Skip to main content

json_rpc/
request.rs

1use std::fmt;
2
3use serde::{
4    Deserialize, Deserializer, Serialize, Serializer,
5    de::{self, DeserializeOwned, MapAccess, SeqAccess, Visitor, value::MapAccessDeserializer},
6};
7use serde_json::value::RawValue;
8
9use crate::Id;
10
11/// The `id` field of a JSON-RPC request.
12///
13/// This wrapper distinguishes three states:
14/// - **Absent**: the field was not present in the JSON → notification
15/// - **Null**: the field was `null` → request with a null id (discouraged)
16/// - **Present**: the field was a string or number → normal request
17#[derive(Clone, Debug, PartialEq, Eq)]
18pub struct RequestId(pub Option<Id>);
19
20impl RequestId {
21    /// Returns `true` if this is a notification (no `id` field).
22    pub fn is_notification(&self) -> bool {
23        self.0.is_none()
24    }
25
26    /// Returns a reference to the inner `Id`, if present.
27    pub fn as_ref(&self) -> Option<&Id> {
28        self.0.as_ref()
29    }
30
31    /// Consumes `self` and returns the inner `Option<Id>`.
32    pub fn into_id(self) -> Option<Id> {
33        self.0
34    }
35}
36
37impl Default for RequestId {
38    fn default() -> Self {
39        Self(None)
40    }
41}
42
43impl Serialize for RequestId {
44    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
45        match &self.0 {
46            Some(id) => id.serialize(serializer),
47            None => serializer.serialize_none(),
48        }
49    }
50}
51
52impl<'de> Deserialize<'de> for RequestId {
53    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
54        struct RequestIdVisitor;
55
56        impl<'de> Visitor<'de> for RequestIdVisitor {
57            type Value = RequestId;
58
59            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60                f.write_str("a JSON-RPC id: a string, an integer, or null")
61            }
62
63            fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
64                Ok(RequestId(Some(Id::Null)))
65            }
66
67            fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
68                Ok(RequestId(Some(Id::Null)))
69            }
70
71            fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
72                Ok(RequestId(Some(Id::Number(v))))
73            }
74
75            fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
76                Ok(RequestId(Some(Id::Number(v as i64))))
77            }
78
79            fn visit_f64<E: de::Error>(self, v: f64) -> Result<Self::Value, E> {
80                if v.fract() == 0.0 {
81                    Ok(RequestId(Some(Id::Number(v as i64))))
82                } else {
83                    Err(de::Error::invalid_value(de::Unexpected::Float(v), &self))
84                }
85            }
86
87            fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
88                Ok(RequestId(Some(Id::String(v.to_owned()))))
89            }
90
91            fn visit_string<E: de::Error>(self, v: String) -> Result<Self::Value, E> {
92                Ok(RequestId(Some(Id::String(v))))
93            }
94        }
95
96        deserializer.deserialize_any(RequestIdVisitor)
97    }
98}
99
100/// A JSON-RPC 2.0 request object.
101///
102/// Params are stored as `Box<RawValue>` to defer deserialization
103/// until the method handler is known.
104#[derive(Clone, Debug, Serialize, Deserialize)]
105pub struct Request {
106    /// The JSON-RPC version — must be `"2.0"`.
107    pub jsonrpc: String,
108    /// The name of the method to invoke.
109    pub method: String,
110    /// Structured parameters for the method (optional).
111    #[serde(default)]
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub params: Option<Box<RawValue>>,
114    /// Client-assigned identifier. Absent for notifications.
115    #[serde(default)]
116    #[serde(skip_serializing_if = "RequestId::is_notification")]
117    pub id: RequestId,
118}
119
120impl Request {
121    /// Returns `true` if this request is a notification (has no `id`).
122    pub fn is_notification(&self) -> bool {
123        self.id.is_notification()
124    }
125
126    /// Deserializes the params into a concrete type.
127    ///
128    /// Returns an error if the params are absent or fail to deserialize.
129    pub fn parse_params<P: DeserializeOwned>(&self) -> Result<P, serde_json::Error> {
130        match &self.params {
131            Some(raw) => serde_json::from_str(raw.get()),
132            None => serde_json::from_str("{}"),
133        }
134    }
135}
136
137/// A single JSON-RPC request message, or a batch of them.
138///
139/// Per the spec, a client may send either a single `Request` object
140/// or an `Array` of request objects (a batch). Batch entries that are
141/// not valid request objects are preserved as raw JSON so the server
142/// can respond with individual `Invalid Request` errors.
143#[derive(Clone, Debug)]
144pub enum RequestMessage {
145    /// A single request.
146    Single(Request),
147    /// A batch of raw request values. Each element is parsed individually
148    /// during dispatch, so invalid entries get individual error responses.
149    Batch(Vec<Request>),
150}
151
152impl RequestMessage {
153    /// Returns `true` if this message is a batch.
154    pub fn is_batch(&self) -> bool {
155        matches!(self, Self::Batch(_))
156    }
157
158    /// Returns the number of entries in this message.
159    pub fn len(&self) -> usize {
160        match self {
161            Self::Single(_) => 1,
162            Self::Batch(entries) => entries.len(),
163        }
164    }
165
166    /// Returns `true` if there are no entries in this message.
167    pub fn is_empty(&self) -> bool {
168        match self {
169            Self::Batch(entries) => entries.is_empty(),
170            _ => false,
171        }
172    }
173}
174
175impl From<Request> for RequestMessage {
176    fn from(req: Request) -> Self {
177        Self::Single(req)
178    }
179}
180
181impl Serialize for RequestMessage {
182    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
183        match self {
184            Self::Single(req) => req.serialize(serializer),
185            Self::Batch(entries) => entries.serialize(serializer),
186        }
187    }
188}
189
190impl<'de> Deserialize<'de> for RequestMessage {
191    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
192        struct RequestMessageVisitor;
193
194        impl<'de> Visitor<'de> for RequestMessageVisitor {
195            type Value = RequestMessage;
196
197            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198                f.write_str("a JSON-RPC request object or array")
199            }
200
201            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
202                let mut requests = Vec::with_capacity(seq.size_hint().unwrap_or(1));
203                while let Some(req) = seq.next_element::<Request>()? {
204                    requests.push(req);
205                }
206                Ok(RequestMessage::Batch(requests))
207            }
208
209            fn visit_map<M: MapAccess<'de>>(self, map: M) -> Result<Self::Value, M::Error> {
210                let req = Request::deserialize(MapAccessDeserializer::new(map))?;
211                Ok(RequestMessage::Single(req))
212            }
213        }
214
215        deserializer.deserialize_any(RequestMessageVisitor)
216    }
217}
218
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Parses `json` into `T`, panicking on any serde error.
    fn parse<T: DeserializeOwned>(json: &str) -> T {
        serde_json::from_str(json).unwrap()
    }

    /// Builds a `"2.0"` request with no params and a numeric id.
    fn numbered_request(method: &str, id: i64) -> Request {
        Request {
            jsonrpc: "2.0".to_owned(),
            method: method.to_owned(),
            params: None,
            id: RequestId(Some(Id::Number(id))),
        }
    }

    #[test]
    fn test_deserialize_request() {
        let request: Request =
            parse(r#"{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}"#);
        assert_eq!(request.jsonrpc, "2.0");
        assert_eq!(request.method, "subtract");
        assert!(request.params.is_some());
        assert_eq!(request.id, RequestId(Some(Id::Number(1))));
        assert!(!request.is_notification());
    }

    #[test]
    fn test_deserialize_notification() {
        // No `id` key at all marks a notification.
        let request: Request = parse(r#"{"jsonrpc":"2.0","method":"update","params":[1,2,3]}"#);
        assert!(request.is_notification());
        assert_eq!(request.id, RequestId(None));
    }

    #[test]
    fn test_deserialize_request_no_params() {
        let request: Request = parse(r#"{"jsonrpc":"2.0","method":"foobar"}"#);
        assert_eq!(request.method, "foobar");
        assert!(request.params.is_none());
        assert!(request.is_notification());
    }

    #[test]
    fn test_deserialize_request_string_id() {
        let request: Request = parse(r#"{"jsonrpc":"2.0","method":"get_data","id":"abc"}"#);
        assert_eq!(request.id, RequestId(Some(Id::String("abc".into()))));
    }

    #[test]
    fn test_deserialize_request_null_id() {
        // An explicit `null` id stays distinct from an absent one.
        let request: Request = parse(r#"{"jsonrpc":"2.0","method":"foo","id":null}"#);
        assert_eq!(request.id, RequestId(Some(Id::Null)));
    }

    #[test]
    fn test_request_parse_params() {
        let request: Request = parse(r#"{"jsonrpc":"2.0","method":"add","params":[1,2],"id":1}"#);
        let pair: (i64, i64) = request.parse_params().unwrap();
        assert_eq!(pair, (1, 2));
    }

    #[test]
    fn test_request_parse_params_absent() {
        #[derive(Deserialize)]
        struct PingParams {
            _extra: Option<String>,
        }

        let request: Request = parse(r#"{"jsonrpc":"2.0","method":"ping","id":1}"#);
        let params: PingParams = request.parse_params().unwrap();
        assert!(params._extra.is_none());
    }

    #[test]
    fn test_request_id_into_id() {
        assert_eq!(RequestId(Some(Id::Number(42))).into_id(), Some(Id::Number(42)));
        assert_eq!(RequestId(None).into_id(), None);
    }

    #[test]
    fn test_deserialize_single_request_message() {
        let message: RequestMessage = parse(r#"{"jsonrpc":"2.0","method":"test","id":1}"#);
        assert!(!message.is_batch());
        assert_eq!(message.len(), 1);
        assert!(!message.is_empty());
    }

    #[test]
    fn test_deserialize_batch_request_message() {
        let message: RequestMessage = parse(
            r#"[{"jsonrpc":"2.0","method":"a","id":1},{"jsonrpc":"2.0","method":"b","id":2}]"#,
        );
        assert!(message.is_batch());
        assert_eq!(message.len(), 2);
        assert!(!message.is_empty());
    }

    #[test]
    fn test_deserialize_batch_with_invalid_entries() {
        let message: RequestMessage = parse(
            r#"[{"jsonrpc":"2.0","method":"a","id":1},42,{"jsonrpc":"2.0","method":"b","id":2}]"#,
        );
        assert!(message.is_batch());
        assert_eq!(message.len(), 2);
    }

    #[test]
    fn test_deserialize_empty_array() {
        let message: RequestMessage = parse("[]");
        assert!(message.is_batch());
        assert!(message.is_empty());
        assert_eq!(message.len(), 0);
    }

    #[test]
    fn test_request_message_from_request() {
        let request: Request = parse(r#"{"jsonrpc":"2.0","method":"x","id":1}"#);
        let message = RequestMessage::from(request);
        assert!(!message.is_batch());
    }

    #[test]
    fn test_request_serialize() {
        let request = Request {
            params: Some(RawValue::from_string("[42,23]".into()).unwrap()),
            ..numbered_request("subtract", 1)
        };
        let encoded = serde_json::to_string(&request).unwrap();
        let decoded: serde_json::Value = parse(&encoded);
        assert_eq!(
            decoded,
            json!({"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1})
        );
    }

    #[test]
    fn test_request_message_batch_serialize() {
        let message =
            RequestMessage::Batch(vec![numbered_request("a", 1), numbered_request("b", 2)]);
        assert!(message.is_batch());
        let decoded: serde_json::Value = parse(&serde_json::to_string(&message).unwrap());
        assert!(decoded.is_array());
        assert_eq!(decoded.as_array().unwrap().len(), 2);
    }
}