1// SPDX-License-Identifier: GPL-2.0-or-later
2/*
3 * Cryptographic API.
4 *
5 * Cipher operations.
6 *
7 * Copyright (c) 2002 James Morris <jmorris@intercode.com.au>
8 * 2002 Adam J. Richter <adam@yggdrasil.com>
9 * 2004 Jean-Luc Cooke <jlcooke@certainkey.com>
10 */
11
12#include <crypto/scatterwalk.h>
13#include <linux/crypto.h>
14#include <linux/errno.h>
15#include <linux/kernel.h>
16#include <linux/mm.h>
17#include <linux/module.h>
18#include <linux/scatterlist.h>
19#include <linux/slab.h>
20
/* Bits kept in skcipher_walk::flags. */
enum {
	SKCIPHER_WALK_SLOW	= 0x1,	/* step bounced through walk->buffer */
	SKCIPHER_WALK_COPY	= 0x2,	/* step bounced through walk->page */
	SKCIPHER_WALK_DIFF	= 0x4,	/* src mapped separately from dst */
	SKCIPHER_WALK_SLEEP	= 0x8,	/* non-atomic walk: GFP_KERNEL, may resched */
};
27
28static inline gfp_t skcipher_walk_gfp(struct skcipher_walk *walk)
29{
30 return walk->flags & SKCIPHER_WALK_SLEEP ? GFP_KERNEL : GFP_ATOMIC;
31}
32
33void scatterwalk_skip(struct scatter_walk *walk, unsigned int nbytes)
34{
35 struct scatterlist *sg = walk->sg;
36
37 nbytes += walk->offset - sg->offset;
38
39 while (nbytes > sg->length) {
40 nbytes -= sg->length;
41 sg = sg_next(sg);
42 }
43 walk->sg = sg;
44 walk->offset = sg->offset + nbytes;
45}
46EXPORT_SYMBOL_GPL(scatterwalk_skip);
47
48inline void memcpy_from_scatterwalk(void *buf, struct scatter_walk *walk,
49 unsigned int nbytes)
50{
51 do {
52 unsigned int to_copy;
53
54 to_copy = scatterwalk_next(walk, total: nbytes);
55 memcpy(to: buf, from: walk->addr, len: to_copy);
56 scatterwalk_done_src(walk, nbytes: to_copy);
57 buf += to_copy;
58 nbytes -= to_copy;
59 } while (nbytes);
60}
61EXPORT_SYMBOL_GPL(memcpy_from_scatterwalk);
62
63inline void memcpy_to_scatterwalk(struct scatter_walk *walk, const void *buf,
64 unsigned int nbytes)
65{
66 do {
67 unsigned int to_copy;
68
69 to_copy = scatterwalk_next(walk, total: nbytes);
70 memcpy(to: walk->addr, from: buf, len: to_copy);
71 scatterwalk_done_dst(walk, nbytes: to_copy);
72 buf += to_copy;
73 nbytes -= to_copy;
74 } while (nbytes);
75}
76EXPORT_SYMBOL_GPL(memcpy_to_scatterwalk);
77
78void memcpy_from_sglist(void *buf, struct scatterlist *sg,
79 unsigned int start, unsigned int nbytes)
80{
81 struct scatter_walk walk;
82
83 if (unlikely(nbytes == 0)) /* in case sg == NULL */
84 return;
85
86 scatterwalk_start_at_pos(walk: &walk, sg, pos: start);
87 memcpy_from_scatterwalk(buf, &walk, nbytes);
88}
89EXPORT_SYMBOL_GPL(memcpy_from_sglist);
90
91void memcpy_to_sglist(struct scatterlist *sg, unsigned int start,
92 const void *buf, unsigned int nbytes)
93{
94 struct scatter_walk walk;
95
96 if (unlikely(nbytes == 0)) /* in case sg == NULL */
97 return;
98
99 scatterwalk_start_at_pos(walk: &walk, sg, pos: start);
100 memcpy_to_scatterwalk(&walk, buf, nbytes);
101}
102EXPORT_SYMBOL_GPL(memcpy_to_sglist);
103
104void memcpy_sglist(struct scatterlist *dst, struct scatterlist *src,
105 unsigned int nbytes)
106{
107 struct skcipher_walk walk = {};
108
109 if (unlikely(nbytes == 0)) /* in case sg == NULL */
110 return;
111
112 walk.total = nbytes;
113
114 scatterwalk_start(walk: &walk.in, sg: src);
115 scatterwalk_start(walk: &walk.out, sg: dst);
116
117 skcipher_walk_first(walk: &walk, atomic: true);
118 do {
119 if (walk.src.virt.addr != walk.dst.virt.addr)
120 memcpy(to: walk.dst.virt.addr, from: walk.src.virt.addr,
121 len: walk.nbytes);
122 skcipher_walk_done(walk: &walk, res: 0);
123 } while (walk.nbytes);
124}
125EXPORT_SYMBOL_GPL(memcpy_sglist);
126
127struct scatterlist *scatterwalk_ffwd(struct scatterlist dst[2],
128 struct scatterlist *src,
129 unsigned int len)
130{
131 for (;;) {
132 if (!len)
133 return src;
134
135 if (src->length > len)
136 break;
137
138 len -= src->length;
139 src = sg_next(sg: src);
140 }
141
142 sg_init_table(dst, 2);
143 sg_set_page(sg: dst, page: sg_page(sg: src), len: src->length - len, offset: src->offset + len);
144 scatterwalk_crypto_chain(head: dst, sg: sg_next(sg: src), num: 2);
145
146 return dst;
147}
148EXPORT_SYMBOL_GPL(scatterwalk_ffwd);
149
150static int skcipher_next_slow(struct skcipher_walk *walk, unsigned int bsize)
151{
152 unsigned alignmask = walk->alignmask;
153 unsigned n;
154 void *buffer;
155
156 if (!walk->buffer)
157 walk->buffer = walk->page;
158 buffer = walk->buffer;
159 if (!buffer) {
160 /* Min size for a buffer of bsize bytes aligned to alignmask */
161 n = bsize + (alignmask & ~(crypto_tfm_ctx_alignment() - 1));
162
163 buffer = kzalloc(n, skcipher_walk_gfp(walk));
164 if (!buffer)
165 return skcipher_walk_done(walk, res: -ENOMEM);
166 walk->buffer = buffer;
167 }
168
169 buffer = PTR_ALIGN(buffer, alignmask + 1);
170 memcpy_from_scatterwalk(buffer, &walk->in, bsize);
171 walk->out.__addr = buffer;
172 walk->in.__addr = walk->out.addr;
173
174 walk->nbytes = bsize;
175 walk->flags |= SKCIPHER_WALK_SLOW;
176
177 return 0;
178}
179
180static int skcipher_next_copy(struct skcipher_walk *walk)
181{
182 void *tmp = walk->page;
183
184 scatterwalk_map(walk: &walk->in);
185 memcpy(to: tmp, from: walk->in.addr, len: walk->nbytes);
186 scatterwalk_unmap(walk: &walk->in);
187 /*
188 * walk->in is advanced later when the number of bytes actually
189 * processed (which might be less than walk->nbytes) is known.
190 */
191
192 walk->in.__addr = tmp;
193 walk->out.__addr = tmp;
194 return 0;
195}
196
197static int skcipher_next_fast(struct skcipher_walk *walk)
198{
199 unsigned long diff;
200
201 diff = offset_in_page(walk->in.offset) -
202 offset_in_page(walk->out.offset);
203 diff |= (u8 *)(sg_page(sg: walk->in.sg) + (walk->in.offset >> PAGE_SHIFT)) -
204 (u8 *)(sg_page(sg: walk->out.sg) + (walk->out.offset >> PAGE_SHIFT));
205
206 scatterwalk_map(walk: &walk->out);
207 walk->in.__addr = walk->out.__addr;
208
209 if (diff) {
210 walk->flags |= SKCIPHER_WALK_DIFF;
211 scatterwalk_map(walk: &walk->in);
212 }
213
214 return 0;
215}
216
217static int skcipher_walk_next(struct skcipher_walk *walk)
218{
219 unsigned int bsize;
220 unsigned int n;
221
222 n = walk->total;
223 bsize = min(walk->stride, max(n, walk->blocksize));
224 n = scatterwalk_clamp(walk: &walk->in, nbytes: n);
225 n = scatterwalk_clamp(walk: &walk->out, nbytes: n);
226
227 if (unlikely(n < bsize)) {
228 if (unlikely(walk->total < walk->blocksize))
229 return skcipher_walk_done(walk, res: -EINVAL);
230
231slow_path:
232 return skcipher_next_slow(walk, bsize);
233 }
234 walk->nbytes = n;
235
236 if (unlikely((walk->in.offset | walk->out.offset) & walk->alignmask)) {
237 if (!walk->page) {
238 gfp_t gfp = skcipher_walk_gfp(walk);
239
240 walk->page = (void *)__get_free_page(gfp);
241 if (!walk->page)
242 goto slow_path;
243 }
244 walk->flags |= SKCIPHER_WALK_COPY;
245 return skcipher_next_copy(walk);
246 }
247
248 return skcipher_next_fast(walk);
249}
250
251static int skcipher_copy_iv(struct skcipher_walk *walk)
252{
253 unsigned alignmask = walk->alignmask;
254 unsigned ivsize = walk->ivsize;
255 unsigned aligned_stride = ALIGN(walk->stride, alignmask + 1);
256 unsigned size;
257 u8 *iv;
258
259 /* Min size for a buffer of stride + ivsize, aligned to alignmask */
260 size = aligned_stride + ivsize +
261 (alignmask & ~(crypto_tfm_ctx_alignment() - 1));
262
263 walk->buffer = kmalloc(size, skcipher_walk_gfp(walk));
264 if (!walk->buffer)
265 return -ENOMEM;
266
267 iv = PTR_ALIGN(walk->buffer, alignmask + 1) + aligned_stride;
268
269 walk->iv = memcpy(to: iv, from: walk->iv, len: walk->ivsize);
270 return 0;
271}
272
273int skcipher_walk_first(struct skcipher_walk *walk, bool atomic)
274{
275 if (WARN_ON_ONCE(in_hardirq()))
276 return -EDEADLK;
277
278 walk->flags = atomic ? 0 : SKCIPHER_WALK_SLEEP;
279
280 walk->buffer = NULL;
281 if (unlikely(((unsigned long)walk->iv & walk->alignmask))) {
282 int err = skcipher_copy_iv(walk);
283 if (err)
284 return err;
285 }
286
287 walk->page = NULL;
288
289 return skcipher_walk_next(walk);
290}
291EXPORT_SYMBOL_GPL(skcipher_walk_first);
292
293/**
294 * skcipher_walk_done() - finish one step of a skcipher_walk
295 * @walk: the skcipher_walk
296 * @res: number of bytes *not* processed (>= 0) from walk->nbytes,
297 * or a -errno value to terminate the walk due to an error
298 *
299 * This function cleans up after one step of walking through the source and
300 * destination scatterlists, and advances to the next step if applicable.
301 * walk->nbytes is set to the number of bytes available in the next step,
302 * walk->total is set to the new total number of bytes remaining, and
303 * walk->{src,dst}.virt.addr is set to the next pair of data pointers. If there
304 * is no more data, or if an error occurred (i.e. -errno return), then
305 * walk->nbytes and walk->total are set to 0 and all resources owned by the
306 * skcipher_walk are freed.
307 *
308 * Return: 0 or a -errno value. If @res was a -errno value then it will be
309 * returned, but other errors may occur too.
310 */
311int skcipher_walk_done(struct skcipher_walk *walk, int res)
312{
313 unsigned int n = walk->nbytes; /* num bytes processed this step */
314 unsigned int total = 0; /* new total remaining */
315
316 if (!n)
317 goto finish;
318
319 if (likely(res >= 0)) {
320 n -= res; /* subtract num bytes *not* processed */
321 total = walk->total - n;
322 }
323
324 if (likely(!(walk->flags & (SKCIPHER_WALK_SLOW |
325 SKCIPHER_WALK_COPY |
326 SKCIPHER_WALK_DIFF)))) {
327 scatterwalk_advance(walk: &walk->in, nbytes: n);
328 } else if (walk->flags & SKCIPHER_WALK_DIFF) {
329 scatterwalk_done_src(walk: &walk->in, nbytes: n);
330 } else if (walk->flags & SKCIPHER_WALK_COPY) {
331 scatterwalk_advance(walk: &walk->in, nbytes: n);
332 scatterwalk_map(walk: &walk->out);
333 memcpy(to: walk->out.addr, from: walk->page, len: n);
334 } else { /* SKCIPHER_WALK_SLOW */
335 if (res > 0) {
336 /*
337 * Didn't process all bytes. Either the algorithm is
338 * broken, or this was the last step and it turned out
339 * the message wasn't evenly divisible into blocks but
340 * the algorithm requires it.
341 */
342 res = -EINVAL;
343 total = 0;
344 } else
345 memcpy_to_scatterwalk(&walk->out, walk->out.addr, n);
346 goto dst_done;
347 }
348
349 scatterwalk_done_dst(walk: &walk->out, nbytes: n);
350dst_done:
351
352 if (res > 0)
353 res = 0;
354
355 walk->total = total;
356 walk->nbytes = 0;
357
358 if (total) {
359 if (walk->flags & SKCIPHER_WALK_SLEEP)
360 cond_resched();
361 walk->flags &= ~(SKCIPHER_WALK_SLOW | SKCIPHER_WALK_COPY |
362 SKCIPHER_WALK_DIFF);
363 return skcipher_walk_next(walk);
364 }
365
366finish:
367 /* Short-circuit for the common/fast path. */
368 if (!((unsigned long)walk->buffer | (unsigned long)walk->page))
369 goto out;
370
371 if (walk->iv != walk->oiv)
372 memcpy(to: walk->oiv, from: walk->iv, len: walk->ivsize);
373 if (walk->buffer != walk->page)
374 kfree(objp: walk->buffer);
375 if (walk->page)
376 free_page((unsigned long)walk->page);
377
378out:
379 return res;
380}
381EXPORT_SYMBOL_GPL(skcipher_walk_done);
382