1use std::{
2 borrow::Cow,
3 collections::BTreeMap,
4 fmt,
5 pin::Pin,
6 time::{SystemTime, UNIX_EPOCH},
7};
8
9use bytes::Bytes;
10use crypto::{Hasher, hmac::Hmac, sha2::Sha256};
11use futures_util::{Stream, StreamExt};
12use url::Url;
13
/// Lowercase hex SHA-256 digest of the empty byte string.
///
/// Sent as `x-amz-content-sha256` (and used as the canonical-request payload hash) for requests
/// without a body, avoiding a hash computation in that case.
pub(crate) const EMPTY_PAYLOAD_SHA256: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
15
/// Borrowed credentials supplied as part of a [`ClientConfig`].
///
/// `Debug` is implemented by hand so that `secret_access_key` and `session_token` are
/// redacted; a derived impl would print them verbatim into logs and panic messages.
#[derive(Clone)]
pub struct StaticCredentials<'a> {
    /// Access key id; it appears in the `Authorization` header of every signed request.
    pub access_key_id: &'a str,
    /// Secret key from which the per-request signing key is derived.
    pub secret_access_key: &'a str,
    /// Temporary session token; an empty string means "no token".
    pub session_token: &'a str,
}

impl fmt::Debug for StaticCredentials<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Keep the "no token" case distinguishable without revealing a real token.
        let token = if self.session_token.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("StaticCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            .field("session_token", &token)
            .finish()
    }
}
22
/// Borrowed configuration used to construct a [`Client`].
#[derive(Debug, Clone)]
pub struct ClientConfig<'a> {
    /// Service endpoint; must be non-blank and parse as a URL (e.g. `https://s3.example.com`).
    pub endpoint: &'a str,
    /// Credentials used to sign every request.
    pub credentials: StaticCredentials<'a>,
    /// Region placed in the SigV4 credential scope; must be non-blank.
    pub region: &'a str,
    /// When `true`, requests that name a bucket are sent to `{bucket}.{endpoint host}`
    /// (virtual-hosted style); otherwise the request path is appended to the endpoint URL
    /// (path style).
    pub virtual_hosted: bool,
}
30
/// Owned copy of the configured credentials, held by a [`Client`].
///
/// `Debug` is implemented by hand so the secret key and session token are redacted: `Client`
/// derives `Debug`, and a derived impl here would print the secrets verbatim.
#[derive(Clone)]
pub(crate) struct OwnedCredentials {
    pub(crate) access_key_id: String,
    pub(crate) secret_access_key: String,
    /// `None` when no session token was configured.
    pub(crate) session_token: Option<String>,
}

impl fmt::Debug for OwnedCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OwnedCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            // Still shows whether a token is present, but never its value.
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
37
/// S3 client that signs requests with AWS Signature Version 4 and sends them through an
/// [`HttpClient`] transport.
#[derive(Debug)]
pub struct Client<H: HttpClient> {
    /// Parsed service endpoint.
    pub(crate) endpoint: Url,
    /// Region used in the SigV4 credential scope.
    pub(crate) region: String,
    /// Owned copy of the configured credentials.
    pub(crate) credentials: OwnedCredentials,
    /// Whether bucket requests use `{bucket}.{host}` (virtual-hosted) addressing.
    pub(crate) virtual_hosted: bool,
    /// Transport used to send the signed requests.
    pub(crate) http: H,
}
46
/// Errors produced by the client.
#[derive(Debug)]
pub enum Error {
    /// The client configuration was rejected; the message names the problem.
    InvalidConfig(&'static str),
    /// Failure reported by the [`HttpClient`] transport or while reading a response body.
    /// `bytes_to_string` also reports UTF-8 decoding failures through this variant.
    Http(Box<dyn std::error::Error + Send + Sync>),
    /// The system clock could not be turned into a Unix timestamp (it is before the epoch).
    Time(std::time::SystemTimeError),
    /// XML (de)serialization failure from `quick_xml`.
    Xml(quick_xml::DeError),
    /// The service replied with a non-2xx status; `body` is the response text, decoded lossily.
    Api { status: u16, body: String },
}
55
56impl fmt::Display for Error {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 match self {
59 Error::InvalidConfig(msg) => write!(f, "invalid config: {msg}"),
60 Error::Http(err) => write!(f, "http error: {err}"),
61 Error::Time(err) => write!(f, "time error: {err}"),
62 Error::Xml(err) => write!(f, "xml error: {err}"),
63 Error::Api {
64 status,
65 body,
66 } => write!(f, "s3 api error (status {status}): {body}"),
67 }
68 }
69}
70
71impl std::error::Error for Error {
72 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
73 match self {
74 Error::Http(err) => Some(&**err),
75 _ => None,
76 }
77 }
78}
79
80impl From<std::time::SystemTimeError> for Error {
81 fn from(value: std::time::SystemTimeError) -> Self {
82 Self::Time(value)
83 }
84}
85
86impl From<quick_xml::DeError> for Error {
87 fn from(value: quick_xml::DeError) -> Self {
88 Self::Xml(value)
89 }
90}
91
/// Boxed, thread-safe error type returned by [`HttpClient`] implementations and body streams.
pub type HttpError = Box<dyn std::error::Error + Send + Sync>;

