Skip to main content

crypto/
hkdf.rs

1use crate::{Hash, Hasher, hmac::Hmac};
2
/// Zero bytes used as the HKDF-Extract salt when the caller passes `None`.
///
/// `extract` slices the first `H::OUTPUT_SIZE` bytes of this buffer, so its length
/// (64) is the widest hash output for which an omitted salt is supported.
const DEFAULT_SALT: [u8; 64] = [0; 64];
4
/// Errors reported by [`expand`] and [`derive_key`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HkdfError {
    /// The PRK given to [`expand`] is shorter than the hash output size.
    /// The payload is the minimum accepted PRK length in bytes (`H::OUTPUT_SIZE`).
    PrkIsTooShort(usize),
    /// The requested output length is larger than `255 * H::OUTPUT_SIZE` bytes.
    OutputIsTooLong,
}
10
#[cfg(feature = "alloc")]
impl core::fmt::Display for HkdfError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::OutputIsTooLong => write!(
                f,
                "HKDF output length exceeds RFC 5869 limit (255 * Hash's output size)"
            ),
            Self::PrkIsTooShort(prk_size) => write!(f, "PRK must be at least {prk_size} bytes"),
        }
    }
}
22
23/// Extract step: `PRK = HMAC-Hash(salt, IKM)`.
24///
25/// If `salt` is `None`, a string of `H::OUTPUT_SIZE` zero bytes is used.
26pub fn extract<H: Hasher>(salt: Option<&[u8]>, ikm: &[u8]) -> Hash {
27    let salt = salt.unwrap_or(&DEFAULT_SALT[..H::OUTPUT_SIZE]);
28    let mut mac = Hmac::<H>::new(salt);
29    mac.update(ikm);
30    return mac.finalize();
31}
32
33/// Expand step: `OKM = T(1) || T(2) || ...`, where
34/// `T(i) = HMAC-Hash(PRK, T(i-1) || info || i)`.
35///
36/// # Error
37///
38/// Returns an error if `N > 255 * H::OUTPUT_SIZE` or if `prk.len() < H::OUTPUT_SIZE`.
39pub fn expand<H: Hasher, const N: usize>(prk: &[u8], info: &[u8]) -> Result<[u8; N], HkdfError> {
40    if prk.len() < H::OUTPUT_SIZE {
41        return Err(HkdfError::PrkIsTooShort(H::OUTPUT_SIZE));
42    }
43
44    if N > 255 * H::OUTPUT_SIZE {
45        return Err(HkdfError::OutputIsTooLong);
46    }
47
48    let mut okm = [0u8; N];
49    if N == 0 {
50        return Ok(okm);
51    }
52
53    let mut t = [0u8; 64];
54    let mut t_len = 0usize;
55    let mut offset = 0usize;
56    let mut counter = 1u8;
57
58    while offset < N {
59        let mut mac = Hmac::<H>::new(&prk[..H::OUTPUT_SIZE]);
60        mac.update(&t[..t_len]);
61        mac.update(info);
62        mac.update(&[counter]);
63        let block = mac.finalize();
64        let block_bytes = block.as_ref();
65        let chunk_len = (N - offset).min(H::OUTPUT_SIZE);
66        okm[offset..offset + chunk_len].copy_from_slice(&block_bytes[..chunk_len]);
67        t[..H::OUTPUT_SIZE].copy_from_slice(block_bytes);
68        t_len = H::OUTPUT_SIZE;
69        offset += chunk_len;
70        counter = counter.wrapping_add(1);
71    }
72
73    return Ok(okm);
74}
75
76/// One-shot extract-then-expand.
77///
78/// # Error
79///
80/// Returns an error if if `N > 255 * H::OUTPUT_SIZE`.
81pub fn derive_key<H: Hasher, const N: usize>(
82    ikm: &[u8],
83    info: &[u8],
84    salt: Option<&[u8]>,
85) -> Result<[u8; N], HkdfError> {
86    let prk = extract::<H>(salt, ikm);
87    return expand::<H, N>(prk.as_ref(), info);
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sha2::{Sha256, Sha512};

    /// One known-answer test case; every field is a hex string.
    /// `salt: None` exercises the "salt not provided" path of `extract`.
    struct TestVector {
        ikm: &'static str,
        salt: Option<&'static str>,
        info: &'static str,
        expected_prk: &'static str,
        expected_okm: &'static str,
    }

    /// Decodes a hex string into bytes, ignoring any whitespace inside it.
    fn decode_hex(input: &str) -> Vec<u8> {
        let input = input.replace(|c: char| c.is_whitespace(), "");
        (0..input.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&input[i..i + 2], 16).unwrap())
            .collect()
    }

    // The first three cases are RFC 5869 Appendix A.1-A.3. The fourth repeats A.3
    // with the salt omitted, which must give the same result as the empty salt.
    const SHA256_VECTORS: [TestVector; 4] = [
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: Some("000102030405060708090a0b0c"),
            info: "f0f1f2f3f4f5f6f7f8f9",
            expected_prk: "077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5",
            expected_okm: "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865",
        },
        TestVector {
            ikm: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f\
                  202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f\
                  404142434445464748494a4b4c4d4e4f",
            salt: Some(
                "606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f\
                 808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f\
                 a0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
            ),
            info: "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecf\
                  d0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeef\
                  f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
            expected_prk: "06a6b88c5853361a06104c9ceb35b45cef760014904671014a193f40c15fc244",
            expected_okm: "b11e398dc80327a1c8e7f78c596a49344f012eda2d4efad8a050cc4c19afa97c59045a99cac7827271cb41c65e590e09da3275600c2f09b8367793a9aca3db71cc30c58179ec3e87c14c01d5c1f3434f1d87",
        },
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: Some(""),
            info: "",
            expected_prk: "19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04",
            expected_okm: "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8",
        },
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: None,
            info: "",
            expected_prk: "19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04",
            expected_okm: "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8",
        },
    ];

    // The same four inputs run through HKDF-SHA-512.
    const SHA512_VECTORS: [TestVector; 4] = [
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: Some("000102030405060708090a0b0c"),
            info: "f0f1f2f3f4f5f6f7f8f9",
            expected_prk: "665799823737ded04a88e47e54a5890bb2c3d247c7a4254a8e61350723590a26c36238127d8661b88cf80ef802d57e2f7cebcf1e00e083848be19929c61b4237",
            expected_okm: "832390086cda71fb47625bb5ceb168e4c8e26a1a16ed34d9fc7fe92c1481579338da362cb8d9f925d7cb",
        },
        TestVector {
            ikm: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f\
                  202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f\
                  404142434445464748494a4b4c4d4e4f",
            salt: Some(
                "606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f\
                 808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f\
                 a0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
            ),
            info: "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecf\
                  d0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeef\
                  f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
            expected_prk: "35672542907d4e142c00e84499e74e1de08be86535f924e022804ad775dde27ec86cd1e5b7d178c74489bdbeb30712beb82d4f97416c5a94ea81ebdf3e629e4a",
            expected_okm: "ce6c97192805b346e6161e821ed165673b84f400a2b514b2fe23d84cd189ddf1b695b48cbd1c8388441137b3ce28f16aa64ba33ba466b24df6cfcb021ecff235f6a2056ce3af1de44d572097a8505d9e7a93",
        },
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: Some(""),
            info: "",
            expected_prk: "fd200c4987ac491313bd4a2a13287121247239e11c9ef82802044b66ef357e5b194498d0682611382348572a7b1611de54764094286320578a863f36562b0df6",
            expected_okm: "f5fa02b18298a72a8c23898a8703472c6eb179dc204c03425c970e3b164bf90fff22d04836d0e2343bac",
        },
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: None,
            info: "",
            expected_prk: "fd200c4987ac491313bd4a2a13287121247239e11c9ef82802044b66ef357e5b194498d0682611382348572a7b1611de54764094286320578a863f36562b0df6",
            expected_okm: "f5fa02b18298a72a8c23898a8703472c6eb179dc204c03425c970e3b164bf90fff22d04836d0e2343bac",
        },
    ];

    /// Checks `extract`, `expand` and `derive_key` for hash `H` against `vectors`.
    ///
    /// The OKM length is a const generic, so only the lengths the vectors use
    /// (42 and 82 bytes) are dispatched. Any other length fails the test loudly.
    fn check_vectors<H: Hasher>(label: &str, vectors: &[TestVector]) {
        for (i, vector) in vectors.iter().enumerate() {
            let ikm = decode_hex(vector.ikm);
            let salt = vector.salt.map(decode_hex);
            let info = decode_hex(vector.info);
            let expected_prk = decode_hex(vector.expected_prk);
            let expected_okm = decode_hex(vector.expected_okm);

            let prk = extract::<H>(salt.as_deref(), &ikm);
            assert_eq!(prk.as_ref(), expected_prk.as_slice(), "{} vector {} PRK", label, i);

            let okm = match expected_okm.len() {
                42 => expand::<H, 42>(prk.as_ref(), &info).unwrap().to_vec(),
                82 => expand::<H, 82>(prk.as_ref(), &info).unwrap().to_vec(),
                len => panic!("{} vector {}: unsupported OKM length {}", label, i, len),
            };
            assert_eq!(okm, expected_okm, "{} vector {} OKM", label, i);

            let derived = match expected_okm.len() {
                42 => derive_key::<H, 42>(&ikm, &info, salt.as_deref()).unwrap().to_vec(),
                82 => derive_key::<H, 82>(&ikm, &info, salt.as_deref()).unwrap().to_vec(),
                len => panic!("{} vector {}: unsupported OKM length {}", label, i, len),
            };
            assert_eq!(derived, expected_okm, "{} vector {} derive_key OKM", label, i);
        }
    }

    /// Runs a Wycheproof HKDF test file (JSON text in `json`) for hash `H`.
    ///
    /// `MAX_OKM` must be the RFC 5869 maximum output length for `H`
    /// (`255 * H::OUTPUT_SIZE`), and `SIZE_TOO_LARGE` must be one byte more.
    /// Valid cases derive the maximum length and compare its prefix, because a shorter
    /// HKDF output is a prefix of a longer one. Invalid cases must be rejected as too long.
    fn check_wycheproof<H: Hasher, const MAX_OKM: usize, const SIZE_TOO_LARGE: usize>(
        label: &str,
        json: &str,
    ) {
        assert_eq!(MAX_OKM, 255 * H::OUTPUT_SIZE, "{} MAX_OKM", label);
        assert_eq!(SIZE_TOO_LARGE, MAX_OKM + 1, "{} SIZE_TOO_LARGE", label);

        let data: serde_json::Value = serde_json::from_str(json).unwrap();
        let mut valid_tested = 0u64;
        let mut invalid_tested = 0u64;
        for group in data["testGroups"].as_array().unwrap() {
            for test in group["tests"].as_array().unwrap() {
                let ikm = hex::decode(test["ikm"].as_str().unwrap()).unwrap();
                let info = hex::decode(test["info"].as_str().unwrap()).unwrap();
                let salt_hex = test["salt"].as_str().unwrap();
                // An empty salt matches the default-salt result (see the RFC vectors),
                // so map it to `None` to exercise that path.
                let salt: Option<Vec<u8>> = if salt_hex.is_empty() {
                    None
                } else {
                    Some(hex::decode(salt_hex).unwrap())
                };
                let size = test["size"].as_u64().unwrap() as usize;
                let expected_okm_hex = test["okm"].as_str().unwrap();
                let result = test["result"].as_str().unwrap();

                if result == "valid" {
                    let okm = derive_key::<H, MAX_OKM>(&ikm, &info, salt.as_deref()).unwrap();
                    assert_eq!(
                        hex::encode(&okm[..size]),
                        expected_okm_hex,
                        "wycheproof {} tcId={} size={}",
                        label,
                        test["tcId"],
                        size
                    );
                    valid_tested += 1;
                } else {
                    assert_eq!(
                        derive_key::<H, SIZE_TOO_LARGE>(&ikm, &info, salt.as_deref()),
                        Err(HkdfError::OutputIsTooLong),
                        "wycheproof {} tcId={} size={} should reject",
                        label,
                        test["tcId"],
                        size
                    );
                    invalid_tested += 1;
                }
            }
        }
        assert!(valid_tested > 0, "no valid {} wycheproof tests were run", label);
        assert!(invalid_tested > 0, "no invalid {} wycheproof tests were run", label);
    }

    #[test]
    fn hkdf_sha256_vectors() {
        check_vectors::<Sha256>("HKDF-SHA-256", &SHA256_VECTORS);
    }

    #[test]
    fn hkdf_sha512_vectors() {
        check_vectors::<Sha512>("HKDF-SHA-512", &SHA512_VECTORS);
    }

    #[test]
    fn hkdf_zero_length_output() {
        let prk = [0u8; 32];
        assert_eq!(expand::<Sha256, 0>(&prk, b"").unwrap(), [] as [u8; 0]);
        assert_eq!(derive_key::<Sha256, 0>(b"ikm", b"info", None).unwrap(), [] as [u8; 0]);
    }

    #[test]
    fn hkdf_expand_rejects_output_longer_than_rfc_limit() {
        // RFC 5869 caps the output at 255 * HashLen bytes. The cap itself must be
        // accepted and the very next length rejected.
        const MAX: usize = 255 * Sha256::OUTPUT_SIZE;
        const TOO_LONG: usize = MAX + 1;
        let prk = [0u8; 32];
        assert!(expand::<Sha256, MAX>(&prk, b"").is_ok());
        assert_eq!(expand::<Sha256, TOO_LONG>(&prk, b""), Err(HkdfError::OutputIsTooLong));
    }

    #[test]
    fn hkdf_expand_rejects_short_prk() {
        assert_eq!(
            expand::<Sha256, 32>(&[0u8; 31], b""),
            Err(HkdfError::PrkIsTooShort(Sha256::OUTPUT_SIZE))
        );
    }

    // --- Wycheproof test vectors ---

    #[test]
    fn hkdf_sha256_wycheproof() {
        // Maximum valid HKDF-SHA-256 output: 255 * 32 = 8160 bytes.
        const MAX_OKM: usize = 8160;
        const SIZE_TOO_LARGE: usize = 8161;
        check_wycheproof::<Sha256, MAX_OKM, SIZE_TOO_LARGE>(
            "HKDF-SHA-256",
            include_str!("../testdata/wycheproof/testvectors_v1/hkdf_sha256_test.json"),
        );
    }

    #[test]
    fn hkdf_sha512_wycheproof() {
        // Maximum valid HKDF-SHA-512 output: 255 * 64 = 16320 bytes.
        const MAX_OKM: usize = 16320;
        const SIZE_TOO_LARGE: usize = 16321;
        check_wycheproof::<Sha512, MAX_OKM, SIZE_TOO_LARGE>(
            "HKDF-SHA-512",
            include_str!("../testdata/wycheproof/testvectors_v1/hkdf_sha512_test.json"),
        );
    }
}