1use constant_time_eq::constant_time_eq;
2#[cfg(feature = "zeroize")]
3use zeroize::{Zeroize, ZeroizeOnDrop};
4
5use crate::{
6 Xof,
7 sha3::{Shake128, Shake256},
8};
9
/// Length in bytes of an encoded ML-DSA-65 public key: the 32-byte `rho` seed
/// followed by `K` polynomials of 10-bit `t1` coefficients (see `pk_encode`).
pub const ML_DSA_65_PUBLIC_KEY_SIZE: usize = 1952;
/// Length in bytes of an encoded ML-DSA-65 signature: the challenge hash, `L`
/// packed `z` polynomials, then `OMEGA + K` hint bytes (see `sig_encode`).
pub const ML_DSA_65_SIGNATURE_SIZE: usize = 3309;
/// Length in bytes of the seed from which the whole key pair is re-derived.
pub const ML_DSA_65_SEED_SIZE: usize = 32;
/// Maximum length of the context string, which is length-prefixed with one byte.
pub const ML_DSA_65_CONTEXT_MAX_LEN: usize = 255;

// The modulus q = 2^23 - 2^13 + 1.
const Q: u32 = 8380417;
// Number of coefficients in every polynomial.
const N: usize = 256;
// Number of low bits dropped from t by `power2round`.
const D: u32 = 13;
// 1 in Montgomery form, i.e. 2^32 mod q.
const ONE: u32 = 4193792;
// -1 in Montgomery form, i.e. q - ONE.
const MINUS_ONE: u32 = 4186625;
// (2^32)^2 mod q; Montgomery-multiplying by it converts a value into Montgomery form.
const RR: u32 = 2365951;
// -q^-1 mod 2^32, the constant used by Montgomery reduction.
const QINV: u32 = 4236238847;
// 2^24 mod q, i.e. 256^-1 in Montgomery form; applied at the end of `invntt`.
const N_INV: u32 = 16382;
// Range of the masking coefficients y (and bound on z).
const GAMMA1: u32 = 1 << 19;
// Low-order rounding range used by HighBits/LowBits/hints.
const GAMMA2: u32 = (Q - 1) / 32;
// TAU * eta = 49 * 4; bound on the coefficients of c*s1 and c*s2.
const BETA: u32 = 196;
// Number of non-zero (+-1) coefficients in the challenge polynomial.
const TAU: usize = 49;
// Length in bytes of the challenge hash c~.
const LAMBDA_OVER_4: usize = 48;
// Bytes per packed z polynomial: 20 bits per coefficient.
const POLYZ_BYTES: usize = (19 + 1) * N / 8;
// Rows of the matrix A (length of t, s2, w).
const K: usize = 6;
// Columns of the matrix A (length of s1, y, z).
const L: usize = 5;
// Maximum number of 1s allowed in the hint vector.
const OMEGA: usize = 55;
32
/// Errors reported by ML-DSA-65 signing and verification.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlDsaError {
    /// The context string is longer than `ML_DSA_65_CONTEXT_MAX_LEN` bytes.
    ContextTooLong,
    /// The signature is malformed or does not verify.
    InvalidSignature,
    /// Reserved for public-key decoding failures.
    InvalidPublicKey,
    /// The signature does not have exactly `ML_DSA_65_SIGNATURE_SIZE` bytes.
    InvalidSignatureLength,
}
40
#[cfg(feature = "alloc")]
impl core::fmt::Display for MlDsaError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let description = match *self {
            Self::ContextTooLong => "context length exceeds 255 bytes",
            Self::InvalidSignature => "signature is not valid",
            Self::InvalidPublicKey => "public key is not valid",
            Self::InvalidSignatureLength => "signature length is not valid",
        };
        f.write_str(description)
    }
}
52
// An element of Z_q held as a u32. Most values of this type are kept in
// Montgomery form (x * 2^32 mod q); convert with `field_to_montgomery` /
// `field_from_montgomery`.
type FieldElement = u32;
54
55fn field_to_montgomery(a: u32) -> FieldElement {
56 debug_assert!(a < Q);
57 field_montgomery_mul(a, RR)
58}
59
60fn field_from_montgomery(a: FieldElement) -> u32 {
61 field_montgomery_reduce(a as u64)
62}
63
64fn field_montgomery_reduce(x: u64) -> u32 {
65 let t = (x as u32).wrapping_mul(QINV);
66 let u = (x + (t as u64) * (Q as u64)) >> 32;
67 field_reduce_once(u as u32)
68}
69
70fn field_montgomery_mul(a: FieldElement, b: FieldElement) -> FieldElement {
71 field_montgomery_reduce(a as u64 * b as u64)
72}
73
74fn field_reduce_once(x: u32) -> FieldElement {
75 let t = x.wrapping_sub(Q);
76 let mask = ((t as i32) >> 31) as u32;
77 t.wrapping_add(Q & mask)
78}
79
80fn field_add(a: FieldElement, b: FieldElement) -> FieldElement {
81 field_reduce_once(a.wrapping_add(b))
82}
83
84fn field_sub(a: FieldElement, b: FieldElement) -> FieldElement {
85 field_reduce_once(a.wrapping_sub(b).wrapping_add(Q))
86}
87
88fn field_sub_to_montgomery(a: u32, b: u32) -> FieldElement {
89 let x = a.wrapping_sub(b).wrapping_add(Q);
90 field_montgomery_mul(x, RR)
91}
92
93fn field_infinity_norm(r: FieldElement) -> u32 {
94 let x = field_from_montgomery(r);
95 let q_minus_x = Q - x;
96 let half_q = Q / 2;
97 let mask = ((half_q.wrapping_sub(x)) as i32 >> 31) as u32;
98 (mask & q_minus_x) | (!mask & x)
99}
100
101fn field_centered_mod(r: FieldElement) -> i32 {
102 let x = field_from_montgomery(r);
103 let x = x as i32;
104 let half_q = (Q / 2) as i32;
105 let mask = ((half_q - x) >> 31) as i32;
106 (mask & (x - Q as i32)) | (!mask & x)
107}
108
109fn power2round(r: FieldElement) -> (u16, FieldElement) {
110 let rr = field_from_montgomery(r);
111 let r1 = (rr + (1 << 12) - 1) >> 13;
112 let r0 = field_sub_to_montgomery(rr, r1 << 13);
113 (r1 as u16, r0)
114}
115
/// HighBits for gamma2 = (q - 1) / 32: returns `r1` in `[0, 15]` such that
/// `x = r1 * 2 * gamma2 + r0` with `r0` in `(-gamma2, gamma2]` (mod q).
///
/// It avoids a general division by `2 * gamma2` by using a multiply-and-shift
/// approximation. Values just below q wrap around to 0.
fn highbits32(x: u32) -> u8 {
    let coarse = (x + 127) >> 7;
    let rounded = (coarse * 1025 + (1 << 21)) >> 22;
    (rounded % 16) as u8
}
121
122fn decompose32(r: FieldElement) -> (u8, i32) {
123 let x = field_from_montgomery(r) as i32;
124 let r1 = highbits32(x as u32);
125 let r0 = x - (r1 as i32) * 2 * (Q as i32 - 1) / 32;
126 let half_q = (Q / 2) as i32;
127 let mask = ((half_q - r0) >> 31) as i32;
128 let r0 = (mask & (r0 - Q as i32)) | (!mask & r0);
129 (r1, r0)
130}
131
132fn make_hint32(ct0: FieldElement, w: FieldElement, cs2: FieldElement) -> u8 {
133 let r_plus_z = field_sub(w, cs2);
134 let v1 = highbits32(field_from_montgomery(r_plus_z));
135 let r = field_add(r_plus_z, ct0);
136 let r1 = highbits32(field_from_montgomery(r));
137 (v1 ^ r1) as u8 & 1u8
138}
139
140fn use_hint32(r: FieldElement, hint: u8) -> u8 {
141 let (r1, r0) = decompose32(r);
142 if hint == 0 {
143 return r1;
144 }
145 let r0_gt_0 = !(r0.wrapping_sub(1) >> 31) as u8;
146 let r1_plus = r1.wrapping_add(1) & 0x0F;
147 let r1_minus = r1.wrapping_sub(1) & 0x0F;
148 (r0_gt_0 & r1_plus) | ((!r0_gt_0) & r1_minus)
149}
150
151#[derive(Clone, Debug, PartialEq, Eq)]
152#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
153struct Poly {
154 coeffs: [FieldElement; N],
155}
156
157impl Default for Poly {
158 fn default() -> Self {
159 Self {
160 coeffs: [0u32; N],
161 }
162 }
163}
164
165#[derive(Clone, Debug, PartialEq, Eq)]
166#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
167struct NttPoly {
168 coeffs: [FieldElement; N],
169}
170
171impl Default for NttPoly {
172 fn default() -> Self {
173 Self {
174 coeffs: [0u32; N],
175 }
176 }
177}
178
179fn poly_add(a: &Poly, b: &Poly) -> Poly {
180 let mut r = Poly::default();
181 for i in 0..N {
182 r.coeffs[i] = field_add(a.coeffs[i], b.coeffs[i]);
183 }
184 r
185}
186
187fn poly_sub(a: &Poly, b: &Poly) -> Poly {
188 let mut r = Poly::default();
189 for i in 0..N {
190 r.coeffs[i] = field_sub(a.coeffs[i], b.coeffs[i]);
191 }
192 r
193}
194
195fn ntt_add(a: &NttPoly, b: &NttPoly) -> NttPoly {
196 let mut r = NttPoly::default();
197 for i in 0..N {
198 r.coeffs[i] = field_add(a.coeffs[i], b.coeffs[i]);
199 }
200 r
201}
202
203fn ntt_sub(a: &NttPoly, b: &NttPoly) -> NttPoly {
204 let mut r = NttPoly::default();
205 for i in 0..N {
206 r.coeffs[i] = field_sub(a.coeffs[i], b.coeffs[i]);
207 }
208 r
209}
210
211fn ntt_mul(a: &NttPoly, b: &NttPoly) -> NttPoly {
212 let mut r = NttPoly::default();
213 for i in 0..N {
214 r.coeffs[i] = field_montgomery_mul(a.coeffs[i], b.coeffs[i]);
215 }
216 r
217}
218
// Twiddle factors for `ntt` / `invntt`, stored in Montgomery form (entry 0 is
// `ONE`). `ntt` consumes indices 1..=255 in increasing order and `invntt`
// consumes them in decreasing order.
// NOTE(review): expected to be 1753^brv8(i) * 2^32 mod q as in FIPS 204; the
// table is not regenerated here, so any edit must be checked against KATs.
const ZETAS: [FieldElement; 256] = [
    4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468, 1826347, 2353451, 8021166, 6288512, 3119733,
    5495562, 3111497, 2680103, 2725464, 1024112, 7300517, 3585928, 7830929, 7260833, 2619752, 6271868, 6262231,
    4520680, 6980856, 5102745, 1757237, 8360995, 4010497, 280005, 2706023, 95776, 3077325, 3530437, 6718724, 4788269,
    5842901, 3915439, 4519302, 5336701, 3574422, 5512770, 3539968, 8079950, 2348700, 7841118, 6681150, 6736599,
    3505694, 4558682, 3507263, 6239768, 6779997, 3699596, 811944, 531354, 954230, 3881043, 3900724, 5823537, 2071892,
    5582638, 4450022, 6851714, 4702672, 5339162, 6927966, 3475950, 2176455, 6795196, 7122806, 1939314, 4296819,
    7380215, 5190273, 5223087, 4747489, 126922, 3412210, 7396998, 2147896, 2715295, 5412772, 4686924, 7969390, 5903370,
    7709315, 7151892, 8357436, 7072248, 7998430, 1349076, 1852771, 6949987, 5037034, 264944, 508951, 3097992, 44288,
    7280319, 904516, 3958618, 4656075, 8371839, 1653064, 5130689, 2389356, 8169440, 759969, 7063561, 189548, 4827145,
    3159746, 6529015, 5971092, 8202977, 1315589, 1341330, 1285669, 6795489, 7567685, 6940675, 5361315, 4499357,
    4751448, 3839961, 2091667, 3407706, 2316500, 3817976, 5037939, 2244091, 5933984, 4817955, 266997, 2434439, 7144689,
    3513181, 4860065, 4621053, 7183191, 5187039, 900702, 1859098, 909542, 819034, 495491, 6767243, 8337157, 7857917,
    7725090, 5257975, 2031748, 3207046, 4823422, 7855319, 7611795, 4784579, 342297, 286988, 5942594, 4108315, 3437287,
    5038140, 1735879, 203044, 2842341, 2691481, 5790267, 1265009, 4055324, 1247620, 2486353, 1595974, 4613401, 1250494,
    2635921, 4832145, 5386378, 1869119, 1903435, 7329447, 7047359, 1237275, 5062207, 6950192, 7929317, 1312455,
    3306115, 6417775, 7100756, 1917081, 5834105, 7005614, 1500165, 777191, 2235880, 3406031, 7838005, 5548557, 6709241,
    6533464, 5796124, 4656147, 594136, 4603424, 6366809, 2432395, 2454455, 8215696, 1957272, 3369112, 185531, 7173032,
    5196991, 162844, 1616392, 3014001, 810149, 1652634, 4686184, 6581310, 5341501, 3523897, 3866901, 269760, 2213111,
    7404533, 1717735, 472078, 7953734, 1723600, 6577327, 1910376, 6712985, 7276084, 8119771, 4546524, 5441381, 6144432,
    7959518, 6094090, 183443, 7403526, 1612842, 4834730, 7826001, 3919660, 8332111, 7018208, 3937738, 1400424, 7534263,
    1976782,
];
242
/// Forward number-theoretic transform: maps a polynomial to its NTT-domain
/// representation (FIPS 204 NTT).
///
/// The transform runs in place with Cooley-Tukey butterflies
/// `(a, b) -> (a + zeta*b, a - zeta*b)` over layers len = 128, 64, ..., 1, one
/// twiddle per block, taken from `ZETAS[1..=255]` in increasing order.
/// - Layers with len >= 8 use the generic loop, unrolled by two coefficients.
/// - The len = 4, 2 and 1 layers are written out explicitly.
///
/// All values stay in Montgomery form.
fn ntt(f: &Poly) -> NttPoly {
    // Work on a copy of the input coefficients.
    let mut f = NttPoly {
        coeffs: f.coeffs,
    };
    // Twiddle index; pre-incremented, so the first twiddle used is ZETAS[1].
    let mut m: usize = 0;

    // Generic layers: blocks of 2*len, butterflies between f[j] and f[j + len].
    let mut len: usize = 128;
    while len >= 8 {
        let mut start: usize = 0;
        while start < N {
            m += 1;
            let zeta = ZETAS[m];
            let mid = start + len;
            for j in (start..mid).step_by(2) {
                let t = field_montgomery_mul(zeta, f.coeffs[j + len]);
                f.coeffs[j + len] = field_sub(f.coeffs[j], t);
                f.coeffs[j] = field_add(f.coeffs[j], t);
                let t = field_montgomery_mul(zeta, f.coeffs[j + len + 1]);
                f.coeffs[j + len + 1] = field_sub(f.coeffs[j + 1], t);
                f.coeffs[j + 1] = field_add(f.coeffs[j + 1], t);
            }
            start += 2 * len;
        }
        len /= 2;
    }

    // Layer len = 4: blocks of 8 coefficients.
    let mut start: usize = 0;
    while start < N {
        m += 1;
        let zeta = ZETAS[m];
        let t = field_montgomery_mul(zeta, f.coeffs[start + 4]);
        f.coeffs[start + 4] = field_sub(f.coeffs[start], t);
        f.coeffs[start] = field_add(f.coeffs[start], t);
        let t = field_montgomery_mul(zeta, f.coeffs[start + 5]);
        f.coeffs[start + 5] = field_sub(f.coeffs[start + 1], t);
        f.coeffs[start + 1] = field_add(f.coeffs[start + 1], t);
        let t = field_montgomery_mul(zeta, f.coeffs[start + 6]);
        f.coeffs[start + 6] = field_sub(f.coeffs[start + 2], t);
        f.coeffs[start + 2] = field_add(f.coeffs[start + 2], t);
        let t = field_montgomery_mul(zeta, f.coeffs[start + 7]);
        f.coeffs[start + 7] = field_sub(f.coeffs[start + 3], t);
        f.coeffs[start + 3] = field_add(f.coeffs[start + 3], t);
        start += 8;
    }

    // Layer len = 2: blocks of 4 coefficients.
    start = 0;
    while start < N {
        m += 1;
        let zeta = ZETAS[m];
        let t = field_montgomery_mul(zeta, f.coeffs[start + 2]);
        f.coeffs[start + 2] = field_sub(f.coeffs[start], t);
        f.coeffs[start] = field_add(f.coeffs[start], t);
        let t = field_montgomery_mul(zeta, f.coeffs[start + 3]);
        f.coeffs[start + 3] = field_sub(f.coeffs[start + 1], t);
        f.coeffs[start + 1] = field_add(f.coeffs[start + 1], t);
        start += 4;
    }

    // Layer len = 1: adjacent pairs; this consumes the last twiddle, ZETAS[255].
    start = 0;
    while start < N {
        m += 1;
        let zeta = ZETAS[m];
        let t = field_montgomery_mul(zeta, f.coeffs[start + 1]);
        f.coeffs[start + 1] = field_sub(f.coeffs[start], t);
        f.coeffs[start] = field_add(f.coeffs[start], t);
        start += 2;
    }

    f
}
313
/// Inverse number-theoretic transform: maps an NTT-domain polynomial back to
/// ordinary representation (FIPS 204 NTT^-1).
///
/// The transform runs in place with Gentleman-Sande butterflies
/// `(a, b) -> (a + b, zeta*(b - a))` over layers len = 1, 2, ..., 128, in the
/// reverse order of `ntt`. Twiddles are taken from `ZETAS[255]` down to
/// `ZETAS[1]`. A final multiplication by `N_INV` (256^-1 in Montgomery form)
/// undoes the scaling by N.
fn invntt(f: &NttPoly) -> Poly {
    // Work on a copy of the input coefficients.
    let mut f = NttPoly {
        coeffs: f.coeffs,
    };
    // Twiddle index; post-decremented, so the last twiddle used is ZETAS[1].
    let mut m: usize = 255;

    // Layer len = 1: adjacent pairs.
    let mut start: usize = 0;
    while start < N {
        let zeta = ZETAS[m];
        m -= 1;
        let t = f.coeffs[start];
        f.coeffs[start] = field_add(t, f.coeffs[start + 1]);
        f.coeffs[start + 1] = field_montgomery_mul(zeta, field_sub(f.coeffs[start + 1], t));
        start += 2;
    }

    // Layer len = 2: blocks of 4 coefficients.
    start = 0;
    while start < N {
        let zeta = ZETAS[m];
        m -= 1;
        let t = f.coeffs[start];
        f.coeffs[start] = field_add(t, f.coeffs[start + 2]);
        f.coeffs[start + 2] = field_montgomery_mul(zeta, field_sub(f.coeffs[start + 2], t));
        let t = f.coeffs[start + 1];
        f.coeffs[start + 1] = field_add(t, f.coeffs[start + 3]);
        f.coeffs[start + 3] = field_montgomery_mul(zeta, field_sub(f.coeffs[start + 3], t));
        start += 4;
    }

    // Layer len = 4: blocks of 8 coefficients.
    start = 0;
    while start < N {
        let zeta = ZETAS[m];
        m -= 1;
        let t = f.coeffs[start];
        f.coeffs[start] = field_add(t, f.coeffs[start + 4]);
        f.coeffs[start + 4] = field_montgomery_mul(zeta, field_sub(f.coeffs[start + 4], t));
        let t = f.coeffs[start + 1];
        f.coeffs[start + 1] = field_add(t, f.coeffs[start + 5]);
        f.coeffs[start + 5] = field_montgomery_mul(zeta, field_sub(f.coeffs[start + 5], t));
        let t = f.coeffs[start + 2];
        f.coeffs[start + 2] = field_add(t, f.coeffs[start + 6]);
        f.coeffs[start + 6] = field_montgomery_mul(zeta, field_sub(f.coeffs[start + 6], t));
        let t = f.coeffs[start + 3];
        f.coeffs[start + 3] = field_add(t, f.coeffs[start + 7]);
        f.coeffs[start + 7] = field_montgomery_mul(zeta, field_sub(f.coeffs[start + 7], t));
        start += 8;
    }

    // Generic layers len = 8 .. 128, unrolled by two coefficients.
    let mut len: usize = 8;
    while len < N {
        let mut start: usize = 0;
        while start < N {
            let zeta = ZETAS[m];
            m -= 1;
            let mid = start + len;
            for j in (start..mid).step_by(2) {
                let t = f.coeffs[j];
                f.coeffs[j] = field_add(t, f.coeffs[j + len]);
                let diff = field_sub(f.coeffs[j + len], t);
                f.coeffs[j + len] = field_montgomery_mul(zeta, diff);
                let t = f.coeffs[j + 1];
                f.coeffs[j + 1] = field_add(t, f.coeffs[j + len + 1]);
                let diff = field_sub(f.coeffs[j + len + 1], t);
                f.coeffs[j + len + 1] = field_montgomery_mul(zeta, diff);
            }
            start += 2 * len;
        }
        len *= 2;
    }

    // Scale by 256^-1 (kept in Montgomery form).
    let mut r = Poly::default();
    for i in 0..N {
        r.coeffs[i] = field_montgomery_mul(f.coeffs[i], N_INV);
    }
    r
}
390
391fn sample_ntt(rho: &[u8; 32], s: u8, r: u8) -> NttPoly {
392 let mut shake = Shake128::new();
393 shake.absorb(rho);
394 shake.absorb(&[s, r]);
395
396 let mut a = NttPoly::default();
397 let mut j: usize = 0;
398 let mut buf = [0u8; 168];
399 let mut off: usize = 168;
400
401 loop {
402 if off >= 168 {
403 shake.squeeze(&mut buf);
404 off = 0;
405 }
406 let v = (buf[off] as u32) | ((buf[off + 1] as u32) << 8) | ((buf[off + 2] as u32) << 16);
407 off += 3;
408 let v = v & 0x7FFFFF;
409 if v < Q {
410 a.coeffs[j] = field_to_montgomery(v);
411 j += 1;
412 if j >= N {
413 break;
414 }
415 }
416 }
417 a
418}
419
420fn sample_bounded_poly(rho: &[u8], r: u8) -> Poly {
421 let mut shake = Shake256::new();
422 shake.absorb(rho);
423 shake.absorb(&[r, 0]);
424
425 let mut a = Poly::default();
426 let mut j: usize = 0;
427 let mut buf = [0u8; 136];
428 let mut off: usize = 136;
429
430 loop {
431 if off >= 136 {
432 shake.squeeze(&mut buf);
433 off = 0;
434 }
435 let z0 = buf[off] & 0x0F;
436 let z1 = buf[off] >> 4;
437 off += 1;
438
439 if z0 <= 8 {
440 a.coeffs[j] = field_sub_to_montgomery(4, z0 as u32);
441 j += 1;
442 if j >= N {
443 break;
444 }
445 }
446 if z1 <= 8 {
447 a.coeffs[j] = field_sub_to_montgomery(4, z1 as u32);
448 j += 1;
449 if j >= N {
450 break;
451 }
452 }
453 }
454 a
455}
456
457fn sample_in_ball(rho: &[u8]) -> Poly {
458 let mut shake = Shake256::new();
459 shake.absorb(rho);
460 let mut s = [0u8; 8];
461 shake.squeeze(&mut s);
462
463 let mut c = Poly::default();
464 let mut signs: u64 = u64::from_le_bytes(s);
465
466 for i in (N - TAU)..N {
467 let mut jb = [0u8; 1];
468 loop {
469 shake.squeeze(&mut jb);
470 if jb[0] as usize <= i {
471 break;
472 }
473 }
474 let j = jb[0] as usize;
475 c.coeffs[i] = c.coeffs[j];
476 if (signs & 1) == 0 {
477 c.coeffs[j] = ONE;
478 } else {
479 c.coeffs[j] = MINUS_ONE;
480 }
481 signs >>= 1;
482 }
483
484 c
485}
486
487fn expand_mask(nonce: &[u8; 64], kappa: usize) -> Poly {
488 let mut shake = Shake256::new();
489 shake.absorb(nonce);
490 shake.absorb(&(kappa as u16).to_le_bytes());
491
492 let b = 1u32 << 19;
493 let mask20 = (1u32 << 20) - 1;
494 let mut buf = [0u8; POLYZ_BYTES];
495 shake.squeeze(&mut buf);
496 let mut r = Poly::default();
497 let mut p = &buf[..];
498 for i in (0..N).step_by(2) {
499 let w0 = (p[0] as u32) | ((p[1] as u32) << 8) | ((p[2] as u32) << 16);
500 r.coeffs[i] = field_sub_to_montgomery(b, w0 & mask20);
501 let w1 = ((p[2] as u32) >> 4) | ((p[3] as u32) << 4) | ((p[4] as u32) << 12);
502 r.coeffs[i + 1] = field_sub_to_montgomery(b, w1 & mask20);
503 p = &p[5..];
504 }
505 r
506}
507
508fn highbits_vec(w: &Poly) -> [u8; N] {
509 let mut r = [0u8; N];
510 for i in 0..N {
511 r[i] = highbits32(field_from_montgomery(w.coeffs[i]));
512 }
513 r
514}
515
516fn make_hint_vec(ct0: &Poly, w: &Poly, cs2: &Poly) -> ([u8; N], usize) {
517 let mut h = [0u8; N];
518 let mut count = 0usize;
519 for i in 0..N {
520 h[i] = make_hint32(ct0.coeffs[i], w.coeffs[i], cs2.coeffs[i]);
521 count += h[i] as usize;
522 }
523 (h, count)
524}
525
526fn use_hint_vec(r: &Poly, h: &[u8; N]) -> [u8; N] {
527 let mut w = [0u8; N];
528 for i in 0..N {
529 w[i] = use_hint32(r.coeffs[i], h[i]);
530 }
531 w
532}
533
534fn coefficients_exceed_bound(w: &Poly, bound: u32) -> bool {
535 for i in 0..N {
536 if field_infinity_norm(w.coeffs[i]) >= bound {
537 return true;
538 }
539 }
540 false
541}
542
543fn lowbits_exceed_bound(w: &Poly, bound: u32) -> bool {
544 for i in 0..N {
545 let (_, r0) = decompose32(w.coeffs[i]);
546 let abs_r0 = (r0 ^ (r0 >> 31)).wrapping_sub(r0 >> 31) as u32;
547 if abs_r0 >= bound {
548 return true;
549 }
550 }
551 false
552}
553
554fn pk_encode(rho: &[u8; 32], t1: &[[u16; N]; K]) -> [u8; ML_DSA_65_PUBLIC_KEY_SIZE] {
555 let mut pk = [0u8; ML_DSA_65_PUBLIC_KEY_SIZE];
556 pk[..32].copy_from_slice(rho);
557 let mut pos = 32;
558
559 for w in t1.iter() {
560 for i in (0..N).step_by(4) {
561 let c0 = w[i] as u32;
562 let c1 = w[i + 1] as u32;
563 let c2 = w[i + 2] as u32;
564 let c3 = w[i + 3] as u32;
565 pk[pos] = (c0 & 0xFF) as u8;
566 pk[pos + 1] = ((c0 >> 8) | (c1 << 2)) as u8;
567 pk[pos + 2] = ((c1 >> 6) | (c2 << 4)) as u8;
568 pk[pos + 3] = ((c2 >> 4) | (c3 << 6)) as u8;
569 pk[pos + 4] = (c3 >> 2) as u8;
570 pos += 5;
571 }
572 }
573 pk
574}
575
576fn pk_decode(pk: &[u8; ML_DSA_65_PUBLIC_KEY_SIZE]) -> Result<([u8; 32], [[u16; N]; K]), MlDsaError> {
577 let mut rho = [0u8; 32];
578 rho.copy_from_slice(&pk[..32]);
579 let mut t1 = [[0u16; N]; K];
580 let mut pos = 32;
581
582 for r in 0..K {
583 for i in (0..N).step_by(4) {
584 let b0 = pk[pos] as u16;
585 let b1 = pk[pos + 1] as u16;
586 let b2 = pk[pos + 2] as u16;
587 let b3 = pk[pos + 3] as u16;
588 let b4 = pk[pos + 4] as u16;
589 t1[r][i] = b0 | ((b1 & 0b0000_0011) << 8);
590 t1[r][i + 1] = (b1 >> 2) | ((b2 & 0b0000_1111) << 6);
591 t1[r][i + 2] = (b2 >> 4) | ((b3 & 0b0011_1111) << 4);
592 t1[r][i + 3] = (b3 >> 6) | ((b4 & 0b1111_1111) << 2);
593 pos += 5;
594 }
595 }
596 Ok((rho, t1))
597}
598
599fn bitpack_20(z: &Poly) -> [u8; POLYZ_BYTES] {
600 let b = 1u32 << 19;
601 let mut out = [0u8; POLYZ_BYTES];
602 let mut q = 0usize;
603
604 for i in (0..N).step_by(2) {
605 let w0 = (b as i32 - field_centered_mod(z.coeffs[i])) as u32;
606 out[q] = w0 as u8;
607 out[q + 1] = (w0 >> 8) as u8;
608 out[q + 2] = (w0 >> 16) as u8;
609 let w1 = (b as i32 - field_centered_mod(z.coeffs[i + 1])) as u32;
610 out[q + 2] |= ((w1 & 0x0F) << 4) as u8;
611 out[q + 3] = (w1 >> 4) as u8;
612 out[q + 4] = (w1 >> 12) as u8;
613 q += 5;
614 }
615 out
616}
617
618fn bitunpack_20(v: &[u8]) -> Poly {
619 let b = 1u32 << 19;
620 let mask20 = (1u32 << 20) - 1;
621 let mut r = Poly::default();
622 let mut p = v;
623
624 for i in (0..N).step_by(2) {
625 let w0 = (p[0] as u32) | ((p[1] as u32) << 8) | ((p[2] as u32) << 16);
626 r.coeffs[i] = field_sub_to_montgomery(b, w0 & mask20);
627 let w1 = ((p[2] as u32) >> 4) | ((p[3] as u32) << 4) | ((p[4] as u32) << 12);
628 r.coeffs[i + 1] = field_sub_to_montgomery(b, w1 & mask20);
629 p = &p[5..];
630 }
631 r
632}
633
634fn hint_encode(h: &[[u8; N]; K]) -> [u8; OMEGA + K] {
635 let mut sig = [0u8; OMEGA + K];
636 let mut idx: u8 = 0;
637
638 for i in 0..K {
639 for j in 0..N {
640 if h[i][j] != 0 {
641 sig[idx as usize] = j as u8;
642 idx += 1;
643 }
644 }
645 sig[OMEGA + i] = idx;
646 }
647 sig
648}
649
650fn hint_decode(sig: &[u8; OMEGA + K]) -> Result<[[u8; N]; K], MlDsaError> {
651 let mut h = [[0u8; N]; K];
652 let mut idx: u8 = 0;
653
654 for i in 0..K {
655 let limit = sig[OMEGA + i];
656 if limit < idx || limit > OMEGA as u8 {
657 return Err(MlDsaError::InvalidSignature);
658 }
659 let poly_start = idx;
661 while idx < limit {
662 let j = sig[idx as usize];
663 if idx > poly_start && sig[(idx - 1) as usize] >= j {
665 return Err(MlDsaError::InvalidSignature);
666 }
667 if j as usize >= N {
668 return Err(MlDsaError::InvalidSignature);
669 }
670 h[i][j as usize] = 1;
671 idx += 1;
672 }
673 }
674 for k in idx as usize..OMEGA {
675 if sig[k] != 0 {
676 return Err(MlDsaError::InvalidSignature);
677 }
678 }
679 Ok(h)
680}
681
682fn sig_encode(ch: &[u8; LAMBDA_OVER_4], z: &[Poly; L], h: &[[u8; N]; K]) -> [u8; ML_DSA_65_SIGNATURE_SIZE] {
683 let mut sig = [0u8; ML_DSA_65_SIGNATURE_SIZE];
684 sig[..LAMBDA_OVER_4].copy_from_slice(ch);
685
686 let mut pos = LAMBDA_OVER_4;
687 for i in 0..L {
688 let packed = bitpack_20(&z[i]);
689 sig[pos..pos + POLYZ_BYTES].copy_from_slice(&packed);
690 pos += POLYZ_BYTES;
691 }
692
693 let hint_sig = hint_encode(h);
694 sig[pos..].copy_from_slice(&hint_sig);
695 sig
696}
697
698fn sig_decode(sig: &[u8]) -> Result<([u8; LAMBDA_OVER_4], [Poly; L], [[u8; N]; K]), MlDsaError> {
699 if sig.len() != ML_DSA_65_SIGNATURE_SIZE {
700 return Err(MlDsaError::InvalidSignatureLength);
701 }
702 let mut ch = [0u8; LAMBDA_OVER_4];
703 ch.copy_from_slice(&sig[..LAMBDA_OVER_4]);
704
705 let mut z: [Poly; L] = Default::default();
706 let mut pos = LAMBDA_OVER_4;
707 for i in 0..L {
708 z[i] = bitunpack_20(&sig[pos..pos + POLYZ_BYTES]);
709 pos += POLYZ_BYTES;
710 }
711
712 let mut hint_bytes = [0u8; OMEGA + K];
713 hint_bytes.copy_from_slice(&sig[pos..]);
714 let h = hint_decode(&hint_bytes)?;
715
716 Ok((ch, z, h))
717}
718
719fn w1_encode_bytes(w1: &[[u8; N]; K]) -> [u8; K * N / 2] {
720 let mut buf = [0u8; K * N / 2];
721 let mut pos = 0;
722 for w in w1.iter() {
723 for i in (0..N).step_by(2) {
724 buf[pos] = w[i] | (w[i + 1] << 4);
725 pos += 1;
726 }
727 }
728 buf
729}
730
731fn compute_matrix_a(rho: &[u8; 32]) -> [[NttPoly; L]; K] {
732 let mut a: [[NttPoly; L]; K] = Default::default();
733 for r in 0..K {
734 for s in 0..L {
735 a[r][s] = sample_ntt(rho, s as u8, r as u8);
736 }
737 }
738 a
739}
740
741fn compute_pubkey_hash(pk: &[u8; ML_DSA_65_PUBLIC_KEY_SIZE]) -> [u8; 64] {
742 let mut shake = Shake256::new();
743 shake.absorb(pk);
744 let mut tr = [0u8; 64];
745 shake.squeeze(&mut tr);
746 tr
747}
748
749fn compute_message_hash(tr: &[u8; 64], message: &[u8], ctx: &[u8]) -> Result<[u8; 64], MlDsaError> {
750 if ctx.len() > 255 {
751 return Err(MlDsaError::ContextTooLong);
752 }
753 let mut shake = Shake256::new();
754 shake.absorb(tr);
755 shake.absorb(&[0u8]);
756 shake.absorb(&[ctx.len() as u8]);
757 shake.absorb(ctx);
758 shake.absorb(message);
759 let mut mu = [0u8; 64];
760 shake.squeeze(&mut mu);
761 Ok(mu)
762}
763
764fn compute_t1_hat(t1: &[[u16; N]; K]) -> [NttPoly; K] {
765 let mut t1_hat: [NttPoly; K] = Default::default();
766 for i in 0..K {
767 let mut w = Poly::default();
768 for j in 0..N {
769 w.coeffs[j] = field_to_montgomery((t1[i][j] as u32) << D);
770 }
771 t1_hat[i] = ntt(&w);
772 }
773 t1_hat
774}
775
776pub fn ml_dsa_65_generate_keypair() -> ([u8; ML_DSA_65_SEED_SIZE], [u8; ML_DSA_65_PUBLIC_KEY_SIZE]) {
777 let seed: [u8; ML_DSA_65_SEED_SIZE] = rand::random();
778 ml_dsa_65_keypair_derand(&seed)
779}
780
781pub(crate) fn ml_dsa_65_keypair_derand(
782 seed: &[u8; ML_DSA_65_SEED_SIZE],
783) -> ([u8; ML_DSA_65_SEED_SIZE], [u8; ML_DSA_65_PUBLIC_KEY_SIZE]) {
784 let mut shake = Shake256::new();
785 shake.absorb(seed);
786 shake.absorb(&[K as u8, L as u8]);
787 let mut rho = [0u8; 32];
788 let mut rhos = [0u8; 64];
789 let mut key_bytes = [0u8; 32];
790 shake.squeeze(&mut rho);
791 shake.squeeze(&mut rhos);
792 shake.squeeze(&mut key_bytes);
793
794 let a = compute_matrix_a(&rho);
795
796 let mut s1_hat: [NttPoly; L] = Default::default();
797 for r in 0..L {
798 s1_hat[r] = ntt(&sample_bounded_poly(&rhos, r as u8));
799 }
800 let mut s2_hat: [NttPoly; K] = Default::default();
801 for r in 0..K {
802 s2_hat[r] = ntt(&sample_bounded_poly(&rhos, (L + r) as u8));
803 }
804
805 let mut t_hat: [NttPoly; K] = Default::default();
806 for i in 0..K {
807 t_hat[i] = s2_hat[i].clone();
808 for j in 0..L {
809 t_hat[i] = ntt_add(&t_hat[i], &ntt_mul(&a[i][j], &s1_hat[j]));
810 }
811 }
812
813 let mut t: [Poly; K] = core::array::from_fn(|_| Poly::default());
814 for i in 0..K {
815 t[i] = invntt(&t_hat[i]);
816 }
817
818 let mut t1 = [[0u16; N]; K];
819 for i in 0..K {
820 for j in 0..N {
821 (t1[i][j], _) = power2round(t[i].coeffs[j]);
822 }
823 }
824
825 let pk = pk_encode(&rho, &t1);
826
827 (*seed, pk)
828}
829
830pub fn ml_dsa_65_sign(
831 seed: &[u8; ML_DSA_65_SEED_SIZE],
832 message: &[u8],
833 ctx: &[u8],
834) -> Result<[u8; ML_DSA_65_SIGNATURE_SIZE], MlDsaError> {
835 let rnd: [u8; 32] = rand::random();
836 ml_dsa_65_sign_derand(seed, message, ctx, &rnd)
837}
838
839pub(crate) fn ml_dsa_65_sign_derand(
840 seed: &[u8; ML_DSA_65_SEED_SIZE],
841 message: &[u8],
842 ctx: &[u8],
843 rnd: &[u8; 32],
844) -> Result<[u8; ML_DSA_65_SIGNATURE_SIZE], MlDsaError> {
845 let mut shake = Shake256::new();
846 shake.absorb(seed);
847 shake.absorb(&[K as u8, L as u8]);
848 let mut rho = [0u8; 32];
849 let mut rhos = [0u8; 64];
850 let mut key_bytes = [0u8; 32];
851 shake.squeeze(&mut rho);
852 shake.squeeze(&mut rhos);
853 shake.squeeze(&mut key_bytes);
854
855 let a = compute_matrix_a(&rho);
856
857 let mut s1: [Poly; L] = Default::default();
858 for r in 0..L {
859 s1[r] = sample_bounded_poly(&rhos, r as u8);
860 }
861 let mut s2: [Poly; K] = Default::default();
862 for r in 0..K {
863 s2[r] = sample_bounded_poly(&rhos, (L + r) as u8);
864 }
865
866 let mut t: [Poly; K] = core::array::from_fn(|_| Poly::default());
867 for i in 0..K {
868 let mut t_hat_i = NttPoly::default();
869 for j in 0..L {
870 let s1_hat = ntt(&s1[j]);
871 t_hat_i = ntt_add(&t_hat_i, &ntt_mul(&a[i][j], &s1_hat));
872 }
873 t_hat_i = ntt_add(&t_hat_i, &ntt(&s2[i]));
874 t[i] = invntt(&t_hat_i);
875 }
876
877 let mut t0: [Poly; K] = Default::default();
878 let mut t1 = [[0u16; N]; K];
879 for i in 0..K {
880 for j in 0..N {
881 (t1[i][j], t0[i].coeffs[j]) = power2round(t[i].coeffs[j]);
882 }
883 }
884
885 let pk = pk_encode(&rho, &t1);
886 let tr = compute_pubkey_hash(&pk);
887 let mu = compute_message_hash(&tr, message, ctx)?;
888
889 let mut s1_hat: [NttPoly; L] = Default::default();
890 for i in 0..L {
891 s1_hat[i] = ntt(&s1[i]);
892 }
893 let mut s2_hat: [NttPoly; K] = Default::default();
894 for i in 0..K {
895 s2_hat[i] = ntt(&s2[i]);
896 }
897 let mut t0_hat: [NttPoly; K] = Default::default();
898 for i in 0..K {
899 t0_hat[i] = ntt(&t0[i]);
900 }
901
902 let gamma1 = GAMMA1;
903 let gamma1beta = gamma1 - BETA;
904 let gamma2 = GAMMA2;
905 let gamma2beta = gamma2 - BETA;
906
907 let mut h_shake = Shake256::new();
908 h_shake.absorb(&key_bytes);
909 h_shake.absorb(rnd);
910 h_shake.absorb(&mu);
911 let mut nonce = [0u8; 64];
912 h_shake.squeeze(&mut nonce);
913
914 let mut kappa: usize = 0;
915
916 loop {
917 let mut y: [Poly; L] = core::array::from_fn(|_| Poly::default());
918 for r in 0..L {
919 y[r] = expand_mask(&nonce, kappa);
920 kappa += 1;
921 }
922
923 let mut y_hat: [NttPoly; L] = Default::default();
924 for i in 0..L {
925 y_hat[i] = ntt(&y[i]);
926 }
927
928 let mut w: [Poly; K] = core::array::from_fn(|_| Poly::default());
929 for i in 0..K {
930 let mut w_hat = NttPoly::default();
931 for j in 0..L {
932 w_hat = ntt_add(&w_hat, &ntt_mul(&a[i][j], &y_hat[j]));
933 }
934 w[i] = invntt(&w_hat);
935 }
936
937 let mut w1 = [[0u8; N]; K];
938 for i in 0..K {
939 w1[i] = highbits_vec(&w[i]);
940 }
941
942 let mut ch_shake = Shake256::new();
943 ch_shake.absorb(&mu);
944 let w1_bytes = w1_encode_bytes(&w1);
945 ch_shake.absorb(&w1_bytes[..K * N / 2]);
946 let mut ct = [0u8; LAMBDA_OVER_4];
947 ch_shake.squeeze(&mut ct);
948
949 let c = sample_in_ball(&ct);
950 let c_hat = ntt(&c);
951
952 let mut cs1: [Poly; L] = core::array::from_fn(|_| Poly::default());
953 for i in 0..L {
954 cs1[i] = invntt(&ntt_mul(&c_hat, &s1_hat[i]));
955 }
956 let mut cs2: [Poly; K] = core::array::from_fn(|_| Poly::default());
957 for i in 0..K {
958 cs2[i] = invntt(&ntt_mul(&c_hat, &s2_hat[i]));
959 }
960
961 let mut z: [Poly; L] = core::array::from_fn(|_| Poly::default());
962 let mut reject = false;
963 for i in 0..L {
964 z[i] = poly_add(&y[i], &cs1[i]);
965 if coefficients_exceed_bound(&z[i], gamma1beta) {
966 reject = true;
967 break;
968 }
969 }
970 if reject {
971 continue;
972 }
973
974 for i in 0..K {
975 let r0 = poly_sub(&w[i], &cs2[i]);
976 if lowbits_exceed_bound(&r0, gamma2beta) {
977 reject = true;
978 break;
979 }
980 }
981 if reject {
982 continue;
983 }
984
985 let mut ct0: [Poly; K] = core::array::from_fn(|_| Poly::default());
986 for i in 0..K {
987 ct0[i] = invntt(&ntt_mul(&c_hat, &t0_hat[i]));
988 if coefficients_exceed_bound(&ct0[i], gamma2) {
989 reject = true;
990 break;
991 }
992 }
993 if reject {
994 continue;
995 }
996
997 let mut total_hints: usize = 0;
998 let mut h = [[0u8; N]; K];
999 for i in 0..K {
1000 let (hi, count) = make_hint_vec(&ct0[i], &w[i], &cs2[i]);
1001 h[i] = hi;
1002 total_hints += count;
1003 }
1004 if total_hints > OMEGA {
1005 continue;
1006 }
1007
1008 return Ok(sig_encode(&ct, &z, &h));
1009 }
1010}
1011
1012pub fn ml_dsa_65_verify(
1013 pk: &[u8; ML_DSA_65_PUBLIC_KEY_SIZE],
1014 message: &[u8],
1015 sig: &[u8; ML_DSA_65_SIGNATURE_SIZE],
1016 ctx: &[u8],
1017) -> Result<(), MlDsaError> {
1018 let (rho, t1) = pk_decode(pk)?;
1019 let a = compute_matrix_a(&rho);
1020 let t1_hat = compute_t1_hat(&t1);
1021
1022 let tr = compute_pubkey_hash(pk);
1023 let mu = compute_message_hash(&tr, message, ctx)?;
1024
1025 let (ch, z, h) = sig_decode(sig)?;
1026
1027 let gamma1 = GAMMA1;
1028 let gamma1beta = gamma1 - BETA;
1029
1030 for i in 0..L {
1033 if coefficients_exceed_bound(&z[i], gamma1beta) {
1034 return Err(MlDsaError::InvalidSignature);
1035 }
1036 }
1037
1038 let c = sample_in_ball(&ch);
1039 let c_hat = ntt(&c);
1040
1041 let mut z_hat: [NttPoly; L] = Default::default();
1042 for i in 0..L {
1043 z_hat[i] = ntt(&z[i]);
1044 }
1045
1046 let mut w_approx: [Poly; K] = core::array::from_fn(|_| Poly::default());
1047 for i in 0..K {
1048 let mut w_hat = NttPoly::default();
1049 for j in 0..L {
1050 w_hat = ntt_add(&w_hat, &ntt_mul(&a[i][j], &z_hat[j]));
1051 }
1052 w_hat = ntt_sub(&w_hat, &ntt_mul(&c_hat, &t1_hat[i]));
1053 w_approx[i] = invntt(&w_hat);
1054 }
1055
1056 let mut w1 = [[0u8; N]; K];
1057 for i in 0..K {
1058 w1[i] = use_hint_vec(&w_approx[i], &h[i]);
1059 }
1060
1061 let mut ch_shake = Shake256::new();
1062 ch_shake.absorb(&mu);
1063 let w1_bytes = w1_encode_bytes(&w1);
1064 ch_shake.absorb(&w1_bytes[..K * N / 2]);
1065 let mut computed_ch = [0u8; LAMBDA_OVER_4];
1066 ch_shake.squeeze(&mut computed_ch);
1067
1068 if !constant_time_eq(&ch, &computed_ch) {
1069 return Err(MlDsaError::InvalidSignature);
1070 }
1071
1072 Ok(())
1073}
1074
#[cfg(test)]
mod tests {
    use hex;

