1#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
3mod chacha_neon;
4
5#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
7mod chacha_wasm_simd128;
8
9#[cfg(any(
11 all(target_arch = "x86_64", feature = "std"),
12 all(target_arch = "x86_64", target_feature = "avx2")
13))]
14mod chacha_avx2;
15
16#[cfg(any(
18 all(target_arch = "x86_64", feature = "std"),
19 all(target_arch = "x86_64", target_feature = "avx512f")
20))]
21mod chacha_avx512;
22
23#[cfg(feature = "zeroize")]
24use zeroize::{Zeroize, ZeroizeOnDrop};
25
26use crate::StreamCipher;
27
/// Number of 32-bit words in the ChaCha state (a 4 x 4 matrix).
const STATE_WORDS: usize = 16;

/// Size in bytes of one keystream block: every state word serialized as 4 bytes.
const BLOCK_SIZE: usize = STATE_WORDS * 4;

/// The four constant words placed in `state[0..4]`; read as little-endian bytes they spell
/// the ASCII string `"expand 32-byte k"`.
const CONSTANT: [u32; 4] = [0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574];
41
/// ChaCha with 8 rounds per block; see [`ChaChaDjb`].
pub type ChaCha8Djb = ChaChaDjb<8>;
/// ChaCha with 12 rounds per block; see [`ChaChaDjb`].
pub type ChaCha12Djb = ChaChaDjb<12>;
/// ChaCha with 20 rounds per block; see [`ChaChaDjb`].
pub type ChaCha20Djb = ChaChaDjb<20>;

