Skip to main content

crypto/mlkem/
mlkem.rs

1use constant_time_eq::constant_time_eq;
2#[cfg(feature = "zeroize")]
3use zeroize::{Zeroize, ZeroizeOnDrop};
4
5use crate::{
6    Xof,
7    sha3::{Sha3_256, Sha3_512, Shake128, Shake256},
8};
9
/// Size in bytes of the shared secret returned by encapsulation and decapsulation.
pub const SHARED_SECRET_SIZE: usize = 32;

/// Number of coefficients stored in a [`Poly`].
pub(crate) const N: usize = 256;
/// Modulus that coefficients are compared against and reduced by.
pub(crate) const Q: i16 = 3329;
/// Size in bytes of the seed appended to a packed public key; the secret key ends with two
/// fields of this size.
const SYMBYTES: usize = 32;
/// Size in bytes of one polynomial serialized by `poly_tobytes` (12 bits per coefficient).
const POLY_BYTES: usize = 384;
/// Number of bytes squeezed from SHAKE-128 per iteration when sampling a uniform polynomial.
const SHAKE128_RATE: usize = 168;
/// Multiplicative inverse of `Q` modulo 2^16 (`Q * QINV ≡ 1 mod 2^16`), used by `montgomery_reduce`.
const QINV: i16 = -3327;
/// Factor multiplied (via `fqmul`) into every coefficient at the end of `invntt`.
const MONT_SQUARED_DIV_N: i16 = 1441;
/// Precomputed twiddle factors consumed by `ntt` (indices 1..=127 ascending), `invntt`
/// (indices 127..=1 descending) and `poly_basemul_montgomery` (indices 64..=127).
// NOTE(review): values are presumably powers of the root of unity in Montgomery form, in
// bit-reversed order as in the Kyber reference implementation — confirm before editing.
const ZETAS: [i16; 128] = [
    -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, 182, 962, -1202, -1474, 1468, 573, -1325, 264,
    383, -829, 1458, -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, 652, -552, 1015, -1293, 1491,
    -282, -1544, 516, -8, -320, -666, -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, -951, -398, 961,
    -1508, -725, 448, -1065, 677, -1275, -1103, 430, 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, -460,
    1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097,
    603, 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, -1187, -1659, -1185, -1530, -1278, 794,
    -1510, -854, -870, 478, -108, -308, 996, 991, 958, -1460, 1522, 1628,
];
28
/// Parameter set for vectors of 3 polynomials: a ciphertext is 960 bytes of compressed
/// vector followed by 128 bytes of compressed polynomial.
pub(crate) const ML_KEM_768: MlKemParams<3> = MlKemParams {
    eta1: 2,
    polycompressedbytes: 128,
    polyveccompressedbytes: 960,
};
/// Parameter set for vectors of 4 polynomials: a ciphertext is 1408 bytes of compressed
/// vector followed by 160 bytes of compressed polynomial.
pub(crate) const ML_KEM_1024: MlKemParams<4> = MlKemParams {
    eta1: 2,
    polycompressedbytes: 160,
    polyveccompressedbytes: 1408,
};
39
/// Errors reported by the ML-KEM routines in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlKemError {
    /// The supplied key was rejected (for example, its size does not match the parameter set).
    InvalidKey,
}