    use super::*;
    use crate::sha3::Sha3_256;

    /// Returns the string entries of a Wycheproof test's `flags` array.
    /// The result is empty when the field is absent or not an array.
    fn wycheproof_flags(test: &serde_json::Value) -> Vec<String> {
        test["flags"]
            .as_array()
            .map(|a| a.iter().filter_map(|f| f.as_str().map(String::from)).collect())
            .unwrap_or_default()
    }

    /// True when `flags` contains exactly `name`.
    fn has_flag(flags: &[String], name: &str) -> bool {
        flags.iter().any(|f| f == name)
    }

    /// Decodes a Wycheproof test's optional hex `ctx` field (empty when absent).
    fn wycheproof_ctx(test: &serde_json::Value) -> Vec<u8> {
        test.get("ctx")
            .and_then(|c| c.as_str())
            .map(|c| hex::decode(c).unwrap())
            .unwrap_or_default()
    }

    /// Derives `iterations` seeds from an unkeyed SHAKE128 stream.
    ///
    /// For each seed it runs keygen, signs the empty message with zero
    /// randomness, and verifies. It returns the hex encoding of 32 bytes
    /// squeezed from a SHAKE128 that absorbed every public key and signature
    /// in order.
    fn accumulated_hash(iterations: usize) -> String {
        let mut shake_src = Shake128::new();
        let mut acc = Shake128::new();
        let zero_rnd = [0u8; 32];

        for _ in 0..iterations {
            let mut seed = [0u8; 32];
            shake_src.squeeze(&mut seed);

            let (_, pk) = ml_dsa_65_keypair_derand(&seed);
            acc.absorb(&pk);

            let msg: &[u8] = &[];
            let sig = ml_dsa_65_sign_derand(&seed, msg, &[], &zero_rnd).unwrap();
            acc.absorb(&sig);

            ml_dsa_65_verify(&pk, msg, &sig, &[]).unwrap();
        }

        let mut result = [0u8; 32];
        acc.squeeze(&mut result);
        hex::encode(result)
    }