/// ChaCha stream cipher keyed by a 256-bit key and an 8-byte nonce, with a 64-bit
/// block counter.
///
/// `ROUNDS` is the number of rounds applied to each block (the portable implementation
/// runs `ROUNDS / 2` double rounds). With the `zeroize` feature enabled, all fields are
/// zeroized on drop via the derived `ZeroizeOnDrop`.
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct ChaChaDjb<const ROUNDS: usize> {
    /// Cipher state: constants (words 0-3), key (4-11), block counter of the next block
    /// to generate (12-13, low word first) and nonce (14-15).
    state: [u32; STATE_WORDS],
    /// Keystream of the last block generated by a call whose input ended mid-block.
    last_keystream_block: [u8; BLOCK_SIZE],
    /// Index of the first unused byte of `last_keystream_block`; 0 means no buffered
    /// keystream is pending.
    last_keystream_block_index: usize,
}
66
67impl<const ROUNDS: usize> ChaChaDjb<ROUNDS> {
68 pub fn new(key: &[u8; 32], nonce: &[u8; 8]) -> ChaChaDjb<ROUNDS> {
69 let mut state = [0u32; STATE_WORDS];
70
71 state[..4].copy_from_slice(&CONSTANT);
73
74 for (state_word, key_chunk) in state[4..12].iter_mut().zip(key.chunks_exact(4)) {
76 *state_word = u32::from_le_bytes(key_chunk.try_into().unwrap());
77 }
78
79 state[14] = u32::from_le_bytes(nonce[0..4].try_into().unwrap());
85 state[15] = u32::from_le_bytes(nonce[4..8].try_into().unwrap());
86
87 return ChaChaDjb {
88 state,
89 last_keystream_block: [0u8; BLOCK_SIZE],
90 last_keystream_block_index: 0,
91 };
92 }
93
94 pub fn set_counter(&mut self, counter: u64) {
97 inject_counter_into_state(&mut self.state, counter);
98 self.last_keystream_block_index = 0;
100 }
101}
102
103impl<const ROUNDS: usize> StreamCipher for ChaChaDjb<ROUNDS> {
104 fn xor_keystream(&mut self, mut in_out: &mut [u8]) {
106 if in_out.len() == 0 {
107 return;
108 }
109
110 if self.last_keystream_block_index != 0 {
112 let remaining_keystream = &self.last_keystream_block[self.last_keystream_block_index..];
113
114 in_out
115 .iter_mut()
116 .zip(remaining_keystream)
117 .for_each(|(plaintext, keystream)| *plaintext ^= *keystream);
118
119 if in_out.len() > remaining_keystream.len() {
120 in_out = &mut in_out[remaining_keystream.len()..];
121 } else if in_out.len() < remaining_keystream.len() {
122 self.last_keystream_block_index += in_out.len();
123 return;
124 } else {
125 self.last_keystream_block_index = 0;
127 return;
128 }
129 }
130 self.last_keystream_block_index = in_out.len() % BLOCK_SIZE;
131
132 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
134 if in_out.len() >= 128 {
135 use chacha_neon::chacha_neon;
136 unsafe {
138 chacha_neon::<ROUNDS>(&mut self.state, in_out, &mut self.last_keystream_block);
139 }
140 return;
141 }
142
143 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
145 if in_out.len() >= 128 {
146 use chacha_wasm_simd128::chacha_wasm_simd128;
147 chacha_wasm_simd128::<ROUNDS>(&mut self.state, in_out, &mut self.last_keystream_block);
148 return;
149 }
150
151 #[cfg(feature = "std")]
153 {
154 #[cfg(target_arch = "x86_64")]
155 if is_x86_feature_detected!("avx512f") && in_out.len() >= 128 {
156 use chacha_avx512::chacha_avx512;
157 unsafe {
158 chacha_avx512::<ROUNDS>(&mut self.state, in_out, &mut self.last_keystream_block);
159 }
160 return;
161 }
162
163 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
164 if is_x86_feature_detected!("avx2") && in_out.len() >= 128 {
165 use chacha_avx2::chacha_avx2;
166 unsafe {
167 chacha_avx2::<ROUNDS>(&mut self.state, in_out, &mut self.last_keystream_block);
168 }
169 return;
170 }
171 }
172
173 #[cfg(not(feature = "std"))]
175 {
176 #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
177 if in_out.len() >= 128 {
178 use chacha_avx512::chacha_avx512;
179 chacha_avx512::<ROUNDS>(&mut self.state, in_out, &mut self.last_keystream_block);
180 return;
181 }
182
183 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2"))]
184 if in_out.len() >= 128 {
185 use chacha_avx2::chacha_avx2;
186 chacha_avx2::<ROUNDS>(&mut self.state, in_out, &mut self.last_keystream_block);
187 return;
188 }
189 }
190
191 chacha_generic::<ROUNDS>(&mut self.state, &mut self.last_keystream_block, in_out);
193 }
194}
195
196#[inline]
197fn chacha_generic<const ROUNDS: usize>(
198 mut state: &mut [u32; STATE_WORDS],
199 last_keystream_block: &mut [u8; BLOCK_SIZE],
200 plaintext: &mut [u8],
201) {
202 let mut keystream = [0u8; BLOCK_SIZE];
203 let keystream_ptr = keystream.as_mut_ptr();
204 let mut counter = extract_counter_from_state(state);
205
206 for plaintext_block in plaintext.chunks_mut(BLOCK_SIZE) {
208 inject_counter_into_state(&mut state, counter);
209
210 let mut tmp_state = *state;
212
213 for _ in 0..(ROUNDS / 2) {
215 quarter_round(&mut tmp_state, 0, 4, 8, 12);
217 quarter_round(&mut tmp_state, 1, 5, 9, 13);
218 quarter_round(&mut tmp_state, 2, 6, 10, 14);
219 quarter_round(&mut tmp_state, 3, 7, 11, 15);
220
221 quarter_round(&mut tmp_state, 0, 5, 10, 15);
223 quarter_round(&mut tmp_state, 1, 6, 11, 12);
224 quarter_round(&mut tmp_state, 2, 7, 8, 13);
225 quarter_round(&mut tmp_state, 3, 4, 9, 14);
226 }
227
228 for word_index in 0..STATE_WORDS {
233 tmp_state[word_index] = tmp_state[word_index].wrapping_add(state[word_index]);
235
236 unsafe {
238 core::ptr::copy_nonoverlapping(
239 tmp_state[word_index].to_le_bytes().as_ptr(),
240 keystream_ptr.add(word_index * 4),
241 4,
242 );
243 }
244 }
245
246 plaintext_block
248 .iter_mut()
249 .zip(keystream)
250 .for_each(|(plaintext, keystream)| *plaintext ^= keystream);
251
252 counter = counter.wrapping_add(1);
253 }
254
255 inject_counter_into_state(state, counter);
256
257 if plaintext.len() % BLOCK_SIZE != 0 {
258 last_keystream_block.copy_from_slice(&keystream);
259 }
260}
261
/// ChaCha quarter round on the state words at indices `a`, `b`, `c`, `d`, in place.
///
/// It consists of four add-xor-rotate steps with rotation amounts 16, 12, 8 and 7.
#[inline(always)]
const fn quarter_round(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize) {
    /// One add-xor-rotate step: `s[x] += s[y]`, then `s[z] = (s[z] ^ s[x]) <<< rot`.
    #[inline(always)]
    const fn add_xor_rotate(s: &mut [u32; 16], x: usize, y: usize, z: usize, rot: u32) {
        s[x] = s[x].wrapping_add(s[y]);
        s[z] = (s[z] ^ s[x]).rotate_left(rot);
    }

    add_xor_rotate(state, a, b, d, 16);
    add_xor_rotate(state, c, d, b, 12);
    add_xor_rotate(state, a, b, d, 8);
    add_xor_rotate(state, c, d, b, 7);
}
284
285#[inline(always)]
286fn extract_counter_from_state(state: &[u32; STATE_WORDS]) -> u64 {
287 return ((state[13] as u64) << 32) | (state[12] as u64);
288}
289
290#[inline(always)]
291fn inject_counter_into_state(state: &mut [u32; STATE_WORDS], counter: u64) {
292 state[12] = counter as u32;
293 state[13] = (counter >> 32) as u32;
294}
295
#[cfg(test)]
mod test {
    use super::{ChaCha8Djb, ChaCha12Djb, ChaCha20Djb};
    use crate::StreamCipher;

