1use std::{collections::HashMap, future::Future, marker::PhantomData, pin::Pin, sync::Arc};
2
3use serde::{Serialize, de::DeserializeOwned};
4use serde_json::value::RawValue;
5
6use crate::{Error, ErrorCode, Id, Request, RequestMessage, Response};
7
8trait MethodHandler<C>: Send + Sync {
9 fn call(&self, ctx: C, params: &RawValue) -> Pin<Box<dyn Future<Output = Result<Box<RawValue>, Error>> + Send>>;
10}
11
/// Adapter that pairs a typed handler function `H` with the type parameters
/// needed to implement [`MethodHandler`] for it.
///
/// The parameters are: `C` context, `P` decoded params, `R` success value,
/// `E` handler error, and `F` the future returned by the handler.
struct MethodHandlerImpl<C, H, P, R, E, F> {
    /// The wrapped handler function.
    handler: H,
    /// Zero-sized marker that mentions the type parameters not stored in any
    /// field (`C`, `P`, `R`, `E`, `F`), which Rust would otherwise reject as unused.
    _phantom: PhantomData<fn(C, P) -> (R, E, F)>,
}
16
17impl<C, P, R, E, F, H> MethodHandler<C> for MethodHandlerImpl<C, H, P, R, E, F>
18where
19 C: Send + 'static,
20 P: DeserializeOwned + Send,
21 R: Serialize + Send,
22 E: Into<Error> + Send,
23 F: Future<Output = Result<R, E>> + Send + 'static,
24 H: Fn(C, P) -> F + Send + Sync,
25{
26 fn call(
27 &self,
28 ctx: C,
29 raw_params: &RawValue,
30 ) -> Pin<Box<dyn Future<Output = Result<Box<RawValue>, Error>> + Send>> {
31 let params: P = match serde_json::from_str(raw_params.get()) {
32 Ok(p) => p,
33 Err(e) => {
34 return Box::pin(async move { Err(Error::invalid_params(e.to_string())) });
35 }
36 };
37 let fut = (self.handler)(ctx, params);
38 Box::pin(async move {
39 match fut.await {
40 Ok(result) => serde_json::value::to_raw_value(&result)
41 .map_err(|e| Error::new(ErrorCode::INTERNAL_ERROR, e.to_string())),
42 Err(e) => Err(e.into()),
43 }
44 })
45 }
46}
47
/// What the server produces for one incoming [`RequestMessage`].
#[derive(Clone, Debug)]
pub enum ResponseMessage {
    /// Exactly one response object.
    Single(Response),
    /// A list of response objects, one per answered batch entry.
    Batch(Vec<Response>),
    /// Nothing to send back.
    Empty,
}
60
61impl ResponseMessage {
62 pub fn to_json(&self) -> serde_json::Result<Option<String>> {
68 match self {
69 Self::Empty => Ok(None),
70 Self::Single(resp) => serde_json::to_string(resp).map(Some),
71 Self::Batch(resps) => {
72 if resps.is_empty() {
73 Ok(None)
74 } else {
75 serde_json::to_string(resps).map(Some)
76 }
77 }
78 }
79 }
80}
81
/// Registry of named method handlers that share a context type `C`.
pub struct Server<C> {
    /// Registered handlers, keyed by method name.
    methods: HashMap<String, Arc<dyn MethodHandler<C>>>,
    /// Parameters passed to a handler when the request carries no `params`.
    empty_params: Box<RawValue>,
}
100
101impl<C: Send + Sync + 'static> Server<C> {
102 pub fn new() -> Self {
104 Self {
105 methods: HashMap::new(),
106 empty_params: RawValue::from_string("{}".to_owned()).expect("{} is valid JSON"),
107 }
108 }
109
110 pub fn register<P, R, E, F>(
115 &mut self,
116 method: impl Into<String>,
117 handler: impl Fn(C, P) -> F + Send + Sync + 'static,
118 ) where
119 P: DeserializeOwned + Send + 'static,
120 R: Serialize + Send + 'static,
121 E: Into<Error> + Send + 'static,
122 F: Future<Output = Result<R, E>> + Send + 'static,
123 {
124 let entry = MethodHandlerImpl::<C, _, P, R, E, F> {
125 handler,
126 _phantom: PhantomData,
127 };
128 self.methods.insert(method.into(), Arc::new(entry));
129 }
130
131 pub async fn handle(&self, ctx: C, message: RequestMessage) -> ResponseMessage
135 where
136 C: Clone,
137 {
138 match message {
139 RequestMessage::Single(req) => self.handle_single(ctx, req).await,
140 RequestMessage::Batch(entries) => self.handle_batch(ctx, entries).await,
141 }
142 }
143
144 async fn handle_single(&self, ctx: C, req: Request) -> ResponseMessage {
145 let Some(id) = req.id.into_id() else {
146 let _ = self
147 .dispatch(ctx, &req.method, req.params.as_deref().unwrap_or(&self.empty_params))
148 .await;
149 return ResponseMessage::Empty;
150 };
151
152 let params = req.params.as_deref().unwrap_or(&self.empty_params);
153 match self.dispatch(ctx, &req.method, params).await {
154 Ok(result) => ResponseMessage::Single(Response::Success {
155 result,
156 id,
157 }),
158 Err(error) => ResponseMessage::Single(Response::Error {
159 error,
160 id,
161 }),
162 }
163 }
164
165 async fn handle_batch(&self, ctx: C, entries: Vec<Request>) -> ResponseMessage
166 where
167 C: Clone,
168 {
169 if entries.is_empty() {
170 return ResponseMessage::Single(Response::Error {
171 error: Error::invalid_request("empty batch"),
172 id: Id::Null,
173 });
174 }
175
176 let mut responses: Vec<Response> = Vec::with_capacity(entries.len());
177
178 for req in entries {
179 let Some(id) = req.id.into_id() else {
180 let _ = self
181 .dispatch(ctx.clone(), &req.method, req.params.as_deref().unwrap_or(&self.empty_params))
182 .await;
183 continue;
184 };
185
186 let params = req.params.as_deref().unwrap_or(&self.empty_params);
187 match self.dispatch(ctx.clone(), &req.method, params).await {
188 Ok(result) => responses.push(Response::Success {
189 result,
190 id,
191 }),
192 Err(error) => responses.push(Response::Error {
193 error,
194 id,
195 }),
196 }
197 }
198
199 if responses.is_empty() {
200 ResponseMessage::Empty
201 } else {
202 ResponseMessage::Batch(responses)
203 }
204 }
205
206 async fn dispatch(&self, ctx: C, method: &str, params: &RawValue) -> Result<Box<RawValue>, Error> {
207 let callback = self
208 .methods
209 .get(method)
210 .ok_or_else(|| Error::method_not_found(method))?;
211 callback.call(ctx, params).await
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ErrorCode, RequestId};

    /// Builds a request for `method`; `id: None` makes it a notification.
    fn make_request(method: &str, params: Option<&str>, id: Option<i64>) -> Request {
        Request {
            jsonrpc: "2.0".into(),
            method: method.into(),
            params: params.map(|s| RawValue::from_string(s.to_owned()).unwrap()),
            id: RequestId(id.map(Id::Number)),
        }
    }

    /// A registered handler is invoked and its result is returned with the request id.
    #[tokio::test]
    async fn test_simple_handler() {
        let mut server: Server<()> = Server::new();
        server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });

        let req = make_request("add", Some("[3, 4]"), Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await;

        match message {
            ResponseMessage::Single(Response::Success {
                result,
                id,
            }) => {
                assert_eq!(id, Id::Number(1));
                let v: i64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(v, 7);
            }
            other => panic!("expected success response, got {other:?}"),
        }
    }

    /// An error returned by the handler is passed through unchanged.
    #[tokio::test]
    async fn test_handler_with_error() {
        let mut server: Server<()> = Server::new();
        server.register("div", |_: (), (a, b): (i64, i64)| async move {
            if b == 0 {
                Err(Error::new(-32000, "division by zero"))
            } else {
                Ok(a / b)
            }
        });

        let req = make_request("div", Some("[4, 0]"), Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await;

        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => {
                assert_eq!(id, Id::Number(1));
                assert_eq!(error.code, -32000);
                assert_eq!(error.message, "division by zero");
            }
            other => panic!("expected error response, got {other:?}"),
        }
    }

    /// Calling an unregistered method yields a method-not-found error.
    #[tokio::test]
    async fn test_method_not_found() {
        let server: Server<()> = Server::new();
        let req = make_request("unknown", None, Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await;

        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => {
                assert_eq!(id, Id::Number(1));
                assert_eq!(error.code, ErrorCode::METHOD_NOT_FOUND);
            }
            other => panic!("expected error response, got {other:?}"),
        }
    }

    /// Parameters that do not decode into the handler's type yield invalid-params.
    #[tokio::test]
    async fn test_invalid_params() {
        let mut server: Server<()> = Server::new();
        server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });

        let req = make_request("add", Some(r#""not_an_array""#), Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await;

        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => {
                assert_eq!(id, Id::Number(1));
                assert_eq!(error.code, ErrorCode::INVALID_PARAMS);
            }
            other => panic!("expected error response, got {other:?}"),
        }
    }

    /// A notification produces no response.
    #[tokio::test]
    async fn test_notification_is_silent() {
        let mut server: Server<()> = Server::new();
        server.register("log", |_: (), _message: (String,)| async move { Ok::<_, Error>(()) });

        let req = make_request("log", Some(r#"["hello"]"#), None);
        let message = server.handle((), RequestMessage::Single(req)).await;

        assert!(matches!(message, ResponseMessage::Empty));
    }

    /// A notification produces no response, but its handler still runs.
    #[tokio::test]
    async fn test_notification_invokes_handler() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let mut server: Server<Arc<AtomicUsize>> = Server::new();
        server.register("log", |calls: Arc<AtomicUsize>, _message: (String,)| async move {
            calls.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(())
        });

        let calls = Arc::new(AtomicUsize::new(0));
        let req = make_request("log", Some(r#"["hello"]"#), None);
        let message = server.handle(Arc::clone(&calls), RequestMessage::Single(req)).await;

        assert!(matches!(message, ResponseMessage::Empty));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// An empty batch is answered with one invalid-request error and a null id.
    #[tokio::test]
    async fn test_empty_batch() {
        let server: Server<()> = Server::new();
        let message = server.handle((), RequestMessage::Batch(vec![])).await;

        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => {
                assert_eq!(id, Id::Null);
                assert_eq!(error.code, ErrorCode::INVALID_REQUEST);
            }
            other => panic!("expected single error for empty batch, got {other:?}"),
        }
    }

    /// In a mixed batch only the entries with an id are answered, in request order.
    #[tokio::test]
    async fn test_batch_mixed() {
        let mut server: Server<()> = Server::new();
        server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });

        let entries = vec![
            make_request("add", Some("[1, 2]"), Some(1)),
            make_request("add", Some("[3, 4]"), None),
            make_request("add", Some("[5, 6]"), Some(2)),
        ];

        let message = server.handle((), RequestMessage::Batch(entries)).await;

        match message {
            ResponseMessage::Batch(responses) => {
                assert_eq!(responses.len(), 2);

                // Pin ids, values, and order, not just the count.
                let mut answered = Vec::with_capacity(responses.len());
                for response in responses {
                    match response {
                        Response::Success {
                            result,
                            id,
                        } => {
                            let v: i64 = serde_json::from_str(result.get()).unwrap();
                            answered.push((id, v));
                        }
                        other => panic!("expected success response, got {other:?}"),
                    }
                }
                assert_eq!(answered, vec![(Id::Number(1), 3), (Id::Number(2), 11)]);
            }
            other => panic!("expected batch response, got {other:?}"),
        }
    }

    /// A batch parsed from JSON with a non-request entry still answers the valid entries.
    #[tokio::test]
    async fn test_batch_with_invalid_entry() {
        let mut server: Server<()> = Server::new();
        server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });

        let json = r#"[
            {"jsonrpc":"2.0","method":"add","params":[1,2],"id":1},
            42,
            {"jsonrpc":"2.0","method":"add","params":[3,4],"id":2}
        ]"#;
        let message: RequestMessage = serde_json::from_str(json).unwrap();
        let message = server.handle((), message).await;

        match message {
            ResponseMessage::Batch(responses) => {
                assert_eq!(responses.len(), 2);
                assert!(responses[0].is_success());
                assert!(responses[1].is_success());
            }
            other => panic!("expected batch response, got {other:?}"),
        }
    }

    /// A batch made only of notifications produces no response at all.
    #[tokio::test]
    async fn test_all_notification_batch_is_empty() {
        let mut server: Server<()> = Server::new();
        server.register("notify", |_: (), _message: (String,)| async move { Ok::<_, Error>(()) });

        let entries = vec![
            make_request("notify", Some(r#"["a"]"#), None),
            make_request("notify", Some(r#"["b"]"#), None),
        ];

        let message = server.handle((), RequestMessage::Batch(entries)).await;
        assert!(matches!(message, ResponseMessage::Empty));
    }

    /// A single response serializes to a JSON object carrying the result.
    #[test]
    fn test_response_message_to_json_single() {
        let resp = Response::success(Id::Number(1), 42).unwrap();
        let message = ResponseMessage::Single(resp);
        let json = message.to_json().unwrap().unwrap();
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["result"], serde_json::json!(42));
    }

    /// An empty message serializes to nothing.
    #[test]
    fn test_response_message_to_json_empty() {
        let message: ResponseMessage = ResponseMessage::Empty;
        assert!(message.to_json().unwrap().is_none());
    }

    /// A batch serializes to a JSON array with one element per response.
    #[test]
    fn test_response_message_to_json_batch() {
        let resps = vec![
            Response::success(Id::Number(1), 10).unwrap(),
            Response::success(Id::Number(2), 20).unwrap(),
        ];
        let message = ResponseMessage::Batch(resps);
        let json = message.to_json().unwrap().unwrap();
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(v.is_array());
        assert_eq!(v.as_array().unwrap().len(), 2);
    }

    /// The context value given to `handle` is passed to the handler.
    #[tokio::test]
    async fn test_handler_with_context() {
        #[derive(Clone)]
        struct State {
            base: i64,
        }

        let mut server: Server<State> = Server::new();
        server.register("add", |ctx: State, (x,): (i64,)| async move { Ok::<_, Error>(ctx.base + x) });

        let state = State {
            base: 100,
        };
        let req = make_request("add", Some("[5]"), Some(1));
        let message = server.handle(state, RequestMessage::Single(req)).await;

        match message {
            ResponseMessage::Single(Response::Success {
                result, ..
            }) => {
                let v: i64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(v, 105);
            }
            other => panic!("expected success, got {other:?}"),
        }
    }
}
455}