impl core::fmt::Display for MlKemError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let description = match self {
            Self::InvalidKey => "key is not valid",
        };
        f.write_str(description)
    }
}
52
/// Parameters that vary between ML-KEM variants; `K` is the number of polynomials per vector.
#[derive(Clone, Copy)]
pub(crate) struct MlKemParams<const K: usize> {
    /// Noise parameter passed to `poly_getnoise` for the secret and error vectors.
    pub(crate) eta1: usize,
    /// Byte length of the compressed polynomial at the end of a ciphertext (128 or 160 are handled).
    pub(crate) polycompressedbytes: usize,
    /// Byte length of the compressed polynomial vector at the start of a ciphertext (960 or 1408 are handled).
    pub(crate) polyveccompressedbytes: usize,
}
59
60#[derive(Clone, Debug, PartialEq, Eq)]
61#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
62pub(crate) struct Poly {
63    pub(crate) coeffs: [i16; N],
64}
65
66impl Default for Poly {
67    #[inline]
68    fn default() -> Self {
69        Self {
70            coeffs: [0; N],
71        }
72    }
73}
74
75#[derive(Clone, Debug, PartialEq, Eq)]
76#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
77pub(crate) struct PolyVec<const K: usize> {
78    pub(crate) vec: [Poly; K],
79}
80
81impl<const K: usize> Default for PolyVec<K> {
82    #[inline]
83    fn default() -> Self {
84        Self {
85            vec: core::array::from_fn(|_| Poly::default()),
86        }
87    }
88}
89
90#[inline]
91pub(crate) fn crypto_kem_keypair_derand<const K: usize, const SECRET_KEY_SIZE: usize, const PUBLIC_KEY_SIZE: usize>(
92    params: &MlKemParams<K>,
93    coins: &[u8; 64],
94) -> ([u8; SECRET_KEY_SIZE], [u8; PUBLIC_KEY_SIZE]) {
95    let mut public_key = [0u8; PUBLIC_KEY_SIZE];
96    let mut secret_key = [0u8; SECRET_KEY_SIZE];
97
98    indcpa_keypair_derand::<K>(
99        params,
100        &mut public_key,
101        &mut secret_key[..indcpa_secret_key_bytes::<K>()],
102        &coins[..32],
103    );
104    secret_key[indcpa_secret_key_bytes::<K>()..indcpa_secret_key_bytes::<K>() + PUBLIC_KEY_SIZE]
105        .copy_from_slice(&public_key);
106
107    let public_key_hash = hash_h(&public_key);
108    secret_key[SECRET_KEY_SIZE - 64..SECRET_KEY_SIZE - 32].copy_from_slice(&public_key_hash);
109    secret_key[SECRET_KEY_SIZE - 32..].copy_from_slice(&coins[32..]);
110
111    (secret_key, public_key)
112}
113
114#[inline]
115pub(crate) fn crypto_kem_enc_derand<const K: usize, const PUBLIC_KEY_SIZE: usize, const CIPHERTEXT_SIZE: usize>(
116    params: &MlKemParams<K>,
117    public_key: &[u8; PUBLIC_KEY_SIZE],
118    coins: &[u8; 32],
119) -> ([u8; CIPHERTEXT_SIZE], [u8; SHARED_SECRET_SIZE]) {
120    let mut ciphertext = [0u8; CIPHERTEXT_SIZE];
121    let mut buf = [0u8; 64];
122    let mut kr = [0u8; 64];
123
124    buf[..32].copy_from_slice(coins);
125    buf[32..].copy_from_slice(&hash_h(public_key));
126    kr.copy_from_slice(&hash_g(&buf));
127
128    indcpa_enc::<K>(params, &mut ciphertext, &buf[..32], public_key, array_ref_32(&kr[32..64]));
129
130    let mut shared_secret = [0u8; SHARED_SECRET_SIZE];
131    shared_secret.copy_from_slice(&kr[..32]);
132    (ciphertext, shared_secret)
133}
134
135#[inline]
136pub(crate) fn crypto_kem_dec<const K: usize, const SECRET_KEY_SIZE: usize, const CIPHERTEXT_SIZE: usize>(
137    params: &MlKemParams<K>,
138    secret_key: &[u8; SECRET_KEY_SIZE],
139    ciphertext: &[u8; CIPHERTEXT_SIZE],
140) -> Result<[u8; SHARED_SECRET_SIZE], MlKemError> {
141    let public_key_offset = indcpa_secret_key_bytes::<K>();
142    let public_key_size = public_key_bytes::<K>();
143    if SECRET_KEY_SIZE != secret_key_size::<K>() {
144        return Err(MlKemError::InvalidKey);
145    }
146
147    let public_key = &secret_key[public_key_offset..public_key_offset + public_key_size];
148    let mut message_and_hash = [0u8; 64];
149    let mut kr = [0u8; 64];
150    let mut cmp = [0u8; CIPHERTEXT_SIZE];
151
152    indcpa_dec::<K>(
153        params,
154        &mut message_and_hash[..32],
155        ciphertext,
156        &secret_key[..public_key_offset],
157    );
158    message_and_hash[32..].copy_from_slice(&secret_key[SECRET_KEY_SIZE - 64..SECRET_KEY_SIZE - 32]);
159    kr.copy_from_slice(&hash_g(&message_and_hash));
160
161    indcpa_enc::<K>(params, &mut cmp, &message_and_hash[..32], public_key, array_ref_32(&kr[32..64]));
162
163    let mut shared_secret = rkprf(array_ref_32(&secret_key[SECRET_KEY_SIZE - 32..]), ciphertext);
164    cmov(&mut shared_secret, array_ref_32(&kr[..32]), constant_time_eq(ciphertext, &cmp));
165    Ok(shared_secret)
166}
167
168#[inline]
169pub(crate) fn indcpa_keypair_derand<const K: usize>(
170    params: &MlKemParams<K>,
171    public_key: &mut [u8],
172    secret_key: &mut [u8],
173    coins: &[u8],
174) {
175    debug_assert_eq!(public_key.len(), public_key_bytes::<K>());
176    debug_assert_eq!(secret_key.len(), indcpa_secret_key_bytes::<K>());
177    debug_assert_eq!(coins.len(), 32);
178
179    let mut g_input = [0u8; 33];
180    g_input[..32].copy_from_slice(coins);
181    g_input[32] = K as u8;
182    let seed_output = hash_g(&g_input);
183    let public_seed = array_ref_32(&seed_output[..32]);
184    let noise_seed = array_ref_32(&seed_output[32..64]);
185    let matrix = gen_matrix::<K>(public_seed, false);
186
187    let mut skpv = PolyVec::<K>::default();
188    let mut e = PolyVec::<K>::default();
189    for (index, poly) in skpv.vec.iter_mut().enumerate() {
190        *poly = poly_getnoise(noise_seed, index as u8, params.eta1);
191    }
192    for (index, poly) in e.vec.iter_mut().enumerate() {
193        *poly = poly_getnoise(noise_seed, (K + index) as u8, params.eta1);
194    }
195
196    polyvec_ntt(&mut skpv);
197    polyvec_ntt(&mut e);
198
199    let mut pkpv = PolyVec::<K>::default();
200    for i in 0..K {
201        pkpv.vec[i] = polyvec_basemul_acc_montgomery(&matrix[i], &skpv);
202        poly_tomont(&mut pkpv.vec[i]);
203    }
204
205    polyvec_add(&mut pkpv, &e);
206    polyvec_reduce(&mut pkpv);
207
208    pack_sk(secret_key, &skpv);
209    pack_pk(public_key, &pkpv, public_seed);
210}
211
212#[inline]
213pub(crate) fn indcpa_enc<const K: usize>(
214    params: &MlKemParams<K>,
215    ciphertext: &mut [u8],
216    message: &[u8],
217    public_key: &[u8],
218    coins: &[u8; 32],
219) {
220    debug_assert_eq!(ciphertext.len(), ciphertext_bytes(params));
221    debug_assert_eq!(message.len(), 32);
222    debug_assert_eq!(public_key.len(), public_key_bytes::<K>());
223
224    let (pkpv, seed) = unpack_pk::<K>(public_key);
225    let at = gen_matrix::<K>(&seed, true);
226    let k = poly_frommsg(message);
227
228    let mut sp = PolyVec::<K>::default();
229    let mut ep = PolyVec::<K>::default();
230    for (index, poly) in sp.vec.iter_mut().enumerate() {
231        *poly = poly_getnoise(coins, index as u8, params.eta1);
232    }
233    let ep_nonce_offset = sp.vec.len();
234    for (index, poly) in ep.vec.iter_mut().enumerate() {
235        *poly = poly_getnoise(coins, (ep_nonce_offset + index) as u8, 2);
236    }
237    let epp = poly_getnoise(coins, (sp.vec.len() + ep.vec.len()) as u8, 2);
238
239    polyvec_ntt(&mut sp);
240
241    let mut b = PolyVec::<K>::default();
242    for i in 0..K {
243        b.vec[i] = polyvec_basemul_acc_montgomery(&at[i], &sp);
244    }
245    let mut v = polyvec_basemul_acc_montgomery(&pkpv, &sp);
246
247    polyvec_invntt_tomont(&mut b);
248    poly_invntt_tomont(&mut v);
249
250    polyvec_add(&mut b, &ep);
251    poly_add(&mut v, &epp);
252    poly_add(&mut v, &k);
253    polyvec_reduce(&mut b);
254    poly_reduce(&mut v);
255
256    pack_ciphertext(params, ciphertext, &b, &v);
257}
258
259#[inline]
260pub(crate) fn indcpa_dec<const K: usize>(
261    params: &MlKemParams<K>,
262    message: &mut [u8],
263    ciphertext: &[u8],
264    secret_key: &[u8],
265) {
266    debug_assert_eq!(message.len(), 32);
267    debug_assert_eq!(ciphertext.len(), ciphertext_bytes(params));
268    debug_assert_eq!(secret_key.len(), indcpa_secret_key_bytes::<K>());
269
270    let (mut b, v) = unpack_ciphertext::<K>(params, ciphertext);
271    let skpv = unpack_sk::<K>(secret_key);
272
273    polyvec_ntt(&mut b);
274    let mut mp = polyvec_basemul_acc_montgomery(&skpv, &b);
275    poly_invntt_tomont(&mut mp);
276    let product = mp.clone();
277    poly_sub(&mut mp, &v, &product);
278    poly_reduce(&mut mp);
279
280    message.copy_from_slice(&poly_tomsg(&mp));
281}
282
283#[inline]
284fn pack_pk<const K: usize>(out: &mut [u8], pk: &PolyVec<K>, seed: &[u8; 32]) {
285    let polyvec_bytes = polyvec_bytes::<K>();
286    polyvec_tobytes(&mut out[..polyvec_bytes], pk);
287    out[polyvec_bytes..polyvec_bytes + 32].copy_from_slice(seed);
288}
289
290#[inline]
291fn unpack_pk<const K: usize>(packed: &[u8]) -> (PolyVec<K>, [u8; 32]) {
292    let polyvec_bytes = polyvec_bytes::<K>();
293    let pk = polyvec_frombytes::<K>(&packed[..polyvec_bytes]);
294    let mut seed = [0u8; 32];
295    seed.copy_from_slice(&packed[polyvec_bytes..polyvec_bytes + 32]);
296    (pk, seed)
297}
298
299#[inline]
300fn pack_sk<const K: usize>(out: &mut [u8], sk: &PolyVec<K>) {
301    polyvec_tobytes(out, sk);
302}
303
304#[inline]
305fn unpack_sk<const K: usize>(packed: &[u8]) -> PolyVec<K> {
306    polyvec_frombytes(packed)
307}
308
309#[inline]
310fn pack_ciphertext<const K: usize>(params: &MlKemParams<K>, out: &mut [u8], b: &PolyVec<K>, v: &Poly) {
311    let split = params.polyveccompressedbytes;
312    polyvec_compress(params, &mut out[..split], b);
313    poly_compress(params, &mut out[split..split + params.polycompressedbytes], v);
314}
315
316#[inline]
317fn unpack_ciphertext<const K: usize>(params: &MlKemParams<K>, packed: &[u8]) -> (PolyVec<K>, Poly) {
318    let split = params.polyveccompressedbytes;
319    (
320        polyvec_decompress(params, &packed[..split]),
321        poly_decompress(params, &packed[split..split + params.polycompressedbytes]),
322    )
323}
324
325#[inline]
326pub(crate) fn gen_matrix<const K: usize>(seed: &[u8; 32], transpose: bool) -> [PolyVec<K>; K] {
327    let mut matrix = core::array::from_fn(|_| PolyVec::<K>::default());
328    for i in 0..K {
329        for j in 0..K {
330            let (x, y) = if transpose {
331                (i as u8, j as u8)
332            } else {
333                (j as u8, i as u8)
334            };
335            matrix[i].vec[j] = uniform_poly(seed, x, y);
336        }
337    }
338    matrix
339}
340
341#[inline]
342fn uniform_poly(seed: &[u8; 32], x: u8, y: u8) -> Poly {
343    let mut shake = Shake128::new();
344    shake.absorb(seed);
345    shake.absorb(&[x, y]);
346
347    let mut poly = Poly::default();
348    let mut ctr = 0usize;
349    let mut block = [0u8; SHAKE128_RATE];
350    while ctr < N {
351        shake.squeeze(&mut block);
352        ctr += rej_uniform(&mut poly.coeffs[ctr..], &block);
353    }
354    poly
355}
356
357#[inline]
358fn rej_uniform(out: &mut [i16], buf: &[u8]) -> usize {
359    let mut ctr = 0usize;
360    let mut pos = 0usize;
361    while ctr < out.len() && pos + 3 <= buf.len() {
362        let val0 = (((buf[pos] as u16) | ((buf[pos + 1] as u16) << 8)) & 0x0fff) as i16;
363        let val1 = ((((buf[pos + 1] as u16) >> 4) | ((buf[pos + 2] as u16) << 4)) & 0x0fff) as i16;
364        pos += 3;
365
366        if val0 < Q {
367            out[ctr] = val0;
368            ctr += 1;
369        }
370        if ctr < out.len() && val1 < Q {
371            out[ctr] = val1;
372            ctr += 1;
373        }
374    }
375    ctr
376}
377
378#[inline]
379pub(crate) fn poly_getnoise(seed: &[u8; 32], nonce: u8, eta: usize) -> Poly {
380    debug_assert_eq!(eta, 2);
381    let mut input = [0u8; 33];
382    input[..32].copy_from_slice(seed);
383    input[32] = nonce;
384    let mut buf = [0u8; 128];
385    Shake256::hash(&input, &mut buf);
386    cbd2(&buf)
387}
388
389#[inline]
390fn cbd2(buf: &[u8; 128]) -> Poly {
391    let mut poly = Poly::default();
392    for i in 0..(N / 8) {
393        let t = load32(&buf[4 * i..4 * i + 4]);
394        let mut d = t & 0x5555_5555;
395        d += (t >> 1) & 0x5555_5555;
396        for j in 0..8 {
397            let a = ((d >> (4 * j)) & 0x3) as i16;
398            let b = ((d >> (4 * j + 2)) & 0x3) as i16;
399            poly.coeffs[8 * i + j] = a - b;
400        }
401    }
402    poly
403}
404
405#[inline]
406pub(crate) fn polyvec_compress<const K: usize>(params: &MlKemParams<K>, out: &mut [u8], a: &PolyVec<K>) {
407    match params.polyveccompressedbytes {
408        960 => {
409            let mut offset = 0usize;
410            for poly in &a.vec {
411                for chunk in poly.coeffs.chunks_exact(4) {
412                    let mut t = [0u16; 4];
413                    for (dst, coeff) in t.iter_mut().zip(chunk.iter()) {
414                        let mut u = *coeff as i32;
415                        u += (u >> 15) & Q as i32;
416                        let mut d0 = u as u64;
417                        d0 <<= 10;
418                        d0 += 1665;
419                        d0 *= 1_290_167;
420                        d0 >>= 32;
421                        *dst = (d0 as u16) & 0x03ff;
422                    }
423                    out[offset] = t[0] as u8;
424                    out[offset + 1] = ((t[0] >> 8) as u8) | ((t[1] << 2) as u8);
425                    out[offset + 2] = ((t[1] >> 6) as u8) | ((t[2] << 4) as u8);
426                    out[offset + 3] = ((t[2] >> 4) as u8) | ((t[3] << 6) as u8);
427                    out[offset + 4] = (t[3] >> 2) as u8;
428                    offset += 5;
429                }
430            }
431        }
432        1408 => {
433            let mut offset = 0usize;
434            for poly in &a.vec {
435                for chunk in poly.coeffs.chunks_exact(8) {
436                    let mut t = [0u16; 8];
437                    for (dst, coeff) in t.iter_mut().zip(chunk.iter()) {
438                        let mut u = *coeff as i32;
439                        u += (u >> 15) & Q as i32;
440                        let mut d0 = u as u64;
441                        d0 <<= 11;
442                        d0 += 1664;
443                        d0 *= 645_084;
444                        d0 >>= 31;
445                        *dst = (d0 as u16) & 0x07ff;
446                    }
447                    out[offset] = t[0] as u8;
448                    out[offset + 1] = ((t[0] >> 8) as u8) | ((t[1] << 3) as u8);
449                    out[offset + 2] = ((t[1] >> 5) as u8) | ((t[2] << 6) as u8);
450                    out[offset + 3] = (t[2] >> 2) as u8;
451                    out[offset + 4] = ((t[2] >> 10) as u8) | ((t[3] << 1) as u8);
452                    out[offset + 5] = ((t[3] >> 7) as u8) | ((t[4] << 4) as u8);
453                    out[offset + 6] = ((t[4] >> 4) as u8) | ((t[5] << 7) as u8);
454                    out[offset + 7] = (t[5] >> 1) as u8;
455                    out[offset + 8] = ((t[5] >> 9) as u8) | ((t[6] << 2) as u8);
456                    out[offset + 9] = ((t[6] >> 6) as u8) | ((t[7] << 5) as u8);
457                    out[offset + 10] = (t[7] >> 3) as u8;
458                    offset += 11;
459                }
460            }
461        }
462        _ => unreachable!(),
463    }
464}
465
466#[inline]
467pub(crate) fn polyvec_decompress<const K: usize>(params: &MlKemParams<K>, input: &[u8]) -> PolyVec<K> {
468    let mut out = PolyVec::<K>::default();
469    match params.polyveccompressedbytes {
470        960 => {
471            let mut offset = 0usize;
472            for poly in &mut out.vec {
473                for j in 0..(N / 4) {
474                    let t0 = (input[offset] as u16) | ((input[offset + 1] as u16) << 8);
475                    let t1 = ((input[offset + 1] as u16) >> 2) | ((input[offset + 2] as u16) << 6);
476                    let t2 = ((input[offset + 2] as u16) >> 4) | ((input[offset + 3] as u16) << 4);
477                    let t3 = ((input[offset + 3] as u16) >> 6) | ((input[offset + 4] as u16) << 2);
478                    offset += 5;
479                    poly.coeffs[4 * j] = ((((t0 & 0x03ff) as u32) * Q as u32 + 512) >> 10) as i16;
480                    poly.coeffs[4 * j + 1] = ((((t1 & 0x03ff) as u32) * Q as u32 + 512) >> 10) as i16;
481                    poly.coeffs[4 * j + 2] = ((((t2 & 0x03ff) as u32) * Q as u32 + 512) >> 10) as i16;
482                    poly.coeffs[4 * j + 3] = ((((t3 & 0x03ff) as u32) * Q as u32 + 512) >> 10) as i16;
483                }
484            }
485        }
486        1408 => {
487            let mut offset = 0usize;
488            for poly in &mut out.vec {
489                for j in 0..(N / 8) {
490                    let t0 = (input[offset] as u16) | ((input[offset + 1] as u16) << 8);
491                    let t1 = ((input[offset + 1] as u16) >> 3) | ((input[offset + 2] as u16) << 5);
492                    let t2 = ((input[offset + 2] as u16) >> 6)
493                        | ((input[offset + 3] as u16) << 2)
494                        | ((input[offset + 4] as u16) << 10);
495                    let t3 = ((input[offset + 4] as u16) >> 1) | ((input[offset + 5] as u16) << 7);
496                    let t4 = ((input[offset + 5] as u16) >> 4) | ((input[offset + 6] as u16) << 4);
497                    let t5 = ((input[offset + 6] as u16) >> 7)
498                        | ((input[offset + 7] as u16) << 1)
499                        | ((input[offset + 8] as u16) << 9);
500                    let t6 = ((input[offset + 8] as u16) >> 2) | ((input[offset + 9] as u16) << 6);
501                    let t7 = ((input[offset + 9] as u16) >> 5) | ((input[offset + 10] as u16) << 3);
502                    offset += 11;
503                    let values = [t0, t1, t2, t3, t4, t5, t6, t7];
504                    for (k, value) in values.into_iter().enumerate() {
505                        poly.coeffs[8 * j + k] = ((((value & 0x07ff) as u32) * Q as u32 + 1024) >> 11) as i16;
506                    }
507                }
508            }
509        }
510        _ => unreachable!(),
511    }
512    out
513}
514
515#[inline]
516fn polyvec_tobytes<const K: usize>(out: &mut [u8], polyvec: &PolyVec<K>) {
517    for (i, poly) in polyvec.vec.iter().enumerate() {
518        poly_tobytes(&mut out[i * POLY_BYTES..(i + 1) * POLY_BYTES], poly);
519    }
520}
521
522#[inline]
523fn polyvec_frombytes<const K: usize>(input: &[u8]) -> PolyVec<K> {
524    let mut out = PolyVec::<K>::default();
525    for (i, poly) in out.vec.iter_mut().enumerate() {
526        *poly = poly_frombytes(&input[i * POLY_BYTES..(i + 1) * POLY_BYTES]);
527    }
528    out
529}
530
531#[inline]
532fn polyvec_ntt<const K: usize>(polyvec: &mut PolyVec<K>) {
533    for poly in &mut polyvec.vec {
534        poly_ntt(poly);
535    }
536}
537
538#[inline]
539fn polyvec_invntt_tomont<const K: usize>(polyvec: &mut PolyVec<K>) {
540    for poly in &mut polyvec.vec {
541        poly_invntt_tomont(poly);
542    }
543}
544
545#[inline]
546fn polyvec_basemul_acc_montgomery<const K: usize>(a: &PolyVec<K>, b: &PolyVec<K>) -> Poly {
547    let mut out = poly_basemul_montgomery(&a.vec[0], &b.vec[0]);
548    for i in 1..K {
549        let t = poly_basemul_montgomery(&a.vec[i], &b.vec[i]);
550        poly_add(&mut out, &t);
551    }
552    poly_reduce(&mut out);
553    out
554}
555
556#[inline]
557fn polyvec_reduce<const K: usize>(polyvec: &mut PolyVec<K>) {
558    for poly in &mut polyvec.vec {
559        poly_reduce(poly);
560    }
561}
562
563#[inline]
564fn polyvec_add<const K: usize>(left: &mut PolyVec<K>, right: &PolyVec<K>) {
565    for i in 0..K {
566        poly_add(&mut left.vec[i], &right.vec[i]);
567    }
568}
569
570#[inline]
571fn poly_compress<const K: usize>(params: &MlKemParams<K>, out: &mut [u8], poly: &Poly) {
572    match params.polycompressedbytes {
573        128 => {
574            let mut offset = 0usize;
575            for chunk in poly.coeffs.chunks_exact(8) {
576                let mut t = [0u8; 8];
577                for (dst, coeff) in t.iter_mut().zip(chunk.iter()) {
578                    let mut u = *coeff as i32;
579                    u += (u >> 15) & Q as i32;
580                    let mut d0 = ((u as u32) << 4) as u64;
581                    d0 += 1665;
582                    d0 *= 80_635;
583                    d0 >>= 28;
584                    *dst = (d0 as u8) & 0x0f;
585                }
586                out[offset] = t[0] | (t[1] << 4);
587                out[offset + 1] = t[2] | (t[3] << 4);
588                out[offset + 2] = t[4] | (t[5] << 4);
589                out[offset + 3] = t[6] | (t[7] << 4);
590                offset += 4;
591            }
592        }
593        160 => {
594            let mut offset = 0usize;
595            for chunk in poly.coeffs.chunks_exact(8) {
596                let mut t = [0u8; 8];
597                for (dst, coeff) in t.iter_mut().zip(chunk.iter()) {
598                    let mut u = *coeff as i32;
599                    u += (u >> 15) & Q as i32;
600                    let mut d0 = ((u as u32) << 5) as u64;
601                    d0 += 1664;
602                    d0 *= 40_318;
603                    d0 >>= 27;
604                    *dst = (d0 as u8) & 0x1f;
605                }
606                out[offset] = t[0] | (t[1] << 5);
607                out[offset + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
608                out[offset + 2] = (t[3] >> 1) | (t[4] << 4);
609                out[offset + 3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6);
610                out[offset + 4] = (t[6] >> 2) | (t[7] << 3);
611                offset += 5;
612            }
613        }
614        _ => unreachable!(),
615    }
616}
617
618#[inline]
619fn poly_decompress<const K: usize>(params: &MlKemParams<K>, input: &[u8]) -> Poly {
620    let mut out = Poly::default();
621    match params.polycompressedbytes {
622        128 => {
623            for i in 0..(N / 2) {
624                out.coeffs[2 * i] = ((((input[i] & 0x0f) as u16) * Q as u16 + 8) >> 4) as i16;
625                out.coeffs[2 * i + 1] = ((((input[i] >> 4) as u16) * Q as u16 + 8) >> 4) as i16;
626            }
627        }
628        160 => {
629            let mut offset = 0usize;
630            for i in 0..(N / 8) {
631                let t0 = input[offset] >> 0;
632                let t1 = (input[offset] >> 5) | (input[offset + 1] << 3);
633                let t2 = input[offset + 1] >> 2;
634                let t3 = (input[offset + 1] >> 7) | (input[offset + 2] << 1);
635                let t4 = (input[offset + 2] >> 4) | (input[offset + 3] << 4);
636                let t5 = input[offset + 3] >> 1;
637                let t6 = (input[offset + 3] >> 6) | (input[offset + 4] << 2);
638                let t7 = input[offset + 4] >> 3;
639                offset += 5;
640                let values = [t0, t1, t2, t3, t4, t5, t6, t7];
641                for (j, value) in values.into_iter().enumerate() {
642                    out.coeffs[8 * i + j] = (((value as u32 & 31) * Q as u32 + 16) >> 5) as i16;
643                }
644            }
645        }
646        _ => unreachable!(),
647    }
648    out
649}
650
651#[inline]
652fn poly_tobytes(out: &mut [u8], poly: &Poly) {
653    for i in 0..(N / 2) {
654        let mut t0 = poly.coeffs[2 * i] as i32;
655        t0 += (t0 >> 15) & Q as i32;
656        let mut t1 = poly.coeffs[2 * i + 1] as i32;
657        t1 += (t1 >> 15) & Q as i32;
658        out[3 * i] = t0 as u8;
659        out[3 * i + 1] = ((t0 >> 8) as u8) | ((t1 << 4) as u8);
660        out[3 * i + 2] = (t1 >> 4) as u8;
661    }
662}
663
664#[inline]
665fn poly_frombytes(input: &[u8]) -> Poly {
666    let mut out = Poly::default();
667    for i in 0..(N / 2) {
668        out.coeffs[2 * i] = (((input[3 * i] as u16) | ((input[3 * i + 1] as u16) << 8)) & 0x0fff) as i16;
669        out.coeffs[2 * i + 1] = ((((input[3 * i + 1] as u16) >> 4) | ((input[3 * i + 2] as u16) << 4)) & 0x0fff) as i16;
670    }
671    out
672}
673
674#[inline]
675pub(crate) fn poly_frommsg(msg: &[u8]) -> Poly {
676    let mut out = Poly::default();
677    let half_q: i16 = ((Q + 1) / 2) as i16;
678    for i in 0..(N / 8) {
679        for j in 0..8 {
680            let bit = ((msg[i] >> j) & 1) as i16;
681            out.coeffs[8 * i + j] = (-bit) & half_q;
682        }
683    }
684    out
685}
686
687#[inline]
688pub(crate) fn poly_tomsg(poly: &Poly) -> [u8; 32] {
689    let mut msg = [0u8; 32];
690    for i in 0..(N / 8) {
691        for j in 0..8 {
692            let mut t = poly.coeffs[8 * i + j] as i32;
693            t <<= 1;
694            t += 1665;
695            t *= 80_635;
696            t >>= 28;
697            msg[i] |= ((t & 1) as u8) << j;
698        }
699    }
700    msg
701}
702
703#[inline]
704fn poly_ntt(poly: &mut Poly) {
705    ntt(&mut poly.coeffs);
706    poly_reduce(poly);
707}
708
709#[inline]
710fn poly_invntt_tomont(poly: &mut Poly) {
711    invntt(&mut poly.coeffs);
712}
713
714#[inline]
715fn poly_basemul_montgomery(a: &Poly, b: &Poly) -> Poly {
716    let mut out = Poly::default();
717    for i in 0..(N / 4) {
718        let r0 = basemul(
719            [a.coeffs[4 * i], a.coeffs[4 * i + 1]],
720            [b.coeffs[4 * i], b.coeffs[4 * i + 1]],
721            ZETAS[64 + i],
722        );
723        out.coeffs[4 * i] = r0[0];
724        out.coeffs[4 * i + 1] = r0[1];
725
726        let r1 = basemul(
727            [a.coeffs[4 * i + 2], a.coeffs[4 * i + 3]],
728            [b.coeffs[4 * i + 2], b.coeffs[4 * i + 3]],
729            -ZETAS[64 + i],
730        );
731        out.coeffs[4 * i + 2] = r1[0];
732        out.coeffs[4 * i + 3] = r1[1];
733    }
734    out
735}
736
737#[inline]
738fn poly_tomont(poly: &mut Poly) {
739    for coeff in &mut poly.coeffs {
740        *coeff = montgomery_reduce(*coeff as i32 * 1353);
741    }
742}
743
744#[inline]
745fn poly_reduce(poly: &mut Poly) {
746    for coeff in &mut poly.coeffs {
747        *coeff = barrett_reduce(*coeff);
748    }
749}
750
751#[inline]
752fn poly_add(left: &mut Poly, right: &Poly) {
753    for i in 0..N {
754        left.coeffs[i] = (left.coeffs[i] as i32 + right.coeffs[i] as i32) as i16;
755    }
756}
757
758#[inline]
759fn poly_sub(out: &mut Poly, left: &Poly, right: &Poly) {
760    for i in 0..N {
761        out.coeffs[i] = (left.coeffs[i] as i32 - right.coeffs[i] as i32) as i16;
762    }
763}
764
765#[inline]
766fn ntt(r: &mut [i16; N]) {
767    let mut k = 1usize;
768    let mut len = 128usize;
769    while len >= 2 {
770        let mut start = 0usize;
771        while start < N {
772            let zeta = ZETAS[k];
773            k += 1;
774            for j in start..start + len {
775                let t = fqmul(zeta, r[j + len]);
776                let rj = r[j] as i32;
777                r[j + len] = (rj - t as i32) as i16;
778                r[j] = (rj + t as i32) as i16;
779            }
780            start += 2 * len;
781        }
782        len >>= 1;
783    }
784}
785
786#[inline]
787fn invntt(r: &mut [i16; N]) {
788    let mut k = 127usize;
789    let mut len = 2usize;
790    while len <= 128 {
791        let mut start = 0usize;
792        while start < N {
793            let zeta = ZETAS[k];
794            k -= 1;
795            for j in start..start + len {
796                let t = r[j];
797                r[j] = barrett_reduce((t as i32 + r[j + len] as i32) as i16);
798                r[j + len] = fqmul(zeta, (r[j + len] as i32 - t as i32) as i16);
799            }
800            start += 2 * len;
801        }
802        len <<= 1;
803    }
804
805    for coeff in r.iter_mut() {
806        *coeff = fqmul(*coeff, MONT_SQUARED_DIV_N);
807    }
808}
809
810#[inline]
811fn basemul(a: [i16; 2], b: [i16; 2], zeta: i16) -> [i16; 2] {
812    let mut out = [0i16; 2];
813    out[0] = fqmul(a[1], b[1]);
814    out[0] = fqmul(out[0], zeta);
815    out[0] = (out[0] as i32 + fqmul(a[0], b[0]) as i32) as i16;
816    out[1] = (fqmul(a[0], b[1]) as i32 + fqmul(a[1], b[0]) as i32) as i16;
817    out
818}
819
820#[inline]
821fn fqmul(a: i16, b: i16) -> i16 {
822    montgomery_reduce(a as i32 * b as i32)
823}
824
825#[inline]
826fn montgomery_reduce(a: i32) -> i16 {
827    let t = (a as i16).wrapping_mul(QINV) as i32;
828    ((a - t * Q as i32) >> 16) as i16
829}
830
831#[inline]
832fn barrett_reduce(a: i16) -> i16 {
833    const V: i32 = ((1 << 26) + (Q as i32 / 2)) / Q as i32;
834    let t = ((V * a as i32 + (1 << 25)) >> 26) * Q as i32;
835    (a as i32 - t) as i16
836}
837
838#[inline]
839fn hash_h(data: &[u8]) -> [u8; 32] {
840    let mut hasher = Sha3_256::new();
841    hasher.write(data);
842    hasher.sum()
843}
844
845#[inline]
846fn hash_g(data: &[u8]) -> [u8; 64] {
847    let mut hasher = Sha3_512::new();
848    hasher.write(data);
849    hasher.sum()
850}
851
852#[inline]
853fn rkprf(cipher_key: &[u8; 32], ciphertext: &[u8]) -> [u8; 32] {
854    let mut shake = Shake256::new();
855    shake.absorb(cipher_key);
856    shake.absorb(ciphertext);
857    let mut out = [0u8; 32];
858    shake.squeeze(&mut out);
859    out
860}
861
862/// Constant-time conditional move: if `cond` is true, copies `value` into `out`.
863/// Uses a compiler barrier on the mask to prevent the optimizer from turning this
864/// into a branch (which would leak timing information in the FO transform).
865#[inline]
866fn cmov(out: &mut [u8; 32], value: &[u8; 32], cond: bool) {
867    let mask = ct_mask_u8(cond);
868    for i in 0..32 {
869        out[i] ^= mask & (out[i] ^ value[i]);
870    }
871}
872
873/// Converts a boolean condition to a constant-time mask (0x00 or 0xFF) with a compiler
874/// barrier to prevent optimization into a branch.
875#[inline]
876fn ct_mask_u8(cond: bool) -> u8 {
877    let mask = 0u8.wrapping_sub(cond as u8);
878    // Prevent the compiler from reasoning about the mask value and potentially
879    // converting downstream code into a conditional branch.
880    ct_barrier_u8(mask)
881}
882
/// Optimization barrier for x86 / x86_64: returns `value` unchanged while hiding it from
/// the optimizer.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline]
fn ct_barrier_u8(value: u8) -> u8 {
    let mut opaque = value;
    // SAFETY: the template contains only an assembler comment, so no instruction is
    // executed; the operand is merely routed through a register, which makes the compiler
    // treat it as an unknown value and prevents branch-based optimizations.
    unsafe {
        core::arch::asm!(
            "/* {0} */",
            inout(reg_byte) opaque,
            options(pure, nomem, nostack, preserves_flags),
        );
    }
    opaque
}
893
/// Optimization barrier for ARM, AArch64 and RISC-V: returns `value` unchanged while
/// hiding it from the optimizer.
#[cfg(any(
    target_arch = "aarch64",
    target_arch = "arm",
    target_arch = "riscv32",
    target_arch = "riscv64"
))]
#[inline]
// The operand only appears inside an assembler comment, so the register-width formatting
// that `asm_sub_register` warns about has no effect on the generated code.
#[allow(asm_sub_register)]
fn ct_barrier_u8(mut value: u8) -> u8 {
    // SAFETY: the template contains only an assembler comment, so no instruction is
    // executed; the operand is merely routed through a register, which makes the compiler
    // treat it as an unknown value and prevents branch-based optimizations.
    unsafe {
        core::arch::asm!("/* {0} */", inout(reg) value, options(pure, nomem, nostack, preserves_flags));
    }
    value
}
908
909#[cfg(not(any(
910    target_arch = "x86",
911    target_arch = "x86_64",
912    target_arch = "aarch64",
913    target_arch = "arm",
914    target_arch = "riscv32",
915    target_arch = "riscv64"
916)))]
917#[inline(never)]
918fn ct_barrier_u8(value: u8) -> u8 {
919    core::hint::black_box(value)
920}
921
/// Reads the first four bytes of `input` as a little-endian `u32`.
///
/// # Panics
/// Panics if `input` holds fewer than four bytes.
#[inline]
fn load32(input: &[u8]) -> u32 {
    // Bytes are indexed in ascending order, least-significant first.
    u32::from_le_bytes([input[0], input[1], input[2], input[3]])
}
926
927#[inline]
928pub(crate) fn public_key_bytes<const K: usize>() -> usize {
929    polyvec_bytes::<K>() + SYMBYTES
930}
931
932#[inline]
933pub(crate) fn indcpa_secret_key_bytes<const K: usize>() -> usize {
934    polyvec_bytes::<K>()
935}
936
937#[inline]
938fn polyvec_bytes<const K: usize>() -> usize {
939    K * POLY_BYTES
940}
941
942#[inline]
943pub(crate) fn secret_key_size<const K: usize>() -> usize {
944    indcpa_secret_key_bytes::<K>() + public_key_bytes::<K>() + 2 * SYMBYTES
945}
946
947#[inline]
948fn ciphertext_bytes<const K: usize>(params: &MlKemParams<K>) -> usize {
949    params.polyveccompressedbytes + params.polycompressedbytes
950}
951
/// Reinterprets a slice as a reference to a 32-byte array.
///
/// # Panics
/// Panics with "slice length should be 32" when `input.len() != 32`.
#[inline]
fn array_ref_32(input: &[u8]) -> &[u8; 32] {
    <&[u8; 32]>::try_from(input).expect("slice length should be 32")
}
956
/// Test helper: decodes the hex string `s` into an array of exactly `N` bytes.
///
/// # Panics
/// Panics if `s` is not valid hex or does not decode to exactly `N` bytes.
#[cfg(test)]
pub(crate) fn decode_hex_array<const N: usize>(s: &str) -> [u8; N] {
    let decoded = hex::decode(s).expect("valid hex");
    assert_eq!(decoded.len(), N);
    // The length was just checked, so every index below N is in bounds.
    core::array::from_fn(|i| decoded[i])
}
965
/// Test helper: SHA3-256 digest of `data`, rendered through `hex::encode`.
#[cfg(test)]
pub(crate) fn sha3_256_hex(data: &[u8]) -> String {
    let mut sha3 = Sha3_256::new();
    sha3.write(data);
    let digest = sha3.sum();
    hex::encode(digest)
}
972
973#[cfg(test)]
974mod tests {
975    use super::*;
976
977    #[test]
978    fn poly_frommsg_tomsg_roundtrip() {
979        for pattern in 0..=255u16 {
980            let mut msg = [0u8; 32];
981            msg[0] = pattern as u8;
982            msg[1] = (pattern >> 8) as u8;
983            let poly = poly_frommsg(&msg);
984            let recovered = poly_tomsg(&poly);
985            assert_eq!(msg, recovered, "roundtrip failed for pattern {pattern:#06x}");
986        }
987    }
988
989    #[test]
990    fn poly_frommsg_constant_time_produces_expected_values() {
991        let half_q = ((Q + 1) / 2) as i16;
992        let mut msg = [0u8; 32];
993        msg[0] = 0b1010_1010;
994        msg[1] = 0b0101_0101;
995        let poly = poly_frommsg(&msg);
996        assert_eq!(poly.coeffs[0], 0);
997        assert_eq!(poly.coeffs[1], half_q);
998        assert_eq!(poly.coeffs[2], 0);
999        assert_eq!(poly.coeffs[3], half_q);
1000        assert_eq!(poly.coeffs[8], half_q);
1001        assert_eq!(poly.coeffs[9], 0);
1002        assert_eq!(poly.coeffs[10], half_q);
1003        assert_eq!(poly.coeffs[11], 0);
1004    }
1005
1006    #[test]
1007    fn cmov_selects_correctly() {
1008        let mut out = [0xAAu8; 32];
1009        let value = [0xBBu8; 32];
1010        cmov(&mut out, &value, false);
1011        assert_eq!(out, [0xAAu8; 32], "cmov with false should not modify output");
1012
1013        cmov(&mut out, &value, true);
1014        assert_eq!(out, [0xBBu8; 32], "cmov with true should copy value");
1015    }
1016
1017    #[test]
1018    fn cmov_is_idempotent() {
1019        let mut out = [0x42u8; 32];
1020        let value = [0x42u8; 32];
1021        cmov(&mut out, &value, true);
1022        assert_eq!(out, [0x42u8; 32]);
1023        cmov(&mut out, &value, false);
1024        assert_eq!(out, [0x42u8; 32]);
1025    }
1026
1027    #[test]
1028    fn barrett_reduce_produces_values_in_range() {
1029        // Barrett reduce should map any i16 to the range [-(Q-1)/2, (Q-1)/2] approximately
1030        for val in [0i16, 1, -1, Q - 1, -(Q - 1), Q, -Q, 3000, -3000, i16::MAX, i16::MIN] {
1031            let reduced = barrett_reduce(val);
1032            // The reduced value should be congruent to val mod Q
1033            let diff = (val as i32 - reduced as i32).rem_euclid(Q as i32);
1034            assert!(diff == 0, "barrett_reduce({val}) = {reduced} not congruent mod Q");
1035        }
1036    }
1037
1038    #[test]
1039    fn montgomery_reduce_correctness() {
1040        // Montgomery reduce: given a, return a * R^(-1) mod Q where R = 2^16
1041        // Verify: montgomery_reduce(a * R) == a mod Q for small a
1042        let r_mod_q: i32 = (1i32 << 16) % Q as i32; // R mod Q = 65536 mod 3329 = 2285
1043        for val in [0i16, 1, -1, 100, -100, Q - 1, -(Q - 1)] {
1044            let product = val as i32 * r_mod_q;
1045            let result = montgomery_reduce(product);
1046            // result should be congruent to val mod Q
1047            let diff = (val as i32 - result as i32).rem_euclid(Q as i32);
1048            assert!(
1049                diff == 0,
1050                "montgomery_reduce({val} * R) = {result}, expected congruent to {val} mod Q"
1051            );
1052        }
1053    }
1054
1055    #[test]
1056    fn ntt_invntt_preserves_polynomial_structure() {
1057        // NTT->InvNTT roundtrip preserves polynomial relationships.
1058        // The full KEM roundtrip tests already validate NTT correctness,
1059        // but this verifies that two distinct inputs remain distinct after transform.
1060        let mut poly_a = Poly::default();
1061        let mut poly_b = Poly::default();
1062        for i in 0..N {
1063            poly_a.coeffs[i] = (i as i16 * 7 + 3) % Q;
1064            poly_b.coeffs[i] = (i as i16 * 11 + 5) % Q;
1065        }
1066        poly_ntt(&mut poly_a);
1067        poly_ntt(&mut poly_b);
1068        // NTT outputs should be different for different inputs
1069        assert_ne!(poly_a.coeffs, poly_b.coeffs);
1070
1071        poly_invntt_tomont(&mut poly_a);
1072        poly_invntt_tomont(&mut poly_b);
1073        // After roundtrip, they should still be different
1074        assert_ne!(poly_a.coeffs, poly_b.coeffs);
1075    }
1076
1077    #[test]
1078    fn poly_compress_decompress_roundtrip_4bit() {
1079        // For 4-bit compression (ML-KEM-768)
1080        let params = &ML_KEM_768;
1081        let mut poly = Poly::default();
1082        for i in 0..N {
1083            poly.coeffs[i] = ((i * 13) % Q as usize) as i16;
1084        }
1085        let mut compressed = [0u8; 128];
1086        poly_compress::<3>(params, &mut compressed, &poly);
1087        let decompressed = poly_decompress::<3>(params, &compressed);
1088        // Compression is lossy but within rounding error
1089        for i in 0..N {
1090            let orig = poly.coeffs[i] as i32;
1091            let dec = decompressed.coeffs[i] as i32;
1092            // Maximum rounding error for d-bit compression: Q / (2^(d+1))
1093            // For 4 bits: Q/32 ≈ 104
1094            let error = ((orig - dec).rem_euclid(Q as i32)).min(((dec - orig).rem_euclid(Q as i32)));
1095            assert!(
1096                error <= Q as i32 / 32 + 1,
1097                "4-bit compress/decompress error too large at index {i}: orig={orig}, dec={dec}, error={error}"
1098            );
1099        }
1100    }
1101
1102    #[test]
1103    fn poly_compress_decompress_roundtrip_5bit() {
1104        // For 5-bit compression (ML-KEM-1024)
1105        let params = &ML_KEM_1024;
1106        let mut poly = Poly::default();
1107        for i in 0..N {
1108            poly.coeffs[i] = ((i * 13) % Q as usize) as i16;
1109        }
1110        let mut compressed = [0u8; 160];
1111        poly_compress::<4>(params, &mut compressed, &poly);
1112        let decompressed = poly_decompress::<4>(params, &compressed);
1113        for i in 0..N {
1114            let orig = poly.coeffs[i] as i32;
1115            let dec = decompressed.coeffs[i] as i32;
1116            let error = ((orig - dec).rem_euclid(Q as i32)).min(((dec - orig).rem_euclid(Q as i32)));
1117            assert!(
1118                error <= Q as i32 / 64 + 1,
1119                "5-bit compress/decompress error too large at index {i}: orig={orig}, dec={dec}, error={error}"
1120            );
1121        }
1122    }
1123
1124    #[test]
1125    fn polyvec_compress_decompress_roundtrip_10bit() {
1126        let params = &ML_KEM_768;
1127        let mut pv = PolyVec::<3>::default();
1128        for k in 0..3 {
1129            for i in 0..N {
1130                pv.vec[k].coeffs[i] = ((k * 97 + i * 13) % Q as usize) as i16;
1131            }
1132        }
1133        let mut compressed = [0u8; 960];
1134        polyvec_compress(params, &mut compressed, &pv);
1135        let decompressed = polyvec_decompress::<3>(params, &compressed);
1136        for k in 0..3 {
1137            for i in 0..N {
1138                let orig = pv.vec[k].coeffs[i] as i32;
1139                let dec = decompressed.vec[k].coeffs[i] as i32;
1140                let error = ((orig - dec).rem_euclid(Q as i32)).min(((dec - orig).rem_euclid(Q as i32)));
1141                assert!(
1142                    error <= Q as i32 / 2048 + 1,
1143                    "10-bit compress/decompress error at [{k}][{i}]: orig={orig}, dec={dec}, error={error}"
1144                );
1145            }
1146        }
1147    }
1148
1149    #[test]
1150    fn polyvec_compress_decompress_roundtrip_11bit() {
1151        let params = &ML_KEM_1024;
1152        let mut pv = PolyVec::<4>::default();
1153        for k in 0..4 {
1154            for i in 0..N {
1155                pv.vec[k].coeffs[i] = ((k * 97 + i * 13) % Q as usize) as i16;
1156            }
1157        }
1158        let mut compressed = [0u8; 1408];
1159        polyvec_compress(params, &mut compressed, &pv);
1160        let decompressed = polyvec_decompress::<4>(params, &compressed);
1161        for k in 0..4 {
1162            for i in 0..N {
1163                let orig = pv.vec[k].coeffs[i] as i32;
1164                let dec = decompressed.vec[k].coeffs[i] as i32;
1165                let error = ((orig - dec).rem_euclid(Q as i32)).min(((dec - orig).rem_euclid(Q as i32)));
1166                assert!(
1167                    error <= Q as i32 / 4096 + 1,
1168                    "11-bit compress/decompress error at [{k}][{i}]: orig={orig}, dec={dec}, error={error}"
1169                );
1170            }
1171        }
1172    }
1173
1174    #[test]
1175    fn poly_tobytes_frombytes_roundtrip() {
1176        let mut poly = Poly::default();
1177        for i in 0..N {
1178            poly.coeffs[i] = (i as i16 * 13) % Q;
1179        }
1180        let mut buf = [0u8; POLY_BYTES];
1181        poly_tobytes(&mut buf, &poly);
1182        let recovered = poly_frombytes(&buf);
1183        assert_eq!(poly.coeffs, recovered.coeffs);
1184    }
1185
1186    #[test]
1187    fn gen_matrix_transpose_relationship() {
1188        let seed = [42u8; 32];
1189        let matrix = gen_matrix::<3>(&seed, false);
1190        let transposed = gen_matrix::<3>(&seed, true);
1191        for i in 0..3 {
1192            for j in 0..3 {
1193                assert_eq!(
1194                    matrix[i].vec[j].coeffs, transposed[j].vec[i].coeffs,
1195                    "A[{i}][{j}] != A^T[{j}][{i}]"
1196                );
1197            }
1198        }
1199    }
1200
1201    #[test]
1202    fn cbd2_produces_values_in_correct_range() {
1203        // CBD with eta=2 should produce coefficients in [-2, 2]
1204        let mut buf = [0u8; 128];
1205        for i in 0..128 {
1206            buf[i] = (i as u8).wrapping_mul(0x37);
1207        }
1208        let poly = cbd2(&buf);
1209        for (i, &coeff) in poly.coeffs.iter().enumerate() {
1210            assert!((-2..=2).contains(&coeff), "CBD2 coeff[{i}] = {coeff} out of range [-2, 2]");
1211        }
1212    }
1213
1214    #[test]
1215    fn rej_uniform_only_accepts_values_less_than_q() {
1216        // Craft input where val0 = Q (3329 = 0xD01) should be rejected
1217        // rej_uniform parses 3 bytes into 2 12-bit values:
1218        // val0 = (buf[0] | buf[1]<<8) & 0x0fff
1219        // val1 = ((buf[1]>>4) | buf[2]<<4) & 0x0fff
1220        let buf = [
1221            0x01, 0x0D,
1222            0x00, // val0 = 0xD01 = 3329 = Q (rejected), val1 = (0x0D>>4 | 0x00<<4) & 0xfff = 0 (accepted)
1223            0x00, 0x0D,
1224            0xD0, // val0 = 0xD00 = 3328 (accepted), val1 = (0x0D>>4 | 0xD0<<4) & 0xfff = 0xD00 = 3328 (accepted)
1225        ];
1226        let mut out = [0i16; 256];
1227        let count = rej_uniform(&mut out, &buf);
1228        // val0=Q rejected, val1=0 accepted, val0=3328 accepted, val1=3328 accepted
1229        assert_eq!(count, 3);
1230        assert_eq!(out[0], 0); // first accepted: val1 from first triple
1231        assert_eq!(out[1], 3328); // second accepted: val0 from second triple
1232        assert_eq!(out[2], 3328); // third accepted: val1 from second triple
1233    }
1234
1235    #[test]
1236    fn nist_acvp_ml_kem_768_full_vector() {
1237        // Verify against NIST FIPS 203 intermediate test vector (ML-KEM-768.txt)
1238        // These values come from the NIST test file and are validated by the CCTV tests
1239        let d: [u8; 32] = decode_hex_array("f688563f7c66a5da2d8bdb5a5f3e07bd8dce6f7efcec7f41298d79863459f7cd");
1240        let z: [u8; 32] = decode_hex_array("d1d49a515250dbceb9f6e3fcc1c7d5306918964b21ddb22207e03e57f0600da8");
1241        let m: [u8; 32] = decode_hex_array("3dc27ca0a6594b0e56320457c45a0f76bb8a213ea4a76d442186a0aefadbcdb9");
1242
1243        let mut coins = [0u8; 64];
1244        coins[..32].copy_from_slice(&d);
1245        coins[32..].copy_from_slice(&z);
1246
1247        let (dk, ek) = crypto_kem_keypair_derand::<3, 2400, 1184>(&ML_KEM_768, &coins);
1248        let (ct, k) = crypto_kem_enc_derand::<3, 1184, 1088>(&ML_KEM_768, &ek, &m);
1249
1250        // Verify public key hash matches NIST vector
1251        assert_eq!(
1252            sha3_256_hex(&ek),
1253            "42d930a50dfd1f0541ca45c4598daebb4f51cd10d711a001bd9bb87d5c87a4bf"
1254        );
1255        // Verify secret key hash
1256        assert_eq!(
1257            sha3_256_hex(&dk),
1258            "db563aebd9fdc875e88563693edad1e5e359cc37b0f685d2d0a3723b37253192"
1259        );
1260        // Verify ciphertext hash
1261        assert_eq!(
1262            sha3_256_hex(&ct),
1263            "9d6e358208c4d583050becb319050b7f916de47caad1d589a1d01fea43fe1750"
1264        );
1265        // Verify shared secret
1266        assert_eq!(
1267            hex::encode(k),
1268            "ae726da2df66601c6648a7565c02b203a089276ac30f6cc226d048f93fafd78c"
1269        );
1270
1271        // Verify decapsulation produces the same shared secret
1272        let k_dec = crypto_kem_dec::<3, 2400, 1088>(&ML_KEM_768, &dk, &ct).unwrap();
1273        assert_eq!(k, k_dec, "decapsulation mismatch against NIST vector");
1274    }
1275
1276    #[test]
1277    fn nist_acvp_ml_kem_1024_full_vector() {
1278        // Verify against NIST FIPS 203 intermediate test vector (ML-KEM-1024.txt)
1279        let d: [u8; 32] = decode_hex_array("2a62c39ef4fc499f2d132716f480bb7521a49558ae84ee80d9352e66daf1e3a8");
1280        let z: [u8; 32] = decode_hex_array("5f574ef7f013d4336801fed022178c3ed91d0b6d51325315fc1dcabf4770a2ea");
1281        let m: [u8; 32] = decode_hex_array("e07d685ed308e609c9c7842026e35732f6ffc6e2fee10f0afd348f2b42a8acb4");
1282
1283        let mut coins = [0u8; 64];
1284        coins[..32].copy_from_slice(&d);
1285        coins[32..].copy_from_slice(&z);
1286
1287        let (dk, ek) = crypto_kem_keypair_derand::<4, 3168, 1568>(&ML_KEM_1024, &coins);
1288        let (ct, k) = crypto_kem_enc_derand::<4, 1568, 1568>(&ML_KEM_1024, &ek, &m);
1289
1290        assert_eq!(
1291            sha3_256_hex(&ek),
1292            "3b308d1344ed70366b84d790acb705b86cd3dfd471fff171969aaa338f26dca5"
1293        );
1294        assert_eq!(
1295            sha3_256_hex(&dk),
1296            "aa63a9e0c035ada6635e7938b71856b24917ff9b3ebca1a4d205a83b502a415a"
1297        );
1298        assert_eq!(
1299            sha3_256_hex(&ct),
1300            "8caba02733421f12a7ba9a2bcbe4de7c9853156a0637df5a7a0f9127c81da943"
1301        );
1302        assert_eq!(
1303            hex::encode(k),
1304            "d53825c3ff666bb2881215dbec04a8bdce9099b2a3680938c2f199b54d505953"
1305        );
1306
1307        let k_dec = crypto_kem_dec::<4, 3168, 1568>(&ML_KEM_1024, &dk, &ct).unwrap();
1308        assert_eq!(k, k_dec, "decapsulation mismatch against NIST vector");
1309    }
1310
1311    #[test]
1312    fn compression_constant_time_no_division() {
1313        // Verify that the compression constants avoid division at runtime.
1314        // This test exercises boundary values where a naive division would
1315        // produce different rounding behavior than the multiplication trick.
1316        let params_768 = &ML_KEM_768;
1317        let params_1024 = &ML_KEM_1024;
1318
1319        // Test boundary values for poly_compress (4-bit)
1320        let mut poly = Poly::default();
1321        poly.coeffs[0] = 0;
1322        poly.coeffs[1] = (Q - 1) as i16;
1323        poly.coeffs[2] = (Q / 2) as i16;
1324        poly.coeffs[3] = (Q / 2 + 1) as i16;
1325        let mut buf4 = [0u8; 128];
1326        poly_compress::<3>(params_768, &mut buf4, &poly);
1327        let dec = poly_decompress::<3>(params_768, &buf4);
1328        // Verify round-trip for boundary values
1329        assert_eq!(dec.coeffs[0], 0); // 0 should compress/decompress to 0
1330
1331        // Test boundary values for poly_compress (5-bit)
1332        let mut buf5 = [0u8; 160];
1333        poly_compress::<4>(params_1024, &mut buf5, &poly);
1334        let dec5 = poly_decompress::<4>(params_1024, &buf5);
1335        assert_eq!(dec5.coeffs[0], 0);
1336    }
1337
1338    #[test]
1339    fn poly_tomsg_boundary_values() {
1340        // Test poly_tomsg at the decision boundary: Q/4 and 3Q/4
1341        let mut poly = Poly::default();
1342        // Value 0 should produce bit 0
1343        poly.coeffs[0] = 0;
1344        // Value Q/2 (1665) should produce bit 1
1345        poly.coeffs[1] = (Q / 2) as i16;
1346        // Value Q/4 (832) is at the boundary
1347        poly.coeffs[2] = (Q / 4) as i16;
1348        // Value 3Q/4 (2497) is at the other boundary
1349        poly.coeffs[3] = (3 * Q as i32 / 4) as i16;
1350
1351        let msg = poly_tomsg(&poly);
1352        // bit 0: value 0 -> 0
1353        assert_eq!(msg[0] & 1, 0);
1354        // bit 1: value Q/2 -> 1
1355        assert_eq!((msg[0] >> 1) & 1, 1);
1356    }
1357}