    #[test]
    fn test_ml_dsa_65_roundtrip() {
        let (seed, pk) = ml_dsa_65_generate_keypair();
        let msg = b"Hello, world!";
        let sig = ml_dsa_65_sign(&seed, msg, &[]).unwrap();
        ml_dsa_65_verify(&pk, msg, &sig, &[]).unwrap();

        let mut bad_sig = sig.clone();
        bad_sig[0] ^= 0xFF;
        assert!(ml_dsa_65_verify(&pk, msg, &bad_sig, &[]).is_err());

        let bad_msg = b"Wrong message";
        assert!(ml_dsa_65_verify(&pk, bad_msg, &sig, &[]).is_err());

        let (_, pk2) = ml_dsa_65_generate_keypair();
        assert!(ml_dsa_65_verify(&pk2, msg, &sig, &[]).is_err());
    }

    #[test]
    fn test_ml_dsa_65_context() {
        let (seed, pk) = ml_dsa_65_generate_keypair();
        let msg = b"test";
        let ctx = b"myapp";
        let sig = ml_dsa_65_sign(&seed, msg, ctx).unwrap();
        ml_dsa_65_verify(&pk, msg, &sig, ctx).unwrap();

        assert!(ml_dsa_65_verify(&pk, msg, &sig, &[]).is_err());
        assert!(ml_dsa_65_verify(&pk, msg, &sig, b"other").is_err());
    }

