Skip to main content

s3/
client.rs

1use std::{
2    borrow::Cow,
3    collections::BTreeMap,
4    fmt,
5    pin::Pin,
6    time::{SystemTime, UNIX_EPOCH},
7};
8
9use bytes::Bytes;
10use crypto::{Hasher, hmac::Hmac, sha2::Sha256};
11use futures_util::{Stream, StreamExt};
12use url::Url;
13
/// Lowercase hex SHA-256 digest of the empty byte string.
///
/// Used as the payload hash for requests without a body so the digest does not
/// have to be recomputed.
pub(crate) const EMPTY_PAYLOAD_SHA256: &str =
    "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
15
/// Borrowed, caller-supplied credentials used to sign requests.
///
/// The `Debug` implementation is written by hand so that the secret access key
/// and the session token never end up in logs or panic messages.
#[derive(Clone)]
pub struct StaticCredentials<'a> {
    /// Access key id placed in the `Credential=` part of the authorization header.
    pub access_key_id: &'a str,
    /// Secret key from which the request signing key is derived.
    pub secret_access_key: &'a str,
    /// Optional session token; an empty string means "no token".
    pub session_token: &'a str,
}

impl fmt::Debug for StaticCredentials<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Keep the "is a token configured?" signal without revealing its value.
        let session_token = if self.session_token.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("StaticCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            .field("session_token", &session_token)
            .finish()
    }
}
22
23#[derive(Debug, Clone)]
24pub struct ClientConfig<'a> {
25    pub endpoint: &'a str,
26    pub credentials: StaticCredentials<'a>,
27    pub region: &'a str,
28    pub virtual_hosted: bool,
29}
30
/// Owned copy of the credentials, stored inside the client.
///
/// `Debug` is implemented by hand so that printing a client (or these
/// credentials) never reveals the secret access key or the session token.
#[derive(Clone)]
pub(crate) struct OwnedCredentials {
    pub(crate) access_key_id: String,
    pub(crate) secret_access_key: String,
    /// `None` when no session token was configured.
    pub(crate) session_token: Option<String>,
}

