Skip to main content

crypto/
xwing.rs

1use crate::{
2    curve25519::x25519,
3    mlkem::{self, MlKemError},
4    sha3::{Sha3_256, Shake256},
5};
6
/// Size in bytes of the serialized decapsulation key, i.e. the seed that [`generate_keypair`]
/// draws and [`SecretKey::to_bytes`] returns.
pub const SECRET_KEY_SIZE: usize = 32;
/// Size in bytes of a serialized [`PublicKey`]: the ML-KEM-768 public key followed by the X25519
/// public key.
pub const PUBLIC_KEY_SIZE: usize = mlkem::PUBLIC_KEY_SIZE_768 + x25519::KEY_SIZE; // 1216
/// Size in bytes of a ciphertext: the ML-KEM-768 ciphertext followed by the sender's ephemeral
/// X25519 public key.
// NOTE(review): the trailing component is an X25519 public key, so `x25519::KEY_SIZE` would name
// it more accurately than `SHARED_SECRET_SIZE`; presumably both are 32 — confirm before changing.
pub const CIPHERTEXT_SIZE: usize = mlkem::CIPHERTEXT_SIZE_768 + x25519::SHARED_SECRET_SIZE; // 1120
/// Size in bytes of the shared secret returned by encapsulation and decapsulation.
pub const SHARED_SECRET_SIZE: usize = 32;