    /// One known-answer vector: XORing the keystream for `key` / `nonce`, starting at block
    /// `initial_counter`, into `plaintext` must yield `expected_ciphertext`.
    struct Test {
        key: [u8; 32],
        nonce: [u8; 8],
        initial_counter: u64,
        plaintext: Vec<u8>,
        expected_ciphertext: Vec<u8>,
    }

    /// ChaCha20 known-answer tests, plus round-trip and split-call consistency checks.
    ///
    /// NOTE(review): these vectors appear to be the RFC 8439 ones (section 2.4.2 and
    /// Appendix A.1/A.2) re-expressed for the 8-byte nonce / 64-bit counter layout; confirm.
    #[test]
    fn chacha20_test_vectors() {
        let tests = vec![
            // 114-byte message, counter starting at 1.
            Test {
                key: hex::decode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f")
                    .unwrap()
                    .try_into()
                    .unwrap(),
                nonce: hex::decode("0000004a00000000").unwrap().try_into().unwrap(),
                initial_counter: 1,
                plaintext: hex::decode(
                    "4c616469657320616e642047656e746c\
656d656e206f662074686520636c6173\
73206f66202739393a20496620492063\
6f756c64206f6666657220796f75206f\
6e6c79206f6e652074697020666f7220\
746865206675747572652c2073756e73\
637265656e20776f756c642062652069\
742e",
                )
                .unwrap(),
                expected_ciphertext: hex::decode(
                    "6e2e359a2568f98041ba0728dd0d6981\
e97e7aec1d4360c20a27afccfd9fae0b\
f91b65c5524733ab8f593dabcd62b357\
1639d624e65152ab8f530c359f0861d8\
07ca0dbf500d6a6156a38e088a22b65e\
52bc514d16ccf806818ce91ab7793736\
5af90bbf74a35be6b40b8eedf2785e42\
874d",
                )
                .unwrap(),
            },
            // All-zero key, nonce and input at counter 0: the output is the raw first
            // keystream block.
            Test {
                key: [0u8; 32],
                nonce: [0u8; 8],
                initial_counter: 0,
                plaintext: [0u8; 64].to_vec(),
                expected_ciphertext: hex::decode(
                    "76b8e0ada0f13d90405d6ae55386bd28\
bdd219b8a08ded1aa836efcc8b770dc7\
da41597c5157488d7724e03fb8d84a37\
6a43b8f41518a11cc387b669b2ee6586",
                )
                .unwrap(),
            },
            // Multi-block message (long enough to take any SIMD path) starting at counter 1.
            Test {
                key: hex::decode("0000000000000000000000000000000000000000000000000000000000000001")
                    .unwrap()
                    .try_into()
                    .unwrap(),
                nonce: hex::decode("0000000000000002").unwrap().try_into().unwrap(),
                initial_counter: 1,
                plaintext: hex::decode(
                    "416e79207375626d697373696f6e2074\
6f20746865204945544620696e74656e\
6465642062792074686520436f6e7472\
696275746f7220666f72207075626c69\
636174696f6e20617320616c6c206f72\
2070617274206f6620616e2049455446\
20496e7465726e65742d447261667420\
6f722052464320616e6420616e792073\
746174656d656e74206d616465207769\
7468696e2074686520636f6e74657874\
206f6620616e20494554462061637469\
7669747920697320636f6e7369646572\
656420616e20224945544620436f6e74\
7269627574696f6e222e205375636820\
73746174656d656e747320696e636c75\
6465206f72616c2073746174656d656e\
747320696e2049455446207365737369\
6f6e732c2061732077656c6c20617320\
7772697474656e20616e6420656c6563\
74726f6e696320636f6d6d756e696361\
74696f6e73206d61646520617420616e\
792074696d65206f7220706c6163652c\
20776869636820617265206164647265\
7373656420746f",
                )
                .unwrap(),
                expected_ciphertext: hex::decode(
                    "a3fbf07df3fa2fde4f376ca23e827370\
41605d9f4f4f57bd8cff2c1d4b7955ec\
2a97948bd3722915c8f3d337f7d37005\
0e9e96d647b7c39f56e031ca5eb6250d\
4042e02785ececfa4b4bb5e8ead0440e\
20b6e8db09d881a7c6132f420e527950\
42bdfa7773d8a9051447b3291ce1411c\
680465552aa6c405b7764d5e87bea85a\
d00f8449ed8f72d0d662ab052691ca66\
424bc86d2df80ea41f43abf937d3259d\
c4b2d0dfb48a6c9139ddd7f76966e928\
e635553ba76c5c879d7b35d49eb2e62b\
0871cdac638939e25e8a1e0ef9d5280f\
a8ca328b351c3c765989cbcf3daa8b6c\
cc3aaf9f3979c92b3720fc88dc95ed84\
a1be059c6499b9fda236e7e818b04b0b\
c39c1e876b193bfe5569753f88128cc0\
8aaa9b63d1a16f80ef2554d7189c411f\
5869ca52c5b83fa36ff216b9c1d30062\
bebcfd2dc5bce0911934fda79a86f6e6\
98ced759c3ff9b6477338f3da4f9cd85\
14ea9982ccafb341b2384dd902f3d1ab\
7ac61dd29c6f21ba5b862f3730e37cfd\
c4fd806c22f221",
                )
                .unwrap(),
            },
            // 127-byte message (one full block plus a partial block) starting at counter 42.
            Test {
                key: hex::decode("1c9240a5eb55d38af333888604f6b5f0473917c1402b80099dca5cbc207075c0")
                    .unwrap()
                    .try_into()
                    .unwrap(),
                nonce: hex::decode("0000000000000002").unwrap().try_into().unwrap(),
                initial_counter: 42,
                plaintext: hex::decode(
                    "2754776173206272696c6c69672c2061\
6e642074686520736c6974687920746f\
7665730a446964206779726520616e64\
2067696d626c6520696e207468652077\
6162653a0a416c6c206d696d73792077\
6572652074686520626f726f676f7665\
732c0a416e6420746865206d6f6d6520\
7261746873206f757467726162652e",
                )
                .unwrap(),
                expected_ciphertext: hex::decode(
                    "62e6347f95ed87a45ffae7426f27a1df\
5fb69110044c0d73118effa95b01e5cf\
166d3df2d721caf9b21e5fb14c616871\
fd84c54f9d65b283196c7fe4f60553eb\
f39c6402c42234e32a356b3e764312a6\
1a5532055716ead6962568f87d3f3f77\
04c6a8d1bcd1bf4d50d6154b6da731b1\
87b58dfd728afa36757a797ac188d1",
                )
                .unwrap(),
            },
        ];

        for (i, test) in tests.into_iter().enumerate() {
            // Known-answer check, starting from the vector's initial block counter.
            let mut cipher = ChaCha20Djb::new(&test.key, &test.nonce);
            cipher.set_counter(test.initial_counter);

            let mut plaintext = test.plaintext.clone();
            cipher.xor_keystream(&mut plaintext);

            assert_eq!(
                plaintext,
                test.expected_ciphertext,
                "test [{i}] failed
Got ciphertext: {}
Expected ciphertext: {}",
                hex::encode(&plaintext),
                hex::encode(&test.expected_ciphertext),
            );

            // Round trip: applying the same keystream again must restore the plaintext.
            let mut cipher = ChaCha20Djb::new(&test.key, &test.nonce);
            cipher.set_counter(test.initial_counter);
            cipher.xor_keystream(&mut plaintext);

            assert_eq!(
                plaintext,
                test.plaintext,
                "test [{i}] failed. Initial plaintext != decrypt(encrypt(plaintext))
Got: {}
Expected: {}",
                hex::encode(&plaintext),
                hex::encode(&test.plaintext),
            );

            // Split-call consistency. `plaintext` holds the original plaintext again at this
            // point. The reference output is a single call starting at counter 0 (no
            // `set_counter`), compared against the same input processed as two calls split
            // at byte `n`.
            let mut cipher = ChaCha20Djb::new(&test.key, &test.nonce);
            cipher.xor_keystream(&mut plaintext);
            for n in 0..10 {
                let mut partial_plaintext: Vec<u8> = test.plaintext.clone();

                let mut cipher = ChaCha20Djb::new(&test.key, &test.nonce);
                cipher.xor_keystream(&mut partial_plaintext[..n]);
                cipher.xor_keystream(&mut partial_plaintext[n..]);

                assert_eq!(
                    plaintext,
                    partial_plaintext,
                    "test [{i}] failed. partial encryption is not valid for n = {n}
 Got: {}
 Expected: {}",
                    hex::encode(&partial_plaintext),
                    hex::encode(&plaintext),
                )
            }
        }
    }

