Skip to main content

bel/
objects.rs

1#[cfg(feature = "time")]
2use std::sync::LazyLock;
3use std::{
4    cmp::Ordering,
5    collections::HashMap,
6    convert::{Infallible, TryInto},
7    fmt::{Display, Formatter},
8    ops,
9    ops::Deref,
10    sync::Arc,
11};
12
13#[cfg(feature = "time")]
14use chrono::TimeZone;
15
16use crate::{
17    ExecutionError, Expression,
18    common::{
19        ast::{EntryExpr, Expr, operators},
20        value::CelVal,
21    },
22    context::Context,
23    functions::FunctionContext,
24};
25
/// Latest timestamp allowed by CEL: `9999-12-31T23:59:59.999999999Z`.
///
/// Timestamp values are limited to the range of values which can be serialized as a string:
/// `["0001-01-01T00:00:00Z", "9999-12-31T23:59:59.999999999Z"]`. This maximum is smaller, and
/// the minimum ([`MIN_TIMESTAMP`]) is larger, than what [`chrono::DateTime`] can represent, so
/// we need to perform our own spec-compliant overflow checks.
///
/// See <https://github.com/google/cel-spec/blob/master/doc/langdef.md#overflow>.
#[cfg(feature = "time")]
static MAX_TIMESTAMP: LazyLock<chrono::DateTime<chrono::FixedOffset>> = LazyLock::new(|| {
    // The date/time components are constants known to be valid, so these unwraps cannot fail.
    let naive = chrono::NaiveDate::from_ymd_opt(9999, 12, 31)
        .unwrap()
        .and_hms_nano_opt(23, 59, 59, 999_999_999)
        .unwrap();
    chrono::FixedOffset::east_opt(0).unwrap().from_utc_datetime(&naive)
});

/// Earliest timestamp allowed by CEL: `0001-01-01T00:00:00Z`.
///
/// See [`MAX_TIMESTAMP`] for why these bounds are checked manually.
#[cfg(feature = "time")]
static MIN_TIMESTAMP: LazyLock<chrono::DateTime<chrono::FixedOffset>> = LazyLock::new(|| {
    // The date/time components are constants known to be valid, so these unwraps cannot fail.
    let naive = chrono::NaiveDate::from_ymd_opt(1, 1, 1)
        .unwrap()
        .and_hms_opt(0, 0, 0)
        .unwrap();
    chrono::FixedOffset::east_opt(0).unwrap().from_utc_datetime(&naive)
});
49
50#[derive(Debug, PartialEq, Clone)]
51// #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
52pub struct Map {
53    pub map: Arc<HashMap<Key, Value>>,
54}
55
56impl PartialOrd for Map {
57    fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
58        None
59    }
60}
61
62impl Map {
63    /// Returns a reference to the value corresponding to the key. Implicitly converts between int
64    /// and uint keys.
65    pub fn get(&self, key: &Key) -> Option<&Value> {
66        self.map.get(key)
67
68        // .or_else(|| {
69        //     // Also check keys that are cross type comparable.
70        //     let converted = match key {
71        //         Key::Int(k) => Key::Uint(u64::try_from(*k).ok()?),
72        //         // Key::Uint(k) => Key::Int(i64::try_from(*k).ok()?),
73        //         _ => return None,
74        //     };
75        //     self.map.get(&converted)
76        // })
77    }
78}
79
/// A value usable as a key in a CEL map.
///
/// Only integers, booleans and strings can be keys; the unsigned variant is currently disabled.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Key {
    Int(i64),
    // Uint(u64),
    Bool(bool),
    String(Arc<String>),
}

// Conversions from primitive types into `Key`.

impl From<String> for Key {
    fn from(owned: String) -> Self {
        Self::String(Arc::new(owned))
    }
}

impl From<Arc<String>> for Key {
    fn from(shared: Arc<String>) -> Self {
        Self::String(shared)
    }
}

impl<'a> From<&'a str> for Key {
    fn from(text: &'a str) -> Self {
        Self::String(Arc::new(text.to_owned()))
    }
}

impl From<bool> for Key {
    fn from(flag: bool) -> Self {
        Self::Bool(flag)
    }
}

