1use super::mlkem::{
2 ML_KEM_768, MlKemError, SHARED_SECRET_SIZE, crypto_kem_dec, crypto_kem_enc_derand, crypto_kem_keypair_derand,
3 indcpa_secret_key_bytes,
4};
5
/// Length in bytes of a serialized ML-KEM-768 encapsulation (public) key.
///
/// Per FIPS 203: `k = 3` polynomials of 384 bytes each, followed by a 32-byte seed.
pub const PUBLIC_KEY_SIZE_768: usize = 3 * 384 + 32;
/// Length in bytes of a serialized ML-KEM-768 decapsulation (secret) key.
///
/// Per FIPS 203: the 3 * 384-byte IND-CPA secret key, the encapsulation key,
/// a 32-byte hash of the encapsulation key and a 32-byte implicit-rejection value.
pub const SECRET_KEY_SIZE_768: usize = 3 * 384 + PUBLIC_KEY_SIZE_768 + 32 + 32;
/// Length in bytes of an ML-KEM-768 ciphertext.
///
/// Per FIPS 203: `32 * (d_u * k + d_v)` with `k = 3`, `d_u = 10`, `d_v = 4`.
pub const CIPHERTEXT_SIZE_768: usize = 32 * (10 * 3 + 4);
9
10#[derive(Clone, Debug, PartialEq, Eq)]
12#[cfg_attr(feature = "zeroize", derive(zeroize::Zeroize, zeroize::ZeroizeOnDrop))]
13pub struct SecretKey768 {
14 bytes: [u8; SECRET_KEY_SIZE_768],
15}
16
17#[derive(Clone, Debug, PartialEq, Eq)]
19pub struct PublicKey768 {
20 bytes: [u8; PUBLIC_KEY_SIZE_768],
21}
22
23#[inline]
24pub fn generate_keypair_768() -> (SecretKey768, PublicKey768) {
25 SecretKey768::generate()
26}
27
28#[inline]
29pub(crate) fn generate_keypair_768_derand(coins: &[u8; 64]) -> (SecretKey768, PublicKey768) {
30 SecretKey768::generate_derand(coins)
31}
32
33impl SecretKey768 {
34 pub fn from_bytes(bytes: &[u8; SECRET_KEY_SIZE_768]) -> Self {
35 Self {
36 bytes: *bytes,
37 }
38 }
39
40 pub fn to_bytes(&self) -> [u8; SECRET_KEY_SIZE_768] {
41 self.bytes
42 }
43
44 pub fn generate() -> (Self, PublicKey768) {
45 let coins: [u8; 64] = rand::random();
46 Self::generate_derand(&coins)
47 }
48
49 pub(crate) fn generate_derand(coins: &[u8; 64]) -> (Self, PublicKey768) {
50 let (sk_bytes, pk_bytes) =
51 crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, coins);
52 (
53 Self {
54 bytes: sk_bytes,
55 },
56 PublicKey768 {
57 bytes: pk_bytes,
58 },
59 )
60 }
61
62 pub fn decapsulate(&self, ciphertext: &[u8; CIPHERTEXT_SIZE_768]) -> Result<[u8; SHARED_SECRET_SIZE], MlKemError> {
63 crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &self.bytes, ciphertext)
64 }
65
66 pub fn public_key(&self) -> PublicKey768 {
67 let offset = indcpa_secret_key_bytes::<3>();
68 let mut pk_bytes = [0u8; PUBLIC_KEY_SIZE_768];
69 pk_bytes.copy_from_slice(&self.bytes[offset..offset + PUBLIC_KEY_SIZE_768]);
70 PublicKey768 {
71 bytes: pk_bytes,
72 }
73 }
74}
75
76impl From<&[u8; SECRET_KEY_SIZE_768]> for SecretKey768 {
77 fn from(bytes: &[u8; SECRET_KEY_SIZE_768]) -> Self {
78 Self::from_bytes(bytes)
79 }
80}
81
82impl TryFrom<&[u8]> for SecretKey768 {
83 type Error = MlKemError;
84
85 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
86 Ok(Self::from_bytes(bytes.try_into().map_err(|_| MlKemError::InvalidKey)?))
87 }
88}
89
90impl PublicKey768 {
91 pub fn from_bytes(bytes: &[u8; PUBLIC_KEY_SIZE_768]) -> Self {
92 Self {
93 bytes: *bytes,
94 }
95 }
96
97 pub fn to_bytes(&self) -> [u8; PUBLIC_KEY_SIZE_768] {
98 self.bytes
99 }
100
101 pub fn encapsulate(&self) -> ([u8; CIPHERTEXT_SIZE_768], [u8; SHARED_SECRET_SIZE]) {
102 let coins: [u8; 32] = rand::random();
103 self.encapsulate_derand(&coins)
104 }
105
106 pub(crate) fn encapsulate_derand(&self, coins: &[u8; 32]) -> ([u8; CIPHERTEXT_SIZE_768], [u8; SHARED_SECRET_SIZE]) {
107 crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &self.bytes, coins)
108 }
109}
110
111impl From<&[u8; PUBLIC_KEY_SIZE_768]> for PublicKey768 {
112 fn from(bytes: &[u8; PUBLIC_KEY_SIZE_768]) -> Self {
113 Self::from_bytes(bytes)
114 }
115}
116
117impl TryFrom<&[u8]> for PublicKey768 {
118 type Error = MlKemError;
119
120 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
121 Ok(Self::from_bytes(bytes.try_into().map_err(|_| MlKemError::InvalidKey)?))
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::{
        super::mlkem::{
            ML_KEM_768, crypto_kem_dec, crypto_kem_enc_derand, crypto_kem_keypair_derand, decode_hex_array,
            sha3_256_hex,
        },
        *,
    };

    #[test]
    fn ml_kem_768_round_trip() {
        let (private_key, public_key) = generate_keypair_768();
        let (ciphertext, encapsulated_secret) = public_key.encapsulate();
        let decapsulated_secret = private_key.decapsulate(&ciphertext).unwrap();

        assert_eq!(encapsulated_secret, decapsulated_secret);
    }

    #[test]
    fn ml_kem_768_decapsulation_rejects_tampered_ciphertext() {
        let (private_key, public_key) = generate_keypair_768();
        let (mut ciphertext, encapsulated_secret) = public_key.encapsulate();

        ciphertext[0] ^= 0x80;

        // Implicit rejection: a tampered ciphertext still decapsulates without error,
        // but to a secret unrelated to the encapsulated one.
        let decapsulated_secret = private_key.decapsulate(&ciphertext).unwrap();

        assert_ne!(encapsulated_secret, decapsulated_secret);
    }

    #[test]
    fn ml_kem_768_deterministic_derand_vectors_are_stable() {
        let key_coins = [7u8; 64];
        let enc_coins = [9u8; 32];
        let (secret_key, public_key) =
            crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &key_coins);
        let (ciphertext, shared_secret) =
            crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &public_key, &enc_coins);
        let decapsulated =
            crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &secret_key, &ciphertext)
                .unwrap();

        assert_eq!(shared_secret, decapsulated);
        assert_eq!(
            hex::encode(&public_key[..32]),
            "925a2700ad064ff778b4da4cf51457a48224a52751250a8ee10b251c818bafca"
        );
        assert_eq!(
            hex::encode(&ciphertext[..32]),
            "766c326c3483444c5b6d917cdddc3c07fbf935295c8f17c92a187a80dc4d15f2"
        );
        assert_eq!(
            hex::encode(shared_secret),
            "afcf18dfd6b710a09b5cf591d0eb8229d83aa10904934a3ca60a52da5ff36b96"
        );
    }

    #[test]
    fn ml_kem_768_cctv_accumulated_10k() {
        use crate::{Xof, sha3::Shake128};

        // Deterministic input stream: SHAKE128 of the empty string.
        let mut rng = Shake128::new();
        rng.absorb(&[]);

        // Every output of every iteration is absorbed here and checked as one hash at the end.
        let mut acc = Shake128::new();

        for _ in 0..10_000u32 {
            let mut d = [0u8; 32];
            let mut z = [0u8; 32];
            let mut m = [0u8; 32];
            let mut ct_random = [0u8; CIPHERTEXT_SIZE_768];

            rng.squeeze(&mut d);
            rng.squeeze(&mut z);
            rng.squeeze(&mut m);
            rng.squeeze(&mut ct_random);

            let mut coins = [0u8; 64];
            coins[..32].copy_from_slice(&d);
            coins[32..].copy_from_slice(&z);

            let (dk, ek) =
                crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &coins);
            let (ct, k_encaps) =
                crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &ek, &m);

            let k_decaps =
                crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &dk, &ct).unwrap();
            assert_eq!(k_encaps, k_decaps);

            let k_decaps_random =
                crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &dk, &ct_random).unwrap();

            acc.absorb(&ek);
            acc.absorb(&dk);
            acc.absorb(&ct);
            acc.absorb(&k_encaps);
            acc.absorb(&k_decaps_random);
        }

        let mut hash = [0u8; 32];
        acc.squeeze(&mut hash);
        assert_eq!(
            hex::encode(hash),
            "f959d18d3d1180121433bf0e05f11e7908cf9d03edc150b2b07cb90bef5bc1c1",
            "ML-KEM-768 CCTV accumulated hash mismatch"
        );
    }

    #[test]
    fn ml_kem_768_cctv_intermediate_vector() {
        let d: [u8; 32] = decode_hex_array("f688563f7c66a5da2d8bdb5a5f3e07bd8dce6f7efcec7f41298d79863459f7cd");
        let z: [u8; 32] = decode_hex_array("d1d49a515250dbceb9f6e3fcc1c7d5306918964b21ddb22207e03e57f0600da8");
        let m: [u8; 32] = decode_hex_array("3dc27ca0a6594b0e56320457c45a0f76bb8a213ea4a76d442186a0aefadbcdb9");

        let mut coins = [0u8; 64];
        coins[..32].copy_from_slice(&d);
        coins[32..].copy_from_slice(&z);

        let (dk, ek) = crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &coins);
        let (ct, k) = crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &ek, &m);

        assert_eq!(
            sha3_256_hex(&ek),
            "42d930a50dfd1f0541ca45c4598daebb4f51cd10d711a001bd9bb87d5c87a4bf"
        );
        assert_eq!(
            sha3_256_hex(&dk),
            "db563aebd9fdc875e88563693edad1e5e359cc37b0f685d2d0a3723b37253192"
        );
        assert_eq!(
            sha3_256_hex(&ct),
            "9d6e358208c4d583050becb319050b7f916de47caad1d589a1d01fea43fe1750"
        );
        assert_eq!(
            hex::encode(k),
            "ae726da2df66601c6648a7565c02b203a089276ac30f6cc226d048f93fafd78c"
        );
    }

    #[test]
    fn ml_kem_768_decapsulation_with_wrong_key_rejects() {
        let (_, alice_pk) = generate_keypair_768();
        let (bob_sk, _bob_pk) = generate_keypair_768();
        let (ct, alice_ss) = alice_pk.encapsulate();

        let wrong_ss = bob_sk.decapsulate(&ct).unwrap();
        assert_ne!(alice_ss, wrong_ss);
    }

    #[test]
    fn ml_kem_768_round_trip_many() {
        for _ in 0..100 {
            let (sk, pk) = generate_keypair_768();
            let (ct, ss_enc) = pk.encapsulate();
            let ss_dec = sk.decapsulate(&ct).unwrap();
            assert_eq!(ss_enc, ss_dec);
        }
    }

    #[test]
    fn ml_kem_768_all_zero_ciphertext_does_not_panic() {
        let (sk, _pk) = generate_keypair_768();
        let ct = [0u8; CIPHERTEXT_SIZE_768];
        let _result = sk.decapsulate(&ct);
    }

    #[test]
    fn ml_kem_768_all_ones_ciphertext_does_not_panic() {
        let (sk, _pk) = generate_keypair_768();
        let ct = [0xffu8; CIPHERTEXT_SIZE_768];
        let _result = sk.decapsulate(&ct);
    }

    #[test]
    fn ml_kem_768_derand_keygen_is_deterministic() {
        let coins = [7u8; 64];
        let (sk1, pk1) = crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &coins);
        let (sk2, pk2) = crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &coins);
        assert_eq!(sk1, sk2);
        assert_eq!(pk1, pk2);
    }

    #[test]
    fn ml_kem_768_key_sizes_are_correct() {
        let (sk, pk) = generate_keypair_768();
        let sk_bytes = sk.to_bytes();
        let pk_bytes = pk.to_bytes();
        assert_eq!(sk_bytes.len(), SECRET_KEY_SIZE_768);
        assert_eq!(pk_bytes.len(), PUBLIC_KEY_SIZE_768);
        let (ct, _) = pk.encapsulate();
        assert_eq!(ct.len(), CIPHERTEXT_SIZE_768);
    }

    #[test]
    fn ml_kem_768_encaps_is_deterministic_with_same_coins() {
        let enc_coins = [9u8; 32];
        let key_coins = [7u8; 64];
        let (_sk, pk) =
            crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &key_coins);
        let (ct1, ss1) =
            crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &pk, &enc_coins);
        let (ct2, ss2) =
            crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &pk, &enc_coins);
        assert_eq!(ct1, ct2);
        assert_eq!(ss1, ss2);
    }

    #[test]
    fn ml_kem_768_decapsulation_with_wrong_key_is_deterministic() {
        let (_, pk_a) = generate_keypair_768();
        let (sk_b, _pk_b) = generate_keypair_768();
        let (ct, _) = pk_a.encapsulate();

        let ss1 = sk_b.decapsulate(&ct).unwrap();
        let ss2 = sk_b.decapsulate(&ct).unwrap();
        assert_eq!(ss1, ss2, "implicit rejection must be deterministic");
    }

    #[test]
    fn ml_kem_768_wycheproof_keygen() {
        let data: serde_json::Value = serde_json::from_str(include_str!(
            "../../testdata/wycheproof/testvectors_v1/mlkem_768_keygen_seed_test.json"
        ))
        .unwrap();
        let mut tested = 0u64;
        for group in data["testGroups"].as_array().unwrap() {
            if group["parameterSet"].as_str() != Some("ML-KEM-768") {
                continue;
            }
            for test in group["tests"].as_array().unwrap() {
                let seed_hex = test["seed"].as_str().unwrap();
                let expected_ek_hex = test["ek"].as_str().unwrap();
                let expected_dk_hex = test["dk"].as_str().unwrap();
                let result = test["result"].as_str().unwrap();

                let seed = hex::decode_array::<64>(seed_hex.as_bytes()).unwrap();

                let (dk, ek) =
                    crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &seed);

                let ek_hex = hex::encode(ek);
                let dk_hex = hex::encode(dk);

                if result == "valid" {
                    assert_eq!(
                        ek_hex, expected_ek_hex,
                        "wycheproof keygen KAT tcId={} ek mismatch",
                        test["tcId"]
                    );
                    assert_eq!(
                        dk_hex, expected_dk_hex,
                        "wycheproof keygen KAT tcId={} dk mismatch",
                        test["tcId"]
                    );
                }
                tested += 1;
            }
        }
        assert!(tested > 0, "no ML-KEM-768 keygen tests were run");
    }

    /// Returns `true` when a Wycheproof KEM vector's seed or ciphertext hex does not
    /// match the fixed sizes (64-byte seed, `ct_size`-byte ciphertext) the API accepts.
    fn wycheproof_kem_skip_invalid_lengths(seed_hex: &str, c_hex: &str, ct_size: usize) -> bool {
        seed_hex.len() != 128 || c_hex.len() != ct_size * 2
    }

    #[test]
    fn ml_kem_768_wycheproof_kem() {
        let data: serde_json::Value =
            serde_json::from_str(include_str!("../../testdata/wycheproof/testvectors_v1/mlkem_768_test.json")).unwrap();
        // Counts only vectors that were actually executed, so skipped vectors alone
        // cannot satisfy the final "tests were run" assertion.
        let mut tested = 0u64;
        for group in data["testGroups"].as_array().unwrap() {
            if group["parameterSet"].as_str() != Some("ML-KEM-768") {
                continue;
            }
            for test in group["tests"].as_array().unwrap() {
                let seed_hex = test["seed"].as_str().unwrap();
                let c_hex = test["c"].as_str().unwrap();
                let expected_k_hex = test["K"].as_str().unwrap();
                let result = test["result"].as_str().unwrap();

                // Wrong-length inputs cannot be expressed through the fixed-size array API.
                if wycheproof_kem_skip_invalid_lengths(seed_hex, c_hex, CIPHERTEXT_SIZE_768) {
                    continue;
                }

                let seed = hex::decode_array::<64>(seed_hex.as_bytes()).unwrap();

                let (dk, ek) =
                    crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &seed);

                if let Some(expected_ek_hex) = test.get("ek").and_then(|v| v.as_str()) {
                    let ek_hex = hex::encode(ek);
                    assert_eq!(ek_hex, expected_ek_hex, "wycheproof KEM KAT tcId={} ek mismatch", test["tcId"]);
                }

                let c = decode_hex_array::<CIPHERTEXT_SIZE_768>(c_hex);
                let shared_secret = crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &dk, &c);

                if result == "valid" {
                    let k = shared_secret.unwrap();
                    let k_hex = hex::encode(k);
                    assert_eq!(k_hex, expected_k_hex, "wycheproof KEM KAT tcId={} K mismatch", test["tcId"]);
                } else {
                    // Implicit rejection: invalid ciphertexts must still decapsulate without error.
                    assert!(
                        shared_secret.is_ok(),
                        "wycheproof KEM KAT tcId={} unexpected error",
                        test["tcId"]
                    );
                }
                tested += 1;
            }
        }
        assert!(tested > 0, "no ML-KEM-768 KEM tests were run");
    }

    #[test]
    fn ml_kem_768_wycheproof_encaps() {
        let data: serde_json::Value = serde_json::from_str(include_str!(
            "../../testdata/wycheproof/testvectors_v1/mlkem_768_encaps_test.json"
        ))
        .unwrap();
        // Length-skipped vectors are not counted; see `ml_kem_768_wycheproof_kem`.
        let mut tested = 0u64;
        for group in data["testGroups"].as_array().unwrap() {
            if group["parameterSet"].as_str() != Some("ML-KEM-768") {
                continue;
            }
            for test in group["tests"].as_array().unwrap() {
                let ek_hex = test["ek"].as_str().unwrap();
                let m_hex = test["m"].as_str().unwrap();
                let expected_c_hex = test["c"].as_str().unwrap();
                let expected_k_hex = test["K"].as_str().unwrap();
                let result = test["result"].as_str().unwrap();

                if ek_hex.len() != PUBLIC_KEY_SIZE_768 * 2 {
                    continue;
                }

                let ek = decode_hex_array::<PUBLIC_KEY_SIZE_768>(ek_hex);

                if result == "valid" {
                    let m = decode_hex_array::<32>(m_hex);
                    let (c, k) =
                        crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &ek, &m);
                    let c_hex_out = hex::encode(c);
                    let k_hex_out = hex::encode(k);
                    assert_eq!(
                        c_hex_out, expected_c_hex,
                        "wycheproof encaps KAT tcId={} c mismatch",
                        test["tcId"]
                    );
                    assert_eq!(
                        k_hex_out, expected_k_hex,
                        "wycheproof encaps KAT tcId={} K mismatch",
                        test["tcId"]
                    );
                }
                tested += 1;
            }
        }
        assert!(tested > 0, "no ML-KEM-768 encaps tests were run");
    }

    #[test]
    fn ml_kem_768_wycheproof_decaps_validation() {
        let data: serde_json::Value = serde_json::from_str(include_str!(
            "../../testdata/wycheproof/testvectors_v1/mlkem_768_semi_expanded_decaps_test.json"
        ))
        .unwrap();
        // Length-skipped vectors are not counted; see `ml_kem_768_wycheproof_kem`.
        let mut tested = 0u64;
        for group in data["testGroups"].as_array().unwrap() {
            if group["parameterSet"].as_str() != Some("ML-KEM-768") {
                continue;
            }
            for test in group["tests"].as_array().unwrap() {
                let flags: Vec<&str> = test["flags"]
                    .as_array()
                    .map(|a| a.iter().filter_map(|v| v.as_str()).collect())
                    .unwrap_or_default();
                let dk_hex = test["dk"].as_str().unwrap();
                let c_hex = test["c"].as_str().unwrap();

                if flags.contains(&"IncorrectDecapsulationKeyLength") || flags.contains(&"IncorrectCiphertextLength") {
                    continue;
                }

                let dk = decode_hex_array::<SECRET_KEY_SIZE_768>(dk_hex);
                let c = decode_hex_array::<CIPHERTEXT_SIZE_768>(c_hex);

                let result = crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &dk, &c);

                // A panic would already abort the test; this checks that no `Err` is returned.
                assert!(result.is_ok(), "wycheproof decaps tcId={} returned an error", test["tcId"]);
                tested += 1;
            }
        }
        assert!(tested > 0, "no ML-KEM-768 decaps validation tests were run");
    }

    #[test]
    fn ml_kem_768_cross_implementation_pqcrypto() {
        let data: serde_json::Value =
            serde_json::from_str(include_str!("../../testdata/mlkem/pqcrypto_768_vectors.json")).unwrap();
        let vectors = data.as_array().unwrap();
        assert!(vectors.len() >= 5, "not enough cross-impl vectors");

        for (i, vector) in vectors.iter().enumerate() {
            let sk_hex = vector["sk"].as_str().unwrap();
            let ct_hex = vector["ct"].as_str().unwrap();
            let expected_ss_hex = vector["ss"].as_str().unwrap();

            let sk = decode_hex_array::<SECRET_KEY_SIZE_768>(sk_hex);
            let ct = decode_hex_array::<CIPHERTEXT_SIZE_768>(ct_hex);

            let ss = crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &sk, &ct).unwrap();
            assert_eq!(
                hex::encode(ss),
                expected_ss_hex,
                "cross-impl pqcrypto vector {i} decapsulation mismatch"
            );
        }
    }
}