1/* SPDX-License-Identifier: GPL-2.0-only */
2
3#ifndef __NET_PSP_HELPERS_H
4#define __NET_PSP_HELPERS_H
5
6#include <linux/skbuff.h>
7#include <linux/rcupdate.h>
8#include <linux/udp.h>
9#include <net/sock.h>
10#include <net/tcp.h>
11#include <net/psp/types.h>
12
13struct inet_timewait_sock;
14
15/* Driver-facing API */
16struct psp_dev *
17psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
18 struct psp_dev_caps *psd_caps, void *priv_ptr);
19void psp_dev_unregister(struct psp_dev *psd);
20bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
21 u8 ver, __be16 sport);
22int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv);
23
24/* Kernel-facing API */
25void psp_assoc_put(struct psp_assoc *pas);
26
27static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
28{
29 return pas->drv_data;
30}
31
32#if IS_ENABLED(CONFIG_INET_PSP)
33unsigned int psp_key_size(u32 version);
34void psp_sk_assoc_free(struct sock *sk);
35void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
36void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
37void psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb);
38
39static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
40{
41 return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
42}
43
44static inline void
45psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
46{
47 struct psp_assoc *pas;
48
49 pas = psp_sk_assoc(sk);
50 if (pas && pas->tx.spi)
51 skb->decrypted = 1;
52}
53
54static inline unsigned long
55__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
56 unsigned long diffs)
57{
58 struct psp_skb_ext *a, *b;
59
60 a = skb_ext_find(one, SKB_EXT_PSP);
61 b = skb_ext_find(two, SKB_EXT_PSP);
62
63 diffs |= (!!a) ^ (!!b);
64 if (!diffs && unlikely(a))
65 diffs |= memcmp(a, b, sizeof(*a));
66 return diffs;
67}
68
69static inline bool
70psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
71{
72 bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
73 u32 end_seq = TCP_SKB_CB(skb)->end_seq;
74 u32 seq = TCP_SKB_CB(skb)->seq;
75 bool pure_fin;
76
77 pure_fin = fin && end_seq - seq == 1;
78
79 return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
80}
81
82static inline bool
83psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
84{
85 return pse && pas->rx.spi == pse->spi &&
86 pas->generation == pse->generation &&
87 pas->version == pse->version &&
88 pas->dev_id == pse->dev_id;
89}
90
91static inline enum skb_drop_reason
92__psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
93{
94 struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
95
96 if (!pas)
97 return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
98
99 if (likely(psp_pse_matches_pas(pse, pas))) {
100 if (unlikely(!pas->peer_tx))
101 pas->peer_tx = 1;
102
103 return 0;
104 }
105
106 if (!pse) {
107 if (!pas->tx.spi ||
108 (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
109 return 0;
110 }
111
112 return SKB_DROP_REASON_PSP_INPUT;
113}
114
115static inline enum skb_drop_reason
116psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
117{
118 return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
119}
120
121static inline enum skb_drop_reason
122psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
123{
124 return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
125}
126
127static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk)
128{
129 struct psp_assoc *pas;
130 int state;
131
132 state = READ_ONCE(sk->sk_state);
133 if (!sk_is_inet(sk) || state == TCP_NEW_SYN_RECV)
134 return NULL;
135
136 pas = state == TCP_TIME_WAIT ?
137 rcu_dereference(inet_twsk(sk)->psp_assoc) :
138 rcu_dereference(sk->psp_assoc);
139 return pas;
140}
141
142static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
143{
144 if (!skb->decrypted || !skb->sk)
145 return NULL;
146
147 return psp_sk_get_assoc_rcu(skb->sk);
148}
149
150static inline unsigned int psp_sk_overhead(const struct sock *sk)
151{
152 int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE;
153 bool has_psp = rcu_access_pointer(sk->psp_assoc);
154
155 return has_psp ? psp_encap : 0;
156}
157#else
/* CONFIG_INET_PSP disabled: the out-of-line helpers become no-ops. */
static inline void psp_sk_assoc_free(struct sock *sk)
{
}

static inline void psp_twsk_init(struct inet_timewait_sock *tw,
				 const struct sock *sk)
{
}

static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
{
}

static inline void psp_reply_set_decrypted(const struct sock *sk,
					   struct sk_buff *skb)
{
}
164
165static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
166{
167 return NULL;
168}
169
/* Nothing to mark when PSP is compiled out. */
static inline void psp_enqueue_set_decrypted(struct sock *sk,
					     struct sk_buff *skb)
{
}
172
/* No PSP extensions exist, so @diffs is passed through untouched. */
static inline unsigned long
__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
			unsigned long diffs) { return diffs; }
179
180static inline enum skb_drop_reason
181psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
182{
183 return 0;
184}
185
186static inline enum skb_drop_reason
187psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
188{
189 return 0;
190}
191
192static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
193{
194 return NULL;
195}
196
/* PSP compiled out: no encapsulation, so no per-packet overhead. */
static inline unsigned int psp_sk_overhead(const struct sock *sk) { return 0; }
201#endif
202
/*
 * psp_skb_coalesce_diff() - compare the PSP state of two skbs
 * @one: first skb
 * @two: second skb
 *
 * NOTE(review): the name suggests this gates skb coalescing — confirm at
 * call sites.
 *
 * Return: 0 when neither skb carries a PSP extension or both carry
 * byte-identical ones, non-zero otherwise. Always 0 when CONFIG_INET_PSP
 * is disabled.
 */
static inline unsigned long
psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
{
	/* Start from "no difference"; the helper ORs in any PSP mismatch. */
	return __psp_skb_coalesce_diff(one, two, 0);
}
208
209#endif /* __NET_PSP_HELPERS_H */
210