// Source: hyper_utils/http_body_util/collected.rs

use std::{
    convert::Infallible,
    pin::Pin,
    task::{Context, Poll},
};

use bytes::{Buf, Bytes};
use hyper::{
    body::{Body, Frame, SizeHint},
    http::HeaderMap,
};

use super::buf_list::BufList;

/// A collected body produced by [`BodyExt::collect`], which gathers all of the body's DATA
/// frames and trailers.
///
/// [`BodyExt::collect`]: crate::http_body_util::BodyExt::collect
#[derive(Debug)]
pub struct Collected<B> {
    // Buffered DATA chunks. `push_frame` only stores chunks that still have remaining bytes.
    bufs: BufList<B>,
    // Trailers gathered so far; `None` until a trailers frame has been pushed.
    trailers: Option<HeaderMap>,
}

25impl<B: Buf> Collected<B> {
26    /// If there is a trailers frame buffered, returns a reference to it.
27    ///
28    /// Returns `None` if the body contained no trailers.
29    pub fn trailers(&self) -> Option<&HeaderMap> {
30        self.trailers.as_ref()
31    }
32
33    /// Aggregate this buffered into a [`Buf`].
34    pub fn aggregate(self) -> impl Buf {
35        self.bufs
36    }
37
38    /// Convert this body into a [`Bytes`].
39    pub fn to_bytes(mut self) -> Bytes {
40        self.bufs.copy_to_bytes(self.bufs.remaining())
41    }
42
43    pub(crate) fn push_frame(&mut self, frame: Frame<B>) {
44        let frame = match frame.into_data() {
45            Ok(data) => {
46                // Only push this frame if it has some data in it, to avoid crashing on
47                // `BufList::push`.
48                if data.has_remaining() {
49                    self.bufs.push(data);
50                }
51                return;
52            }
53            Err(frame) => frame,
54        };
55
56        if let Ok(trailers) = frame.into_trailers() {
57            if let Some(current) = &mut self.trailers {
58                current.extend(trailers);
59            } else {
60                self.trailers = Some(trailers);
61            }
62        };
63    }
64}
65
66impl<B: Buf> Body for Collected<B> {
67    type Data = B;
68    type Error = Infallible;
69
70    fn poll_frame(
71        mut self: Pin<&mut Self>,
72        _: &mut Context<'_>,
73    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
74        let frame = if let Some(data) = self.bufs.pop() {
75            Frame::data(data)
76        } else if let Some(trailers) = self.trailers.take() {
77            Frame::trailers(trailers)
78        } else {
79            return Poll::Ready(None);
80        };
81
82        Poll::Ready(Some(Ok(frame)))
83    }
84}
85
86impl<B> Default for Collected<B> {
87    fn default() -> Self {
88        Self {
89            bufs: BufList::default(),
90            trailers: None,
91        }
92    }
93}
94
95impl<B> Unpin for Collected<B> {}
96
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use futures_util::stream;

    use super::*;
    use crate::http_body_util::{BodyExt, Full, StreamBody};

    /// Polls `body` until it yields its next frame or reports end-of-stream.
    async fn next_frame<T>(body: &mut T) -> Option<Result<Frame<T::Data>, T::Error>>
    where
        T: Body + Unpin,
    {
        futures_util::future::poll_fn(|cx| Pin::new(&mut *body).poll_frame(cx)).await
    }

    #[tokio::test]
    async fn full_body() {
        let body = Full::new(&b"hello"[..]);

        let buffered = body.collect().await.unwrap();

        assert_eq!(&buffered.to_bytes()[..], &b"hello"[..]);
    }

    #[tokio::test]
    async fn segmented_body() {
        let bufs = [&b"hello"[..], &b"world"[..], &b"!"[..]];
        let body = StreamBody::new(stream::iter(bufs.map(Frame::data).map(Ok::<_, Infallible>)));

        let buffered = body.collect().await.unwrap();

        assert_eq!(&buffered.to_bytes()[..], b"helloworld!");
    }

    #[tokio::test]
    async fn delayed_segments() {
        let one = stream::once(async { Ok::<_, Infallible>(Frame::data(&b"hello "[..])) });
        let two = stream::once(async {
            // a yield just so its not ready immediately
            tokio::task::yield_now().await;
            Ok::<_, Infallible>(Frame::data(&b"world!"[..]))
        });
        let stream = futures_util::StreamExt::chain(one, two);

        let body = StreamBody::new(stream);

        let buffered = body.collect().await.unwrap();

        assert_eq!(&buffered.to_bytes()[..], b"hello world!");
    }

    #[tokio::test]
    async fn trailers() {
        let mut trailers = HeaderMap::new();
        trailers.insert("this", "a trailer".try_into().unwrap());
        let bufs = [
            Frame::data(&b"hello"[..]),
            Frame::data(&b"world!"[..]),
            Frame::trailers(trailers.clone()),
        ];

        let body = StreamBody::new(stream::iter(bufs.map(Ok::<_, Infallible>)));

        let buffered = body.collect().await.unwrap();

        assert_eq!(&trailers, buffered.trailers().unwrap());
        assert_eq!(&buffered.to_bytes()[..], b"helloworld!");
    }

    /// Several trailers frames are merged into one `HeaderMap`.
    #[tokio::test]
    async fn multiple_trailer_frames_are_merged() {
        let mut first = HeaderMap::new();
        first.insert("first", "1".try_into().unwrap());
        let mut second = HeaderMap::new();
        second.insert("second", "2".try_into().unwrap());
        let frames = [Frame::<&[u8]>::trailers(first), Frame::trailers(second)];

        let body = StreamBody::new(stream::iter(frames.map(Ok::<_, Infallible>)));
        let buffered = body.collect().await.unwrap();

        let trailers = buffered.trailers().unwrap();
        assert_eq!(trailers.len(), 2);
        assert_eq!(trailers["first"], "1");
        assert_eq!(trailers["second"], "2");
    }

    /// A collected body can itself be polled: DATA frames first, then trailers, then the end.
    #[tokio::test]
    async fn replays_data_then_trailers() {
        let mut trailers = HeaderMap::new();
        trailers.insert("this", "a trailer".try_into().unwrap());
        let frames = [
            Frame::data(&b"hello"[..]),
            Frame::data(&b"world!"[..]),
            Frame::trailers(trailers.clone()),
        ];

        let body = StreamBody::new(stream::iter(frames.map(Ok::<_, Infallible>)));
        let mut collected = body.collect().await.unwrap();

        let first = next_frame(&mut collected).await.unwrap().unwrap();
        assert_eq!(first.into_data().unwrap(), &b"hello"[..]);
        let second = next_frame(&mut collected).await.unwrap().unwrap();
        assert_eq!(second.into_data().unwrap(), &b"world!"[..]);
        let third = next_frame(&mut collected).await.unwrap().unwrap();
        assert_eq!(third.into_trailers().unwrap(), trailers);
        assert!(next_frame(&mut collected).await.is_none());
    }

    /// Test for issue [#88](https://github.com/hyperium/http-body/issues/88).
    #[tokio::test]
    async fn empty_frame() {
        let bufs: [&[u8]; 1] = [&[]];

        let body = StreamBody::new(stream::iter(bufs.map(Frame::data).map(Ok::<_, Infallible>)));
        let buffered = body.collect().await.unwrap();

        assert!(buffered.to_bytes().is_empty());
    }
}