1use crypto::{Hasher, hmac::Hmac, sha2::Sha256};
2use rand::TryRngCore;
3
4use crate::{
5 error::{PgError, Result},
6 protocol::{base64_decode, base64_encode},
7};
8
9fn hi(password: &str, salt: &[u8], iterations: u32) -> [u8; 32] {
10 let pw = password.as_bytes();
11 let mut u = hmac_sha256(pw, &[salt, &[0, 0, 0, 1]].concat());
12 let mut result = [0u8; 32];
13 let mut result = [0u8; 32];
14 result.copy_from_slice(u.as_ref());
15
16 for _ in 1..iterations {
17 u = hmac_sha256(pw, u.as_ref());
18 for (a, b) in result.iter_mut().zip(u.as_ref()) {
19 *a ^= b;
20 }
21 }
22
23 result
24}
25
26fn hmac_sha256(key: &[u8], data: &[u8]) -> crypto::Hash {
27 let mut mac = Hmac::<Sha256>::new(key);
28 mac.update(data);
29 mac.finalize()
30}
31
/// Client-side state for one SCRAM-SHA-256 authentication exchange.
///
/// `ScramClient::new` creates the value. The server-derived fields start as `None`, and
/// `parse_server_first_message` fills them in.
pub(crate) struct ScramClient {
    /// Hex-encoded random nonce chosen by the client.
    client_nonce: String,
    /// The user's password as given. It is kept until the salted password has been derived.
    password: String,
    /// `client-first-message-bare` (`n=...,r=<client_nonce>`), which is part of the AuthMessage.
    client_first_message_bare: String,
    /// Raw server-first-message text. It is echoed into the AuthMessage.
    server_first_message: Option<String>,
    /// `client-final-message-without-proof` (`c=biws,r=<combined nonce>`).
    client_final_without_proof: Option<String>,
    /// `SaltedPassword = Hi(password, salt, iterations)`, which is 32 bytes for SHA-256.
    salted_password: Option<[u8; 32]>,
}
40
41impl ScramClient {
42 pub fn new(username: &str, password: &str) -> Self {
43 let mut rng = rand::rngs::OsRng;
44 let mut raw = [0u8; 24];
45 rng.try_fill_bytes(&mut raw).unwrap();
46 let client_nonce = hex::encode(&raw);
47
48 ScramClient {
49 client_first_message_bare: format!("n={},r={}", username, client_nonce),
50 client_nonce,
51 password: password.to_string(),
52 server_first_message: None,
53 client_final_without_proof: None,
54 salted_password: None,
55 }
56 }
57
58 pub fn client_first_message(&self) -> &str {
59 &self.client_first_message_bare
60 }
61
62 pub fn parse_server_first_message(&mut self, data: &[u8]) -> Result<()> {
63 let msg = std::str::from_utf8(data).map_err(|_| PgError::Auth("invalid utf-8 in server-first".into()))?;
64 self.server_first_message = Some(msg.to_string());
65
66 let mut combined_nonce = None;
67 let mut salt_b64 = None;
68 let mut iterations = None;
69
70 for part in msg.split(',') {
71 if let Some(val) = part.strip_prefix("r=") {
72 combined_nonce = Some(val.to_string());
73 } else if let Some(val) = part.strip_prefix("s=") {
74 salt_b64 = Some(val.to_string());
75 } else if let Some(val) = part.strip_prefix("i=") {
76 iterations = Some(
77 val.parse::<u32>()
78 .map_err(|_| PgError::Auth("invalid iteration count".into()))?,
79 );
80 }
81 }
82
83 let combined_nonce = combined_nonce.ok_or_else(|| PgError::Auth("missing nonce in server-first".into()))?;
84 let salt_b64 = salt_b64.ok_or_else(|| PgError::Auth("missing salt in server-first".into()))?;
85 let iterations = iterations.ok_or_else(|| PgError::Auth("missing iterations in server-first".into()))?;
86
87 if !combined_nonce.starts_with(&self.client_nonce) {
88 return Err(PgError::Auth("server nonce doesn't start with client nonce".into()));
89 }
90
91 let salt = base64_decode(&salt_b64).map_err(|e| PgError::Auth(format!("invalid base64 salt: {}", e)))?;
92
93 let salted_password = hi(&self.password, &salt, iterations);
94 self.salted_password = Some(salted_password);
95 self.client_final_without_proof = Some(format!("c=biws,r={}", combined_nonce));
96
97 Ok(())
98 }
99
100 pub fn build_client_final_message(&self) -> Vec<u8> {
101 let sp = self.salted_password.as_ref().expect("salted password not computed");
102
103 let client_key = hmac_sha256(sp, b"Client Key");
104 let client_key_bytes: &[u8] = client_key.as_ref();
105
106 let mut hasher = crypto::sha2::Sha256::new();
107 hasher.update(client_key_bytes);
108 let stored_key = hasher.sum();
109 let stored_key_bytes: &[u8] = stored_key.as_ref();
110
111 let server_first = self.server_first_message.as_ref().expect("no server-first message");
112 let cfnop = self
113 .client_final_without_proof
114 .as_ref()
115 .expect("no client-final-without-proof");
116
117 let auth_message = format!("{},{},{}", self.client_first_message_bare, server_first, cfnop);
118
119 let client_signature = hmac_sha256(stored_key_bytes, auth_message.as_bytes());
120
121 let mut client_proof = [0u8; 32];
122 client_proof.copy_from_slice(client_key_bytes);
123 for (a, b) in client_proof.iter_mut().zip(client_signature.as_ref()) {
124 *a ^= b;
125 }
126
127 let client_proof_b64 = base64_encode(&client_proof);
128 let client_final = format!("{},p={}", cfnop, client_proof_b64);
129 client_final.into_bytes()
130 }
131
132 pub fn parse_server_final_message(&self, data: &[u8]) -> Result<()> {
133 let msg = std::str::from_utf8(data).map_err(|_| PgError::Auth("invalid utf-8 in server-final".into()))?;
134
135 let mut server_sig_b64 = None;
136 for part in msg.split(',') {
137 if let Some(val) = part.strip_prefix("v=") {
138 server_sig_b64 = Some(val.to_string());
139 } else if let Some(val) = part.strip_prefix("e=") {
140 return Err(PgError::Auth(format!("server returned auth error: {}", val)));
141 }
142 }
143
144 let server_sig_b64 = server_sig_b64.ok_or_else(|| PgError::Auth("missing server signature".into()))?;
145
146 let sp = self.salted_password.as_ref().expect("salted password not computed");
147 let server_key = hmac_sha256(sp, b"Server Key");
148
149 let server_first = self.server_first_message.as_ref().expect("no server-first message");
150 let cfnop = self
151 .client_final_without_proof
152 .as_ref()
153 .expect("no client-final-without-proof");
154
155 let auth_message = format!("{},{},{}", self.client_first_message_bare, server_first, cfnop);
156
157 let expected_signature = hmac_sha256(server_key.as_ref(), auth_message.as_bytes());
158 let expected_b64 = base64_encode(expected_signature.as_ref());
159
160 if expected_b64 != server_sig_b64 {
161 return Err(PgError::Auth("server signature mismatch".into()));
162 }
163
164 Ok(())
165 }
166}