1use core::convert::TryInto;
2
3use serde::{
4 de::{self, DeserializeSeed, MapAccess, SeqAccess, Visitor},
5 forward_to_deserialize_any,
6};
7
8use super::{
9 MaxMindDBError,
10 MaxMindDBError::{DecodingError, InvalidDatabaseError},
11};
12
/// Interprets `bytes` as a big-endian unsigned integer, with `base` supplying the bits
/// above the most significant byte, and returns the combined value.
///
/// With an empty `bytes` slice the result is simply `base`.
fn to_usize(base: u8, bytes: &[u8]) -> usize {
    let mut value = base as usize;
    for &byte in bytes {
        // Make room for the next (less significant) byte, then append it.
        value = (value << 8) | byte as usize;
    }
    value
}
16
/// One decoded data field, as produced by `Decoder::decode_any_value` and handed to a serde
/// visitor by `Decoder::decode_any`.
///
/// `'de` is the lifetime of the underlying database buffer (borrowed strings/bytes point into
/// it); `'a` is the lifetime of the mutable borrow of the decoder held by the container variants.
enum Value<'a, 'de> {
    /// A pointer (type 1) was read and the decoder's cursor has already been moved to the
    /// pointer target. `prev_ptr` is the cursor position just past the pointer, which
    /// `decode_any` restores once the target has been decoded.
    Any { prev_ptr: usize },
    /// Raw bytes (type 4), borrowed from the buffer.
    Bytes(&'de [u8]),
    /// UTF-8 string (type 2), borrowed from the buffer.
    String(&'de str),
    /// Boolean (type 14); the value is carried in the size field.
    Bool(bool),
    /// Signed 32-bit integer (type 8).
    I32(i32),
    /// Unsigned 16-bit integer (type 5).
    U16(u16),
    /// Unsigned 32-bit integer (type 6).
    U32(u32),
    /// Unsigned 64-bit integer (type 9).
    U64(u64),
    /// Unsigned 128-bit integer (type 10).
    U128(u128),
    /// 64-bit float (type 3).
    F64(f64),
    /// 32-bit float (type 15).
    F32(f32),
    /// Map (type 7); entries are decoded lazily through the accessor.
    Map(MapAccessor<'a, 'de>),
    /// Array (type 11); elements are decoded lazily through the accessor.
    Array(ArrayAccess<'a, 'de>),
}
32
/// Cursor-based decoder over a borrowed byte buffer.
///
/// Implements `serde::Deserializer` (via `&mut Decoder`) so that values can be deserialized
/// directly out of the buffer without copying strings or byte slices.
#[derive(Debug)]
pub struct Decoder<'de> {
    /// The bytes being decoded; all offsets are indices into this slice.
    buf: &'de [u8],
    /// Offset in `buf` of the next byte to be read.
    current_ptr: usize,
}
38
39impl<'de> Decoder<'de> {
40 pub fn new(buf: &'de [u8], start_ptr: usize) -> Decoder<'de> {
41 Decoder {
42 buf,
43 current_ptr: start_ptr,
44 }
45 }
46
47 fn ensure_available(&self, n: usize) -> DecodeResult<()> {
48 if self.current_ptr + n > self.buf.len() {
49 return Err(InvalidDatabaseError(format!(
50 "truncated data: expected {} bytes at offset {} but buffer length is {}",
51 n,
52 self.current_ptr,
53 self.buf.len()
54 )));
55 }
56 Ok(())
57 }
58
59 fn eat_byte(&mut self) -> DecodeResult<u8> {
60 self.ensure_available(1)?;
61 let b = self.buf[self.current_ptr];
62 self.current_ptr += 1;
63 Ok(b)
64 }
65
66 fn size_from_ctrl_byte(&mut self, ctrl_byte: u8) -> DecodeResult<usize> {
67 let size = (ctrl_byte & 0x1f) as usize;
68
69 let bytes_to_read = if size > 28 { size - 28 } else { 0 };
70
71 let new_offset = self.current_ptr + bytes_to_read;
72 self.ensure_available(bytes_to_read)?;
73 let size_bytes = &self.buf[self.current_ptr..new_offset];
74 self.current_ptr = new_offset;
75
76 Ok(match size {
77 s if s < 29 => s,
78 29 => 29_usize + size_bytes[0] as usize,
79 30 => 285_usize + to_usize(0, size_bytes),
80 _ => 65_821_usize + to_usize(0, size_bytes),
81 })
82 }
83
84 fn size_and_type(&mut self) -> DecodeResult<(usize, u8)> {
85 let ctrl_byte = self.eat_byte()?;
86 let mut type_num = ctrl_byte >> 5;
87 if type_num == 0 {
88 type_num = self.eat_byte()? + 7;
89 }
90 let size = self.size_from_ctrl_byte(ctrl_byte)?;
91 Ok((size, type_num))
92 }
93
94 fn decode_any<V: Visitor<'de>>(&mut self, visitor: V) -> DecodeResult<V::Value> {
95 match self.decode_any_value()? {
96 Value::Any {
97 prev_ptr,
98 } => {
99 let res = self.decode_any(visitor);
100 self.current_ptr = prev_ptr;
101 res
102 }
103 Value::Bool(x) => visitor.visit_bool(x),
104 Value::Bytes(x) => visitor.visit_borrowed_bytes(x),
105 Value::String(x) => visitor.visit_borrowed_str(x),
106 Value::I32(x) => visitor.visit_i32(x),
107 Value::U16(x) => visitor.visit_u16(x),
108 Value::U32(x) => visitor.visit_u32(x),
109 Value::U64(x) => visitor.visit_u64(x),
110 Value::U128(x) => visitor.visit_u128(x),
111 Value::F64(x) => visitor.visit_f64(x),
112 Value::F32(x) => visitor.visit_f32(x),
113 Value::Map(x) => visitor.visit_map(x),
114 Value::Array(x) => visitor.visit_seq(x),
115 }
116 }
117
118 fn decode_any_value(&mut self) -> DecodeResult<Value<'_, 'de>> {
119 let (size, type_num) = self.size_and_type()?;
120
121 Ok(match type_num {
122 1 => {
123 let new_ptr = self.decode_pointer(size)?;
124 if new_ptr > self.buf.len() {
125 return Err(InvalidDatabaseError(format!(
126 "pointer at offset {} points beyond end of buffer ({})",
127 self.current_ptr - 1,
128 new_ptr
129 )));
130 }
131 let prev_ptr = self.current_ptr;
132 self.current_ptr = new_ptr;
133
134 Value::Any {
135 prev_ptr,
136 }
137 }
138 2 => Value::String(self.decode_string(size)?),
139 3 => Value::F64(self.decode_double(size)?),
140 4 => Value::Bytes(self.decode_bytes(size)?),
141 5 => Value::U16(self.decode_uint16(size)?),
142 6 => Value::U32(self.decode_uint32(size)?),
143 7 => self.decode_map(size),
144 8 => Value::I32(self.decode_int(size)?),
145 9 => Value::U64(self.decode_uint64(size)?),
146 10 => Value::U128(self.decode_uint128(size)?),
147 11 => self.decode_array(size),
148 14 => Value::Bool(self.decode_bool(size)?),
149 15 => Value::F32(self.decode_float(size)?),
150 u => return Err(InvalidDatabaseError(format!("Unknown data type: {u:?}"))),
151 })
152 }
153
154 fn decode_array(&mut self, size: usize) -> Value<'_, 'de> {
155 Value::Array(ArrayAccess {
156 de: self,
157 count: size,
158 })
159 }
160
161 fn decode_bool(&mut self, size: usize) -> DecodeResult<bool> {
162 match size {
163 0 | 1 => Ok(size != 0),
164 s => Err(InvalidDatabaseError(format!("bool of size {s:?}"))),
165 }
166 }
167
168 fn decode_bytes(&mut self, size: usize) -> DecodeResult<&'de [u8]> {
169 let new_offset = self.current_ptr + size;
170 self.ensure_available(size)?;
171 let u8_slice = &self.buf[self.current_ptr..new_offset];
172 self.current_ptr = new_offset;
173
174 Ok(u8_slice)
175 }
176
177 fn decode_float(&mut self, size: usize) -> DecodeResult<f32> {
178 let new_offset = self.current_ptr + size;
179 self.ensure_available(size)?;
180 let value: [u8; 4] = self.buf[self.current_ptr..new_offset]
181 .try_into()
182 .map_err(|_| InvalidDatabaseError(format!("float of size {:?}", new_offset - self.current_ptr)))?;
183 self.current_ptr = new_offset;
184 let float_value = f32::from_be_bytes(value);
185 Ok(float_value)
186 }
187
188 fn decode_double(&mut self, size: usize) -> DecodeResult<f64> {
189 let new_offset = self.current_ptr + size;
190 self.ensure_available(size)?;
191 let value: [u8; 8] = self.buf[self.current_ptr..new_offset]
192 .try_into()
193 .map_err(|_| InvalidDatabaseError(format!("double of size {:?}", new_offset - self.current_ptr)))?;
194 self.current_ptr = new_offset;
195 let float_value = f64::from_be_bytes(value);
196 Ok(float_value)
197 }
198
199 fn decode_uint64(&mut self, size: usize) -> DecodeResult<u64> {
200 match size {
201 s if s <= 8 => {
202 self.ensure_available(size)?;
203 let new_offset = self.current_ptr + size;
204
205 let value = self.buf[self.current_ptr..new_offset]
206 .iter()
207 .fold(0_u64, |acc, &b| (acc << 8) | u64::from(b));
208 self.current_ptr = new_offset;
209 Ok(value)
210 }
211 s => Err(InvalidDatabaseError(format!("u64 of size {s:?}"))),
212 }
213 }
214
215 fn decode_uint128(&mut self, size: usize) -> DecodeResult<u128> {
216 match size {
217 s if s <= 16 => {
218 self.ensure_available(size)?;
219 let new_offset = self.current_ptr + size;
220
221 let value = self.buf[self.current_ptr..new_offset]
222 .iter()
223 .fold(0_u128, |acc, &b| (acc << 8) | u128::from(b));
224 self.current_ptr = new_offset;
225 Ok(value)
226 }
227 s => Err(InvalidDatabaseError(format!("u128 of size {s:?}"))),
228 }
229 }
230
231 fn decode_uint32(&mut self, size: usize) -> DecodeResult<u32> {
232 match size {
233 s if s <= 4 => {
234 self.ensure_available(size)?;
235 let new_offset = self.current_ptr + size;
236
237 let value = self.buf[self.current_ptr..new_offset]
238 .iter()
239 .fold(0_u32, |acc, &b| (acc << 8) | u32::from(b));
240 self.current_ptr = new_offset;
241 Ok(value)
242 }
243 s => Err(InvalidDatabaseError(format!("u32 of size {s:?}"))),
244 }
245 }
246
247 fn decode_uint16(&mut self, size: usize) -> DecodeResult<u16> {
248 match size {
249 s if s <= 2 => {
250 self.ensure_available(size)?;
251 let new_offset = self.current_ptr + size;
252
253 let value = self.buf[self.current_ptr..new_offset]
254 .iter()
255 .fold(0_u16, |acc, &b| (acc << 8) | u16::from(b));
256 self.current_ptr = new_offset;
257 Ok(value)
258 }
259 s => Err(InvalidDatabaseError(format!("u16 of size {s:?}"))),
260 }
261 }
262
263 fn decode_int(&mut self, size: usize) -> DecodeResult<i32> {
264 match size {
265 s if s <= 4 => {
266 self.ensure_available(size)?;
267 let new_offset = self.current_ptr + size;
268
269 let value = self.buf[self.current_ptr..new_offset]
270 .iter()
271 .fold(0_i32, |acc, &b| (acc << 8) | i32::from(b));
272 self.current_ptr = new_offset;
273 Ok(value)
274 }
275 s => Err(InvalidDatabaseError(format!("int32 of size {s:?}"))),
276 }
277 }
278
279 fn decode_map(&mut self, size: usize) -> Value<'_, 'de> {
280 Value::Map(MapAccessor {
281 de: self,
282 count: size * 2,
283 })
284 }
285
286 fn decode_pointer(&mut self, size: usize) -> DecodeResult<usize> {
287 let pointer_value_offset = [0, 0, 2048, 526_336, 0];
288 let pointer_size = ((size >> 3) & 0x3) + 1;
289 let new_offset = self.current_ptr + pointer_size;
290 self.ensure_available(pointer_size)?;
291 let pointer_bytes = &self.buf[self.current_ptr..new_offset];
292 self.current_ptr = new_offset;
293
294 let base = if pointer_size == 4 { 0 } else { (size & 0x7) as u8 };
295 let unpacked = to_usize(base, pointer_bytes);
296
297 Ok(unpacked + pointer_value_offset[pointer_size])
298 }
299
300 fn decode_string(&mut self, size: usize) -> DecodeResult<&'de str> {
301 use std::str::from_utf8;
302
303 let new_offset: usize = self.current_ptr + size;
304 self.ensure_available(size)?;
305 let bytes = &self.buf[self.current_ptr..new_offset];
306 self.current_ptr = new_offset;
307 match from_utf8(bytes) {
308 Ok(v) => Ok(v),
309 Err(_) => Err(InvalidDatabaseError("error decoding string".to_owned())),
310 }
311 }
312}
313
/// Result type used throughout the decoder; every failure is reported as a `MaxMindDBError`.
pub type DecodeResult<T> = Result<T, MaxMindDBError>;
315
316impl<'de: 'a, 'a> de::Deserializer<'de> for &'a mut Decoder<'de> {
317 type Error = MaxMindDBError;
318
319 fn deserialize_any<V>(self, visitor: V) -> DecodeResult<V::Value>
320 where
321 V: Visitor<'de>,
322 {
323 tracing::debug!("deserialize_any");
324
325 self.decode_any(visitor)
326 }
327
328 fn deserialize_option<V>(self, visitor: V) -> DecodeResult<V::Value>
329 where
330 V: Visitor<'de>,
331 {
332 tracing::debug!("deserialize_option");
333
334 visitor.visit_some(self)
335 }
336
337 forward_to_deserialize_any! {
338 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
339 bytes byte_buf unit unit_struct newtype_struct seq tuple
340 tuple_struct map struct enum identifier ignored_any
341 }
342}
343
/// Lazy accessor over the elements of an array, handed to serde visitors via `visit_seq`.
struct ArrayAccess<'a, 'de: 'a> {
    /// Decoder positioned at the next element to read.
    de: &'a mut Decoder<'de>,
    /// Number of elements not yet yielded.
    count: usize,
}
348
349impl<'de, 'a> SeqAccess<'de> for ArrayAccess<'a, 'de> {
350 type Error = MaxMindDBError;
351
352 fn next_element_seed<T>(&mut self, seed: T) -> DecodeResult<Option<T::Value>>
353 where
354 T: DeserializeSeed<'de>,
355 {
356 if self.count == 0 {
357 return Ok(None);
358 }
359 self.count -= 1;
360
361 seed.deserialize(&mut *self.de).map(Some)
362 }
363}
364
/// Lazy accessor over the entries of a map, handed to serde visitors via `visit_map`.
struct MapAccessor<'a, 'de: 'a> {
    /// Decoder positioned at the next key or value to read.
    de: &'a mut Decoder<'de>,
    /// Number of items not yet yielded, counting keys and values separately
    /// (initialised to twice the number of entries).
    count: usize,
}
369
370impl<'de, 'a> MapAccess<'de> for MapAccessor<'a, 'de> {
371 type Error = MaxMindDBError;
372
373 fn next_key_seed<K>(&mut self, seed: K) -> DecodeResult<Option<K::Value>>
374 where
375 K: DeserializeSeed<'de>,
376 {
377 if self.count == 0 {
378 return Ok(None);
379 }
380 self.count -= 1;
381
382 seed.deserialize(&mut *self.de).map(Some)
383 }
384
385 fn next_value_seed<V>(&mut self, seed: V) -> DecodeResult<V::Value>
386 where
387 V: DeserializeSeed<'de>,
388 {
389 if self.count == 0 {
390 return Err(DecodingError("no more entries".to_owned()));
391 }
392 self.count -= 1;
393
394 seed.deserialize(&mut *self.de)
395 }
396}