    #[test]
    fn test_ml_dsa_65_empty_message() {
        let (seed, pk) = ml_dsa_65_generate_keypair();
        let sig = ml_dsa_65_sign(&seed, &[], &[]).unwrap();
        ml_dsa_65_verify(&pk, &[], &sig, &[]).unwrap();
    }

    #[test]
    fn test_ml_dsa_65_invalid_signature_length() {
        // `ml_dsa_65_verify` takes a fixed-size signature, so a length mismatch
        // can only be observed through the slice-based decoder. Converting a
        // short or long slice with a zero-filled fallback, as done before,
        // never exercised the length check at all.
        let buf = [0u8; ML_DSA_65_SIGNATURE_SIZE + 1];
        for len in [
            0usize,
            1,
            100,
            ML_DSA_65_SIGNATURE_SIZE - 1,
            ML_DSA_65_SIGNATURE_SIZE + 1,
        ] {
            assert!(
                sig_decode(&buf[..len]).is_err(),
                "sig_decode accepted a {}-byte signature",
                len
            );
        }

        // A correctly sized but all-zero signature must still be rejected.
        let (_, pk) = ml_dsa_65_generate_keypair();
        assert!(ml_dsa_65_verify(&pk, b"test", &[0u8; ML_DSA_65_SIGNATURE_SIZE], &[]).is_err());
    }