    /// ChaCha12 known-answer test: 100 zero bytes at counter 0, so `buffer` ends up holding
    /// the first 100 keystream bytes (one full block plus a partial block).
    /// NOTE(review): the origin of the expected bytes is not recorded here; confirm.
    #[test]
    fn chacha12_case_1() {
        let nonce: &[u8; 8] = &[0xdb, 0x4b, 0x4a, 0x41, 0xd8, 0xdf, 0x18, 0xaa];
        let key: &[u8; 32] = &[
            0x27, 0xfc, 0x12, 0x0b, 0x01, 0x3b, 0x82, 0x9f, 0x1f, 0xae, 0xef, 0xd1, 0xab, 0x41, 0x7e, 0x86, 0x62, 0xf4,
            0x3e, 0x0d, 0x73, 0xf9, 0x8d, 0xe8, 0x66, 0xe3, 0x46, 0x35, 0x31, 0x80, 0xfd, 0xb7,
        ];

        let mut buffer = [0u8; 100];
        ChaCha12Djb::new(key, nonce).xor_keystream(&mut buffer);

        assert_eq!(
            buffer,
            [
                0x5f, 0x3c, 0x8c, 0x19, 0x0a, 0x78, 0xab, 0x7f, 0xe8, 0x08, 0xca, 0xe9, 0xcb, 0xcb, 0x0a, 0x98, 0x37,
                0xc8, 0x93, 0x49, 0x2d, 0x96, 0x3a, 0x1c, 0x2e, 0xda, 0x6c, 0x15, 0x58, 0xb0, 0x2c, 0x83, 0xfc, 0x02,
                0xa4, 0x4c, 0xbb, 0xb7, 0xe6, 0x20, 0x4d, 0x51, 0xd1, 0xc2, 0x43, 0x0e, 0x9c, 0x0b, 0x58, 0xf2, 0x93,
                0x7b, 0xf5, 0x93, 0x84, 0x0c, 0x85, 0x0b, 0xda, 0x90, 0x51, 0xa1, 0xf0, 0x51, 0xdd, 0xf0, 0x9d, 0x2a,
                0x03, 0xeb, 0xf0, 0x9f, 0x01, 0xbd, 0xba, 0x9d, 0xa0, 0xb6, 0xda, 0x79, 0x1b, 0x2e, 0x64, 0x56, 0x41,
                0x04, 0x7d, 0x11, 0xeb, 0xf8, 0x50, 0x87, 0xd4, 0xde, 0x5c, 0x01, 0x5f, 0xdd, 0xd0, 0x44,
            ]
        );
    }