impl From<i64> for Key {
    fn from(n: i64) -> Self {
        Self::Int(n)
    }
}
119
120// impl From<u64> for Key {
121//     fn from(v: u64) -> Self {
122//         Key::Uint(v)
123//     }
124// }
125
126impl serde::Serialize for Key {
127    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128    where
129        S: serde::Serializer,
130    {
131        match self {
132            Key::Int(v) => v.serialize(serializer),
133            // Key::Uint(v) => v.serialize(serializer),
134            Key::Bool(v) => v.serialize(serializer),
135            Key::String(v) => v.serialize(serializer),
136        }
137    }
138}
139
140impl Display for Key {
141    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
142        match self {
143            Key::Int(v) => write!(f, "{v}"),
144            // Key::Uint(v) => write!(f, "{v}"),
145            Key::Bool(v) => write!(f, "{v}"),
146            Key::String(v) => write!(f, "{v}"),
147        }
148    }
149}
150
151/// Implement conversions from [`Key`] into [`Value`]
152impl TryInto<Key> for Value {
153    type Error = Value;
154
155    #[inline(always)]
156    fn try_into(self) -> Result<Key, Self::Error> {
157        match self {
158            Value::Int(v) => Ok(Key::Int(v)),
159            // Value::UInt(v) => Ok(Key::Uint(v)),
160            Value::String(v) => Ok(Key::String(v)),
161            Value::Bool(v) => Ok(Key::Bool(v)),
162            _ => Err(self),
163        }
164    }
165}
166
167// Implement conversion from HashMap<K, V> into CelMap
168impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Map {
169    fn from(map: HashMap<K, V>) -> Self {
170        let mut new_map = HashMap::with_capacity(map.len());
171        for (k, v) in map {
172            new_map.insert(k.into(), v.into());
173        }
174        Map {
175            map: Arc::new(new_map),
176        }
177    }
178}
179
180pub trait TryIntoValue {
181    type Error: std::error::Error + 'static + Send + Sync;
182    fn try_into_value(self) -> Result<Value, Self::Error>;
183}
184
185impl<T: serde::Serialize> TryIntoValue for T {
186    type Error = crate::ser::SerializationError;
187    fn try_into_value(self) -> Result<Value, Self::Error> {
188        crate::ser::to_value(self)
189    }
190}
191impl TryIntoValue for Value {
192    type Error = Infallible;
193    fn try_into_value(self) -> Result<Value, Self::Error> {
194        Ok(self)
195    }
196}
197
198#[derive(Debug, Clone)]
199// #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
200pub enum Value {
201    List(Arc<Vec<Value>>),
202    Map(Map),
203
204    Function(Arc<String>, Option<Box<Value>>),
205
206    // Atoms
207    Int(i64),
208    // UInt(u64),
209    Float(f64),
210    String(Arc<String>),
211    Bytes(Arc<Vec<u8>>),
212    Bool(bool),
213    #[cfg(feature = "time")]
214    Duration(chrono::Duration),
215    #[cfg(feature = "time")]
216    Timestamp(chrono::DateTime<chrono::FixedOffset>),
217    #[cfg(feature = "regex")]
218    Regex(regex::Regex),
219    #[cfg(feature = "ip")]
220    Ip(ipnetwork::IpNetwork),
221    Null,
222}
223
224impl From<CelVal> for Value {
225    fn from(val: CelVal) -> Self {
226        match val {
227            CelVal::String(s) => Value::String(Arc::new(s)),
228            CelVal::Boolean(b) => Value::Bool(b),
229            CelVal::Int(i) => Value::Int(i),
230            // CelVal::UInt(u) => Value::UInt(u),
231            CelVal::Float(d) => Value::Float(d),
232            CelVal::Bytes(bytes) => Value::Bytes(Arc::new(bytes)),
233            CelVal::Null => Value::Null,
234            v => unimplemented!("{v:?}"),
235        }
236    }
237}
238
/// The runtime type of a [`Value`], as reported by `Value::type_of` and used in
/// type-mismatch errors.
#[derive(Clone, Copy, Debug)]
pub enum ValueType {
    List,
    Map,
    Function,
    Int,
    // UInt,
    Float,
    String,
    Bytes,
    Bool,
    Duration,
    Timestamp,
    Regex,
    Ip,
    Null,
}

