Skip to main content

constant_time_eq/
lib.rs

1#![no_std]
2
/// Returns `value` unchanged while preventing the optimizer from reasoning about it.
///
/// The byte is routed through an inline-assembly block whose template is only an
/// assembly comment. The compiler cannot look inside the block, so it must treat the
/// result as an unknown value produced from the input. The options declare that the
/// block reads and writes no memory, does not touch the stack and leaves the flags
/// intact, which keeps the barrier cheap.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline]
#[must_use]
fn optimizer_hide(value: u8) -> u8 {
    let mut opaque = value;
    // SAFETY: the template is an assembly comment, so no instruction is emitted; the
    // register bound to `opaque` is left untouched and the value comes back unchanged.
    unsafe {
        core::arch::asm!(
            "/* {0} */",
            inout(reg_byte) opaque,
            options(pure, nomem, nostack, preserves_flags)
        );
    }
    opaque
}
13
/// Returns `value` unchanged while preventing the optimizer from reasoning about it.
///
/// The byte is bound to a general-purpose register (`reg` class) and passed through an
/// inline-assembly block whose template is only an assembly comment, so the compiler
/// must treat the result as opaque. The `asm_sub_register` lint is allowed because the
/// operand is only interpolated into that comment, so which register name gets printed
/// for the narrower-than-register `u8` does not matter.
#[cfg(any(
    target_arch = "arm",
    target_arch = "aarch64",
    target_arch = "riscv32",
    target_arch = "riscv64"
))]
#[inline]
#[must_use]
#[allow(asm_sub_register)]
fn optimizer_hide(value: u8) -> u8 {
    let mut opaque = value;
    // SAFETY: the template is an assembly comment, so no instruction is emitted; the
    // register bound to `opaque` is left untouched and the value comes back unchanged.
    unsafe {
        core::arch::asm!(
            "/* {0} */",
            inout(reg) opaque,
            options(pure, nomem, nostack, preserves_flags)
        );
    }
    opaque
}
30
/// Portable fallback for targets without a dedicated inline-assembly barrier.
///
/// Returns `value` unchanged while discouraging the optimizer from reasoning about it.
#[cfg(not(any(
    target_arch = "x86",
    target_arch = "x86_64",
    target_arch = "arm",
    target_arch = "aarch64",
    target_arch = "riscv32",
    target_arch = "riscv64"
)))]
#[inline(never)]
#[must_use]
fn optimizer_hide(value: u8) -> u8 {
    use core::hint::black_box;

    // In the main codegen backends `black_box` currently behaves roughly like
    //     let result = value;
    //     asm!("", in(reg) &result);
    //     result
    // i.e. the value makes a round trip through the stack instead of staying in a
    // register. An experimental backend could legally lower `black_box` to a plain
    // identity function with no optimization barrier, so it is a weaker guarantee than
    // the inline-asm versions. `#[inline(never)]` is added as a second line of defence:
    // keeping this function out of line makes it harder for an optimizer to see through.
    black_box(value)
}
55
56#[inline]
57#[must_use]
58fn constant_time_ne(a: &[u8], b: &[u8]) -> u8 {
59    assert!(a.len() == b.len());
60
61    // These useless slices make the optimizer elide the bounds checks.
62    // See the comment in clone_from_slice() added on Rust commit 6a7bc47.
63    let len = a.len();
64    let a = &a[..len];
65    let b = &b[..len];
66
67    let mut tmp = 0;
68    for i in 0..len {
69        tmp |= a[i] ^ b[i];
70    }
71
72    // The compare with 0 must happen outside this function.
73    optimizer_hide(tmp)
74}
75
76/// Compares two equal-sized byte strings in constant time.
77///
78/// # Examples
79///
80/// ```
81/// use constant_time_eq::constant_time_eq;
82///
83/// assert!(constant_time_eq(b"foo", b"foo"));
84/// assert!(!constant_time_eq(b"foo", b"bar"));
85/// assert!(!constant_time_eq(b"bar", b"baz"));
86/// # assert!(constant_time_eq(b"", b""));
87///
88/// // Not equal-sized, so won't take constant time.
89/// assert!(!constant_time_eq(b"foo", b""));
90/// assert!(!constant_time_eq(b"foo", b"quux"));
91/// ```
92#[must_use]
93pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
94    a.len() == b.len() && constant_time_ne(a, b) == 0
95}
96
97// Fixed-size array variant.
98
99#[inline]
100#[must_use]
101fn constant_time_ne_n<const N: usize>(a: &[u8; N], b: &[u8; N]) -> u8 {
102    let mut tmp = 0;
103    for i in 0..N {
104        tmp |= a[i] ^ b[i];
105    }
106
107    // The compare with 0 must happen outside this function.
108    optimizer_hide(tmp)
109}
110
111/// Compares two fixed-size byte strings in constant time.
112///
113/// # Examples
114///
115/// ```
116/// use constant_time_eq::constant_time_eq_n;
117///
118/// assert!(constant_time_eq_n(&[3; 20], &[3; 20]));
119/// assert!(!constant_time_eq_n(&[3; 20], &[7; 20]));
120/// ```
121#[must_use]
122pub fn constant_time_eq_n<const N: usize>(a: &[u8; N], b: &[u8; N]) -> bool {
123    constant_time_ne_n(a, b) == 0
124}
125
126// Fixed-size variants for the most common sizes.
127
128/// Compares two 128-bit byte strings in constant time.
129///
130/// # Examples
131///
132/// ```
133/// use constant_time_eq::constant_time_eq_16;
134///
135/// assert!(constant_time_eq_16(&[3; 16], &[3; 16]));
136/// assert!(!constant_time_eq_16(&[3; 16], &[7; 16]));
137/// ```
138#[inline]
139#[must_use]
140pub fn constant_time_eq_16(a: &[u8; 16], b: &[u8; 16]) -> bool {
141    constant_time_eq_n(a, b)
142}
143
144/// Compares two 256-bit byte strings in constant time.
145///
146/// # Examples
147///
148/// ```
149/// use constant_time_eq::constant_time_eq_32;
150///
151/// assert!(constant_time_eq_32(&[3; 32], &[3; 32]));
152/// assert!(!constant_time_eq_32(&[3; 32], &[7; 32]));
153/// ```
154#[inline]
155#[must_use]
156pub fn constant_time_eq_32(a: &[u8; 32], b: &[u8; 32]) -> bool {
157    constant_time_eq_n(a, b)
158}
159
160/// Compares two 512-bit byte strings in constant time.
161///
162/// # Examples
163///
164/// ```
165/// use constant_time_eq::constant_time_eq_64;
166///
167/// assert!(constant_time_eq_64(&[3; 64], &[3; 64]));
168/// assert!(!constant_time_eq_64(&[3; 64], &[7; 64]));
169/// ```
170#[inline]
171#[must_use]
172pub fn constant_time_eq_64(a: &[u8; 64], b: &[u8; 64]) -> bool {
173    constant_time_eq_n(a, b)
174}
175
#[cfg(test)]
mod tests {
    #[cfg(feature = "count_instructions_test")]
    extern crate std;

