1use std::fmt;
2
3use serde::{
4 Deserialize, Deserializer, Serialize, Serializer,
5 de::{self, DeserializeOwned, MapAccess, SeqAccess, Visitor, value::MapAccessDeserializer},
6};
7use serde_json::value::RawValue;
8
9use crate::Id;
10
11#[derive(Clone, Debug, PartialEq, Eq)]
18pub struct RequestId(pub Option<Id>);
19
20impl RequestId {
21 pub fn is_notification(&self) -> bool {
23 self.0.is_none()
24 }
25
26 pub fn as_ref(&self) -> Option<&Id> {
28 self.0.as_ref()
29 }
30
31 pub fn into_id(self) -> Option<Id> {
33 self.0
34 }
35}
36
37impl Default for RequestId {
38 fn default() -> Self {
39 Self(None)
40 }
41}
42
43impl Serialize for RequestId {
44 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
45 match &self.0 {
46 Some(id) => id.serialize(serializer),
47 None => serializer.serialize_none(),
48 }
49 }
50}
51
52impl<'de> Deserialize<'de> for RequestId {
53 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
54 struct RequestIdVisitor;
55
56 impl<'de> Visitor<'de> for RequestIdVisitor {
57 type Value = RequestId;
58
59 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.write_str("a JSON-RPC id: a string, an integer, or null")
61 }
62
63 fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
64 Ok(RequestId(Some(Id::Null)))
65 }
66
67 fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
68 Ok(RequestId(Some(Id::Null)))
69 }
70
71 fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
72 Ok(RequestId(Some(Id::Number(v))))
73 }
74
75 fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
76 Ok(RequestId(Some(Id::Number(v as i64))))
77 }
78
79 fn visit_f64<E: de::Error>(self, v: f64) -> Result<Self::Value, E> {
80 if v.fract() == 0.0 {
81 Ok(RequestId(Some(Id::Number(v as i64))))
82 } else {
83 Err(de::Error::invalid_value(de::Unexpected::Float(v), &self))
84 }
85 }
86
87 fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
88 Ok(RequestId(Some(Id::String(v.to_owned()))))
89 }
90
91 fn visit_string<E: de::Error>(self, v: String) -> Result<Self::Value, E> {
92 Ok(RequestId(Some(Id::String(v))))
93 }
94 }
95
96 deserializer.deserialize_any(RequestIdVisitor)
97 }
98}
99
100#[derive(Clone, Debug, Serialize, Deserialize)]
105pub struct Request {
106 pub jsonrpc: String,
108 pub method: String,
110 #[serde(default)]
112 #[serde(skip_serializing_if = "Option::is_none")]
113 pub params: Option<Box<RawValue>>,
114 #[serde(default)]
116 #[serde(skip_serializing_if = "RequestId::is_notification")]
117 pub id: RequestId,
118}
119
120impl Request {
121 pub fn is_notification(&self) -> bool {
123 self.id.is_notification()
124 }
125
126 pub fn parse_params<P: DeserializeOwned>(&self) -> Result<P, serde_json::Error> {
130 match &self.params {
131 Some(raw) => serde_json::from_str(raw.get()),
132 None => serde_json::from_str("{}"),
133 }
134 }
135}
136
137#[derive(Clone, Debug)]
144pub enum RequestMessage {
145 Single(Request),
147 Batch(Vec<Request>),
150}
151
152impl RequestMessage {
153 pub fn is_batch(&self) -> bool {
155 matches!(self, Self::Batch(_))
156 }
157
158 pub fn len(&self) -> usize {
160 match self {
161 Self::Single(_) => 1,
162 Self::Batch(entries) => entries.len(),
163 }
164 }
165
166 pub fn is_empty(&self) -> bool {
168 match self {
169 Self::Batch(entries) => entries.is_empty(),
170 _ => false,
171 }
172 }
173}
174
175impl From<Request> for RequestMessage {
176 fn from(req: Request) -> Self {
177 Self::Single(req)
178 }
179}
180
181impl Serialize for RequestMessage {
182 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
183 match self {
184 Self::Single(req) => req.serialize(serializer),
185 Self::Batch(entries) => entries.serialize(serializer),
186 }
187 }
188}
189
190impl<'de> Deserialize<'de> for RequestMessage {
191 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
192 struct RequestMessageVisitor;
193
194 impl<'de> Visitor<'de> for RequestMessageVisitor {
195 type Value = RequestMessage;
196
197 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 f.write_str("a JSON-RPC request object or array")
199 }
200
201 fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
202 let mut requests = Vec::with_capacity(seq.size_hint().unwrap_or(1));
203 while let Some(req) = seq.next_element::<Request>()? {
204 requests.push(req);
205 }
206 Ok(RequestMessage::Batch(requests))
207 }
208
209 fn visit_map<M: MapAccess<'de>>(self, map: M) -> Result<Self::Value, M::Error> {
210 let req = Request::deserialize(MapAccessDeserializer::new(map))?;
211 Ok(RequestMessage::Single(req))
212 }
213 }
214
215 deserializer.deserialize_any(RequestMessageVisitor)
216 }
217}
218
// Unit tests pinning the wire behavior of `RequestId`, `Request` and `RequestMessage`.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    // A full request: every member present, numeric id, not a notification.
    #[test]
    fn test_deserialize_request() {
        let json = r#"{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.method, "subtract");
        assert!(req.params.is_some());
        assert_eq!(req.id, RequestId(Some(Id::Number(1))));
        assert!(!req.is_notification());
    }

    // No `id` member: the defaulted `RequestId(None)` marks a notification.
    #[test]
    fn test_deserialize_notification() {
        let json = r#"{"jsonrpc":"2.0","method":"update","params":[1,2,3]}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(req.is_notification());
        assert_eq!(req.id, RequestId(None));
    }

    // Both `params` and `id` may be omitted.
    #[test]
    fn test_deserialize_request_no_params() {
        let json = r#"{"jsonrpc":"2.0","method":"foobar"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert_eq!(req.method, "foobar");
        assert!(req.params.is_none());
        assert!(req.is_notification());
    }

    #[test]
    fn test_deserialize_request_string_id() {
        let json = r#"{"jsonrpc":"2.0","method":"get_data","id":"abc"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert_eq!(req.id, RequestId(Some(Id::String("abc".into()))));
    }

    // An explicit `null` id is distinct from an absent one: it is NOT a notification.
    #[test]
    fn test_deserialize_request_null_id() {
        let json = r#"{"jsonrpc":"2.0","method":"foo","id":null}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert_eq!(req.id, RequestId(Some(Id::Null)));
    }

    // Positional params decode into a tuple.
    #[test]
    fn test_request_parse_params() {
        let req: Request = serde_json::from_str(r#"{"jsonrpc":"2.0","method":"add","params":[1,2],"id":1}"#).unwrap();
        let (a, b): (i64, i64) = req.parse_params().unwrap();
        assert_eq!((a, b), (1, 2));
    }

    // With no `params`, `parse_params` decodes from `{}`; a struct whose only field is
    // an `Option` therefore succeeds with that field `None`.
    #[test]
    fn test_request_parse_params_absent() {
        #[derive(Deserialize)]
        struct PingParams {
            _extra: Option<String>,
        }
        let req: Request = serde_json::from_str(r#"{"jsonrpc":"2.0","method":"ping","id":1}"#).unwrap();
        let p: PingParams = req.parse_params().unwrap();
        assert!(p._extra.is_none());
    }

    #[test]
    fn test_request_id_into_id() {
        let id = RequestId(Some(Id::Number(42)));
        assert_eq!(id.into_id(), Some(Id::Number(42)));
        let id = RequestId(None);
        assert_eq!(id.into_id(), None);
    }

    // A top-level object deserializes as `RequestMessage::Single`.
    #[test]
    fn test_deserialize_single_request_message() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let message: RequestMessage = serde_json::from_str(json).unwrap();
        assert!(!message.is_batch());
        assert_eq!(message.len(), 1);
        assert!(!message.is_empty());
    }

    // A top-level array deserializes as `RequestMessage::Batch`.
    #[test]
    fn test_deserialize_batch_request_message() {
        let json = r#"[
        {"jsonrpc":"2.0","method":"a","id":1},
        {"jsonrpc":"2.0","method":"b","id":2}
    ]"#;
        let message: RequestMessage = serde_json::from_str(json).unwrap();
        assert!(message.is_batch());
        assert_eq!(message.len(), 2);
        assert!(!message.is_empty());
    }

    // Batch entries that are not request objects (here the bare `42`) are expected to
    // be dropped, leaving only the two valid requests, rather than failing the message.
    #[test]
    fn test_deserialize_batch_with_invalid_entries() {
        let json = r#"[
        {"jsonrpc":"2.0","method":"a","id":1},
        42,
        {"jsonrpc":"2.0","method":"b","id":2}
    ]"#;
        let message: RequestMessage = serde_json::from_str(json).unwrap();
        assert!(message.is_batch());
        assert_eq!(message.len(), 2);
    }

    // `[]` is still a batch, just an empty one.
    #[test]
    fn test_deserialize_empty_array() {
        let json = "[]";
        let message: RequestMessage = serde_json::from_str(json).unwrap();
        assert!(message.is_batch());
        assert!(message.is_empty());
        assert_eq!(message.len(), 0);
    }

    #[test]
    fn test_request_message_from_request() {
        let req: Request = serde_json::from_str(r#"{"jsonrpc":"2.0","method":"x","id":1}"#).unwrap();
        let message: RequestMessage = req.into();
        assert!(!message.is_batch());
    }

    // Serialization is compared as `serde_json::Value` so member order does not matter.
    #[test]
    fn test_request_serialize() {
        let req = Request {
            jsonrpc: "2.0".into(),
            method: "subtract".into(),
            params: Some(RawValue::from_string("[42,23]".into()).unwrap()),
            id: RequestId(Some(Id::Number(1))),
        };
        let json = serde_json::to_string(&req).unwrap();
        let expected = json!({"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1});
        let actual: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(actual, expected);
    }

    // A `Batch` serializes as a JSON array with one element per request.
    #[test]
    fn test_request_message_batch_serialize() {
        let reqs = vec![
            Request {
                jsonrpc: "2.0".into(),
                method: "a".into(),
                params: None,
                id: RequestId(Some(Id::Number(1))),
            },
            Request {
                jsonrpc: "2.0".into(),
                method: "b".into(),
                params: None,
                id: RequestId(Some(Id::Number(2))),
            },
        ];
        let message = RequestMessage::Batch(reqs);
        assert!(message.is_batch());
        let json = serde_json::to_string(&message).unwrap();
        let actual: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(actual.is_array());
        assert_eq!(actual.as_array().unwrap().len(), 2);
    }
}