/// Writes the lowercase CEL name of the type (e.g. `int`, `list`); width/fill flags are ignored.
impl Display for ValueType {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ValueType::List => "list",
            ValueType::Map => "map",
            ValueType::Function => "function",
            ValueType::Int => "int",
            // ValueType::UInt => "uint",
            ValueType::Float => "float",
            ValueType::String => "string",
            ValueType::Bytes => "bytes",
            ValueType::Bool => "bool",
            ValueType::Duration => "duration",
            ValueType::Timestamp => "timestamp",
            ValueType::Regex => "regex",
            ValueType::Ip => "ip",
            ValueType::Null => "null",
        };
        f.write_str(name)
    }
}
277
278impl Value {
279    pub fn type_of(&self) -> ValueType {
280        match self {
281            Value::List(_) => ValueType::List,
282            Value::Map(_) => ValueType::Map,
283            Value::Function(_, _) => ValueType::Function,
284            Value::Int(_) => ValueType::Int,
285            // Value::UInt(_) => ValueType::UInt,
286            Value::Float(_) => ValueType::Float,
287            Value::String(_) => ValueType::String,
288            Value::Bytes(_) => ValueType::Bytes,
289            Value::Bool(_) => ValueType::Bool,
290            #[cfg(feature = "time")]
291            Value::Duration(_) => ValueType::Duration,
292            #[cfg(feature = "time")]
293            Value::Timestamp(_) => ValueType::Timestamp,
294            #[cfg(feature = "regex")]
295            Value::Regex(_) => ValueType::Regex,
296            #[cfg(feature = "ip")]
297            Value::Ip(_) => ValueType::Ip,
298            Value::Null => ValueType::Null,
299        }
300    }
301
302    pub fn error_expected_type(&self, expected: ValueType) -> ExecutionError {
303        ExecutionError::UnexpectedType {
304            got: self.type_of().to_string(),
305            want: expected.to_string(),
306        }
307    }
308}
309
310impl From<&Value> for Value {
311    fn from(value: &Value) -> Self {
312        value.clone()
313    }
314}
315
316impl PartialEq for Value {
317    fn eq(&self, other: &Self) -> bool {
318        match (self, other) {
319            (Value::Map(a), Value::Map(b)) => a == b,
320            (Value::List(a), Value::List(b)) => a == b,
321            (Value::Function(a1, a2), Value::Function(b1, b2)) => a1 == b1 && a2 == b2,
322            (Value::Int(a), Value::Int(b)) => a == b,
323            // (Value::UInt(a), Value::UInt(b)) => a == b,
324            (Value::Float(a), Value::Float(b)) => a == b,
325            (Value::String(a), Value::String(b)) => a == b,
326            (Value::Bytes(a), Value::Bytes(b)) => a == b,
327            (Value::Bool(a), Value::Bool(b)) => a == b,
328            (Value::Null, Value::Null) => true,
329            #[cfg(feature = "time")]
330            (Value::Duration(a), Value::Duration(b)) => a == b,
331            #[cfg(feature = "time")]
332            (Value::Timestamp(a), Value::Timestamp(b)) => a == b,
333            // Allow different numeric types to be compared without explicit casting.
334            // (Value::Int(a), Value::UInt(b)) => a
335            //     .to_owned()
336            //     .try_into()
337            //     .map(|a: u64| a == *b)
338            //     .unwrap_or(false),
339            (Value::Int(a), Value::Float(b)) => (*a as f64) == *b,
340            // (Value::UInt(a), Value::Int(b)) => a
341            //     .to_owned()
342            //     .try_into()
343            //     .map(|a: i64| a == *b)
344            //     .unwrap_or(false),
345            // (Value::UInt(a), Value::Float(b)) => (*a as f64) == *b,
346            (Value::Float(a), Value::Int(b)) => *a == (*b as f64),
347            // (Value::Float(a), Value::UInt(b)) => *a == (*b as f64),
348            #[cfg(feature = "ip")]
349            (Value::Ip(a), Value::Ip(b)) => a == b,
350            (_, _) => false,
351        }
352    }
353}
354
355impl Eq for Value {}
356
357impl PartialOrd for Value {
358    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
359        match (self, other) {
360            (Value::Int(a), Value::Int(b)) => Some(a.cmp(b)),
361            // (Value::UInt(a), Value::UInt(b)) => Some(a.cmp(b)),
362            (Value::Float(a), Value::Float(b)) => a.partial_cmp(b),
363            (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
364            (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
365            (Value::Null, Value::Null) => Some(Ordering::Equal),
366            #[cfg(feature = "time")]
367            (Value::Duration(a), Value::Duration(b)) => Some(a.cmp(b)),
368            #[cfg(feature = "time")]
369            (Value::Timestamp(a), Value::Timestamp(b)) => Some(a.cmp(b)),
370            // Allow different numeric types to be compared without explicit casting.
371            // (Value::Int(a), Value::UInt(b)) => Some(
372            //     a.to_owned()
373            //         .try_into()
374            //         .map(|a: u64| a.cmp(b))
375            //         // If the i64 doesn't fit into a u64 it must be less than 0.
376            //         .unwrap_or(Ordering::Less),
377            // ),
378            (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
379            // (Value::UInt(a), Value::Int(b)) => Some(
380            //     a.to_owned()
381            //         .try_into()
382            //         .map(|a: i64| a.cmp(b))
383            //         // If the u64 doesn't fit into a i64 it must be greater than i64::MAX.
384            //         .unwrap_or(Ordering::Greater),
385            // ),
386            // (Value::UInt(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
387            (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)),
388            // (Value::Float(a), Value::UInt(b)) => a.partial_cmp(&(*b as f64)),
389            #[cfg(feature = "ip")]
390            (Value::Ip(a), Value::Ip(b)) => Some(a.cmp(b)),
391            _ => None,
392        }
393    }
394}
395
396impl From<&Key> for Value {
397    fn from(value: &Key) -> Self {
398        match value {
399            Key::Int(v) => Value::Int(*v),
400            // Key::Uint(v) => Value::UInt(*v),
401            Key::Bool(v) => Value::Bool(*v),
402            Key::String(v) => Value::String(v.clone()),
403        }
404    }
405}
406
407impl From<Key> for Value {
408    fn from(value: Key) -> Self {
409        match value {
410            Key::Int(v) => Value::Int(v),
411            // Key::Uint(v) => Value::UInt(v),
412            Key::Bool(v) => Value::Bool(v),
413            Key::String(v) => Value::String(v),
414        }
415    }
416}
417
418impl From<&Key> for Key {
419    fn from(key: &Key) -> Self {
420        key.clone()
421    }
422}
423
424// Convert Vec<T> to Value
425impl<T: Into<Value>> From<Vec<T>> for Value {
426    fn from(v: Vec<T>) -> Self {
427        Value::List(v.into_iter().map(|v| v.into()).collect::<Vec<_>>().into())
428    }
429}
430
431// Convert Vec<u8> to Value
432impl From<Vec<u8>> for Value {
433    fn from(v: Vec<u8>) -> Self {
434        Value::Bytes(v.into())
435    }
436}
437
438// Convert String to Value
439impl From<String> for Value {
440    fn from(v: String) -> Self {
441        Value::String(v.into())
442    }
443}
444
445impl From<&str> for Value {
446    fn from(v: &str) -> Self {
447        Value::String(v.to_string().into())
448    }
449}
450
451// Convert Option<T> to Value
452impl<T: Into<Value>> From<Option<T>> for Value {
453    fn from(v: Option<T>) -> Self {
454        match v {
455            Some(v) => v.into(),
456            None => Value::Null,
457        }
458    }
459}
460
461// Convert HashMap<K, V> to Value
462impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Value {
463    fn from(v: HashMap<K, V>) -> Self {
464        Value::Map(v.into())
465    }
466}
467
468impl From<ExecutionError> for ResolveResult {
469    fn from(value: ExecutionError) -> Self {
470        Err(value)
471    }
472}
473
474pub type ResolveResult = Result<Value, ExecutionError>;
475
476impl From<Value> for ResolveResult {
477    fn from(value: Value) -> Self {
478        Ok(value)
479    }
480}
481
482impl Value {
483    pub fn resolve_all(expr: &[Expression], ctx: &Context) -> ResolveResult {
484        let mut res = Vec::with_capacity(expr.len());
485        for expr in expr {
486            res.push(Value::resolve(expr, ctx)?);
487        }
488        Ok(Value::List(res.into()))
489    }
490
491    #[inline(always)]
492    pub fn resolve(expr: &Expression, ctx: &Context) -> ResolveResult {
493        match &expr.expr {
494            Expr::Literal(val) => Ok(val.clone().into()),
495            Expr::Call(call) => {
496                if call.args.len() == 3 && call.func_name == operators::CONDITIONAL {
497                    let cond = Value::resolve(&call.args[0], ctx)?;
498                    return if cond.to_bool()? {
499                        Value::resolve(&call.args[1], ctx)
500                    } else {
501                        Value::resolve(&call.args[2], ctx)
502                    };
503                }
504                if call.args.len() == 2 {
505                    match call.func_name.as_str() {
506                        operators::ADD => {
507                            return Value::resolve(&call.args[0], ctx)? + Value::resolve(&call.args[1], ctx)?;
508                        }
509                        operators::SUBSTRACT => {
510                            return Value::resolve(&call.args[0], ctx)? - Value::resolve(&call.args[1], ctx)?;
511                        }
512                        operators::DIVIDE => {
513                            return Value::resolve(&call.args[0], ctx)? / Value::resolve(&call.args[1], ctx)?;
514                        }
515                        operators::MULTIPLY => {
516                            return Value::resolve(&call.args[0], ctx)? * Value::resolve(&call.args[1], ctx)?;
517                        }
518                        operators::MODULO => {
519                            return Value::resolve(&call.args[0], ctx)? % Value::resolve(&call.args[1], ctx)?;
520                        }
521                        operators::EQUALS => {
522                            return Value::Bool(
523                                Value::resolve(&call.args[0], ctx)?.eq(&Value::resolve(&call.args[1], ctx)?),
524                            )
525                            .into();
526                        }
527                        operators::NOT_EQUALS => {
528                            return Value::Bool(
529                                Value::resolve(&call.args[0], ctx)?.ne(&Value::resolve(&call.args[1], ctx)?),
530                            )
531                            .into();
532                        }
533                        operators::LESS => {
534                            let left = Value::resolve(&call.args[0], ctx)?;
535                            let right = Value::resolve(&call.args[1], ctx)?;
536                            return Value::Bool(
537                                left.partial_cmp(&right)
538                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
539                                    == Ordering::Less,
540                            )
541                            .into();
542                        }
543                        operators::LESS_EQUALS => {
544                            let left = Value::resolve(&call.args[0], ctx)?;
545                            let right = Value::resolve(&call.args[1], ctx)?;
546                            return Value::Bool(
547                                left.partial_cmp(&right)
548                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
549                                    != Ordering::Greater,
550                            )
551                            .into();
552                        }
553                        operators::GREATER => {
554                            let left = Value::resolve(&call.args[0], ctx)?;
555                            let right = Value::resolve(&call.args[1], ctx)?;
556                            return Value::Bool(
557                                left.partial_cmp(&right)
558                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
559                                    == Ordering::Greater,
560                            )
561                            .into();
562                        }
563                        operators::GREATER_EQUALS => {
564                            let left = Value::resolve(&call.args[0], ctx)?;
565                            let right = Value::resolve(&call.args[1], ctx)?;
566                            return Value::Bool(
567                                left.partial_cmp(&right)
568                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
569                                    != Ordering::Less,
570                            )
571                            .into();
572                        }
573                        // operators::IN => {
574                        //     let left = Value::resolve(&call.args[0], ctx)?;
575                        //     let right = Value::resolve(&call.args[1], ctx)?;
576                        //     match (left, right) {
577                        //         (Value::String(l), Value::String(r)) => {
578                        //             return Value::Bool(r.contains(&*l)).into();
579                        //         }
580                        //         (any, Value::List(v)) => {
581                        //             return Value::Bool(v.contains(&any)).into();
582                        //         }
583                        //         (any, Value::Map(m)) => match any.try_into() {
584                        //             Ok(key) => return Value::Bool(m.map.contains_key(&key)).into(),
585                        //             Err(_) => return Value::Bool(false).into(),
586                        //         },
587                        //         (left, right) => {
588                        //             Err(ExecutionError::ValuesNotComparable(left, right))?
589                        //         }
590                        //     }
591                        // }
592                        operators::LOGICAL_OR => {
593                            let left = Value::resolve(&call.args[0], ctx)?;
594                            return if left.to_bool()? {
595                                left.into()
596                            } else {
597                                Value::resolve(&call.args[1], ctx)
598                            };
599                        }
600                        operators::LOGICAL_AND => {
601                            let left = Value::resolve(&call.args[0], ctx)?;
602                            return if !left.to_bool()? {
603                                Value::Bool(false)
604                            } else {
605                                let right = Value::resolve(&call.args[1], ctx)?;
606                                Value::Bool(right.to_bool()?)
607                            }
608                            .into();
609                        }
610                        operators::INDEX => {
611                            let value = Value::resolve(&call.args[0], ctx)?;
612                            let idx = Value::resolve(&call.args[1], ctx)?;
613                            return match (value, idx) {
614                                (Value::List(items), Value::Int(idx)) => {
615                                    items.get(idx as usize).cloned().unwrap_or(Value::Null).into()
616                                }
617                                (Value::String(str), Value::Int(idx)) => {
618                                    match str.get(idx as usize..(idx + 1) as usize) {
619                                        None => Ok(Value::Null),
620                                        Some(str) => Ok(Value::String(str.to_string().into())),
621                                    }
622                                }
623                                (Value::Map(map), Value::String(property)) => {
624                                    map.get(&property.into()).cloned().unwrap_or(Value::Null).into()
625                                }
626                                (Value::Map(map), Value::Bool(property)) => {
627                                    map.get(&property.into()).cloned().unwrap_or(Value::Null).into()
628                                }
629                                (Value::Map(map), Value::Int(property)) => {
630                                    map.get(&property.into()).cloned().unwrap_or(Value::Null).into()
631                                }
632                                // (Value::Map(map), Value::UInt(property)) => map
633                                //     .get(&property.into())
634                                //     .cloned()
635                                //     .unwrap_or(Value::Null)
636                                //     .into(),
637                                (Value::Map(_), index) => Err(ExecutionError::UnsupportedMapIndex(index)),
638                                (Value::List(_), index) => Err(ExecutionError::UnsupportedListIndex(index)),
639                                (value, index) => Err(ExecutionError::UnsupportedIndex(value, index)),
640                            };
641                        }
642                        _ => (),
643                    }
644                }
645                if call.args.len() == 1 {
646                    let expr = Value::resolve(&call.args[0], ctx)?;
647                    match call.func_name.as_str() {
648                        operators::LOGICAL_NOT => return Ok(Value::Bool(!expr.to_bool()?)),
649                        operators::NEGATE => {
650                            return match expr {
651                                Value::Int(i) => Ok(Value::Int(-i)),
652                                Value::Float(f) => Ok(Value::Float(-f)),
653                                value => Err(ExecutionError::UnsupportedUnaryOperator("minus", value)),
654                            };
655                        }
656                        operators::NOT_STRICTLY_FALSE => {
657                            return match expr {
658                                Value::Bool(b) => Ok(Value::Bool(b)),
659                                _ => Ok(Value::Bool(true)),
660                            };
661                        }
662                        _ => (),
663                    }
664                }
665                let func = ctx
666                    .get_function(call.func_name.as_str())
667                    .ok_or_else(|| ExecutionError::UndeclaredReference(call.func_name.clone().into()))?;
668                match &call.target {
669                    None => {
670                        let mut ctx = FunctionContext::new(call.func_name.clone().into(), None, ctx, call.args.clone());
671                        (func)(&mut ctx)
672                    }
673                    Some(target) => {
674                        let mut ctx = FunctionContext::new(
675                            call.func_name.clone().into(),
676                            Some(Value::resolve(target, ctx)?),
677                            ctx,
678                            call.args.clone(),
679                        );
680                        (func)(&mut ctx)
681                    }
682                }
683            }
684            Expr::Ident(name) => ctx.get_variable(name),
685            Expr::Select(select) => {
686                let left = Value::resolve(select.operand.deref(), ctx)?;
687                if select.test {
688                    match &left {
689                        Value::Map(map) => {
690                            for key in map.map.deref().keys() {
691                                if key.to_string().eq(&select.field) {
692                                    return Ok(Value::Bool(true));
693                                }
694                            }
695                            Ok(Value::Bool(false))
696                        }
697                        _ => Ok(Value::Bool(false)),
698                    }
699                } else {
700                    left.member(&select.field)
701                }
702            }
703            Expr::List(list_expr) => {
704                let list = list_expr
705                    .elements
706                    .iter()
707                    .map(|i| Value::resolve(i, ctx))
708                    .collect::<Result<Vec<_>, _>>()?;
709                Value::List(list.into()).into()
710            }
711            Expr::Map(map_expr) => {
712                let mut map = HashMap::with_capacity(map_expr.entries.len());
713                for entry in map_expr.entries.iter() {
714                    let (k, v) = match &entry.expr {
715                        EntryExpr::StructField(_) => panic!("WAT?"),
716                        EntryExpr::MapEntry(e) => (&e.key, &e.value),
717                    };
718                    let key = Value::resolve(k, ctx)?
719                        .try_into()
720                        .map_err(ExecutionError::UnsupportedKeyType)?;
721                    let value = Value::resolve(v, ctx)?;
722                    map.insert(key, value);
723                }
724                Ok(Value::Map(Map {
725                    map: Arc::from(map),
726                }))
727            }
728            Expr::Comprehension(comprehension) => {
729                let accu_init = Value::resolve(&comprehension.accu_init, ctx)?;
730                let iter = Value::resolve(&comprehension.iter_range, ctx)?;
731                let mut ctx = ctx.new_inner_scope();
732                ctx.add_variable(&comprehension.accu_var, accu_init)
733                    .expect("Failed to add accu variable");
734
735                match iter {
736                    Value::List(items) => {
737                        for item in items.deref() {
738                            if !Value::resolve(&comprehension.loop_cond, &ctx)?.to_bool()? {
739                                break;
740                            }
741                            ctx.add_variable_from_value(&comprehension.iter_var, item.clone());
742                            let accu = Value::resolve(&comprehension.loop_step, &ctx)?;
743                            ctx.add_variable_from_value(&comprehension.accu_var, accu);
744                        }
745                    }
746                    Value::Map(map) => {
747                        for key in map.map.deref().keys() {
748                            if !Value::resolve(&comprehension.loop_cond, &ctx)?.to_bool()? {
749                                break;
750                            }
751                            ctx.add_variable_from_value(&comprehension.iter_var, key.clone());
752                            let accu = Value::resolve(&comprehension.loop_step, &ctx)?;
753                            ctx.add_variable_from_value(&comprehension.accu_var, accu);
754                        }
755                    }
756                    t => todo!("Support {t:?}"),
757                }
758                Value::resolve(&comprehension.result, &ctx)
759            }
760            Expr::Struct(_) => todo!("Support structs!"),
761            Expr::Unspecified => panic!("Can't evaluate Unspecified Expr"),
762        }
763    }
764
765    // >> a(b)
766    // Member(Ident("a"),
767    //        FunctionCall([Ident("b")]))
768    // >> a.b(c)
769    // Member(Member(Ident("a"),
770    //               Attribute("b")),
771    //        FunctionCall([Ident("c")]))
772
773    fn member(self, name: &str) -> ResolveResult {
774        // todo! Ideally we would avoid creating a String just to create a Key for lookup in the
775        // map, but this would require something like the `hashbrown` crate's `Equivalent` trait.
776        let name: Arc<String> = name.to_owned().into();
777
778        // This will always either be because we're trying to access
779        // a property on self, or a method on self.
780        let child = match self {
781            Value::Map(ref m) => m.map.get(&name.clone().into()).cloned(),
782            _ => None,
783        };
784
785        // If the property is both an attribute and a method, then we
786        // give priority to the property. Maybe we can implement lookahead
787        // to see if the next token is a function call?
788        if let Some(child) = child {
789            child.into()
790        } else {
791            ExecutionError::NoSuchKey(name.clone()).into()
792        }
793    }
794
795    #[inline(always)]
796    fn to_bool(&self) -> Result<bool, ExecutionError> {
797        match self {
798            Value::Bool(v) => Ok(*v),
799            _ => Err(ExecutionError::NoSuchOverload),
800        }
801    }
802}
803
804impl ops::Add<Value> for Value {
805    type Output = ResolveResult;
806
807    #[inline(always)]
808    fn add(self, rhs: Value) -> Self::Output {
809        match (self, rhs) {
810            (Value::Int(l), Value::Int(r)) => l
811                .checked_add(r)
812                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
813                .map(Value::Int),
814
815            // (Value::UInt(l), Value::UInt(r)) => l
816            //     .checked_add(r)
817            //     .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
818            //     .map(Value::UInt),
819            (Value::Float(l), Value::Float(r)) => Value::Float(l + r).into(),
820
821            (Value::List(mut l), Value::List(mut r)) => {
822                {
823                    // If this is the only reference to `l`, we can append to it in place.
824                    // `l` is replaced with a clone otherwise.
825                    let l = Arc::make_mut(&mut l);
826
827                    // Likewise, if this is the only reference to `r`, we can move its values
828                    // instead of cloning them.
829                    match Arc::get_mut(&mut r) {
830                        Some(r) => l.append(r),
831                        None => l.extend(r.iter().cloned()),
832                    }
833                }
834
835                Ok(Value::List(l))
836            }
837            (Value::String(mut l), Value::String(r)) => {
838                // If this is the only reference to `l`, we can append to it in place.
839                // `l` is replaced with a clone otherwise.
840                Arc::make_mut(&mut l).push_str(&r);
841                Ok(Value::String(l))
842            }
843            #[cfg(feature = "time")]
844            (Value::Duration(l), Value::Duration(r)) => l
845                .checked_add(&r)
846                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
847                .map(Value::Duration),
848            #[cfg(feature = "time")]
849            (Value::Timestamp(l), Value::Duration(r)) => checked_op(TsOp::Add, &l, &r),
850            #[cfg(feature = "time")]
851            (Value::Duration(l), Value::Timestamp(r)) => r
852                .checked_add_signed(l)
853                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
854                .map(Value::Timestamp),
855            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator("add", left, right)),
856        }
857    }
858}
859
860impl ops::Sub<Value> for Value {
861    type Output = ResolveResult;
862
863    #[inline(always)]
864    fn sub(self, rhs: Value) -> Self::Output {
865        match (self, rhs) {
866            (Value::Int(l), Value::Int(r)) => l
867                .checked_sub(r)
868                .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
869                .map(Value::Int),
870
871            // (Value::UInt(l), Value::UInt(r)) => l
872            //     .checked_sub(r)
873            //     .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
874            //     .map(Value::UInt),
875            (Value::Float(l), Value::Float(r)) => Value::Float(l - r).into(),
876
877            #[cfg(feature = "time")]
878            (Value::Duration(l), Value::Duration(r)) => l
879                .checked_sub(&r)
880                .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
881                .map(Value::Duration),
882            #[cfg(feature = "time")]
883            (Value::Timestamp(l), Value::Duration(r)) => checked_op(TsOp::Sub, &l, &r),
884            #[cfg(feature = "time")]
885            (Value::Timestamp(l), Value::Timestamp(r)) => Value::Duration(l.signed_duration_since(r)).into(),
886            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator("sub", left, right)),
887        }
888    }
889}
890
891impl ops::Div<Value> for Value {
892    type Output = ResolveResult;
893
894    #[inline(always)]
895    fn div(self, rhs: Value) -> Self::Output {
896        match (self, rhs) {
897            (Value::Int(l), Value::Int(r)) => {
898                if r == 0 {
899                    Err(ExecutionError::DivisionByZero(l.into()))
900                } else {
901                    l.checked_div(r)
902                        .ok_or(ExecutionError::Overflow("div", l.into(), r.into()))
903                        .map(Value::Int)
904                }
905            }
906
907            // (Value::UInt(l), Value::UInt(r)) => l
908            //     .checked_div(r)
909            //     .ok_or(ExecutionError::DivisionByZero(l.into()))
910            //     .map(Value::UInt),
911            (Value::Float(l), Value::Float(r)) => Value::Float(l / r).into(),
912
913            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator("div", left, right)),
914        }
915    }
916}
917
918impl ops::Mul<Value> for Value {
919    type Output = ResolveResult;
920
921    #[inline(always)]
922    fn mul(self, rhs: Value) -> Self::Output {
923        match (self, rhs) {
924            (Value::Int(l), Value::Int(r)) => l
925                .checked_mul(r)
926                .ok_or(ExecutionError::Overflow("mul", l.into(), r.into()))
927                .map(Value::Int),
928
929            // (Value::UInt(l), Value::UInt(r)) => l
930            //     .checked_mul(r)
931            //     .ok_or(ExecutionError::Overflow("mul", l.into(), r.into()))
932            //     .map(Value::UInt),
933            (Value::Float(l), Value::Float(r)) => Value::Float(l * r).into(),
934
935            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator("mul", left, right)),
936        }
937    }
938}
939
940impl ops::Rem<Value> for Value {
941    type Output = ResolveResult;
942
943    #[inline(always)]
944    fn rem(self, rhs: Value) -> Self::Output {
945        match (self, rhs) {
946            (Value::Int(l), Value::Int(r)) => {
947                if r == 0 {
948                    Err(ExecutionError::RemainderByZero(l.into()))
949                } else {
950                    l.checked_rem(r)
951                        .ok_or(ExecutionError::Overflow("rem", l.into(), r.into()))
952                        .map(Value::Int)
953                }
954            }
955
956            // (Value::UInt(l), Value::UInt(r)) => l
957            //     .checked_rem(r)
958            //     .ok_or(ExecutionError::RemainderByZero(l.into()))
959            //     .map(Value::UInt),
960            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator("rem", left, right)),
961        }
962    }
963}
964
/// `TsOp` selects the arithmetic operation [`checked_op`] applies to a timestamp (left operand)
/// and a duration (right operand).
#[cfg(feature = "time")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TsOp {
    /// `timestamp + duration`
    Add,
    /// `timestamp - duration`
    Sub,
}
972
#[cfg(feature = "time")]
impl TsOp {
    /// Returns the operator name reported in [`ExecutionError::Overflow`] errors
    /// (`"add"` or `"sub"`).
    fn str(&self) -> &'static str {
        match *self {
            Self::Add => "add",
            Self::Sub => "sub",
        }
    }
}
982
/// Applies `op` to the timestamp `lhs` and the duration `rhs`, with overflow checking.
///
/// The operation fails with [`ExecutionError::Overflow`] (naming `op` and both operands) when
/// either of these holds:
/// * chrono cannot represent the result, or
/// * the result lies outside the cel-spec timestamp range `[MIN_TIMESTAMP, MAX_TIMESTAMP]`
///   (see [`MIN_TIMESTAMP`] and [`MAX_TIMESTAMP`]).
#[cfg(feature = "time")]
fn checked_op(op: TsOp, lhs: &chrono::DateTime<chrono::FixedOffset>, rhs: &chrono::Duration) -> ResolveResult {
    // `None` here means chrono's own representable range was exceeded.
    let shifted = match op {
        TsOp::Add => lhs.checked_add_signed(*rhs),
        TsOp::Sub => lhs.checked_sub_signed(*rhs),
    };

    match shifted {
        Some(ts) if (*MIN_TIMESTAMP..=*MAX_TIMESTAMP).contains(&ts) => Value::Timestamp(ts).into(),
        // Either chrono overflowed or the result escaped the cel-spec bounds.
        _ => Err(ExecutionError::Overflow(op.str(), (*lhs).into(), (*rhs).into())),
    }
}
1002
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use crate::{Context, ExecutionError, Program, Value, objects::Key};

    #[test]
    fn test_indexed_map_access() {
        let mut context = Context::default();
        let mut headers = HashMap::new();
        headers.insert("Content-Type", "application/json".to_string());
        context.add_variable_from_value("headers", headers);

        let program = Program::compile("headers[\"Content-Type\"]").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, "application/json".into());
    }

    #[test]
    fn test_numeric_map_access() {
        let mut context = Context::default();
        let mut numbers = HashMap::new();
        numbers.insert(Key::Int(1), "one".to_string());
        context.add_variable_from_value("numbers", numbers);

        let program = Program::compile("numbers[1]").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, "one".into());
    }

    #[test]
    fn test_heterogeneous_compare() {
        let context = Context::default();

        // Disabled: the unsigned-integer cases below stay commented out while UInt support is
        // commented out in this file.
        // let program = Program::compile("1 < Uint(2)").unwrap();
        // let value = program.execute(&context).unwrap();
        // assert_eq!(value, true.into());

        let program = Program::compile("1 < 1.1").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, true.into());

        // let program = Program::compile("Uint(0) > -10").unwrap();
        // let value = program.execute(&context).unwrap();
        // assert_eq!(
        //     value,
        //     true.into(),
        //     "negative signed ints should be less than uints"
        // );
    }

    #[test]
    fn test_float_compare() {
        let context = Context::default();

        let program = Program::compile("1.0 > 0.0").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, true.into());

        let program = Program::compile(r#"Float("NaN") == Float("NaN")"#).unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, false.into(), "NaN should not equal itself");

        let program = Program::compile(r#"1.0 > Float("NaN")"#).unwrap();
        let result = program.execute(&context);
        assert!(result.is_err(), "NaN should not be comparable with inequality operators");
    }

    #[test]
    fn test_invalid_compare() {
        let context = Context::default();

        let program = Program::compile("{} == []").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, false.into());
    }

    #[test]
    fn test_size_fn_var() {
        let program = Program::compile("length(requests) + size == 5").unwrap();
        let mut context = Context::default();
        let requests = vec![Value::Int(42), Value::Int(42)];
        context
            .add_variable("requests", Value::List(Arc::new(requests)))
            .unwrap();
        context.add_variable("size", Value::Int(3)).unwrap();
        assert_eq!(program.execute(&context).unwrap(), Value::Bool(true));
    }

    /// Compiles `program`, runs it against an empty [`Context`], and asserts that execution fails
    /// with exactly `expected`.
    fn test_execution_error(program: &str, expected: ExecutionError) {
        let program = Program::compile(program).unwrap();
        let result = program.execute(&Context::default());
        assert_eq!(result.unwrap_err(), expected);
    }

    #[test]
    fn test_invalid_sub() {
        test_execution_error(
            r#""foo" - 10"#,
            ExecutionError::UnsupportedBinaryOperator("sub", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_add() {
        test_execution_error(
            r#""foo" + 10"#,
            ExecutionError::UnsupportedBinaryOperator("add", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_div() {
        test_execution_error(
            r#""foo" / 10"#,
            ExecutionError::UnsupportedBinaryOperator("div", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_rem() {
        test_execution_error(
            r#""foo" % 10"#,
            ExecutionError::UnsupportedBinaryOperator("rem", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn out_of_bound_list_access() {
        let program = Program::compile("list[10]").unwrap();
        let mut context = Context::default();
        context.add_variable("list", Value::List(Arc::new(vec![]))).unwrap();
        let result = program.execute(&context);
        assert_eq!(result.unwrap(), Value::Null);
    }

    #[test]
    fn reference_to_value() {
        let test = "example".to_string();
        let direct: Value = test.as_str().into();
        assert_eq!(direct, Value::String(Arc::new(String::from("example"))));

        let vec = vec![test.as_str()];
        let indirect: Value = vec.into();
        assert_eq!(
            indirect,
            Value::List(Arc::new(vec![Value::String(Arc::new(String::from("example")))]))
        );
    }

    #[test]
    fn test_short_circuit_and() {
        let mut context = Context::default();
        let data: HashMap<String, String> = HashMap::new();
        context.add_variable_from_value("data", data);

        let program = Program::compile("has(data.x) && data.x.starts_with(\"foo\")").unwrap();
        let value = program.execute(&context);
        // Report the actual outcome in the failure message instead of printing it unconditionally.
        assert!(
            value.is_ok(),
            "The AND expression should support short-circuit evaluation, got {value:?}"
        );
    }

    #[test]
    fn invalid_int_math() {
        use ExecutionError::*;

        let cases = [
            ("1 / 0", DivisionByZero(1.into())),
            ("1 % 0", RemainderByZero(1.into())),
            (&format!("{} + 1", i64::MAX), Overflow("add", i64::MAX.into(), 1.into())),
            (&format!("{} - 1", i64::MIN), Overflow("sub", i64::MIN.into(), 1.into())),
            (&format!("{} * 2", i64::MAX), Overflow("mul", i64::MAX.into(), 2.into())),
            (&format!("{} / -1", i64::MIN), Overflow("div", i64::MIN.into(), (-1).into())),
            (&format!("{} % -1", i64::MIN), Overflow("rem", i64::MIN.into(), (-1).into())),
        ];

        for (expr, err) in cases {
            test_execution_error(expr, err);
        }
    }

    // Disabled: the unsigned-integer arithmetic this exercises is commented out in this file.
    // #[test]
    // fn invalid_uint_math() {
    //     use ExecutionError::*;

    //     let cases = [
    //         ("1u / 0u", DivisionByZero(1u64.into())),
    //         ("1u % 0u", RemainderByZero(1u64.into())),
    //         (
    //             &format!("{}u + 1u", u64::MAX),
    //             Overflow("add", u64::MAX.into(), 1u64.into()),
    //         ),
    //         ("0u - 1u", Overflow("sub", 0u64.into(), 1u64.into())),
    //         (
    //             &format!("{}u * 2u", u64::MAX),
    //             Overflow("mul", u64::MAX.into(), 2u64.into()),
    //         ),
    //     ];

    //     for (expr, err) in cases {
    //         test_execution_error(expr, err);
    //     }
    // }

    #[test]
    fn test_function_identifier() {
        fn with(
            ftx: &crate::FunctionContext,
            crate::extractors::This(this): crate::extractors::This<Value>,
            ident: crate::extractors::Identifier,
            expr: crate::parser::Expression,
        ) -> crate::ResolveResult {
            let mut ptx = ftx.ptx.new_inner_scope();
            ptx.add_variable_from_value(&ident, this);
            ptx.resolve(&expr)
        }
        let mut context = Context::default();
        context.add_function("with", with);

        let program = Program::compile("[1,2].with(a, a + a)").unwrap();
        let value = program.execute(&context);
        assert_eq!(
            value,
            Ok(Value::List(Arc::new(vec![
                Value::Int(1),
                Value::Int(2),
                Value::Int(1),
                Value::Int(2)
            ])))
        );
    }
}