Skip to main content

hyper_utils/
rt.rs

1// Some code from https://github.com/hyperium/hyper-util
2
3//! Tokio IO integration for hyper
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use hyper::rt::{Executor, Sleep, Timer};
12use pin_project_lite::pin_project;
13
/// An executor for hyper that runs futures as `tokio` tasks.
///
/// The type carries no state; it is cheap to clone and can be created with
/// either `TokioExecutor::new()` or `TokioExecutor::default()`.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct TokioExecutor {}
18
19pin_project! {
20    /// A wrapper that implements Tokio's IO traits for an inner type that
21    /// implements hyper's IO traits, or vice versa (implements hyper's IO
22    /// traits for a type that implements Tokio's IO traits).
23    #[derive(Debug)]
24    pub struct TokioIo<T> {
25        #[pin]
26        inner: T,
27    }
28}
29
/// A hyper timer implementation built on top of `tokio::time`.
///
/// The type is a stateless unit struct; create it with `TokioTimer::new()`
/// or `TokioTimer::default()`.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct TokioTimer;
34
35// Use TokioSleep to get tokio::time::Sleep to implement Unpin.
36// see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html
37pin_project! {
38    #[derive(Debug)]
39    struct TokioSleep {
40        #[pin]
41        inner: tokio::time::Sleep,
42    }
43}
44
45// ===== impl TokioExecutor =====
46
47impl<Fut> Executor<Fut> for TokioExecutor
48where
49    Fut: Future + Send + 'static,
50    Fut::Output: Send + 'static,
51{
52    fn execute(&self, fut: Fut) {
53        tokio::spawn(fut);
54    }
55}
56
57impl TokioExecutor {
58    /// Create new executor that relies on [`tokio::spawn`] to execute futures.
59    pub fn new() -> Self {
60        Self {}
61    }
62}
63
// ===== impl TokioIo =====
65
66impl<T> TokioIo<T> {
67    /// Wrap a type implementing Tokio's or hyper's IO traits.
68    pub fn new(inner: T) -> Self {
69        Self {
70            inner,
71        }
72    }
73
74    /// Borrow the inner type.
75    pub fn inner(&self) -> &T {
76        &self.inner
77    }
78
79    /// Mut borrow the inner type.
80    pub fn inner_mut(&mut self) -> &mut T {
81        &mut self.inner
82    }
83
84    /// Consume this wrapper and get the inner type.
85    pub fn into_inner(self) -> T {
86        self.inner
87    }
88}
89
90impl<T> hyper::rt::Read for TokioIo<T>
91where
92    T: tokio::io::AsyncRead,
93{
94    fn poll_read(
95        self: Pin<&mut Self>,
96        cx: &mut Context<'_>,
97        mut buf: hyper::rt::ReadBufCursor<'_>,
98    ) -> Poll<Result<(), std::io::Error>> {
99        let n = unsafe {
100            let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
101            match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
102                Poll::Ready(Ok(())) => tbuf.filled().len(),
103                other => return other,
104            }
105        };
106
107        unsafe {
108            buf.advance(n);
109        }
110        Poll::Ready(Ok(()))
111    }
112}
113
114impl<T> hyper::rt::Write for TokioIo<T>
115where
116    T: tokio::io::AsyncWrite,
117{
118    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
119        tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
120    }
121
122    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
123        tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
124    }
125
126    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
127        tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
128    }
129
130    fn is_write_vectored(&self) -> bool {
131        tokio::io::AsyncWrite::is_write_vectored(&self.inner)
132    }
133
134    fn poll_write_vectored(
135        self: Pin<&mut Self>,
136        cx: &mut Context<'_>,
137        bufs: &[std::io::IoSlice<'_>],
138    ) -> Poll<Result<usize, std::io::Error>> {
139        tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
140    }
141}
142
143impl<T> tokio::io::AsyncRead for TokioIo<T>
144where
145    T: hyper::rt::Read,
146{
147    fn poll_read(
148        self: Pin<&mut Self>,
149        cx: &mut Context<'_>,
150        tbuf: &mut tokio::io::ReadBuf<'_>,
151    ) -> Poll<Result<(), std::io::Error>> {
152        //let init = tbuf.initialized().len();
153        let filled = tbuf.filled().len();
154        let sub_filled = unsafe {
155            let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
156
157            match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
158                Poll::Ready(Ok(())) => buf.filled().len(),
159                other => return other,
160            }
161        };
162
163        let n_filled = filled + sub_filled;
164        // At least sub_filled bytes had to have been initialized.
165        let n_init = sub_filled;
166        unsafe {
167            tbuf.assume_init(n_init);
168            tbuf.set_filled(n_filled);
169        }
170
171        Poll::Ready(Ok(()))
172    }
173}
174
175impl<T> tokio::io::AsyncWrite for TokioIo<T>
176where
177    T: hyper::rt::Write,
178{
179    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
180        hyper::rt::Write::poll_write(self.project().inner, cx, buf)
181    }
182
183    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
184        hyper::rt::Write::poll_flush(self.project().inner, cx)
185    }
186
187    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
188        hyper::rt::Write::poll_shutdown(self.project().inner, cx)
189    }
190
191    fn is_write_vectored(&self) -> bool {
192        hyper::rt::Write::is_write_vectored(&self.inner)
193    }
194
195    fn poll_write_vectored(
196        self: Pin<&mut Self>,
197        cx: &mut Context<'_>,
198        bufs: &[std::io::IoSlice<'_>],
199    ) -> Poll<Result<usize, std::io::Error>> {
200        hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
201    }
202}
203
// ===== impl TokioTimer =====
205
206impl Timer for TokioTimer {
207    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
208        Box::pin(TokioSleep {
209            inner: tokio::time::sleep(duration),
210        })
211    }
212
213    fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
214        Box::pin(TokioSleep {
215            inner: tokio::time::sleep_until(deadline.into()),
216        })
217    }
218
219    fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
220        if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
221            sleep.reset(new_deadline)
222        }
223    }
224}
225
226impl TokioTimer {
227    /// Create a new TokioTimer
228    pub fn new() -> Self {
229        Self {}
230    }
231}
232
233impl Future for TokioSleep {
234    type Output = ();
235
236    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237        self.project().inner.poll(cx)
238    }
239}
240
// Empty impl: lets `TokioSleep` be returned as `Pin<Box<dyn Sleep>>` by
// `TokioTimer::sleep` and `TokioTimer::sleep_until`.
impl Sleep for TokioSleep {}
242
243impl TokioSleep {
244    fn reset(self: Pin<&mut Self>, deadline: Instant) {
245        self.project().inner.as_mut().reset(deadline.into());
246    }
247}
248
#[cfg(test)]
mod tests {
    use hyper::rt::Executor;
    use tokio::sync::oneshot;

    use crate::rt::TokioExecutor;

    /// A future handed to `TokioExecutor::execute` actually runs: the spawned
    /// task signals completion over a oneshot channel that the test awaits.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn simple_execute() -> Result<(), Box<dyn std::error::Error>> {
        let (done_tx, done_rx) = oneshot::channel();
        TokioExecutor::new().execute(async move {
            done_tx.send(()).unwrap();
        });
        // A receive error means the task dropped the sender without sending.
        done_rx.await?;
        Ok(())
    }
}