1use crate::{
2 curve25519::x25519,
3 mlkem::{self, MlKemError},
4 sha3::{Sha3_256, Shake256},
5};
6
7pub const SECRET_KEY_SIZE: usize = 32;
8pub const PUBLIC_KEY_SIZE: usize = mlkem::PUBLIC_KEY_SIZE_768 + x25519::KEY_SIZE; pub const CIPHERTEXT_SIZE: usize = mlkem::CIPHERTEXT_SIZE_768 + x25519::SHARED_SECRET_SIZE; pub const SHARED_SECRET_SIZE: usize = 32;
11
12const XWING_LABEL: &[u8; 6] = b"\\.//^\\";
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub enum XWingError {
16 MlKem(MlKemError),
17}
18
19impl From<MlKemError> for XWingError {
20 fn from(err: MlKemError) -> Self {
21 XWingError::MlKem(err)
22 }
23}
24
25impl core::fmt::Display for XWingError {
26 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27 match self {
28 XWingError::MlKem(err) => write!(f, "ML-KEM error: {err}"),
29 }
30 }
31}
32
33#[derive(Clone, Debug, PartialEq, Eq)]
35pub struct SecretKey {
36 bytes: [u8; SECRET_KEY_SIZE],
37 x25519_secret_key: x25519::SecretKey,
38 x25519_public_key_bytes: [u8; x25519::KEY_SIZE],
39 mlkem_secret_key: mlkem::SecretKey768,
40}
41
42impl SecretKey {
43 pub fn to_bytes(&self) -> [u8; SECRET_KEY_SIZE] {
44 self.bytes
45 }
46
47 pub fn decapsulate(&self, ct: &[u8; CIPHERTEXT_SIZE]) -> Result<[u8; SHARED_SECRET_SIZE], XWingError> {
48 let ct_m = &ct[..mlkem::CIPHERTEXT_SIZE_768].try_into().unwrap();
49 let ct_x = x25519::PublicKey::from_bytes(&ct[mlkem::CIPHERTEXT_SIZE_768..].try_into().unwrap());
50
51 let ss_m = self.mlkem_secret_key.decapsulate(&ct_m)?;
52 let ss_x = self.x25519_secret_key.ecdh(&ct_x);
53
54 Ok(combiner(&ss_m, &ss_x, &ct_x.to_bytes(), &self.x25519_public_key_bytes))
55 }
56}
57
58#[derive(Clone, Debug, PartialEq, Eq)]
60pub struct PublicKey {
61 mlkem_public_key: mlkem::PublicKey768,
62 x25519_public_key: x25519::PublicKey,
63}
64
65impl PublicKey {
66 pub fn to_bytes(&self) -> [u8; PUBLIC_KEY_SIZE] {
67 let mut bytes = [0u8; PUBLIC_KEY_SIZE];
68 bytes[..mlkem::PUBLIC_KEY_SIZE_768].copy_from_slice(&self.mlkem_public_key.to_bytes());
69 bytes[mlkem::PUBLIC_KEY_SIZE_768..].copy_from_slice(&self.x25519_public_key.to_bytes());
70 bytes
71 }
72
73 pub fn encapsulate(&self) -> ([u8; SHARED_SECRET_SIZE], [u8; CIPHERTEXT_SIZE]) {
74 let eseed: [u8; 64] = rand::random();
75 self.encapsulate_derand(&eseed)
76 }
77
78 fn encapsulate_derand(&self, eseed: &[u8; 64]) -> ([u8; SHARED_SECRET_SIZE], [u8; CIPHERTEXT_SIZE]) {
79 let ek_x = x25519::SecretKey::from_bytes(&eseed[32..64].try_into().unwrap());
80 let ct_x = ek_x.public_key();
81 let ss_x = ek_x.ecdh(&self.x25519_public_key);
82
83 let m = &eseed[..32].try_into().unwrap();
84 let (ct_m, ss_m) = self.mlkem_public_key.encapsulate_derand(&m);
85
86 let ss = combiner(&ss_m, &ss_x, &ct_x.to_bytes(), &self.x25519_public_key.to_bytes());
87
88 let mut ct = [0u8; CIPHERTEXT_SIZE];
89 ct[..mlkem::CIPHERTEXT_SIZE_768].copy_from_slice(&ct_m);
90 ct[mlkem::CIPHERTEXT_SIZE_768..].copy_from_slice(&ct_x.to_bytes());
91
92 (ss, ct)
93 }
94}
95
96pub fn generate_keypair() -> (SecretKey, PublicKey) {
97 let seed: [u8; SECRET_KEY_SIZE] = rand::random();
98 generate_keypair_derand(&seed)
99}
100
101fn generate_keypair_derand(secret_key: &[u8; SECRET_KEY_SIZE]) -> (SecretKey, PublicKey) {
104 let (mlkem_sk, x25519_sk, mlkem_pk, x25519_pk) = expand_decapsulation_key(secret_key);
105
106 let secret_key = SecretKey {
107 bytes: *secret_key,
108 x25519_secret_key: x25519_sk,
109 x25519_public_key_bytes: x25519_pk.to_bytes(),
110 mlkem_secret_key: mlkem_sk,
111 };
112
113 let public_key = PublicKey {
114 mlkem_public_key: mlkem_pk,
115 x25519_public_key: x25519_pk,
116 };
117
118 (secret_key, public_key)
119}
120
121fn expand_decapsulation_key(
122 secret_key: &[u8; 32],
123) -> (mlkem::SecretKey768, x25519::SecretKey, mlkem::PublicKey768, x25519::PublicKey) {
124 let mut expanded_secret_key = [0u8; 96];
125 Shake256::hash(secret_key, &mut expanded_secret_key);
126
127 let (sk_m, pk_m) = derive_mlkeem_keys(&expanded_secret_key);
128
129 let sk_x = x25519::SecretKey::from_bytes(&expanded_secret_key[64..96].try_into().unwrap());
130 let pk_x = sk_x.public_key();
131
132 (sk_m, sk_x, pk_m, pk_x)
133}
134
135fn derive_mlkeem_keys(expnded_secret_key: &[u8; 96]) -> (mlkem::SecretKey768, mlkem::PublicKey768) {
136 mlkem::generate_keypair_768_derand(&expnded_secret_key[..64].try_into().unwrap())
137}
138
139fn combiner(
140 ss_m: &[u8; mlkem::SHARED_SECRET_SIZE],
141 ss_x: &[u8; x25519::KEY_SIZE],
142 ct_x: &[u8; x25519::KEY_SIZE],
143 pk_x: &[u8; x25519::KEY_SIZE],
144) -> [u8; SHARED_SECRET_SIZE] {
145 let mut hasher = Sha3_256::new();
146 hasher.write(ss_m);
147 hasher.write(ss_x);
148 hasher.write(ct_x);
149 hasher.write(pk_x);
150 hasher.write(XWING_LABEL);
151 hasher.sum()
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex string into a fixed-size array; panics on bad hex or wrong length.
    fn hex_to_array<const N: usize>(hex_str: &str) -> [u8; N] {
        let bytes = hex::decode(hex_str).expect("test input is not valid hex");
        bytes
            .try_into()
            .unwrap_or_else(|v: Vec<u8>| panic!("expected {} bytes, got {}", N, v.len()))
    }

    #[test]
    fn constants() {
        assert_eq!(SECRET_KEY_SIZE, 32);
        assert_eq!(PUBLIC_KEY_SIZE, 1216);
        assert_eq!(CIPHERTEXT_SIZE, 1120);
        assert_eq!(SHARED_SECRET_SIZE, 32);
    }

    /// One (seed, eseed, expected shared secret) triple, hex-encoded.
    struct TestVector {
        seed: &'static str,
        eseed: &'static str,
        ss: &'static str,
    }

    const TEST_VECTORS: [TestVector; 3] = [
        TestVector {
            seed: "7f9c2ba4e88f827d616045507605853ed73b8093f6efbc88eb1a6eacfa66ef26",
            eseed: "3cb1eea988004b93103cfb0aeefd2a686e01fa4a58e8a3639ca8a1e3f9ae57e235b8cc873c23dc62b8d260169afa2f75ab916a58d974918835d25e6a435085b2",
            ss: "d2df0522128f09dd8e2c92b1e905c793d8f57a54c3da25861f10bf4ca613e384",
        },
        TestVector {
            seed: "badfd6dfaac359a5efbb7bcc4b59d538df9a04302e10c8bc1cbf1a0b3a5120ea",
            eseed: "17cda7cfad765f5623474d368ccca8af0007cd9f5e4c849f167a580b14aabdefaee7eef47cb0fca9767be1fda69419dfb927e9df07348b196691abaeb580b32d",
            ss: "f2e86241c64d60f6649fbc6c5b7d17180b780a3f34355e64a85749949c45f150",
        },
        TestVector {
            seed: "ef58538b8d23f87732ea63b02b4fa0f4873360e2841928cd60dd4cee8cc0d4c9",
            eseed: "22a96188d032675c8ac850933c7aff1533b94c834adbb69c6115bad4692d8619f90b0cdf8a7b9c264029ac185b70b83f2801f2f4b3f70c593ea3aeeb613a7f1b",
            ss: "953f7f4e8c5b5049bdc771d1dffada0dd961477d1a2ae0988baa7ea6898d893f",
        },
    ];

    #[test]
    fn test_vectors_from_draft() {
        for (i, tv) in TEST_VECTORS.iter().enumerate() {
            let seed: [u8; 32] = hex_to_array(tv.seed);
            let eseed: [u8; 64] = hex_to_array(tv.eseed);
            let expected_ss: [u8; 32] = hex_to_array(tv.ss);

            let (secret_key, pk) = generate_keypair_derand(&seed);
            assert_eq!(secret_key.to_bytes(), seed, "vector {i}: sk mismatch");

            let (ss, ct) = pk.encapsulate_derand(&eseed);
            assert_eq!(ss, expected_ss, "vector {i}: encaps ss mismatch");

            let decapsulated_ss = secret_key.decapsulate(&ct).unwrap();
            assert_eq!(decapsulated_ss, expected_ss, "vector {i}: decaps ss mismatch");
        }
    }

    #[test]
    fn round_trip() {
        let (secret_key, public_key) = generate_keypair();
        let (ss, ct) = public_key.encapsulate();
        let decapsulated = secret_key.decapsulate(&ct).unwrap();
        assert_eq!(ss, decapsulated);
    }

    #[test]
    fn round_trip_many() {
        for _ in 0..10 {
            let (secret_key, public_key) = generate_keypair();
            let (ss, ct) = public_key.encapsulate();
            let decapsulated = secret_key.decapsulate(&ct).unwrap();
            assert_eq!(ss, decapsulated);
        }
    }

    #[test]
    fn secret_key_seed_rederives_same_keypair() {
        let (secret_key, public_key) = generate_keypair();
        let (rederived_sk, rederived_pk) = generate_keypair_derand(&secret_key.to_bytes());
        assert_eq!(rederived_sk.to_bytes(), secret_key.to_bytes());
        assert_eq!(rederived_pk, public_key);
    }

    #[test]
    fn decapsulation_with_wrong_key_produces_different_secret() {
        let (_, pk_a) = generate_keypair();
        let (sk_b, _) = generate_keypair();

        let (ss_a, ct) = pk_a.encapsulate();
        let ss_b = sk_b.decapsulate(&ct).unwrap();
        assert_ne!(ss_a, ss_b);
    }

    #[test]
    fn tampered_ciphertext_produces_different_secret() {
        let (secret_key, public_key) = generate_keypair();
        let (ss, mut ct) = public_key.encapsulate();

        ct[0] ^= 0x80;

        let tampered_ss = secret_key.decapsulate(&ct).unwrap();
        assert_ne!(ss, tampered_ss);
    }

    #[test]
    fn tampered_x25519_ciphertext_produces_different_secret() {
        let (secret_key, public_key) = generate_keypair();
        let (ss, mut ct) = public_key.encapsulate();

        // Flip a low bit in the ct_X half, leaving ct_M untouched.
        ct[CIPHERTEXT_SIZE - 1] ^= 0x01;

        let tampered_ss = secret_key.decapsulate(&ct).unwrap();
        assert_ne!(ss, tampered_ss);
    }

    #[test]
    fn derandomized_keygen_is_deterministic() {
        let seed: [u8; 32] = hex_to_array(TEST_VECTORS[0].seed);
        let (sk1, pk1) = generate_keypair_derand(&seed);
        let (sk2, pk2) = generate_keypair_derand(&seed);
        assert_eq!(sk1.to_bytes(), sk2.to_bytes());
        assert_eq!(pk1.to_bytes(), pk2.to_bytes());
    }

    #[test]
    fn derandomized_encaps_is_deterministic() {
        let seed: [u8; 32] = hex_to_array(TEST_VECTORS[0].seed);
        let eseed: [u8; 64] = hex_to_array(TEST_VECTORS[0].eseed);
        let (_, pk) = generate_keypair_derand(&seed);

        let (ss1, ct1) = pk.encapsulate_derand(&eseed);
        let (ss2, ct2) = pk.encapsulate_derand(&eseed);
        assert_eq!(ct1, ct2);
        assert_eq!(ss1, ss2);
    }

    #[test]
    fn xwing_label_is_correct() {
        assert_eq!(XWING_LABEL.len(), 6);
        assert_eq!(hex::encode(XWING_LABEL), "5c2e2f2f5e5c");
    }

    #[test]
    fn expand_decapsulation_key_is_deterministic() {
        let seed: [u8; 32] = hex_to_array(TEST_VECTORS[0].seed);

        let (sk_m1, sk_x1, pk_m1, pk_x1) = expand_decapsulation_key(&seed);
        let (sk_m2, sk_x2, pk_m2, pk_x2) = expand_decapsulation_key(&seed);
        assert_eq!(sk_m1, sk_m2);
        assert_eq!(sk_x1, sk_x2);
        assert_eq!(pk_m1, pk_m2);
        assert_eq!(pk_x1, pk_x2);
    }

    #[test]
    fn combiner_is_deterministic() {
        let ss_m = [0x01u8; 32];
        let ss_x = [0x02u8; 32];
        let ct_x = [0x03u8; 32];
        let pk_x = [0x04u8; 32];

        let result1 = combiner(&ss_m, &ss_x, &ct_x, &pk_x);
        let result2 = combiner(&ss_m, &ss_x, &ct_x, &pk_x);
        assert_eq!(result1, result2);
    }
}