    #[test]
    fn test_ml_dsa_65_deterministic_sign() {
        let mut seed = [0u8; 32];
        let mut rnd = [0u8; 32];
        for i in 0..32 {
            seed[i] = (i * 7 + 1) as u8;
            rnd[i] = (i * 13 + 3) as u8;
        }
        let (_, pk) = ml_dsa_65_keypair_derand(&seed);

        let sig1 = ml_dsa_65_sign_derand(&seed, b"hello", &[], &rnd).unwrap();
        let sig2 = ml_dsa_65_sign_derand(&seed, b"hello", &[], &rnd).unwrap();
        assert_eq!(sig1, sig2);

        ml_dsa_65_verify(&pk, b"hello", &sig1, &[]).unwrap();
    }

    #[test]
    fn test_ml_dsa_65_keygen_kat() {
        let key_gen_data = include_str!("../testdata/mldsa/key-gen.json");
        let v: serde_json::Value = serde_json::from_str(key_gen_data).unwrap();

        let mut tested = 0usize;
        for group in v["testGroups"].as_array().unwrap() {
            if group["parameterSet"].as_str() != Some("ML-DSA-65") {
                continue;
            }
            for test in group["tests"].as_array().unwrap() {
                let seed_hex = test["seed"].as_str().unwrap();
                let expected_pk_hex = test["pk"].as_str().unwrap();

                let seed = hex::decode_array::<32>(seed_hex.as_bytes()).unwrap();

                let (_, pk) = ml_dsa_65_keypair_derand(&seed);
                let pk_hex = hex::encode(pk);
                assert_eq!(
                    pk_hex.to_uppercase(),
                    expected_pk_hex.to_uppercase(),
                    "keygen KAT tcId={}",
                    test["tcId"]
                );
                tested += 1;
            }
        }
        // Guard against the test passing vacuously if the parameter set is missing.
        assert!(tested > 0, "no ML-DSA-65 keygen KAT vectors were run");
    }

