Skip to main content

maxminddb/
decoder.rs

1use core::convert::TryInto;
2
3use serde::{
4    de::{self, DeserializeSeed, MapAccess, SeqAccess, Visitor},
5    forward_to_deserialize_any,
6};
7
8use super::{
9    MaxMindDBError,
10    MaxMindDBError::{DecodingError, InvalidDatabaseError},
11};
12
/// Interprets `bytes` as a big-endian unsigned integer, seeded with `base` as the
/// most-significant part, and returns the combined value.
///
/// With an empty `bytes` slice the result is simply `base`.
fn to_usize(base: u8, bytes: &[u8]) -> usize {
    let mut value = usize::from(base);
    for &byte in bytes {
        // Make room for the next (less significant) byte, then append it.
        value = (value << 8) | usize::from(byte);
    }
    value
}
16
17enum Value<'a, 'de> {
18    Any { prev_ptr: usize },
19    Bytes(&'de [u8]),
20    String(&'de str),
21    Bool(bool),
22    I32(i32),
23    U16(u16),
24    U32(u32),
25    U64(u64),
26    U128(u128),
27    F64(f64),
28    F32(f32),
29    Map(MapAccessor<'a, 'de>),
30    Array(ArrayAccess<'a, 'de>),
31}
32
/// Cursor-based decoder over a borrowed byte buffer.
///
/// Strings and byte slices handed to visitors borrow directly from `buf`, which is why
/// the buffer lifetime `'de` doubles as the serde deserializer lifetime.
#[derive(Debug)]
pub struct Decoder<'de> {
    /// The bytes being decoded.
    buf: &'de [u8],
    /// Offset into `buf` of the next byte to read.
    current_ptr: usize,
}
38
39impl<'de> Decoder<'de> {
40    pub fn new(buf: &'de [u8], start_ptr: usize) -> Decoder<'de> {
41        Decoder {
42            buf,
43            current_ptr: start_ptr,
44        }
45    }
46
47    fn ensure_available(&self, n: usize) -> DecodeResult<()> {
48        if self.current_ptr + n > self.buf.len() {
49            return Err(InvalidDatabaseError(format!(
50                "truncated data: expected {} bytes at offset {} but buffer length is {}",
51                n,
52                self.current_ptr,
53                self.buf.len()
54            )));
55        }
56        Ok(())
57    }
58
59    fn eat_byte(&mut self) -> DecodeResult<u8> {
60        self.ensure_available(1)?;
61        let b = self.buf[self.current_ptr];
62        self.current_ptr += 1;
63        Ok(b)
64    }
65
66    fn size_from_ctrl_byte(&mut self, ctrl_byte: u8) -> DecodeResult<usize> {
67        let size = (ctrl_byte & 0x1f) as usize;
68
69        let bytes_to_read = if size > 28 { size - 28 } else { 0 };
70
71        let new_offset = self.current_ptr + bytes_to_read;
72        self.ensure_available(bytes_to_read)?;
73        let size_bytes = &self.buf[self.current_ptr..new_offset];
74        self.current_ptr = new_offset;
75
76        Ok(match size {
77            s if s < 29 => s,
78            29 => 29_usize + size_bytes[0] as usize,
79            30 => 285_usize + to_usize(0, size_bytes),
80            _ => 65_821_usize + to_usize(0, size_bytes),
81        })
82    }
83
84    fn size_and_type(&mut self) -> DecodeResult<(usize, u8)> {
85        let ctrl_byte = self.eat_byte()?;
86        let mut type_num = ctrl_byte >> 5;
87        if type_num == 0 {
88            type_num = self.eat_byte()? + 7;
89        }
90        let size = self.size_from_ctrl_byte(ctrl_byte)?;
91        Ok((size, type_num))
92    }
93
94    fn decode_any<V: Visitor<'de>>(&mut self, visitor: V) -> DecodeResult<V::Value> {
95        match self.decode_any_value()? {
96            Value::Any {
97                prev_ptr,
98            } => {
99                let res = self.decode_any(visitor);
100                self.current_ptr = prev_ptr;
101                res
102            }
103            Value::Bool(x) => visitor.visit_bool(x),
104            Value::Bytes(x) => visitor.visit_borrowed_bytes(x),
105            Value::String(x) => visitor.visit_borrowed_str(x),
106            Value::I32(x) => visitor.visit_i32(x),
107            Value::U16(x) => visitor.visit_u16(x),
108            Value::U32(x) => visitor.visit_u32(x),
109            Value::U64(x) => visitor.visit_u64(x),
110            Value::U128(x) => visitor.visit_u128(x),
111            Value::F64(x) => visitor.visit_f64(x),
112            Value::F32(x) => visitor.visit_f32(x),
113            Value::Map(x) => visitor.visit_map(x),
114            Value::Array(x) => visitor.visit_seq(x),
115        }
116    }
117
118    fn decode_any_value(&mut self) -> DecodeResult<Value<'_, 'de>> {
119        let (size, type_num) = self.size_and_type()?;
120
121        Ok(match type_num {
122            1 => {
123                let new_ptr = self.decode_pointer(size)?;
124                if new_ptr > self.buf.len() {
125                    return Err(InvalidDatabaseError(format!(
126                        "pointer at offset {} points beyond end of buffer ({})",
127                        self.current_ptr - 1,
128                        new_ptr
129                    )));
130                }
131                let prev_ptr = self.current_ptr;
132                self.current_ptr = new_ptr;
133
134                Value::Any {
135                    prev_ptr,
136                }
137            }
138            2 => Value::String(self.decode_string(size)?),
139            3 => Value::F64(self.decode_double(size)?),
140            4 => Value::Bytes(self.decode_bytes(size)?),
141            5 => Value::U16(self.decode_uint16(size)?),
142            6 => Value::U32(self.decode_uint32(size)?),
143            7 => self.decode_map(size),
144            8 => Value::I32(self.decode_int(size)?),
145            9 => Value::U64(self.decode_uint64(size)?),
146            10 => Value::U128(self.decode_uint128(size)?),
147            11 => self.decode_array(size),
148            14 => Value::Bool(self.decode_bool(size)?),
149            15 => Value::F32(self.decode_float(size)?),
150            u => return Err(InvalidDatabaseError(format!("Unknown data type: {u:?}"))),
151        })
152    }
153
154    fn decode_array(&mut self, size: usize) -> Value<'_, 'de> {
155        Value::Array(ArrayAccess {
156            de: self,
157            count: size,
158        })
159    }
160
161    fn decode_bool(&mut self, size: usize) -> DecodeResult<bool> {
162        match size {
163            0 | 1 => Ok(size != 0),
164            s => Err(InvalidDatabaseError(format!("bool of size {s:?}"))),
165        }
166    }
167
168    fn decode_bytes(&mut self, size: usize) -> DecodeResult<&'de [u8]> {
169        let new_offset = self.current_ptr + size;
170        self.ensure_available(size)?;
171        let u8_slice = &self.buf[self.current_ptr..new_offset];
172        self.current_ptr = new_offset;
173
174        Ok(u8_slice)
175    }
176
177    fn decode_float(&mut self, size: usize) -> DecodeResult<f32> {
178        let new_offset = self.current_ptr + size;
179        self.ensure_available(size)?;
180        let value: [u8; 4] = self.buf[self.current_ptr..new_offset]
181            .try_into()
182            .map_err(|_| InvalidDatabaseError(format!("float of size {:?}", new_offset - self.current_ptr)))?;
183        self.current_ptr = new_offset;
184        let float_value = f32::from_be_bytes(value);
185        Ok(float_value)
186    }
187
188    fn decode_double(&mut self, size: usize) -> DecodeResult<f64> {
189        let new_offset = self.current_ptr + size;
190        self.ensure_available(size)?;
191        let value: [u8; 8] = self.buf[self.current_ptr..new_offset]
192            .try_into()
193            .map_err(|_| InvalidDatabaseError(format!("double of size {:?}", new_offset - self.current_ptr)))?;
194        self.current_ptr = new_offset;
195        let float_value = f64::from_be_bytes(value);
196        Ok(float_value)
197    }
198
199    fn decode_uint64(&mut self, size: usize) -> DecodeResult<u64> {
200        match size {
201            s if s <= 8 => {
202                self.ensure_available(size)?;
203                let new_offset = self.current_ptr + size;
204
205                let value = self.buf[self.current_ptr..new_offset]
206                    .iter()
207                    .fold(0_u64, |acc, &b| (acc << 8) | u64::from(b));
208                self.current_ptr = new_offset;
209                Ok(value)
210            }
211            s => Err(InvalidDatabaseError(format!("u64 of size {s:?}"))),
212        }
213    }
214
215    fn decode_uint128(&mut self, size: usize) -> DecodeResult<u128> {
216        match size {
217            s if s <= 16 => {
218                self.ensure_available(size)?;
219                let new_offset = self.current_ptr + size;
220
221                let value = self.buf[self.current_ptr..new_offset]
222                    .iter()
223                    .fold(0_u128, |acc, &b| (acc << 8) | u128::from(b));
224                self.current_ptr = new_offset;
225                Ok(value)
226            }
227            s => Err(InvalidDatabaseError(format!("u128 of size {s:?}"))),
228        }
229    }
230
231    fn decode_uint32(&mut self, size: usize) -> DecodeResult<u32> {
232        match size {
233            s if s <= 4 => {
234                self.ensure_available(size)?;
235                let new_offset = self.current_ptr + size;
236
237                let value = self.buf[self.current_ptr..new_offset]
238                    .iter()
239                    .fold(0_u32, |acc, &b| (acc << 8) | u32::from(b));
240                self.current_ptr = new_offset;
241                Ok(value)
242            }
243            s => Err(InvalidDatabaseError(format!("u32 of size {s:?}"))),
244        }
245    }
246
247    fn decode_uint16(&mut self, size: usize) -> DecodeResult<u16> {
248        match size {
249            s if s <= 2 => {
250                self.ensure_available(size)?;
251                let new_offset = self.current_ptr + size;
252
253                let value = self.buf[self.current_ptr..new_offset]
254                    .iter()
255                    .fold(0_u16, |acc, &b| (acc << 8) | u16::from(b));
256                self.current_ptr = new_offset;
257                Ok(value)
258            }
259            s => Err(InvalidDatabaseError(format!("u16 of size {s:?}"))),
260        }
261    }
262
263    fn decode_int(&mut self, size: usize) -> DecodeResult<i32> {
264        match size {
265            s if s <= 4 => {
266                self.ensure_available(size)?;
267                let new_offset = self.current_ptr + size;
268
269                let value = self.buf[self.current_ptr..new_offset]
270                    .iter()
271                    .fold(0_i32, |acc, &b| (acc << 8) | i32::from(b));
272                self.current_ptr = new_offset;
273                Ok(value)
274            }
275            s => Err(InvalidDatabaseError(format!("int32 of size {s:?}"))),
276        }
277    }
278
279    fn decode_map(&mut self, size: usize) -> Value<'_, 'de> {
280        Value::Map(MapAccessor {
281            de: self,
282            count: size * 2,
283        })
284    }
285
286    fn decode_pointer(&mut self, size: usize) -> DecodeResult<usize> {
287        let pointer_value_offset = [0, 0, 2048, 526_336, 0];
288        let pointer_size = ((size >> 3) & 0x3) + 1;
289        let new_offset = self.current_ptr + pointer_size;
290        self.ensure_available(pointer_size)?;
291        let pointer_bytes = &self.buf[self.current_ptr..new_offset];
292        self.current_ptr = new_offset;
293
294        let base = if pointer_size == 4 { 0 } else { (size & 0x7) as u8 };
295        let unpacked = to_usize(base, pointer_bytes);
296
297        Ok(unpacked + pointer_value_offset[pointer_size])
298    }
299
300    fn decode_string(&mut self, size: usize) -> DecodeResult<&'de str> {
301        use std::str::from_utf8;
302
303        let new_offset: usize = self.current_ptr + size;
304        self.ensure_available(size)?;
305        let bytes = &self.buf[self.current_ptr..new_offset];
306        self.current_ptr = new_offset;
307        match from_utf8(bytes) {
308            Ok(v) => Ok(v),
309            Err(_) => Err(InvalidDatabaseError("error decoding string".to_owned())),
310        }
311    }
312}
313
314pub type DecodeResult<T> = Result<T, MaxMindDBError>;
315
316impl<'de: 'a, 'a> de::Deserializer<'de> for &'a mut Decoder<'de> {
317    type Error = MaxMindDBError;
318
319    fn deserialize_any<V>(self, visitor: V) -> DecodeResult<V::Value>
320    where
321        V: Visitor<'de>,
322    {
323        tracing::debug!("deserialize_any");
324
325        self.decode_any(visitor)
326    }
327
328    fn deserialize_option<V>(self, visitor: V) -> DecodeResult<V::Value>
329    where
330        V: Visitor<'de>,
331    {
332        tracing::debug!("deserialize_option");
333
334        visitor.visit_some(self)
335    }
336
337    forward_to_deserialize_any! {
338        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
339        bytes byte_buf unit unit_struct newtype_struct seq tuple
340        tuple_struct map struct enum identifier ignored_any
341    }
342}
343
344struct ArrayAccess<'a, 'de: 'a> {
345    de: &'a mut Decoder<'de>,
346    count: usize,
347}
348
349impl<'de, 'a> SeqAccess<'de> for ArrayAccess<'a, 'de> {
350    type Error = MaxMindDBError;
351
352    fn next_element_seed<T>(&mut self, seed: T) -> DecodeResult<Option<T::Value>>
353    where
354        T: DeserializeSeed<'de>,
355    {
356        if self.count == 0 {
357            return Ok(None);
358        }
359        self.count -= 1;
360
361        seed.deserialize(&mut *self.de).map(Some)
362    }
363}
364
365struct MapAccessor<'a, 'de: 'a> {
366    de: &'a mut Decoder<'de>,
367    count: usize,
368}
369
370impl<'de, 'a> MapAccess<'de> for MapAccessor<'a, 'de> {
371    type Error = MaxMindDBError;
372
373    fn next_key_seed<K>(&mut self, seed: K) -> DecodeResult<Option<K::Value>>
374    where
375        K: DeserializeSeed<'de>,
376    {
377        if self.count == 0 {
378            return Ok(None);
379        }
380        self.count -= 1;
381
382        seed.deserialize(&mut *self.de).map(Some)
383    }
384
385    fn next_value_seed<V>(&mut self, seed: V) -> DecodeResult<V::Value>
386    where
387        V: DeserializeSeed<'de>,
388    {
389        if self.count == 0 {
390            return Err(DecodingError("no more entries".to_owned()));
391        }
392        self.count -= 1;
393
394        seed.deserialize(&mut *self.de)
395    }
396}