    /// ChaCha8 known-answer test: 100 zero bytes at counter 0, so `buffer` ends up holding
    /// the first 100 keystream bytes (one full block plus a partial block).
    /// NOTE(review): the origin of the expected bytes is not recorded here; confirm.
    #[test]
    fn chacha8_case_1() {
        let key = &[
            0x64, 0x1a, 0xea, 0xeb, 0x08, 0x03, 0x6b, 0x61, 0x7a, 0x42, 0xcf, 0x14, 0xe8, 0xc5, 0xd2, 0xd1, 0x15, 0xf8,
            0xd7, 0xcb, 0x6e, 0xa5, 0xe2, 0x8b, 0x9b, 0xfa, 0xf8, 0x3e, 0x03, 0x84, 0x26, 0xa7,
        ];
        let nonce = &[0xa1, 0x4a, 0x11, 0x68, 0x27, 0x1d, 0x45, 0x9b];

        let mut buffer = [0u8; 100];
        ChaCha8Djb::new(key, nonce).xor_keystream(&mut buffer);

        assert_eq!(
            buffer,
            [
                0x17, 0x21, 0xc0, 0x44, 0xa8, 0xa6, 0x45, 0x35, 0x22, 0xdd, 0xdb, 0x31, 0x43, 0xd0, 0xbe, 0x35, 0x12,
                0x63, 0x3c, 0xa3, 0xc7, 0x9b, 0xf8, 0xcc, 0xc3, 0x59, 0x4c, 0xb2, 0xc2, 0xf3, 0x10, 0xf7, 0xbd, 0x54,
                0x4f, 0x55, 0xce, 0x0d, 0xb3, 0x81, 0x23, 0x41, 0x2d, 0x6c, 0x45, 0x20, 0x7d, 0x5c, 0xf9, 0xaf, 0x0c,
                0x6c, 0x68, 0x0c, 0xce, 0x1f, 0x7e, 0x43, 0x38, 0x8d, 0x1b, 0x03, 0x46, 0xb7, 0x13, 0x3c, 0x59, 0xfd,
                0x6a, 0xf4, 0xa5, 0xa5, 0x68, 0xaa, 0x33, 0x4c, 0xcd, 0xc3, 0x8a, 0xf5, 0xac, 0xe2, 0x01, 0xdf, 0x84,
                0xd0, 0xa3, 0xca, 0x22, 0x54, 0x94, 0xca, 0x62, 0x09, 0x34, 0x5f, 0xcf, 0x30, 0x13, 0x2e,
            ]
        );
    }
}