/// Six-byte label (`\.//^\`) that `combiner` feeds to the hash after all of its other inputs.
const XWING_LABEL: &[u8; 6] = b"\\.//^\\";
13
/// Errors reported by X-Wing operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum XWingError {
    /// Failure propagated from the underlying ML-KEM implementation.
    MlKem(MlKemError),
}
18
19impl From<MlKemError> for XWingError {
20    fn from(err: MlKemError) -> Self {
21        XWingError::MlKem(err)
22    }
23}
24
25impl core::fmt::Display for XWingError {
26    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27        match self {
28            XWingError::MlKem(err) => write!(f, "ML-KEM error: {err}"),
29        }
30    }
31}
32
/// The X-Wing decapsulation (private) key
///
/// Holds the 32-byte seed together with the component keys expanded from it, which
/// [`SecretKey::decapsulate`] uses directly.
// NOTE(review): the derived `Debug` and `PartialEq` operate on secret key material (presumably
// printing it and comparing it in variable time) — confirm this is acceptable for callers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SecretKey {
    /// Seed this key was generated from; returned unchanged by `to_bytes`.
    bytes: [u8; SECRET_KEY_SIZE],
    /// X25519 secret key expanded from the seed.
    x25519_secret_key: x25519::SecretKey,
    /// Serialized X25519 public key matching `x25519_secret_key`; hashed by the combiner during
    /// decapsulation.
    x25519_public_key_bytes: [u8; x25519::KEY_SIZE],
    /// ML-KEM-768 secret key expanded from the seed.
    mlkem_secret_key: mlkem::SecretKey768,
}
41
42impl SecretKey {
43    pub fn to_bytes(&self) -> [u8; SECRET_KEY_SIZE] {
44        self.bytes
45    }
46
47    pub fn decapsulate(&self, ct: &[u8; CIPHERTEXT_SIZE]) -> Result<[u8; SHARED_SECRET_SIZE], XWingError> {
48        let ct_m = &ct[..mlkem::CIPHERTEXT_SIZE_768].try_into().unwrap();
49        let ct_x = x25519::PublicKey::from_bytes(&ct[mlkem::CIPHERTEXT_SIZE_768..].try_into().unwrap());
50
51        let ss_m = self.mlkem_secret_key.decapsulate(&ct_m)?;
52        let ss_x = self.x25519_secret_key.ecdh(&ct_x);
53
54        Ok(combiner(&ss_m, &ss_x, &ct_x.to_bytes(), &self.x25519_public_key_bytes))
55    }
56}
57
/// The X-Wing encapsulation (public) key
///
/// Pairs an ML-KEM-768 public key with an X25519 public key; `to_bytes` serializes them in that
/// order.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PublicKey {
    /// ML-KEM-768 half of the key; occupies the first bytes of the serialized form.
    mlkem_public_key: mlkem::PublicKey768,
    /// X25519 half of the key; occupies the trailing bytes of the serialized form.
    x25519_public_key: x25519::PublicKey,
}
64
65impl PublicKey {
66    pub fn to_bytes(&self) -> [u8; PUBLIC_KEY_SIZE] {
67        let mut bytes = [0u8; PUBLIC_KEY_SIZE];
68        bytes[..mlkem::PUBLIC_KEY_SIZE_768].copy_from_slice(&self.mlkem_public_key.to_bytes());
69        bytes[mlkem::PUBLIC_KEY_SIZE_768..].copy_from_slice(&self.x25519_public_key.to_bytes());
70        bytes
71    }
72
73    pub fn encapsulate(&self) -> ([u8; SHARED_SECRET_SIZE], [u8; CIPHERTEXT_SIZE]) {
74        let eseed: [u8; 64] = rand::random();
75        self.encapsulate_derand(&eseed)
76    }
77
78    fn encapsulate_derand(&self, eseed: &[u8; 64]) -> ([u8; SHARED_SECRET_SIZE], [u8; CIPHERTEXT_SIZE]) {
79        let ek_x = x25519::SecretKey::from_bytes(&eseed[32..64].try_into().unwrap());
80        let ct_x = ek_x.public_key();
81        let ss_x = ek_x.ecdh(&self.x25519_public_key);
82
83        let m = &eseed[..32].try_into().unwrap();
84        let (ct_m, ss_m) = self.mlkem_public_key.encapsulate_derand(&m);
85
86        let ss = combiner(&ss_m, &ss_x, &ct_x.to_bytes(), &self.x25519_public_key.to_bytes());
87
88        let mut ct = [0u8; CIPHERTEXT_SIZE];
89        ct[..mlkem::CIPHERTEXT_SIZE_768].copy_from_slice(&ct_m);
90        ct[mlkem::CIPHERTEXT_SIZE_768..].copy_from_slice(&ct_x.to_bytes());
91
92        (ss, ct)
93    }
94}
95
96pub fn generate_keypair() -> (SecretKey, PublicKey) {
97    let seed: [u8; SECRET_KEY_SIZE] = rand::random();
98    generate_keypair_derand(&seed)
99}
100
101/// Generate a deterministic keypair from the given seed. This function is not public because it
102/// should be used exclusively for testing.
103fn generate_keypair_derand(secret_key: &[u8; SECRET_KEY_SIZE]) -> (SecretKey, PublicKey) {
104    let (mlkem_sk, x25519_sk, mlkem_pk, x25519_pk) = expand_decapsulation_key(secret_key);
105
106    let secret_key = SecretKey {
107        bytes: *secret_key,
108        x25519_secret_key: x25519_sk,
109        x25519_public_key_bytes: x25519_pk.to_bytes(),
110        mlkem_secret_key: mlkem_sk,
111    };
112
113    let public_key = PublicKey {
114        mlkem_public_key: mlkem_pk,
115        x25519_public_key: x25519_pk,
116    };
117
118    (secret_key, public_key)
119}
120
121fn expand_decapsulation_key(
122    secret_key: &[u8; 32],
123) -> (mlkem::SecretKey768, x25519::SecretKey, mlkem::PublicKey768, x25519::PublicKey) {
124    let mut expanded_secret_key = [0u8; 96];
125    Shake256::hash(secret_key, &mut expanded_secret_key);
126
127    let (sk_m, pk_m) = derive_mlkeem_keys(&expanded_secret_key);
128
129    let sk_x = x25519::SecretKey::from_bytes(&expanded_secret_key[64..96].try_into().unwrap());
130    let pk_x = sk_x.public_key();
131
132    (sk_m, sk_x, pk_m, pk_x)
133}
134
135fn derive_mlkeem_keys(expnded_secret_key: &[u8; 96]) -> (mlkem::SecretKey768, mlkem::PublicKey768) {
136    mlkem::generate_keypair_768_derand(&expnded_secret_key[..64].try_into().unwrap())
137}
138
139fn combiner(
140    ss_m: &[u8; mlkem::SHARED_SECRET_SIZE],
141    ss_x: &[u8; x25519::KEY_SIZE],
142    ct_x: &[u8; x25519::KEY_SIZE],
143    pk_x: &[u8; x25519::KEY_SIZE],
144) -> [u8; SHARED_SECRET_SIZE] {
145    let mut hasher = Sha3_256::new();
146    hasher.write(ss_m);
147    hasher.write(ss_x);
148    hasher.write(ct_x);
149    hasher.write(pk_x);
150    hasher.write(XWING_LABEL);
151    hasher.sum()
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex fixture into a fixed-size array, panicking if it is malformed or has the
    /// wrong length.
    fn hex_to_array<const N: usize>(hex_str: &str) -> [u8; N] {
        hex::decode(hex_str)
            .expect("test fixture must be valid hex")
            .try_into()
            .expect("test fixture must decode to exactly N bytes")
    }

    /// Pins the derived sizes to their expected numeric values.
    #[test]
    fn constants() {
        assert_eq!(PUBLIC_KEY_SIZE, 1216);
        assert_eq!(CIPHERTEXT_SIZE, 1120);
    }

    /// One known-answer vector: key seed, encapsulation seed and the expected shared secret, all
    /// hex-encoded.
    struct TestVector {
        seed: &'static str,
        eseed: &'static str,
        ss: &'static str,
    }

    // NOTE(review): presumably copied from the X-Wing Internet-Draft test vectors (see the test
    // name below) — confirm the draft revision when updating.
    const TEST_VECTORS: [TestVector; 3] = [
        TestVector {
            seed: "7f9c2ba4e88f827d616045507605853ed73b8093f6efbc88eb1a6eacfa66ef26",
            eseed: "3cb1eea988004b93103cfb0aeefd2a686e01fa4a58e8a3639ca8a1e3f9ae57e235b8cc873c23dc62b8d260169afa2f75ab916a58d974918835d25e6a435085b2",
            ss: "d2df0522128f09dd8e2c92b1e905c793d8f57a54c3da25861f10bf4ca613e384",
        },
        TestVector {
            seed: "badfd6dfaac359a5efbb7bcc4b59d538df9a04302e10c8bc1cbf1a0b3a5120ea",
            eseed: "17cda7cfad765f5623474d368ccca8af0007cd9f5e4c849f167a580b14aabdefaee7eef47cb0fca9767be1fda69419dfb927e9df07348b196691abaeb580b32d",
            ss: "f2e86241c64d60f6649fbc6c5b7d17180b780a3f34355e64a85749949c45f150",
        },
        TestVector {
            seed: "ef58538b8d23f87732ea63b02b4fa0f4873360e2841928cd60dd4cee8cc0d4c9",
            eseed: "22a96188d032675c8ac850933c7aff1533b94c834adbb69c6115bad4692d8619f90b0cdf8a7b9c264029ac185b70b83f2801f2f4b3f70c593ea3aeeb613a7f1b",
            ss: "953f7f4e8c5b5049bdc771d1dffada0dd961477d1a2ae0988baa7ea6898d893f",
        },
    ];

    /// Known-answer test: deterministic keygen and encapsulation must reproduce the expected
    /// shared secret, and decapsulation must recover it.
    #[test]
    fn test_vectors_from_draft() {
        for (i, tv) in TEST_VECTORS.iter().enumerate() {
            let seed: [u8; 32] = hex_to_array(tv.seed);
            let eseed: [u8; 64] = hex_to_array(tv.eseed);
            let expected_ss: [u8; 32] = hex_to_array(tv.ss);

            let (secret_key, pk) = generate_keypair_derand(&seed);
            assert_eq!(secret_key.to_bytes(), seed, "vector {i}: sk mismatch");

            let (ss, ct) = pk.encapsulate_derand(&eseed);
            assert_eq!(ss, expected_ss, "vector {i}: encaps ss mismatch");

            let decapsulated_ss = secret_key.decapsulate(&ct).unwrap();
            assert_eq!(decapsulated_ss, expected_ss, "vector {i}: decaps ss mismatch");
        }
    }

    /// A randomly generated keypair must round-trip a randomly encapsulated secret.
    #[test]
    fn round_trip() {
        let (secret_key, public_key) = generate_keypair();
        let (ss, ct) = public_key.encapsulate();
        let decapsulated = secret_key.decapsulate(&ct).unwrap();
        assert_eq!(ss, decapsulated);
    }

    /// Same as `round_trip`, repeated over several independent random keypairs.
    #[test]
    fn round_trip_many() {
        for _ in 0..10 {
            let (secret_key, public_key) = generate_keypair();
            let (ss, ct) = public_key.encapsulate();
            let decapsulated = secret_key.decapsulate(&ct).unwrap();
            assert_eq!(ss, decapsulated);
        }
    }

    /// Decapsulating with an unrelated secret key must not yield the sender's secret.
    #[test]
    fn decapsulation_with_wrong_key_produces_different_secret() {
        let (_, pk_a) = generate_keypair();
        let (sk_b, _) = generate_keypair();

        let (ss_a, ct) = pk_a.encapsulate();
        let ss_b = sk_b.decapsulate(&ct).unwrap();
        assert_ne!(ss_a, ss_b);
    }

    /// Flipping a bit in the ML-KEM part of the ciphertext must change the decapsulated secret.
    #[test]
    fn tampered_ciphertext_produces_different_secret() {
        let (secret_key, public_key) = generate_keypair();
        let (ss, mut ct) = public_key.encapsulate();

        ct[0] ^= 0x80;

        let tampered_ss = secret_key.decapsulate(&ct).unwrap();
        assert_ne!(ss, tampered_ss);
    }

    /// The same seed must always expand to the same keypair.
    #[test]
    fn derandomized_keygen_is_deterministic() {
        let seed: [u8; 32] = hex_to_array(TEST_VECTORS[0].seed);
        let (sk1, pk1) = generate_keypair_derand(&seed);
        let (sk2, pk2) = generate_keypair_derand(&seed);
        assert_eq!(sk1.to_bytes(), sk2.to_bytes());
        assert_eq!(pk1.to_bytes(), pk2.to_bytes());
    }

    /// The same key and encapsulation seed must always produce the same ciphertext and secret.
    #[test]
    fn derandomized_encaps_is_deterministic() {
        let seed: [u8; 32] = hex_to_array(TEST_VECTORS[0].seed);
        let eseed: [u8; 64] = hex_to_array(TEST_VECTORS[0].eseed);
        let (_, pk) = generate_keypair_derand(&seed);

        let (ss1, ct1) = pk.encapsulate_derand(&eseed);
        let (ss2, ct2) = pk.encapsulate_derand(&eseed);
        assert_eq!(ct1, ct2);
        assert_eq!(ss1, ss2);
    }

    /// Guards the label bytes against accidental escaping mistakes in the byte-string literal.
    #[test]
    fn xwing_label_is_correct() {
        assert_eq!(XWING_LABEL.len(), 6);
        assert_eq!(hex::encode(XWING_LABEL), "5c2e2f2f5e5c");
    }

    /// Seed expansion must yield identical component keys on every call.
    #[test]
    fn expand_decapsulation_key_is_deterministic() {
        let seed: [u8; 32] = hex_to_array(TEST_VECTORS[0].seed);

        let (sk_m1, sk_x1, pk_m1, pk_x1) = expand_decapsulation_key(&seed);
        let (sk_m2, sk_x2, pk_m2, pk_x2) = expand_decapsulation_key(&seed);
        assert_eq!(sk_m1, sk_m2);
        assert_eq!(sk_x1, sk_x2);
        assert_eq!(pk_m1, pk_m2);
        assert_eq!(pk_x1, pk_x2);
    }

    /// The combiner must be a pure function of its inputs.
    #[test]
    fn combiner_is_deterministic() {
        let ss_m = [0x01u8; 32];
        let ss_x = [0x02u8; 32];
        let ct_x = [0x03u8; 32];
        let pk_x = [0x04u8; 32];

        let result1 = combiner(&ss_m, &ss_x, &ct_x, &pk_x);
        let result2 = combiner(&ss_m, &ss_x, &ct_x, &pk_x);
        assert_eq!(result1, result2);
    }
}
298        let result1 = combiner(&ss_m, &ss_x, &ct_x, &pk_x);
299        let result2 = combiner(&ss_m, &ss_x, &ct_x, &pk_x);
300        assert_eq!(result1, result2);
301    }
302}