1use std::{
5 future::Future,
6 pin::Pin,
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use hyper::rt::{Executor, Sleep, Timer};
12use pin_project_lite::pin_project;
13
/// A hyper [`Executor`] that hands futures to the Tokio runtime.
///
/// The type carries no state; it is marked `#[non_exhaustive]` so that code
/// outside this crate must construct it through [`TokioExecutor::new`] or
/// `Default`.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct TokioExecutor {}
18
pin_project! {
    /// A wrapper that bridges hyper's and Tokio's IO traits.
    ///
    /// When the wrapped type implements Tokio's `AsyncRead`/`AsyncWrite`, the
    /// wrapper implements hyper's `rt::Read`/`rt::Write`, and vice versa.
    #[derive(Debug)]
    pub struct TokioIo<T> {
        // The wrapped IO object. It is structurally pinned so the trait
        // implementations can forward `Pin<&mut Self>` to it.
        #[pin]
        inner: T,
    }
}
29
/// A hyper [`Timer`] whose sleeps are backed by `tokio::time`.
///
/// The type carries no state; construct it with [`TokioTimer::new`] or
/// `Default`.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct TokioTimer;
34
// Private newtype around `tokio::time::Sleep`. It exists so that hyper's
// `Sleep` trait can be implemented for it and so that `TokioTimer::reset` can
// recognise sleeps it created by downcasting to this type.
pin_project! {
    #[derive(Debug)]
    struct TokioSleep {
        // The underlying Tokio sleep future; pinned structurally so it can be
        // polled and reset through `Pin<&mut TokioSleep>`.
        #[pin]
        inner: tokio::time::Sleep,
    }
}
44
45impl<Fut> Executor<Fut> for TokioExecutor
48where
49 Fut: Future + Send + 'static,
50 Fut::Output: Send + 'static,
51{
52 fn execute(&self, fut: Fut) {
53 tokio::spawn(fut);
54 }
55}
56
57impl TokioExecutor {
58 pub fn new() -> Self {
60 Self {}
61 }
62}
63
64impl<T> TokioIo<T> {
67 pub fn new(inner: T) -> Self {
69 Self {
70 inner,
71 }
72 }
73
74 pub fn inner(&self) -> &T {
76 &self.inner
77 }
78
79 pub fn inner_mut(&mut self) -> &mut T {
81 &mut self.inner
82 }
83
84 pub fn into_inner(self) -> T {
86 self.inner
87 }
88}
89
90impl<T> hyper::rt::Read for TokioIo<T>
91where
92 T: tokio::io::AsyncRead,
93{
94 fn poll_read(
95 self: Pin<&mut Self>,
96 cx: &mut Context<'_>,
97 mut buf: hyper::rt::ReadBufCursor<'_>,
98 ) -> Poll<Result<(), std::io::Error>> {
99 let n = unsafe {
100 let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
101 match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
102 Poll::Ready(Ok(())) => tbuf.filled().len(),
103 other => return other,
104 }
105 };
106
107 unsafe {
108 buf.advance(n);
109 }
110 Poll::Ready(Ok(()))
111 }
112}
113
114impl<T> hyper::rt::Write for TokioIo<T>
115where
116 T: tokio::io::AsyncWrite,
117{
118 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
119 tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
120 }
121
122 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
123 tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
124 }
125
126 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
127 tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
128 }
129
130 fn is_write_vectored(&self) -> bool {
131 tokio::io::AsyncWrite::is_write_vectored(&self.inner)
132 }
133
134 fn poll_write_vectored(
135 self: Pin<&mut Self>,
136 cx: &mut Context<'_>,
137 bufs: &[std::io::IoSlice<'_>],
138 ) -> Poll<Result<usize, std::io::Error>> {
139 tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
140 }
141}
142
// Lets Tokio-based code read from a wrapped hyper reader.
impl<T> tokio::io::AsyncRead for TokioIo<T>
where
    T: hyper::rt::Read,
{
    // Polls the wrapped hyper reader, letting it write into the unfilled part
    // of `tbuf`, then marks the newly read bytes as initialized and filled.
    // `Pending` and errors from the inner reader are returned unchanged.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        tbuf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Number of bytes already filled before this call; the new fill mark
        // is computed relative to it.
        let filled = tbuf.filled().len();
        // SAFETY: the unfilled region of `tbuf` is only handed to a hyper
        // `ReadBuf` created with `uninit`, so it is treated as uninitialized
        // memory and only written through hyper's cursor.
        let sub_filled = unsafe {
            let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());

            match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
                // Number of bytes the inner reader reported as filled.
                Poll::Ready(Ok(())) => buf.filled().len(),
                other => return other,
            }
        };

        let n_filled = filled + sub_filled;
        // At least `sub_filled` bytes of the unfilled region must have been
        // initialized by the read above.
        let n_init = sub_filled;
        // SAFETY: relies on the inner reader only reporting bytes it actually
        // wrote (NOTE(review): this is hyper's `ReadBufCursor` contract —
        // confirm). `assume_init` must run before `set_filled` so the new
        // fill mark does not exceed the initialized region.
        unsafe {
            tbuf.assume_init(n_init);
            tbuf.set_filled(n_filled);
        }

        Poll::Ready(Ok(()))
    }
}
174
175impl<T> tokio::io::AsyncWrite for TokioIo<T>
176where
177 T: hyper::rt::Write,
178{
179 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
180 hyper::rt::Write::poll_write(self.project().inner, cx, buf)
181 }
182
183 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
184 hyper::rt::Write::poll_flush(self.project().inner, cx)
185 }
186
187 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
188 hyper::rt::Write::poll_shutdown(self.project().inner, cx)
189 }
190
191 fn is_write_vectored(&self) -> bool {
192 hyper::rt::Write::is_write_vectored(&self.inner)
193 }
194
195 fn poll_write_vectored(
196 self: Pin<&mut Self>,
197 cx: &mut Context<'_>,
198 bufs: &[std::io::IoSlice<'_>],
199 ) -> Poll<Result<usize, std::io::Error>> {
200 hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
201 }
202}
203
204impl Timer for TokioTimer {
207 fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
208 Box::pin(TokioSleep {
209 inner: tokio::time::sleep(duration),
210 })
211 }
212
213 fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
214 Box::pin(TokioSleep {
215 inner: tokio::time::sleep_until(deadline.into()),
216 })
217 }
218
219 fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
220 if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
221 sleep.reset(new_deadline)
222 }
223 }
224}
225
226impl TokioTimer {
227 pub fn new() -> Self {
229 Self {}
230 }
231}
232
233impl Future for TokioSleep {
234 type Output = ();
235
236 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237 self.project().inner.poll(cx)
238 }
239}
240
// Marks `TokioSleep` as a hyper `Sleep` so it can be returned as
// `Pin<Box<dyn Sleep>>`; the impl body is empty.
impl Sleep for TokioSleep {}
242
243impl TokioSleep {
244 fn reset(self: Pin<&mut Self>, deadline: Instant) {
245 self.project().inner.as_mut().reset(deadline.into());
246 }
247}
248
#[cfg(test)]
mod tests {
    use hyper::rt::Executor;
    use tokio::sync::oneshot;

    use crate::rt::TokioExecutor;

    // Spawns a future through `TokioExecutor` and checks that it actually
    // runs, by waiting for the message it sends over a oneshot channel.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn simple_execute() -> Result<(), Box<dyn std::error::Error>> {
        let (done_tx, done_rx) = oneshot::channel();
        let exec = TokioExecutor::new();
        exec.execute(async move {
            done_tx.send(()).unwrap();
        });
        done_rx.await?;
        Ok(())
    }
}