// json_rpc/server.rs

1use std::{collections::HashMap, future::Future, marker::PhantomData, pin::Pin, sync::Arc};
2
3use serde::{Serialize, de::DeserializeOwned};
4use serde_json::value::RawValue;
5
6use crate::{Error, ErrorCode, Id, Request, RequestMessage, Response};
7
8trait MethodHandler<C>: Send + Sync {
9    fn call(&self, ctx: C, params: &RawValue) -> Pin<Box<dyn Future<Output = Result<Box<RawValue>, Error>> + Send>>;
10}
11
12struct MethodHandlerImpl<C, H, P, R, E, F> {
13    handler: H,
14    _phantom: PhantomData<fn(C, P) -> (R, E, F)>,
15}
16
17impl<C, P, R, E, F, H> MethodHandler<C> for MethodHandlerImpl<C, H, P, R, E, F>
18where
19    C: Send + 'static,
20    P: DeserializeOwned + Send,
21    R: Serialize + Send,
22    E: Into<Error> + Send,
23    F: Future<Output = Result<R, E>> + Send + 'static,
24    H: Fn(C, P) -> F + Send + Sync,
25{
26    fn call(
27        &self,
28        ctx: C,
29        raw_params: &RawValue,
30    ) -> Pin<Box<dyn Future<Output = Result<Box<RawValue>, Error>> + Send>> {
31        let params: P = match serde_json::from_str(raw_params.get()) {
32            Ok(p) => p,
33            Err(e) => {
34                return Box::pin(async move { Err(Error::invalid_params(e.to_string())) });
35            }
36        };
37        let fut = (self.handler)(ctx, params);
38        Box::pin(async move {
39            match fut.await {
40                Ok(result) => serde_json::value::to_raw_value(&result)
41                    .map_err(|e| Error::new(ErrorCode::INTERNAL_ERROR, e.to_string())),
42                Err(e) => Err(e.into()),
43            }
44        })
45    }
46}
47
48/// The output of [`Server::handle`].
49///
50/// An `Empty` variant means nothing should be sent back (e.g., all-notification batch).
51#[derive(Clone, Debug)]
52pub enum ResponseMessage {
53    /// A single response.
54    Single(Response),
55    /// A batch of responses.
56    Batch(Vec<Response>),
57    /// No response to send (notification, or all-notification batch).
58    Empty,
59}
60
61impl ResponseMessage {
62    /// Serializes this message into a JSON string suitable for writing to a transport.
63    ///
64    /// For `Empty` variants, returns `None`.
65    /// For `Single`, returns the serialized `Response`.
66    /// For `Batch`, returns the serialized array.
67    pub fn to_json(&self) -> serde_json::Result<Option<String>> {
68        match self {
69            Self::Empty => Ok(None),
70            Self::Single(resp) => serde_json::to_string(resp).map(Some),
71            Self::Batch(resps) => {
72                if resps.is_empty() {
73                    Ok(None)
74                } else {
75                    serde_json::to_string(resps).map(Some)
76                }
77            }
78        }
79    }
80}
81
82/// A JSON-RPC 2.0 server.
83///
84/// Generic over a context type `C` that is cloned once per handler invocation.
85///
86/// # Example
87///
88/// ```rust
89/// use jsonrpc::{Server, Error};
90///
91/// let mut server = Server::new();
92/// server.register("add", |_: (), (a, b): (i64, i64)| async move {
93///     Ok::<_, Error>(a + b)
94/// });
95/// ```
96pub struct Server<C> {
97    methods: HashMap<String, Arc<dyn MethodHandler<C>>>,
98    empty_params: Box<RawValue>,
99}
100
101impl<C: Send + Sync + 'static> Server<C> {
102    /// Creates a new server with no registered methods.
103    pub fn new() -> Self {
104        Self {
105            methods: HashMap::new(),
106            empty_params: RawValue::from_string("{}".to_owned()).expect("{} is valid JSON"),
107        }
108    }
109
110    /// Registers an async handler for the given method name.
111    ///
112    /// The handler receives an owned clone of the context and deserialized
113    /// method parameters, and returns a future.
114    pub fn register<P, R, E, F>(
115        &mut self,
116        method: impl Into<String>,
117        handler: impl Fn(C, P) -> F + Send + Sync + 'static,
118    ) where
119        P: DeserializeOwned + Send + 'static,
120        R: Serialize + Send + 'static,
121        E: Into<Error> + Send + 'static,
122        F: Future<Output = Result<R, E>> + Send + 'static,
123    {
124        let entry = MethodHandlerImpl::<C, _, P, R, E, F> {
125            handler,
126            _phantom: PhantomData,
127        };
128        self.methods.insert(method.into(), Arc::new(entry));
129    }
130
131    /// Handles a request message and returns the corresponding response message.
132    ///
133    /// The context `ctx` is consumed and, for batches, cloned once per handler invocation.
134    pub async fn handle(&self, ctx: C, message: RequestMessage) -> ResponseMessage
135    where
136        C: Clone,
137    {
138        match message {
139            RequestMessage::Single(req) => self.handle_single(ctx, req).await,
140            RequestMessage::Batch(entries) => self.handle_batch(ctx, entries).await,
141        }
142    }
143
144    async fn handle_single(&self, ctx: C, req: Request) -> ResponseMessage {
145        let Some(id) = req.id.into_id() else {
146            let _ = self
147                .dispatch(ctx, &req.method, req.params.as_deref().unwrap_or(&self.empty_params))
148                .await;
149            return ResponseMessage::Empty;
150        };
151
152        let params = req.params.as_deref().unwrap_or(&self.empty_params);
153        match self.dispatch(ctx, &req.method, params).await {
154            Ok(result) => ResponseMessage::Single(Response::Success {
155                result,
156                id,
157            }),
158            Err(error) => ResponseMessage::Single(Response::Error {
159                error,
160                id,
161            }),
162        }
163    }
164
165    async fn handle_batch(&self, ctx: C, entries: Vec<Request>) -> ResponseMessage
166    where
167        C: Clone,
168    {
169        if entries.is_empty() {
170            return ResponseMessage::Single(Response::Error {
171                error: Error::invalid_request("empty batch"),
172                id: Id::Null,
173            });
174        }
175
176        let mut responses: Vec<Response> = Vec::with_capacity(entries.len());
177
178        for req in entries {
179            let Some(id) = req.id.into_id() else {
180                let _ = self
181                    .dispatch(ctx.clone(), &req.method, req.params.as_deref().unwrap_or(&self.empty_params))
182                    .await;
183                continue;
184            };
185
186            let params = req.params.as_deref().unwrap_or(&self.empty_params);
187            match self.dispatch(ctx.clone(), &req.method, params).await {
188                Ok(result) => responses.push(Response::Success {
189                    result,
190                    id,
191                }),
192                Err(error) => responses.push(Response::Error {
193                    error,
194                    id,
195                }),
196            }
197        }
198
199        if responses.is_empty() {
200            ResponseMessage::Empty
201        } else {
202            ResponseMessage::Batch(responses)
203        }
204    }
205
206    async fn dispatch(&self, ctx: C, method: &str, params: &RawValue) -> Result<Box<RawValue>, Error> {
207        let callback = self
208            .methods
209            .get(method)
210            .ok_or_else(|| Error::method_not_found(method))?;
211        callback.call(ctx, params).await
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ErrorCode, RequestId};

    /// Builds a `Request` for `method`; the tests below pass `id = None` to
    /// model a notification.
    fn make_request(method: &str, params: Option<&str>, id: Option<i64>) -> Request {
        let params = params.map(|raw| RawValue::from_string(raw.to_owned()).unwrap());
        Request {
            jsonrpc: "2.0".into(),
            method: method.into(),
            params,
            id: RequestId(id.map(Id::Number)),
        }
    }

    /// Context-free server exposing the two-argument `add` method that several
    /// tests exercise.
    fn adder() -> Server<()> {
        let mut server: Server<()> = Server::new();
        server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });
        server
    }

    #[tokio::test]
    async fn test_simple_handler() {
        let server = adder();

        let request = make_request("add", Some("[3, 4]"), Some(1));
        let reply = server.handle((), RequestMessage::Single(request)).await;

        match reply {
            ResponseMessage::Single(Response::Success { result, id }) => {
                assert_eq!(id, Id::Number(1));
                let sum: i64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(sum, 7);
            }
            other => panic!("expected success response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_handler_with_error() {
        let mut server: Server<()> = Server::new();
        server.register("div", |_: (), (num, den): (i64, i64)| async move {
            match den {
                0 => Err(Error::new(-32000, "division by zero")),
                _ => Ok(num / den),
            }
        });

        let request = make_request("div", Some("[4, 0]"), Some(1));
        let reply = server.handle((), RequestMessage::Single(request)).await;

        match reply {
            ResponseMessage::Single(Response::Error { error, id }) => {
                assert_eq!(id, Id::Number(1));
                assert_eq!(error.code, -32000);
                assert_eq!(error.message, "division by zero");
            }
            other => panic!("expected error response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_method_not_found() {
        // No handlers are registered at all.
        let server: Server<()> = Server::new();

        let request = make_request("unknown", None, Some(1));
        let reply = server.handle((), RequestMessage::Single(request)).await;

        match reply {
            ResponseMessage::Single(Response::Error { error, id }) => {
                assert_eq!(id, Id::Number(1));
                assert_eq!(error.code, ErrorCode::METHOD_NOT_FOUND);
            }
            other => panic!("expected error response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_invalid_params() {
        let server = adder();

        // A bare JSON string cannot decode into the `(i64, i64)` tuple.
        let request = make_request("add", Some(r#""not_an_array""#), Some(1));
        let reply = server.handle((), RequestMessage::Single(request)).await;

        match reply {
            ResponseMessage::Single(Response::Error { error, id }) => {
                assert_eq!(id, Id::Number(1));
                assert_eq!(error.code, ErrorCode::INVALID_PARAMS);
            }
            other => panic!("expected error response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_notification_is_silent() {
        let mut server: Server<()> = Server::new();
        server.register("log", |_: (), _message: (String,)| async move { Ok::<_, Error>(()) });

        let notification = make_request("log", Some(r#"["hello"]"#), None);
        let reply = server.handle((), RequestMessage::Single(notification)).await;

        assert!(matches!(reply, ResponseMessage::Empty));
    }

    #[tokio::test]
    async fn test_empty_batch() {
        let server: Server<()> = Server::new();

        let reply = server.handle((), RequestMessage::Batch(Vec::new())).await;

        match reply {
            ResponseMessage::Single(Response::Error { error, id }) => {
                assert_eq!(id, Id::Null);
                assert_eq!(error.code, ErrorCode::INVALID_REQUEST);
            }
            other => panic!("expected single error for empty batch, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_batch_mixed() {
        let server = adder();

        // Two calls with ids around one notification.
        let batch = vec![
            make_request("add", Some("[1, 2]"), Some(1)),
            make_request("add", Some("[3, 4]"), None),
            make_request("add", Some("[5, 6]"), Some(2)),
        ];
        let reply = server.handle((), RequestMessage::Batch(batch)).await;

        match reply {
            ResponseMessage::Batch(responses) => {
                assert_eq!(responses.len(), 2);
            }
            other => panic!("expected batch response, got {other:?}"),
        }
    }

    // NOTE(review): this pins that a non-object batch entry (`42`) produces no
    // response at all. The JSON-RPC 2.0 spec examples answer such entries with
    // an Invalid Request error and a null id; confirm skipping is intended
    // (that depends on `RequestMessage` deserialization, which lives elsewhere).
    #[tokio::test]
    async fn test_batch_with_invalid_entry() {
        let server = adder();

        let json = r#"[
            {"jsonrpc":"2.0","method":"add","params":[1,2],"id":1},
            42,
            {"jsonrpc":"2.0","method":"add","params":[3,4],"id":2}
        ]"#;
        let incoming: RequestMessage = serde_json::from_str(json).unwrap();
        let reply = server.handle((), incoming).await;

        match reply {
            ResponseMessage::Batch(responses) => {
                assert_eq!(responses.len(), 2);
                assert!(responses[0].is_success());
                assert!(responses[1].is_success());
            }
            other => panic!("expected batch response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_all_notification_batch_is_empty() {
        let mut server: Server<()> = Server::new();
        server.register("notify", |_: (), _message: (String,)| async move { Ok::<_, Error>(()) });

        let batch = vec![
            make_request("notify", Some(r#"["a"]"#), None),
            make_request("notify", Some(r#"["b"]"#), None),
        ];
        let reply = server.handle((), RequestMessage::Batch(batch)).await;

        assert!(matches!(reply, ResponseMessage::Empty));
    }

    #[test]
    fn test_response_message_to_json_single() {
        let message = ResponseMessage::Single(Response::success(Id::Number(1), 42).unwrap());

        let text = message.to_json().unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&text).unwrap();

        assert_eq!(parsed["result"], serde_json::json!(42));
    }

    #[test]
    fn test_response_message_to_json_empty() {
        let message: ResponseMessage = ResponseMessage::Empty;
        assert!(message.to_json().unwrap().is_none());
    }

    #[test]
    fn test_response_message_to_json_batch() {
        let message = ResponseMessage::Batch(vec![
            Response::success(Id::Number(1), 10).unwrap(),
            Response::success(Id::Number(2), 20).unwrap(),
        ]);

        let text = message.to_json().unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&text).unwrap();

        assert!(parsed.is_array());
        assert_eq!(parsed.as_array().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_handler_with_context() {
        #[derive(Clone)]
        struct State {
            base: i64,
        }

        let mut server: Server<State> = Server::new();
        server.register("add", |ctx: State, (x,): (i64,)| async move { Ok::<_, Error>(ctx.base + x) });

        let request = make_request("add", Some("[5]"), Some(1));
        let reply = server
            .handle(State { base: 100 }, RequestMessage::Single(request))
            .await;

        match reply {
            ResponseMessage::Single(Response::Success { result, .. }) => {
                let total: i64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(total, 105);
            }
            other => panic!("expected success, got {other:?}"),
        }
    }
}
455}