/// Response body as a pinned, boxed `Send` stream of byte chunks; each chunk may fail with an
/// [`HttpError`].
pub type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, HttpError>> + Send>>;
96
/// HTTP methods used by the S3 client.
#[derive(Debug, Clone, Copy)]
pub enum HttpMethod {
    Get,
    Put,
    Post,
    Delete,
    Head,
}

impl HttpMethod {
    /// Upper-case method token, as sent on the wire and placed in the canonical request.
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            Self::Head => "HEAD",
            Self::Delete => "DELETE",
            Self::Post => "POST",
            Self::Put => "PUT",
            Self::Get => "GET",
        }
    }
}
117
/// A request handed to an [`HttpClient`] for sending.
///
/// When built by the client, `headers` already carries `host`, the `x-amz-*` headers, and the
/// SigV4 `authorization` header.
#[derive(Debug, Clone)]
pub struct HttpRequest {
    /// Request method.
    pub method: HttpMethod,
    /// Absolute request URL, including the query string when there is one.
    pub url: String,
    /// Header name/value pairs, in the order they should be sent.
    pub headers: Vec<(String, String)>,
    /// Request payload; may be empty.
    pub body: Vec<u8>,
}
125
/// Status, headers, and streaming body of a response returned by an [`HttpClient`].
pub struct HttpResponseData {
    /// Numeric HTTP status code.
    pub status_code: u16,
    /// Response headers as name/value pairs; use `header` for case-insensitive lookup.
    pub headers: Vec<(String, String)>,
    /// Response body, delivered as a stream of chunks.
    pub body: ByteStream,
}
131
132impl HttpResponseData {
133 pub fn header(&self, name: &str) -> Option<&str> {
134 self.headers
135 .iter()
136 .find(|(k, _)| k.eq_ignore_ascii_case(name))
137 .map(|(_, v)| v.as_str())
138 }
139}
140
/// Pluggable HTTP transport used by [`Client`] to send signed requests.
///
/// The client checks the response status itself: non-2xx statuses become [`Error::Api`].
/// Implementations should therefore presumably return such responses as `Ok` and reserve
/// `Err` for transport failures.
pub trait HttpClient: Send + Sync {
    /// Sends `request` and resolves to the response status, headers, and streaming body.
    ///
    /// The returned future must be `Send`.
    fn send(
        &self,
        request: HttpRequest,
    ) -> impl std::future::Future<Output = Result<HttpResponseData, HttpError>> + Send;
}
147
/// [`HttpClient`] implementation backed by a `reqwest::Client` (requires the `reqwest` feature).
#[cfg(feature = "reqwest")]
#[derive(Debug, Default, Clone)]
pub struct ReqwestHttpClient {
    inner: reqwest::Client,
}
153
#[cfg(feature = "reqwest")]
impl ReqwestHttpClient {
    /// Creates a transport around a default-configured `reqwest::Client`.
    pub fn new() -> Self {
        // The derived `Default` builds the inner client via `reqwest::Client::default()`,
        // which is `reqwest::Client::new()`.
        Self::default()
    }
}
162
#[cfg(feature = "reqwest")]
impl HttpClient for ReqwestHttpClient {
    /// Sends the request with `reqwest` and exposes the response body as a [`ByteStream`].
    ///
    /// Response headers whose values are not visible ASCII are dropped. An empty request body
    /// is not attached at all.
    async fn send(&self, request: HttpRequest) -> Result<HttpResponseData, HttpError> {
        let HttpRequest {
            method,
            url,
            headers,
            body,
        } = request;

        let method = reqwest::Method::from_bytes(method.as_str().as_bytes())?;
        let builder = headers
            .into_iter()
            .fold(self.inner.request(method, &url), |builder, (name, value)| {
                builder.header(&name, &value)
            });
        let builder = if body.is_empty() {
            builder
        } else {
            builder.body(body)
        };

        let response = builder.send().await?;
        let status_code = response.status().as_u16();

        let mut response_headers = Vec::with_capacity(response.headers().len());
        for (name, value) in response.headers() {
            if let Ok(text) = value.to_str() {
                response_headers.push((name.as_str().to_string(), text.to_string()));
            }
        }

        let body: ByteStream = Box::pin(
            response
                .bytes_stream()
                .map(|chunk| chunk.map_err(|err| Box::new(err) as HttpError)),
        );

        Ok(HttpResponseData {
            status_code,
            headers: response_headers,
            body,
        })
    }
}
193
194impl<H: HttpClient> Client<H> {
195 pub fn with_http_client(config: &ClientConfig<'_>, http: H) -> Result<Self, Error> {
196 if config.endpoint.trim().is_empty() {
197 return Err(Error::InvalidConfig("endpoint must not be empty"));
198 }
199 if config.region.trim().is_empty() {
200 return Err(Error::InvalidConfig("region must not be empty"));
201 }
202 if config.credentials.access_key_id.trim().is_empty() {
203 return Err(Error::InvalidConfig("access key id must not be empty"));
204 }
205 if config.credentials.secret_access_key.trim().is_empty() {
206 return Err(Error::InvalidConfig("secret access key must not be empty"));
207 }
208
209 let endpoint = Url::parse(config.endpoint).map_err(|_| Error::InvalidConfig("invalid endpoint URL"))?;
210
211 Ok(Self {
212 endpoint,
213 region: config.region.to_string(),
214 credentials: OwnedCredentials {
215 access_key_id: config.credentials.access_key_id.to_string(),
216 secret_access_key: config.credentials.secret_access_key.to_string(),
217 session_token: if config.credentials.session_token.is_empty() {
218 None
219 } else {
220 Some(config.credentials.session_token.to_string())
221 },
222 },
223 virtual_hosted: config.virtual_hosted,
224 http,
225 })
226 }
227
228 pub(crate) async fn execute(
229 &self,
230 method: HttpMethod,
231 canonical_uri: &str,
232 canonical_query: &str,
233 body: &[u8],
234 bucket: &str,
235 ) -> Result<HttpResponseData, Error> {
236 self.execute_with_headers(method, canonical_uri, canonical_query, body, &[], bucket)
237 .await
238 }
239
240 pub(crate) async fn execute_with_headers(
241 &self,
242 method: HttpMethod,
243 canonical_uri: &str,
244 canonical_query: &str,
245 body: &[u8],
246 extra_headers: &[(String, String)],
247 bucket: &str,
248 ) -> Result<HttpResponseData, Error> {
249 fn canonical_header_value(value: &str) -> String {
250 let trimmed = value.trim();
251 let mut out = String::with_capacity(trimmed.len());
252 let mut in_space = false;
253 for ch in trimmed.chars() {
254 if ch.is_whitespace() {
255 if !in_space {
256 out.push(' ');
257 in_space = true;
258 }
259 } else {
260 out.push(ch);
261 in_space = false;
262 }
263 }
264 out
265 }
266
267 let (date, amz_datetime) = amz_datetime(SystemTime::now())?;
268 let credential_scope = format!("{}/{}/s3/aws4_request", date, self.region);
269
270 let base_host = endpoint_host(&self.endpoint);
271 let host = if self.virtual_hosted && !bucket.is_empty() {
272 Cow::Owned(format!("{bucket}.{base_host}"))
273 } else {
274 Cow::Borrowed(&base_host)
275 };
276 let payload_hash = if body.is_empty() {
277 EMPTY_PAYLOAD_SHA256.to_string()
278 } else {
279 hex::encode(&sha256(body))
280 };
281
282 let mut headers = vec![
283 ("host".to_string(), host.to_string()),
284 ("x-amz-date".to_string(), amz_datetime.clone()),
285 ("x-amz-content-sha256".to_string(), payload_hash.clone()),
286 ];
287 headers.extend(extra_headers.iter().cloned());
288 if let Some(token) = self.credentials.session_token.as_deref() {
289 headers.push(("x-amz-security-token".to_string(), token.to_string()));
290 }
291
292 let mut canonical_headers = headers
293 .iter()
294 .map(|(name, value)| format!("{}:{}\n", name.to_ascii_lowercase(), canonical_header_value(value)))
295 .collect::<Vec<_>>();
296 let mut signed_headers = headers
297 .iter()
298 .map(|(name, _)| name.to_ascii_lowercase())
299 .collect::<Vec<_>>();
300
301 canonical_headers.sort();
302 signed_headers.sort();
303
304 let canonical_headers_joined = canonical_headers.concat();
305 let signed_headers_joined = signed_headers.join(";");
306
307 let canonical_request = format!(
308 "{}\n{}\n{}\n{}\n{}\n{}",
309 method.as_str(),
310 canonical_uri,
311 canonical_query,
312 canonical_headers_joined,
313 signed_headers_joined,
314 payload_hash
315 );
316
317 let string_to_sign = format!(
318 "AWS4-HMAC-SHA256\n{}\n{}\n{}",
319 amz_datetime,
320 credential_scope,
321 hex::encode(&sha256(canonical_request.as_bytes()))
322 );
323
324 let signature = hex::encode(&sign_v4(
325 &self.credentials.secret_access_key,
326 &date,
327 &self.region,
328 &string_to_sign,
329 ));
330
331 let authorization = format!(
332 "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
333 self.credentials.access_key_id, credential_scope, signed_headers_joined, signature
334 );
335
336 let url = if self.virtual_hosted && !bucket.is_empty() {
337 let scheme = self.endpoint.scheme();
338 let host_str = endpoint_host(&self.endpoint);
339 let path = canonical_uri;
340 if canonical_query.is_empty() {
341 format!("{scheme}://{bucket}.{host_str}{path}")
342 } else {
343 format!("{scheme}://{bucket}.{host_str}{path}?{canonical_query}")
344 }
345 } else if canonical_query.is_empty() {
346 format!("{}{}", self.endpoint.as_str().trim_end_matches('/'), canonical_uri)
347 } else {
348 format!(
349 "{}{}?{}",
350 self.endpoint.as_str().trim_end_matches('/'),
351 canonical_uri,
352 canonical_query
353 )
354 };
355
356 headers.push(("authorization".to_string(), authorization));
357
358 let request = HttpRequest {
359 method,
360 url,
361 headers,
362 body: body.to_vec(),
363 };
364 let response = self.http.send(request).await.map_err(Error::Http)?;
365
366 if (200..300).contains(&response.status_code) {
367 return Ok(response);
368 }
369
370 let status = response.status_code;
371 let body_bytes = collect_body(response.body).await.unwrap_or_default();
372 let body = String::from_utf8_lossy(&body_bytes).into_owned();
373 Err(Error::Api {
374 status,
375 body,
376 })
377 }
378}
379
#[cfg(feature = "reqwest")]
impl Client<ReqwestHttpClient> {
    /// Creates a client that sends requests through a freshly built [`ReqwestHttpClient`].
    ///
    /// # Errors
    ///
    /// Returns the same validation errors as [`Client::with_http_client`].
    pub fn new(config: &ClientConfig<'_>) -> Result<Self, Error> {
        let transport = ReqwestHttpClient::new();
        Client::with_http_client(config, transport)
    }
}
386
387pub(crate) fn consume_empty(_response: HttpResponseData) -> Result<(), Error> {
388 Ok(())
389}
390
391pub(crate) fn canonical_bucket_uri(bucket: &str) -> String {
392 canonical_uri(&format!("/{bucket}"))
393}
394
395pub(crate) fn canonical_object_uri(bucket: &str, key: &str) -> String {
396 canonical_uri(&format!("/{bucket}/{key}"))
397}
398
399pub(crate) fn canonical_uri(path: &str) -> String {
400 path.split('/').map(percent_encode).collect::<Vec<_>>().join("/")
401}
402
403pub(crate) fn canonical_query_string(params: &BTreeMap<String, String>) -> String {
404 params
405 .iter()
406 .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(v)))
407 .collect::<Vec<_>>()
408 .join("&")
409}
410
411fn endpoint_host(url: &Url) -> String {
412 match (url.host_str(), url.port()) {
413 (Some(host), Some(port)) => format!("{host}:{port}"),
414 (Some(host), None) => host.to_string(),
415 _ => String::new(),
416 }
417}
418
419pub(crate) fn percent_encode(input: &str) -> String {
420 let mut out = String::with_capacity(input.len());
421 for &b in input.as_bytes() {
422 if b.is_ascii_uppercase()
423 || b.is_ascii_lowercase()
424 || b.is_ascii_digit()
425 || matches!(b, b'-' | b'_' | b'.' | b'~')
426 {
427 out.push(b as char);
428 } else {
429 out.push('%');
430 out.push(hex_nibble_upper((b >> 4) & 0x0f));
431 out.push(hex_nibble_upper(b & 0x0f));
432 }
433 }
434 out
435}
436
/// Maps a value in `0..=15` to its upper-case hexadecimal digit.
///
/// # Panics
///
/// Panics (via `unreachable!`) if `value` is greater than 15.
fn hex_nibble_upper(value: u8) -> char {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    match DIGITS.get(usize::from(value)) {
        Some(&digit) => char::from(digit),
        None => unreachable!(),
    }
}
444
445fn sign_v4(secret_access_key: &str, date: &str, region: &str, string_to_sign: &str) -> [u8; 32] {
446 let k_date = hmac_sha256(format!("AWS4{secret_access_key}").as_bytes(), date.as_bytes());
447 let k_region = hmac_sha256(&k_date, region.as_bytes());
448 let k_service = hmac_sha256(&k_region, b"s3");
449 let k_signing = hmac_sha256(&k_service, b"aws4_request");
450 hmac_sha256(&k_signing, string_to_sign.as_bytes())
451}
452
/// Computes HMAC-SHA256 of `data` keyed with `key`.
///
/// Used for each step of the signing-key derivation and for the final signature in `sign_v4`.
pub(crate) fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; Sha256::OUTPUT_SIZE] {
    let mut mac = Hmac::<Sha256>::new(key);
    mac.update(data);
    // NOTE(review): converts the finalized tag slice into a fixed-size array; the `unwrap`
    // assumes the tag is exactly `Sha256::OUTPUT_SIZE` bytes — confirm against the `crypto` crate.
    *mac.finalize().as_ref().as_array().unwrap()
}
458
/// Returns the SHA-256 digest of `data`.
///
/// Used for the request payload hash (`x-amz-content-sha256`) and for hashing the canonical
/// request into the string to sign.
pub(crate) fn sha256(data: &[u8]) -> [u8; Sha256::OUTPUT_SIZE] {
    let mut hasher = Sha256::new();
    hasher.update(data);
    // The `unwrap` relies on the digest being exactly `Sha256::OUTPUT_SIZE` bytes long.
    hasher.sum().as_ref().try_into().unwrap()
}
464
465fn amz_datetime(now: SystemTime) -> Result<(String, String), Error> {
466 let elapsed = now.duration_since(UNIX_EPOCH)?;
467 let total_seconds = elapsed.as_secs() as i64;
468
469 let days = total_seconds.div_euclid(86_400);
470 let seconds_of_day = total_seconds.rem_euclid(86_400);
471
472 let (year, month, day) = civil_from_days(days);
473 let hour = seconds_of_day / 3_600;
474 let minute = (seconds_of_day % 3_600) / 60;
475 let second = seconds_of_day % 60;
476
477 let date = format!("{year:04}{month:02}{day:02}");
478 let datetime = format!("{date}T{hour:02}{minute:02}{second:02}Z");
479
480 Ok((date, datetime))
481}
482
/// Converts a day count relative to 1970-01-01 into a Gregorian `(year, month, day)` triple.
///
/// `month` and `day` are 1-based. The arithmetic follows Howard Hinnant's `civil_from_days`
/// algorithm:
/// - days are re-based to 0000-03-01, so the leap day falls at the end of each computed year;
/// - the result is then split into 400-year eras of 146 097 days.
fn civil_from_days(days_since_unix_epoch: i64) -> (i32, i64, i64) {
    // Shift the epoch from 1970-01-01 back to 0000-03-01 (719 468 days earlier).
    let z = days_since_unix_epoch + 719_468;
    // Floor division by the era length, also correct for negative `z`.
    let era = if z >= 0 { z } else { z - 146_096 } / 146_097;
    // Day of era, in [0, 146096].
    let doe = z - era * 146_097;
    // Year of era, in [0, 399]; the corrections remove the 4/100/400-year leap days.
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365;
    let y = yoe + era * 400;
    // Day of the March-based year, in [0, 365].
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
    // Month index with March = 0 ... February = 11.
    let mp = (5 * doy + 2) / 153;
    let day = doy - (153 * mp + 2) / 5 + 1;
    // Convert to the civil month number (1 = January).
    let month = mp + if mp < 10 { 3 } else { -9 };
    // January and February belong to the next civil year in the March-based count.
    let year = y + if month <= 2 { 1 } else { 0 };

    (year as i32, month, day)
}
497
498pub(crate) async fn collect_body(mut stream: ByteStream) -> Result<Vec<u8>, Error> {
499 let mut out = Vec::new();
500 while let Some(chunk) = stream.next().await {
501 let bytes = chunk.map_err(Error::Http)?;
502 out.extend_from_slice(&bytes);
503 }
504 Ok(out)
505}
506
507pub(crate) fn bytes_to_string(bytes: Vec<u8>) -> Result<String, Error> {
508 String::from_utf8(bytes).map_err(|e| Error::Http(Box::new(e)))
509}
510
511pub(crate) fn header_to_string(response: &HttpResponseData, name: &str) -> Option<String> {
512 response.header(name).map(ToString::to_string)
513}
514
515pub(crate) fn header_to_u64(response: &HttpResponseData, name: &str) -> Option<u64> {
516 response.header(name).and_then(|s| s.parse::<u64>().ok())
517}
518
/// Escapes `s` for safe embedding in XML text content or a quoted attribute value.
///
/// `&`, `<`, `>`, `"` and `'` are replaced with their predefined entities, in a single pass.
/// The previous implementation replaced each character with itself and so escaped nothing:
/// object keys containing markup could inject elements into request bodies.
pub(crate) fn xml_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            other => escaped.push(other),
        }
    }
    escaped
}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Transport stub that answers every request with an empty `200` response.
    #[derive(Default)]
    struct NoopHttpClient;

    impl HttpClient for NoopHttpClient {
        async fn send(&self, _request: HttpRequest) -> Result<HttpResponseData, HttpError> {
            Ok(HttpResponseData {
                status_code: 200,
                headers: Vec::new(),
                body: Box::pin(futures_util::stream::empty()),
            })
        }
    }

    /// Builds a path-style config with a fixed access key id.
    fn config<'a>(
        endpoint: &'a str,
        region: &'a str,
        secret_access_key: &'a str,
        session_token: &'a str,
    ) -> ClientConfig<'a> {
        ClientConfig {
            endpoint,
            credentials: StaticCredentials {
                access_key_id: "a",
                secret_access_key,
                session_token,
            },
            region,
            virtual_hosted: false,
        }
    }

    /// Builds a `200` response carrying `headers` and an empty body.
    fn response_with_headers(headers: &[(&str, &str)]) -> HttpResponseData {
        HttpResponseData {
            status_code: 200,
            headers: headers
                .iter()
                .map(|(name, value)| (name.to_string(), value.to_string()))
                .collect(),
            body: Box::pin(futures_util::stream::empty()),
        }
    }

    #[test]
    fn sha256_known_vectors() {
        assert_eq!(
            hex::encode(&sha256(b"")),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            hex::encode(&sha256(b"abc")),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn hmac_sha256_known_vector() {
        // RFC 4231 test case 1.
        let key = [0x0b; 20];
        let sig = hmac_sha256(&key, b"Hi There");
        assert_eq!(
            hex::encode(&sig),
            "b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7"
        );
    }

    #[test]
    fn percent_encoding_works() {
        assert_eq!(percent_encode("abcXYZ-_.~"), "abcXYZ-_.~");
        assert_eq!(percent_encode("hello world/é"), "hello%20world%2F%C3%A9");
        assert_eq!(percent_encode("a+b=c*"), "a%2Bb%3Dc%2A");
        assert_eq!(percent_encode(""), "");
    }

    #[test]
    fn hex_nibble_upper_covers_all_nibbles() {
        let all: String = (0u8..16).map(hex_nibble_upper).collect();
        assert_eq!(all, "0123456789ABCDEF");
    }

    #[test]
    fn canonical_uri_preserves_slashes() {
        assert_eq!(
            canonical_object_uri("my-bucket", "folder/my file.txt"),
            "/my-bucket/folder/my%20file.txt"
        );
        assert_eq!(canonical_bucket_uri("my bucket"), "/my%20bucket");
    }

    #[test]
    fn canonical_query_string_encodes_and_joins() {
        let mut params = BTreeMap::new();
        params.insert("prefix".to_string(), "photos/2024 jan".to_string());
        params.insert("list-type".to_string(), "2".to_string());
        assert_eq!(
            canonical_query_string(&params),
            "list-type=2&prefix=photos%2F2024%20jan"
        );
        assert_eq!(canonical_query_string(&BTreeMap::new()), "");
    }

    #[test]
    fn endpoint_host_includes_only_explicit_ports() {
        let custom = Url::parse("http://localhost:9000").unwrap();
        assert_eq!(endpoint_host(&custom), "localhost:9000");
        // The url crate drops default ports while parsing, so none is reported.
        let default_port = Url::parse("https://s3.amazonaws.com:443/").unwrap();
        assert_eq!(endpoint_host(&default_port), "s3.amazonaws.com");
    }

    #[test]
    fn amz_datetime_format_works() {
        // 1_700_000_000 s after the epoch is 2023-11-14T22:13:20Z.
        let ts = UNIX_EPOCH + std::time::Duration::from_secs(1_700_000_000);
        let (date, datetime) = amz_datetime(ts).unwrap();
        assert_eq!(date, "20231114");
        assert_eq!(datetime, "20231114T221320Z");

        let (epoch_date, epoch_datetime) = amz_datetime(UNIX_EPOCH).unwrap();
        assert_eq!(epoch_date, "19700101");
        assert_eq!(epoch_datetime, "19700101T000000Z");
    }

    #[test]
    fn civil_from_days_known_dates() {
        assert_eq!(civil_from_days(0), (1970, 1, 1));
        assert_eq!(civil_from_days(19_675), (2023, 11, 14));
        // Leap day: 2024-01-01 is day 19_723, plus 31 (January) + 28 (February 1..=28).
        assert_eq!(civil_from_days(19_782), (2024, 2, 29));
    }

    #[test]
    fn client_config_validation_works() {
        let rejected = [
            config("", "auto", "b", ""),
            config("https://s3.example.com", "   ", "b", ""),
            config("https://s3.example.com", "auto", "", ""),
            config("not a url", "auto", "b", ""),
        ];
        for cfg in &rejected {
            assert!(matches!(
                Client::with_http_client(cfg, NoopHttpClient),
                Err(Error::InvalidConfig(_))
            ));
        }
    }

    #[test]
    fn client_config_stores_optional_session_token() {
        let without_token =
            Client::with_http_client(&config("https://s3.example.com", "auto", "b", ""), NoopHttpClient)
                .expect("valid config");
        assert!(without_token.credentials.session_token.is_none());
        assert_eq!(without_token.region, "auto");

        let with_token = Client::with_http_client(
            &config("https://s3.example.com", "auto", "b", "tok"),
            NoopHttpClient,
        )
        .expect("valid config");
        assert_eq!(with_token.credentials.session_token.as_deref(), Some("tok"));
    }

    #[test]
    fn response_header_lookup_is_case_insensitive() {
        let response =
            response_with_headers(&[("Content-Length", "42"), ("ETag", "\"abc\"")]);
        assert_eq!(response.header("content-length"), Some("42"));
        assert_eq!(response.header("x-missing"), None);
        assert_eq!(header_to_u64(&response, "CONTENT-LENGTH"), Some(42));
        assert_eq!(header_to_u64(&response, "etag"), None);
        assert_eq!(
            header_to_string(&response, "etag"),
            Some("\"abc\"".to_string())
        );
    }

    #[test]
    fn error_display_messages() {
        assert_eq!(
            Error::InvalidConfig("region must not be empty").to_string(),
            "invalid config: region must not be empty"
        );
        assert_eq!(
            Error::Api {
                status: 404,
                body: "NoSuchKey".to_string(),
            }
            .to_string(),
            "s3 api error (status 404): NoSuchKey"
        );
    }

    #[test]
    fn sign_v4_known_output_length() {
        let sig = sign_v4(
            "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
            "20130524",
            "auto",
            "AWS4-HMAC-SHA256\n20130524T000000Z\n20130524/auto/s3/aws4_request\nabc",
        );
        assert_eq!(sig.len(), 32);
    }

    #[test]
    fn empty_payload_sha256_constant_is_correct() {
        assert_eq!(EMPTY_PAYLOAD_SHA256, hex::encode(&sha256(b"")));
    }
}