impl fmt::Debug for OwnedCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Preserve Some/None so the presence of a token stays visible.
        let session_token = self.session_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("OwnedCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            .field("session_token", &session_token)
            .finish()
    }
}
37
38#[derive(Debug)]
39pub struct Client<H: HttpClient> {
40    pub(crate) endpoint: Url,
41    pub(crate) region: String,
42    pub(crate) credentials: OwnedCredentials,
43    pub(crate) virtual_hosted: bool,
44    pub(crate) http: H,
45}
46
47#[derive(Debug)]
48pub enum Error {
49    InvalidConfig(&'static str),
50    Http(Box<dyn std::error::Error + Send + Sync>),
51    Time(std::time::SystemTimeError),
52    Xml(quick_xml::DeError),
53    Api { status: u16, body: String },
54}
55
56impl fmt::Display for Error {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        match self {
59            Error::InvalidConfig(msg) => write!(f, "invalid config: {msg}"),
60            Error::Http(err) => write!(f, "http error: {err}"),
61            Error::Time(err) => write!(f, "time error: {err}"),
62            Error::Xml(err) => write!(f, "xml error: {err}"),
63            Error::Api {
64                status,
65                body,
66            } => write!(f, "s3 api error (status {status}): {body}"),
67        }
68    }
69}
70
71impl std::error::Error for Error {
72    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
73        match self {
74            Error::Http(err) => Some(&**err),
75            _ => None,
76        }
77    }
78}
79
80impl From<std::time::SystemTimeError> for Error {
81    fn from(value: std::time::SystemTimeError) -> Self {
82        Self::Time(value)
83    }
84}
85
86impl From<quick_xml::DeError> for Error {
87    fn from(value: quick_xml::DeError) -> Self {
88        Self::Xml(value)
89    }
90}
91
92pub type HttpError = Box<dyn std::error::Error + Send + Sync>;
93
94/// A pinned, boxed async stream of byte chunks.
95pub type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, HttpError>> + Send>>;
96
/// HTTP request methods the client can issue.
#[derive(Debug, Clone, Copy)]
pub enum HttpMethod {
    /// `GET`
    Get,
    /// `PUT`
    Put,
    /// `POST`
    Post,
    /// `DELETE`
    Delete,
    /// `HEAD`
    Head,
}
105
106impl HttpMethod {
107    pub(crate) fn as_str(self) -> &'static str {
108        match self {
109            HttpMethod::Get => "GET",
110            HttpMethod::Put => "PUT",
111            HttpMethod::Post => "POST",
112            HttpMethod::Delete => "DELETE",
113            HttpMethod::Head => "HEAD",
114        }
115    }
116}
117
118#[derive(Debug, Clone)]
119pub struct HttpRequest {
120    pub method: HttpMethod,
121    pub url: String,
122    pub headers: Vec<(String, String)>,
123    pub body: Vec<u8>,
124}
125
126pub struct HttpResponseData {
127    pub status_code: u16,
128    pub headers: Vec<(String, String)>,
129    pub body: ByteStream,
130}
131
132impl HttpResponseData {
133    pub fn header(&self, name: &str) -> Option<&str> {
134        self.headers
135            .iter()
136            .find(|(k, _)| k.eq_ignore_ascii_case(name))
137            .map(|(_, v)| v.as_str())
138    }
139}
140
141pub trait HttpClient: Send + Sync {
142    fn send(
143        &self,
144        request: HttpRequest,
145    ) -> impl std::future::Future<Output = Result<HttpResponseData, HttpError>> + Send;
146}
147
/// [`HttpClient`] implementation backed by a `reqwest::Client`.
#[cfg(feature = "reqwest")]
#[derive(Debug, Default, Clone)]
pub struct ReqwestHttpClient {
    /// Underlying reqwest client used for every request.
    inner: reqwest::Client,
}
153
#[cfg(feature = "reqwest")]
impl ReqwestHttpClient {
    /// Creates a transport around a freshly constructed `reqwest::Client`.
    pub fn new() -> Self {
        let inner = reqwest::Client::new();
        Self { inner }
    }
}
162
#[cfg(feature = "reqwest")]
impl HttpClient for ReqwestHttpClient {
    /// Sends `request` through reqwest and exposes the response body as a stream.
    ///
    /// A body is only attached when `request.body` is non-empty. Response
    /// headers whose values are not valid visible ASCII are skipped.
    async fn send(&self, request: HttpRequest) -> Result<HttpResponseData, HttpError> {
        let HttpRequest {
            method,
            url,
            headers,
            body,
        } = request;

        let verb = reqwest::Method::from_bytes(method.as_str().as_bytes())?;
        let mut builder = self.inner.request(verb, &url);
        for (name, value) in headers {
            builder = builder.header(&name, &value);
        }
        if !body.is_empty() {
            builder = builder.body(body);
        }

        let response = builder.send().await?;
        let status_code = response.status().as_u16();

        let mut response_headers = Vec::new();
        for (name, value) in response.headers().iter() {
            // `to_str` fails for values outside visible ASCII; those are dropped.
            if let Ok(text) = value.to_str() {
                response_headers.push((name.as_str().to_string(), text.to_string()));
            }
        }

        let chunks = response
            .bytes_stream()
            .map(|chunk| chunk.map_err(|err| -> HttpError { Box::new(err) }));
        let body: ByteStream = Box::pin(chunks);

        Ok(HttpResponseData {
            status_code,
            headers: response_headers,
            body,
        })
    }
}
193
194impl<H: HttpClient> Client<H> {
195    pub fn with_http_client(config: &ClientConfig<'_>, http: H) -> Result<Self, Error> {
196        if config.endpoint.trim().is_empty() {
197            return Err(Error::InvalidConfig("endpoint must not be empty"));
198        }
199        if config.region.trim().is_empty() {
200            return Err(Error::InvalidConfig("region must not be empty"));
201        }
202        if config.credentials.access_key_id.trim().is_empty() {
203            return Err(Error::InvalidConfig("access key id must not be empty"));
204        }
205        if config.credentials.secret_access_key.trim().is_empty() {
206            return Err(Error::InvalidConfig("secret access key must not be empty"));
207        }
208
209        let endpoint = Url::parse(config.endpoint).map_err(|_| Error::InvalidConfig("invalid endpoint URL"))?;
210
211        Ok(Self {
212            endpoint,
213            region: config.region.to_string(),
214            credentials: OwnedCredentials {
215                access_key_id: config.credentials.access_key_id.to_string(),
216                secret_access_key: config.credentials.secret_access_key.to_string(),
217                session_token: if config.credentials.session_token.is_empty() {
218                    None
219                } else {
220                    Some(config.credentials.session_token.to_string())
221                },
222            },
223            virtual_hosted: config.virtual_hosted,
224            http,
225        })
226    }
227
228    pub(crate) async fn execute(
229        &self,
230        method: HttpMethod,
231        canonical_uri: &str,
232        canonical_query: &str,
233        body: &[u8],
234        bucket: &str,
235    ) -> Result<HttpResponseData, Error> {
236        self.execute_with_headers(method, canonical_uri, canonical_query, body, &[], bucket)
237            .await
238    }
239
240    pub(crate) async fn execute_with_headers(
241        &self,
242        method: HttpMethod,
243        canonical_uri: &str,
244        canonical_query: &str,
245        body: &[u8],
246        extra_headers: &[(String, String)],
247        bucket: &str,
248    ) -> Result<HttpResponseData, Error> {
249        fn canonical_header_value(value: &str) -> String {
250            let trimmed = value.trim();
251            let mut out = String::with_capacity(trimmed.len());
252            let mut in_space = false;
253            for ch in trimmed.chars() {
254                if ch.is_whitespace() {
255                    if !in_space {
256                        out.push(' ');
257                        in_space = true;
258                    }
259                } else {
260                    out.push(ch);
261                    in_space = false;
262                }
263            }
264            out
265        }
266
267        let (date, amz_datetime) = amz_datetime(SystemTime::now())?;
268        let credential_scope = format!("{}/{}/s3/aws4_request", date, self.region);
269
270        let base_host = endpoint_host(&self.endpoint);
271        let host = if self.virtual_hosted && !bucket.is_empty() {
272            Cow::Owned(format!("{bucket}.{base_host}"))
273        } else {
274            Cow::Borrowed(&base_host)
275        };
276        let payload_hash = if body.is_empty() {
277            EMPTY_PAYLOAD_SHA256.to_string()
278        } else {
279            hex::encode(&sha256(body))
280        };
281
282        let mut headers = vec![
283            ("host".to_string(), host.to_string()),
284            ("x-amz-date".to_string(), amz_datetime.clone()),
285            ("x-amz-content-sha256".to_string(), payload_hash.clone()),
286        ];
287        headers.extend(extra_headers.iter().cloned());
288        if let Some(token) = self.credentials.session_token.as_deref() {
289            headers.push(("x-amz-security-token".to_string(), token.to_string()));
290        }
291
292        let mut canonical_headers = headers
293            .iter()
294            .map(|(name, value)| format!("{}:{}\n", name.to_ascii_lowercase(), canonical_header_value(value)))
295            .collect::<Vec<_>>();
296        let mut signed_headers = headers
297            .iter()
298            .map(|(name, _)| name.to_ascii_lowercase())
299            .collect::<Vec<_>>();
300
301        canonical_headers.sort();
302        signed_headers.sort();
303
304        let canonical_headers_joined = canonical_headers.concat();
305        let signed_headers_joined = signed_headers.join(";");
306
307        let canonical_request = format!(
308            "{}\n{}\n{}\n{}\n{}\n{}",
309            method.as_str(),
310            canonical_uri,
311            canonical_query,
312            canonical_headers_joined,
313            signed_headers_joined,
314            payload_hash
315        );
316
317        let string_to_sign = format!(
318            "AWS4-HMAC-SHA256\n{}\n{}\n{}",
319            amz_datetime,
320            credential_scope,
321            hex::encode(&sha256(canonical_request.as_bytes()))
322        );
323
324        let signature = hex::encode(&sign_v4(
325            &self.credentials.secret_access_key,
326            &date,
327            &self.region,
328            &string_to_sign,
329        ));
330
331        let authorization = format!(
332            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
333            self.credentials.access_key_id, credential_scope, signed_headers_joined, signature
334        );
335
336        let url = if self.virtual_hosted && !bucket.is_empty() {
337            let scheme = self.endpoint.scheme();
338            let host_str = endpoint_host(&self.endpoint);
339            let path = canonical_uri;
340            if canonical_query.is_empty() {
341                format!("{scheme}://{bucket}.{host_str}{path}")
342            } else {
343                format!("{scheme}://{bucket}.{host_str}{path}?{canonical_query}")
344            }
345        } else if canonical_query.is_empty() {
346            format!("{}{}", self.endpoint.as_str().trim_end_matches('/'), canonical_uri)
347        } else {
348            format!(
349                "{}{}?{}",
350                self.endpoint.as_str().trim_end_matches('/'),
351                canonical_uri,
352                canonical_query
353            )
354        };
355
356        headers.push(("authorization".to_string(), authorization));
357
358        let request = HttpRequest {
359            method,
360            url,
361            headers,
362            body: body.to_vec(),
363        };
364        let response = self.http.send(request).await.map_err(Error::Http)?;
365
366        if (200..300).contains(&response.status_code) {
367            return Ok(response);
368        }
369
370        let status = response.status_code;
371        let body_bytes = collect_body(response.body).await.unwrap_or_default();
372        let body = String::from_utf8_lossy(&body_bytes).into_owned();
373        Err(Error::Api {
374            status,
375            body,
376        })
377    }
378}
379
#[cfg(feature = "reqwest")]
impl Client<ReqwestHttpClient> {
    /// Creates a client that uses the bundled reqwest transport.
    ///
    /// Fails with the same errors as [`Client::with_http_client`].
    pub fn new(config: &ClientConfig<'_>) -> Result<Self, Error> {
        let http = ReqwestHttpClient::new();
        Self::with_http_client(config, http)
    }
}
386
387pub(crate) fn consume_empty(_response: HttpResponseData) -> Result<(), Error> {
388    Ok(())
389}
390
391pub(crate) fn canonical_bucket_uri(bucket: &str) -> String {
392    canonical_uri(&format!("/{bucket}"))
393}
394
395pub(crate) fn canonical_object_uri(bucket: &str, key: &str) -> String {
396    canonical_uri(&format!("/{bucket}/{key}"))
397}
398
399pub(crate) fn canonical_uri(path: &str) -> String {
400    path.split('/').map(percent_encode).collect::<Vec<_>>().join("/")
401}
402
403pub(crate) fn canonical_query_string(params: &BTreeMap<String, String>) -> String {
404    params
405        .iter()
406        .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(v)))
407        .collect::<Vec<_>>()
408        .join("&")
409}
410
411fn endpoint_host(url: &Url) -> String {
412    match (url.host_str(), url.port()) {
413        (Some(host), Some(port)) => format!("{host}:{port}"),
414        (Some(host), None) => host.to_string(),
415        _ => String::new(),
416    }
417}
418
419pub(crate) fn percent_encode(input: &str) -> String {
420    let mut out = String::with_capacity(input.len());
421    for &b in input.as_bytes() {
422        if b.is_ascii_uppercase()
423            || b.is_ascii_lowercase()
424            || b.is_ascii_digit()
425            || matches!(b, b'-' | b'_' | b'.' | b'~')
426        {
427            out.push(b as char);
428        } else {
429            out.push('%');
430            out.push(hex_nibble_upper((b >> 4) & 0x0f));
431            out.push(hex_nibble_upper(b & 0x0f));
432        }
433    }
434    out
435}
436
/// Maps a nibble (`0..=15`) to its upper-case hexadecimal digit.
///
/// # Panics
///
/// Panics if `value` is greater than 15.
fn hex_nibble_upper(value: u8) -> char {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    match DIGITS.get(usize::from(value)) {
        Some(&digit) => char::from(digit),
        None => unreachable!(),
    }
}
444
445fn sign_v4(secret_access_key: &str, date: &str, region: &str, string_to_sign: &str) -> [u8; 32] {
446    let k_date = hmac_sha256(format!("AWS4{secret_access_key}").as_bytes(), date.as_bytes());
447    let k_region = hmac_sha256(&k_date, region.as_bytes());
448    let k_service = hmac_sha256(&k_region, b"s3");
449    let k_signing = hmac_sha256(&k_service, b"aws4_request");
450    hmac_sha256(&k_signing, string_to_sign.as_bytes())
451}
452
453pub(crate) fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; Sha256::OUTPUT_SIZE] {
454    let mut mac = Hmac::<Sha256>::new(key);
455    mac.update(data);
456    *mac.finalize().as_ref().as_array().unwrap()
457}
458
459pub(crate) fn sha256(data: &[u8]) -> [u8; Sha256::OUTPUT_SIZE] {
460    let mut hasher = Sha256::new();
461    hasher.update(data);
462    hasher.sum().as_ref().try_into().unwrap()
463}
464
465fn amz_datetime(now: SystemTime) -> Result<(String, String), Error> {
466    let elapsed = now.duration_since(UNIX_EPOCH)?;
467    let total_seconds = elapsed.as_secs() as i64;
468
469    let days = total_seconds.div_euclid(86_400);
470    let seconds_of_day = total_seconds.rem_euclid(86_400);
471
472    let (year, month, day) = civil_from_days(days);
473    let hour = seconds_of_day / 3_600;
474    let minute = (seconds_of_day % 3_600) / 60;
475    let second = seconds_of_day % 60;
476
477    let date = format!("{year:04}{month:02}{day:02}");
478    let datetime = format!("{date}T{hour:02}{minute:02}{second:02}Z");
479
480    Ok((date, datetime))
481}
482
/// Converts a day count relative to 1970-01-01 into `(year, month, day)`.
///
/// Works on 400-year eras of 146 097 days with the year internally starting on
/// 1 March, so that the leap day is the last day of the internal year.
/// Negative day counts (dates before 1970) are supported.
fn civil_from_days(days_since_unix_epoch: i64) -> (i32, i64, i64) {
    const DAYS_PER_ERA: i64 = 146_097;

    // Re-base so that day 0 is 0000-03-01.
    let shifted = days_since_unix_epoch + 719_468;

    // Floor division: bias negative values before the truncating divide.
    let biased = if shifted >= 0 {
        shifted
    } else {
        shifted - (DAYS_PER_ERA - 1)
    };
    let era = biased / DAYS_PER_ERA;

    let day_of_era = shifted - era * DAYS_PER_ERA;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, ..., 11 = February).
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };

    // January and February belong to the next calendar year.
    let march_based_year = year_of_era + era * 400;
    let year = if month <= 2 {
        march_based_year + 1
    } else {
        march_based_year
    };

    (year as i32, month, day)
}
497
498pub(crate) async fn collect_body(mut stream: ByteStream) -> Result<Vec<u8>, Error> {
499    let mut out = Vec::new();
500    while let Some(chunk) = stream.next().await {
501        let bytes = chunk.map_err(Error::Http)?;
502        out.extend_from_slice(&bytes);
503    }
504    Ok(out)
505}
506
507pub(crate) fn bytes_to_string(bytes: Vec<u8>) -> Result<String, Error> {
508    String::from_utf8(bytes).map_err(|e| Error::Http(Box::new(e)))
509}
510
511pub(crate) fn header_to_string(response: &HttpResponseData, name: &str) -> Option<String> {
512    response.header(name).map(ToString::to_string)
513}
514
515pub(crate) fn header_to_u64(response: &HttpResponseData, name: &str) -> Option<u64> {
516    response.header(name).and_then(|s| s.parse::<u64>().ok())
517}
518
/// Escapes `&`, `<` and `>` for use in XML text content.
///
/// Quote characters are left untouched.
pub(crate) fn xml_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            other => escaped.push(other),
        }
    }
    escaped
}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// HTTP client stub that answers every request with an empty `200` response.
    #[derive(Default)]
    struct NoopHttpClient;

    impl HttpClient for NoopHttpClient {
        async fn send(&self, _request: HttpRequest) -> Result<HttpResponseData, HttpError> {
            Ok(HttpResponseData {
                status_code: 200,
                headers: Vec::new(),
                body: Box::pin(futures_util::stream::empty()),
            })
        }
    }

    #[test]
    fn sha256_known_vectors() {
        assert_eq!(
            hex::encode(&sha256(b"")),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            hex::encode(&sha256(b"abc")),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn hmac_sha256_known_vector() {
        let key = [0x0b; 20];
        let sig = hmac_sha256(&key, b"Hi There");
        assert_eq!(
            hex::encode(&sig),
            "b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7"
        );
    }

    #[test]
    fn percent_encoding_works() {
        assert_eq!(percent_encode("abcXYZ-_.~"), "abcXYZ-_.~");
        assert_eq!(percent_encode("hello world/é"), "hello%20world%2F%C3%A9");
    }

    #[test]
    fn canonical_uri_preserves_slashes() {
        assert_eq!(
            canonical_object_uri("my-bucket", "folder/my file.txt"),
            "/my-bucket/folder/my%20file.txt"
        );
        assert_eq!(canonical_bucket_uri("my-bucket"), "/my-bucket");
    }

    #[test]
    fn canonical_query_string_encodes_and_orders_pairs() {
        let mut params = BTreeMap::new();
        params.insert("prefix".to_string(), "a b/c".to_string());
        params.insert("list-type".to_string(), "2".to_string());
        assert_eq!(canonical_query_string(&params), "list-type=2&prefix=a%20b%2Fc");
        assert_eq!(canonical_query_string(&BTreeMap::new()), "");
    }

    #[test]
    fn amz_datetime_format_works() {
        // 1_700_000_000 seconds after the epoch is 2023-11-14T22:13:20Z.
        let ts = UNIX_EPOCH + std::time::Duration::from_secs(1_700_000_000);
        let (date, datetime) = amz_datetime(ts).unwrap();
        assert_eq!(date, "20231114");
        assert_eq!(datetime, "20231114T221320Z");

        let (date, datetime) = amz_datetime(UNIX_EPOCH).unwrap();
        assert_eq!(date, "19700101");
        assert_eq!(datetime, "19700101T000000Z");
    }

    #[test]
    fn amz_datetime_rejects_times_before_epoch() {
        let ts = UNIX_EPOCH - std::time::Duration::from_secs(1);
        assert!(matches!(amz_datetime(ts), Err(Error::Time(_))));
    }

    #[test]
    fn civil_from_days_known_dates() {
        assert_eq!(civil_from_days(0), (1970, 1, 1));
        assert_eq!(civil_from_days(-1), (1969, 12, 31));
        // Leap day of a year divisible by 400.
        assert_eq!(civil_from_days(11_016), (2000, 2, 29));
        assert_eq!(civil_from_days(19_675), (2023, 11, 14));
    }

    #[test]
    fn client_config_validation_works() {
        let cfg = ClientConfig {
            endpoint: "",
            credentials: StaticCredentials {
                access_key_id: "a",
                secret_access_key: "b",
                session_token: "",
            },
            region: "auto",
            virtual_hosted: false,
        };
        assert!(matches!(
            Client::with_http_client(&cfg, NoopHttpClient),
            Err(Error::InvalidConfig(_))
        ));
    }

    #[test]
    fn sign_v4_known_output_length() {
        let sig = sign_v4(
            "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
            "20130524",
            "auto",
            "AWS4-HMAC-SHA256\n20130524T000000Z\n20130524/auto/s3/aws4_request\nabc",
        );
        assert_eq!(sig.len(), 32);
    }

    #[test]
    fn empty_payload_sha256_constant_is_correct() {
        assert_eq!(EMPTY_PAYLOAD_SHA256, hex::encode(&sha256(b"")));
    }

    #[test]
    fn xml_escape_replaces_markup_characters() {
        assert_eq!(xml_escape("a&b<c>d"), "a&amp;b&lt;c&gt;d");
        assert_eq!(xml_escape("plain"), "plain");
    }

    #[test]
    fn response_header_lookup_is_case_insensitive() {
        let response = HttpResponseData {
            status_code: 200,
            headers: vec![("Content-Length".to_string(), "42".to_string())],
            body: Box::pin(futures_util::stream::empty()),
        };
        assert_eq!(response.header("content-length"), Some("42"));
        assert_eq!(header_to_u64(&response, "CONTENT-LENGTH"), Some(42));
        assert_eq!(header_to_string(&response, "etag"), None);
    }
}