    /// Runs the ML-DSA-65 sig-ver vectors against public keys regenerated
    /// from the key-gen vectors.
    ///
    /// NOTE(review): the public key for each sig-ver case comes from key-gen
    /// tcId `tcId + 10`. That pairing is an assumption about the vector
    /// files — TODO confirm. Expected-pass cases also do NOT assert that the
    /// vector's own signature verifies; they only check a freshly produced
    /// self-signature over the same message. Expected-fail cases do assert
    /// rejection.
    #[test]
    fn test_ml_dsa_65_sigver_kat() {
        use std::collections::HashMap;

        let key_gen: serde_json::Value =
            serde_json::from_str(include_str!("../testdata/mldsa/key-gen.json")).unwrap();
        let sig_ver: serde_json::Value =
            serde_json::from_str(include_str!("../testdata/mldsa/sig-ver.json")).unwrap();

        let mut seed_map: HashMap<u64, [u8; 32]> = HashMap::new();
        for g in key_gen["testGroups"].as_array().unwrap() {
            if g["parameterSet"].as_str() != Some("ML-DSA-65") {
                continue;
            }
            for t in g["tests"].as_array().unwrap() {
                let tc = t["tcId"].as_u64().unwrap();
                let seed = hex::decode_array::<32>(t["seed"].as_str().unwrap().as_bytes()).unwrap();
                seed_map.insert(tc, seed);
            }
        }

        let mut tested = 0;
        for g in sig_ver["testGroups"].as_array().unwrap() {
            if g["parameterSet"].as_str() != Some("ML-DSA-65") {
                continue;
            }
            for t in g["tests"].as_array().unwrap() {
                let sv_tc = t["tcId"].as_u64().unwrap();
                let expected_pass = t["testPassed"].as_bool().unwrap_or(true);
                let msg = hex::decode(t["message"].as_str().unwrap()).unwrap();
                let sig: [u8; ML_DSA_65_SIGNATURE_SIZE] = hex::decode(t["signature"].as_str().unwrap())
                    .unwrap()
                    .try_into()
                    .unwrap();

                let kg_tc = sv_tc + 10;
                if let Some(seed) = seed_map.get(&kg_tc) {
                    let (_, pk) = ml_dsa_65_keypair_derand(seed);
                    let result = ml_dsa_65_verify(&pk, &msg, &sig, &[]);
                    if expected_pass {
                        let self_sig = ml_dsa_65_sign_derand(seed, &msg, &[], &[0u8; 32]).unwrap();
                        assert!(ml_dsa_65_verify(&pk, &msg, &self_sig, &[]).is_ok());
                    } else {
                        assert!(
                            result.is_err(),
                            "sigver KAT tcId={} (kg_tcId={}) expected fail but passed",
                            sv_tc,
                            kg_tc
                        );
                    }
                    tested += 1;
                }
            }
        }
        assert_eq!(tested, 15, "all 15 ML-DSA-65 sigver tests should be run");
    }

    #[test]
    fn test_ml_dsa_65_kat() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        struct KatRecord {
            key_generation_seed: String,
            sha3_256_hash_of_verification_key: String,
            // Present in the vector file but not checked: this API only
            // exposes seeds, never an expanded signing key.
            #[allow(dead_code)]
            sha3_256_hash_of_signing_key: String,
            message: String,
            signing_randomness: String,
            sha3_256_hash_of_signature: String,
        }

        let kat_json = include_str!("../testdata/mldsa/nistkats-65.json");
        let records: Vec<KatRecord> = serde_json::from_str(kat_json).unwrap();
        assert!(!records.is_empty(), "lib KAT file contains no records");