    /// Checks that `optimizer_hide` acts as an optimization barrier.
    ///
    /// Summing four hidden constants must execute more instructions than summing the
    /// same constants through a plain inlinable identity function. The compiler is free
    /// to constant-fold the identity version.
    #[cfg(feature = "count_instructions_test")]
    #[test]
    fn count_optimizer_hide_instructions() -> std::io::Result<()> {
        use count_instructions::count_instructions;

        use super::optimizer_hide;

        /// Instructions executed while summing 1..=4 through `optimizer_hide`.
        fn count() -> std::io::Result<usize> {
            // If optimizer_hide does not work, constant propagation and folding
            // will make this identical to count_optimized() below.
            let mut steps = 0;
            let total = count_instructions(
                || optimizer_hide(1) + optimizer_hide(2) + optimizer_hide(3) + optimizer_hide(4),
                |_| steps += 1,
            )?;
            assert_eq!(10u8, total);
            Ok(steps)
        }

        /// Instructions executed while summing 1..=4 through a transparent identity.
        fn count_optimized() -> std::io::Result<usize> {
            #[inline]
            fn inline_identity(value: u8) -> u8 {
                value
            }

            let mut steps = 0;
            let total = count_instructions(
                || inline_identity(1) + inline_identity(2) + inline_identity(3) + inline_identity(4),
                |_| steps += 1,
            )?;
            assert_eq!(10u8, total);
            Ok(steps)
        }

        let hidden_steps = count()?;
        let folded_steps = count_optimized()?;
        assert!(hidden_steps > folded_steps);
        Ok(())
    }
}