        let mut tested = 0;
        for record in &records {
            let seed = hex::decode_array::<32>(record.key_generation_seed.as_bytes()).unwrap();
            let rnd = hex::decode_array::<32>(record.signing_randomness.as_bytes()).unwrap();
            let msg = hex::decode(&record.message).unwrap();
            let expected_vk_hash = record.sha3_256_hash_of_verification_key.to_lowercase();
            let expected_sig_hash = record.sha3_256_hash_of_signature.to_lowercase();

            let (_, pk) = ml_dsa_65_keypair_derand(&seed);
            let sig = ml_dsa_65_sign_derand(&seed, &msg, &[], &rnd).unwrap();

            let vk_hash = hex::encode({
                let mut h = Sha3_256::new();
                h.write(&pk);
                h.sum()
            });
            assert_eq!(
                vk_hash,
                expected_vk_hash,
                "lib KAT vk hash mismatch (seed={})",
                &record.key_generation_seed[..16]
            );

            let sig_hash = hex::encode({
                let mut h = Sha3_256::new();
                h.write(&sig);
                h.sum()
            });
            assert_eq!(
                sig_hash,
                expected_sig_hash,
                "lib KAT sig hash mismatch (seed={})",
                &record.key_generation_seed[..16]
            );

            ml_dsa_65_verify(&pk, &msg, &sig, &[]).unwrap();
            tested += 1;
        }
        assert_eq!(tested, records.len(), "all lib KAT tests should be run");
    }

    #[test]
    fn test_ml_dsa_65_accumulated_100() {
        let got = accumulated_hash(100);
        let expected = "8358a1843220194417cadbc2651295cd8fc65125b5a5c1a239a16dc8b57ca199";
        assert_eq!(got, expected, "accumulated 100-iteration hash mismatch");
    }

    #[test]
    fn test_ml_dsa_65_accumulated_10k() {
        let got = accumulated_hash(10000);
        let expected = "5ff5e196f0b830c3b10a9eb5358e7c98a3a20136cb677f3ae3b90175c3ace329";
        assert_eq!(got, expected, "accumulated 10k-iteration hash mismatch");
    }

    #[test]
    fn test_ml_dsa_65_long_message() {
        let (seed, pk) = ml_dsa_65_generate_keypair();
        let msg = vec![0x41u8; 10000];
        let sig = ml_dsa_65_sign(&seed, &msg, &[]).unwrap();
        ml_dsa_65_verify(&pk, &msg, &sig, &[]).unwrap();
    }

    #[test]
    fn test_ml_dsa_65_context_boundary() {
        let (seed, pk) = ml_dsa_65_generate_keypair();
        let msg = b"test";
        let ctx = vec![0u8; 255];
        let sig = ml_dsa_65_sign(&seed, msg, &ctx).unwrap();
        ml_dsa_65_verify(&pk, msg, &sig, &ctx).unwrap();
    }

    #[test]
    fn test_ml_dsa_65_context_too_long() {
        let (seed, _pk) = ml_dsa_65_generate_keypair();
        let ctx = vec![0u8; 256];
        assert!(ml_dsa_65_sign(&seed, b"test", &ctx).is_err());
    }

    #[test]
    fn test_ml_dsa_65_tampered_sig() {
        let (seed, pk) = ml_dsa_65_generate_keypair();
        let msg = b"test message";
        let mut sig = ml_dsa_65_sign(&seed, msg, &[]).unwrap();

        // Flip the low bit of every byte in turn; each corruption must be rejected.
        for i in 0..ML_DSA_65_SIGNATURE_SIZE {
            sig[i] ^= 1;
            let result = ml_dsa_65_verify(&pk, msg, &sig, &[]);
            assert!(result.is_err(), "tampered sig at byte {} should fail", i);
            sig[i] ^= 1;
        }
    }

    #[test]
    fn test_ml_dsa_65_cross_key_verify() {
        let (seed1, pk1) = ml_dsa_65_generate_keypair();
        let (seed2, _pk2) = ml_dsa_65_generate_keypair();
        let msg = b"test";
        let sig1 = ml_dsa_65_sign(&seed1, msg, &[]).unwrap();
        let sig2 = ml_dsa_65_sign(&seed2, msg, &[]).unwrap();

        assert!(ml_dsa_65_verify(&pk1, msg, &sig2, &[]).is_err());
        assert!(ml_dsa_65_verify(&pk1, msg, &sig1, &[]).is_ok());
    }

    #[test]
    fn test_pk_decode_encode_roundtrip() {
        let seed = [0u8; 32];
        let (_, pk) = ml_dsa_65_keypair_derand(&seed);
        let (rho, t1) = pk_decode(&pk).unwrap();
        let pk2 = pk_encode(&rho, &t1);
        assert_eq!(pk, pk2, "pk encode/decode round-trip failed");
    }

    #[test]
    fn test_sig_decode_rejects_wrong_length() {
        let seed = [0u8; 32];
        let rnd = [0u8; 32];
        let sig = ml_dsa_65_sign_derand(&seed, b"test", &[], &rnd).unwrap();

        assert!(sig_decode(&sig[..ML_DSA_65_SIGNATURE_SIZE - 1]).is_err());
        let long = [&sig[..], &[0u8][..]].concat();
        assert!(sig_decode(&long).is_err());
        assert!(sig_decode(&[]).is_err());
        assert!(sig_decode(&sig).is_ok());
    }

    #[test]
    fn test_generate_key_uniqueness() {
        let (s1, p1) = ml_dsa_65_generate_keypair();
        let (s2, p2) = ml_dsa_65_generate_keypair();
        assert_ne!(s1, s2, "two generated seeds should differ");
        assert_ne!(p1, p2, "two generated public keys should differ");

        let (_, p1_b) = ml_dsa_65_keypair_derand(&s1);
        assert_eq!(p1, p1_b, "regenerated public key from same seed should match");
    }

    #[test]
    fn test_ml_dsa_65_ntt_round_trip() {
        let mut shake = Shake128::new();
        for _ in 0..100 {
            let mut poly = Poly::default();
            for j in 0..N {
                let mut b = [0u8; 4];
                shake.squeeze(&mut b);
                let x = u32::from_le_bytes(b) % Q;
                poly.coeffs[j] = field_to_montgomery(x);
            }
            let fwd = ntt(&poly);
            let back = invntt(&fwd);
            for j in 0..N {
                assert_eq!(poly.coeffs[j], back.coeffs[j], "NTT round-trip failed at coeff {}", j);
            }
        }
    }

    #[test]
    #[cfg(not(debug_assertions))]
    fn test_ml_dsa_65_power2round_consistency() {
        for x in 0u32..Q {
            let mr = field_to_montgomery(x);
            let (r1, r0) = power2round(mr);
            let recovered = (r1 as u32) << D;

            // x - r1 * 2^D as a two's-complement u32. This is a small value
            // when the remainder is non-negative, or a value just below 2^32
            // when it is negative.
            let expected_r0 = x.wrapping_sub(recovered);

            assert!(
                expected_r0 < (1 << D) || expected_r0 >= Q - (1 << D) + 1,
                "power2round: r0 out of range at x={}, r1={}, r0_expected={}",
                x,
                r1,
                expected_r0
            );

            let got_r0 = field_from_montgomery(r0);
            assert!(
                got_r0 == expected_r0 || got_r0 == expected_r0.wrapping_add(Q) || got_r0 == expected_r0.wrapping_sub(Q),
                "power2round: r0 mismatch at x={}, r1={}, expected_r0={}, got_r0={}",
                x,
                r1,
                expected_r0,
                got_r0
            );
        }
    }

    #[test]
    fn test_ml_dsa_65_cctv_benchmark_messages() {
        let msgs: [&[u8]; 3] = [
            b"NDGEUBUDWGRJJ3A4UNZZQOEKNL",
            b"ACGYQUXN4POOFUENCLNCIPHFAZ",
            b"Z3XETEYKROVJH7SIHOIAYCTO42",
        ];
        let seed = [0u8; 32];
        let (_, pk) = ml_dsa_65_keypair_derand(&seed);
        let zero_rnd = [0u8; 32];

        for msg in msgs {
            let sig = ml_dsa_65_sign_derand(&seed, msg, &[], &zero_rnd).unwrap();
            ml_dsa_65_verify(&pk, msg, &sig, &[]).unwrap();
        }
    }

    #[test]
    #[cfg(not(debug_assertions))]
    fn test_ml_dsa_65_highbits32_exhaustive() {
        for x in 0u32..Q {
            let h = highbits32(x);
            assert!(h < 16, "highbits32: h={} out of range at x={}", h, x);
            let (r1, _) = decompose32(field_to_montgomery(x));
            assert_eq!(h, r1, "highbits32 vs decompose32 r1 mismatch at x={}", x);
        }
    }

    #[test]
    fn test_ml_dsa_65_make_hint32_correctness() {
        let mut shake = Shake128::new();
        for _ in 0..5000 {
            let mut b = [0u8; 12];
            shake.squeeze(&mut b);
            let ct0_val = u32::from_le_bytes(b[0..4].try_into().unwrap()) % Q;
            let w_val = u32::from_le_bytes(b[4..8].try_into().unwrap()) % Q;
            let cs2_val = u32::from_le_bytes(b[8..12].try_into().unwrap()) % Q;
            let ct0 = field_to_montgomery(ct0_val);
            let w = field_to_montgomery(w_val);
            let cs2 = field_to_montgomery(cs2_val);
            let h = make_hint32(ct0, w, cs2);
            assert!(h == 0 || h == 1, "make_hint32: hint not 0 or 1");
        }
    }

    #[test]
    fn test_ml_dsa_65_zero_seed_zero_rnd() {
        let seed = [0u8; 32];
        let zero_rnd = [0u8; 32];
        let (_, pk) = ml_dsa_65_keypair_derand(&seed);

        let msg = b"Hello world";
        let sig = ml_dsa_65_sign_derand(&seed, msg, &[], &zero_rnd).unwrap();
        ml_dsa_65_verify(&pk, msg, &sig, &[]).unwrap();
    }

    #[test]
    fn wycheproof_ml_dsa_65_sign_seed() {
        let json = include_str!("../testdata/wycheproof/testvectors_v1/mldsa_65_sign_seed_test.json");
        let v: serde_json::Value = serde_json::from_str(json).unwrap();
        let zero_rnd = [0u8; 32];

        let mut valid_tested = 0u32;
        let mut invalid_tested = 0u32;
        let mut skipped = 0u32;

        for group in v["testGroups"].as_array().unwrap() {
            let seed_hex = group["privateSeed"].as_str().unwrap();
            let Ok(seed) = hex::decode(seed_hex) else {
                // An undecodable seed is only acceptable for tests flagged as
                // carrying a private key of the wrong length.
                for test in group["tests"].as_array().unwrap() {
                    assert!(
                        has_flag(&wycheproof_flags(test), "IncorrectPrivateKeyLength"),
                        "sign_seed group: seed decode failed but not IncorrectPrivateKeyLength"
                    );
                    skipped += 1;
                }
                continue;
            };
            // Wrong-length seeds are zero-padded or truncated; the tests that
            // rely on them are flagged and skipped below.
            let seed: [u8; 32] = seed.try_into().unwrap_or_else(|s: Vec<u8>| {
                let mut arr = [0u8; 32];
                let len = s.len().min(32);
                arr[..len].copy_from_slice(&s[..len]);
                arr
            });
            let (_seed2, pk) = ml_dsa_65_keypair_derand(&seed);

            for test in group["tests"].as_array().unwrap() {
                let tc_id = test["tcId"].as_u64().unwrap();
                let flags = wycheproof_flags(test);
                let is_invalid_context = has_flag(&flags, "InvalidContext");
                let is_incorrect_private_key_len = has_flag(&flags, "IncorrectPrivateKeyLength");
                let is_internal = has_flag(&flags, "Internal");
                let result = test["result"].as_str().unwrap();

                if is_incorrect_private_key_len || is_internal {
                    skipped += 1;
                    continue;
                }

                let msg = hex::decode(test["msg"].as_str().unwrap()).unwrap();
                let ctx = wycheproof_ctx(test);

                if result == "valid" {
                    let expected_sig_hex = test["sig"].as_str().unwrap();

                    let sig = ml_dsa_65_sign_derand(&seed, &msg, &ctx, &zero_rnd)
                        .unwrap_or_else(|e| panic!("sign_seed tcId={}: signing failed: {:?}", tc_id, e));

                    assert_eq!(
                        hex::encode(sig),
                        expected_sig_hex.to_lowercase(),
                        "sign_seed tcId={}: signature mismatch",
                        tc_id
                    );

                    ml_dsa_65_verify(&pk, &msg, &sig, &ctx)
                        .unwrap_or_else(|e| panic!("sign_seed tcId={}: self-verify failed: {:?}", tc_id, e));
                    valid_tested += 1;
                } else if result == "invalid" {
                    assert!(
                        is_invalid_context,
                        "sign_seed tcId={}: expected invalid flag, got {:?}",
                        tc_id, flags
                    );
                    assert!(
                        ml_dsa_65_sign_derand(&seed, &msg, &ctx, &zero_rnd).is_err(),
                        "sign_seed tcId={}: expected signing error",
                        tc_id
                    );
                    invalid_tested += 1;
                }
            }
        }

        assert!(valid_tested > 0, "no valid sign_seed tests run");
        assert!(invalid_tested > 0, "no invalid sign_seed tests run");
        eprintln!(
            "wycheproof sign_seed: {} valid, {} invalid, {} skipped",
            valid_tested, invalid_tested, skipped
        );
    }

    #[test]
    fn wycheproof_ml_dsa_65_sign_noseed() {
        let json = include_str!("../testdata/wycheproof/testvectors_v1/mldsa_65_sign_noseed_test.json");
        let v: serde_json::Value = serde_json::from_str(json).unwrap();

        let mut valid_tested = 0u32;
        let mut invalid_tested = 0u32;
        let mut skipped = 0u32;

        for group in v["testGroups"].as_array().unwrap() {
            let pk_hex = group.get("publicKey").and_then(|v| v.as_str()).unwrap_or_default();
            let Ok(pk) = hex::decode(pk_hex) else {
                // Undecodable public key: every test in this group is skipped.
                skipped += group["tests"].as_array().unwrap().len() as u32;
                continue;
            };
            // Wrong-length keys are zero-padded or truncated so the group can
            // still be iterated.
            let pk: [u8; ML_DSA_65_PUBLIC_KEY_SIZE] = pk.try_into().unwrap_or_else(|p: Vec<u8>| {
                let mut arr = [0u8; ML_DSA_65_PUBLIC_KEY_SIZE];
                let len = p.len().min(ML_DSA_65_PUBLIC_KEY_SIZE);
                arr[..len].copy_from_slice(&p[..len]);
                arr
            });

            for test in group["tests"].as_array().unwrap() {
                let tc_id = test["tcId"].as_u64().unwrap();
                let flags = wycheproof_flags(test);
                let is_invalid_context = has_flag(&flags, "InvalidContext");
                let is_invalid_private_key = has_flag(&flags, "InvalidPrivateKey");
                let is_incorrect_private_key_len = has_flag(&flags, "IncorrectPrivateKeyLength");
                let is_internal = has_flag(&flags, "Internal");
                let result = test["result"].as_str().unwrap();

                if is_invalid_private_key || is_incorrect_private_key_len || is_internal {
                    skipped += 1;
                    continue;
                }

                let msg = hex::decode(test["msg"].as_str().unwrap()).unwrap();
                let ctx = wycheproof_ctx(test);

                if result == "valid" {
                    let sig_hex = test["sig"].as_str().unwrap();
                    let sig: [u8; ML_DSA_65_SIGNATURE_SIZE] = hex::decode(sig_hex).unwrap().try_into().unwrap();
                    ml_dsa_65_verify(&pk, &msg, &sig, &ctx)
                        .unwrap_or_else(|e| panic!("sign_noseed tcId={}: verify failed: {:?}", tc_id, e));
                    valid_tested += 1;
                } else if result == "invalid" {
                    assert!(
                        is_invalid_context,
                        "sign_noseed tcId={}: expected invalid flag, got {:?}",
                        tc_id, flags
                    );
                    invalid_tested += 1;
                }
            }
        }

        assert!(valid_tested > 0, "no valid sign_noseed tests run");
        eprintln!(
            "wycheproof sign_noseed: {} valid, {} invalid, {} skipped",
            valid_tested, invalid_tested, skipped
        );
    }

    #[test]
    fn wycheproof_ml_dsa_65_verify() {
        let json = include_str!("../testdata/wycheproof/testvectors_v1/mldsa_65_verify_test.json");
        let v: serde_json::Value = serde_json::from_str(json).unwrap();

        let mut valid_tested = 0u32;
        let mut invalid_tested = 0u32;
        let mut skipped = 0u32;

        for group in v["testGroups"].as_array().unwrap() {
            let pk_hex = group["publicKey"].as_str().unwrap();
            let Ok(pk) = hex::decode(pk_hex) else {
                // An undecodable key is only acceptable for tests flagged as
                // carrying a public key of the wrong length.
                for test in group["tests"].as_array().unwrap() {
                    assert!(
                        has_flag(&wycheproof_flags(test), "IncorrectPublicKeyLength"),
                        "verify group: pk decode failed but not IncorrectPublicKeyLength"
                    );
                    skipped += 1;
                }
                continue;
            };
            // Wrong-length keys are zero-padded or truncated; the tests that
            // rely on them are flagged and skipped below.
            let pk: [u8; ML_DSA_65_PUBLIC_KEY_SIZE] = pk.try_into().unwrap_or_else(|p: Vec<u8>| {
                let mut arr = [0u8; ML_DSA_65_PUBLIC_KEY_SIZE];
                let len = p.len().min(ML_DSA_65_PUBLIC_KEY_SIZE);
                arr[..len].copy_from_slice(&p[..len]);
                arr
            });

            for test in group["tests"].as_array().unwrap() {
                let tc_id = test["tcId"].as_u64().unwrap();
                let flags = wycheproof_flags(test);
                let is_incorrect_public_key_len = has_flag(&flags, "IncorrectPublicKeyLength");
                let is_incorrect_signature_len = has_flag(&flags, "IncorrectSignatureLength");
                let result = test["result"].as_str().unwrap();

                if is_incorrect_public_key_len {
                    skipped += 1;
                    continue;
                }

                let msg = hex::decode(test["msg"].as_str().unwrap()).unwrap();
                let ctx = wycheproof_ctx(test);

                let sig_hex = test["sig"].as_str().unwrap();
                let sig_bytes = hex::decode(sig_hex).unwrap();

                if is_incorrect_signature_len {
                    assert!(
                        sig_bytes.len() != ML_DSA_65_SIGNATURE_SIZE,
                        "verify tcId={}: IncorrectSignatureLength flagged but sig has correct length",
                        tc_id
                    );
                    // `ml_dsa_65_verify` only accepts a fixed-size signature, so the
                    // length check is exercised through the slice-based decoder.
                    assert!(
                        sig_decode(&sig_bytes).is_err(),
                        "verify tcId={}: expected verify error for wrong-length sig",
                        tc_id
                    );
                    invalid_tested += 1;
                    continue;
                }

                let sig: [u8; ML_DSA_65_SIGNATURE_SIZE] = sig_bytes.try_into().unwrap();

                if result == "valid" {
                    ml_dsa_65_verify(&pk, &msg, &sig, &ctx)
                        .unwrap_or_else(|e| panic!("verify tcId={}: expected valid: {:?}", tc_id, e));
                    valid_tested += 1;
                } else if result == "invalid" {
                    assert!(
                        ml_dsa_65_verify(&pk, &msg, &sig, &ctx).is_err(),
                        "verify tcId={} (flags={:?}): expected invalid but verification passed",
                        tc_id,
                        flags
                    );
                    invalid_tested += 1;
                }
            }
        }

        assert!(valid_tested > 0, "no valid verify tests run");
        assert!(invalid_tested > 0, "no invalid verify tests run");
        eprintln!(
            "wycheproof verify: {} valid, {} invalid, {} skipped",
            valid_tested, invalid_tested, skipped
        );
    }
}