1use std::{
2 collections::HashMap,
3 hash::{Hash, Hasher},
4 num::NonZeroUsize,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use bytes::BytesMut;
11use lru::LruCache;
12use rustls::pki_types::ServerName;
13use tokio::{
14 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
15 sync::Mutex,
16};
17use tokio_rustls::TlsConnector;
18
19use crate::{
20 config::ConnectParams,
21 encode::ToSql,
22 error::{PgError, Result},
23 protocol::{BackendDecoder, BackendMessage, FieldDescription, FrontendMessage, ScramClient},
24 row::Row,
25};
26
/// Byte stream to the PostgreSQL server.
///
/// `tls_handshake` produces `Tls` after the server accepts the SSLRequest and
/// `Plain` when the connection is not upgraded. The `AsyncRead`/`AsyncWrite`
/// impls forward every call to whichever variant is active.
enum PgStream {
    /// Unencrypted TCP connection.
    Plain(tokio::net::TcpStream),
    /// TLS session running over the TCP connection.
    Tls(tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
}
31
32impl AsyncRead for PgStream {
33 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
34 match self.get_mut() {
35 PgStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
36 PgStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
37 }
38 }
39}
40
41impl AsyncWrite for PgStream {
42 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
43 match self.get_mut() {
44 PgStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
45 PgStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
46 }
47 }
48
49 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
50 match self.get_mut() {
51 PgStream::Plain(s) => Pin::new(s).poll_flush(cx),
52 PgStream::Tls(s) => Pin::new(s).poll_flush(cx),
53 }
54 }
55
56 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
57 match self.get_mut() {
58 PgStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
59 PgStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
60 }
61 }
62}
63
/// Mutable per-connection state, shared by all `Connection` clones behind a mutex.
struct Inner {
    /// Transport to the server (plain TCP or TLS).
    stream: PgStream,
    /// Bytes read from `stream` that have not yet been decoded into a complete
    /// backend message.
    buf: BytesMut,
    /// Process id from the server's BackendKeyData, stored during startup.
    _pid: i32,
    /// Secret key from the server's BackendKeyData, stored during startup.
    _secret_key: i32,
    // NOTE(review): never populated in this file; do_startup discards
    // ParameterStatus messages. The leading underscores only silence
    // unused-field warnings.
    _parameter_status: HashMap<String, String>,
    /// Prepared statements keyed by SQL text, each mapped to
    /// (server-side statement name, field descriptions). The capacity is fixed
    /// when the connection is opened.
    statement_cache: LruCache<String, (String, Vec<FieldDescription>)>,
}
72
/// Handle to a single PostgreSQL connection.
///
/// Cloning is cheap and yields another handle to the same connection. All
/// clones share one `Inner` behind an `Arc<Mutex<_>>`, so requests issued
/// through different clones run one at a time.
#[derive(Clone)]
pub struct Connection {
    inner: Arc<Mutex<Inner>>,
}
77
78impl std::fmt::Debug for Connection {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("Connection").finish_non_exhaustive()
81 }
82}
83
/// Derives a server-side prepared-statement name from the SQL text.
///
/// The result is `s` followed by the lowercase hex of the text's 64-bit
/// `DefaultHasher` digest. Identical SQL therefore maps to the same name
/// within one build of the program.
fn sql_statement_name(sql: &str) -> String {
    let digest = {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(sql, &mut state);
        state.finish()
    };
    format!("s{:x}", digest)
}
89
90impl Connection {
91 pub async fn connect(params: &ConnectParams) -> Result<Self> {
92 let addr = format!("{}:{}", params.host, params.port);
93 let tcp = tokio::time::timeout(params.connect_timeout, tokio::net::TcpStream::connect(&addr))
94 .await
95 .map_err(|_| PgError::Protocol(format!("connection timeout to {}", addr)))?
96 .map_err(PgError::Io)?;
97
98 tcp.set_nodelay(true).ok();
99
100 let stream = tls_handshake(tcp, ¶ms.host).await?;
101
102 let mut inner = Inner {
103 stream,
104 buf: BytesMut::with_capacity(8192),
105 _pid: 0,
106 _secret_key: 0,
107 _parameter_status: HashMap::new(),
108 statement_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
109 };
110
111 do_startup(&mut inner, params).await?;
112
113 Ok(Connection {
114 inner: Arc::new(Mutex::new(inner)),
115 })
116 }
117
118 pub async fn query_raw(&self, sql: &str, params: &[&dyn ToSql]) -> Result<Vec<Row>> {
119 let mut guard = self.inner.lock().await;
120
121 if params.is_empty() {
122 return simple_query(&mut guard, sql).await;
123 }
124
125 let stmt_name = sql_statement_name(sql);
126 let param_oids: Vec<u32> = params.iter().map(|p| p.pg_type().oid).collect();
127
128 let cached_name = guard.statement_cache.get(sql).map(|(name, _)| name.clone());
129 if let Some(ref old_name) = cached_name {
130 if *old_name != stmt_name {
131 send_close_statement(&mut guard, old_name).await?;
132 guard.statement_cache.pop(sql);
133 }
134 }
135
136 if !guard.statement_cache.contains(sql) {
137 send_parse(&mut guard, &stmt_name, sql, ¶m_oids).await?;
138 read_until_parse_complete(&mut guard).await?;
139
140 send_describe_statement(&mut guard, &stmt_name).await?;
141 let fields = match read_describe_response(&mut guard).await? {
142 Some(fields) => fields,
143 None => Vec::new(),
144 };
145
146 guard.statement_cache.put(sql.to_string(), (stmt_name.clone(), fields));
147 }
148
149 let (cached_stmt_name, fields) = guard.statement_cache.get(sql).expect("just cached").clone();
150
151 let param_binary: Vec<Vec<u8>> = params.iter().map(|p| p.to_sql()).collect::<Result<Vec<Vec<u8>>>>()?;
152 let param_formats: Vec<crate::types::Format> = params.iter().map(|_| crate::types::Format::Binary).collect();
153
154 send_bind(
155 &mut guard,
156 "",
157 &cached_stmt_name,
158 ¶m_formats,
159 ¶m_binary,
160 crate::types::Format::Binary,
161 )
162 .await?;
163 read_until_bind_complete(&mut guard).await?;
164
165 send_describe_portal(&mut guard, "").await?;
166
167 send_execute(&mut guard, "", 0).await?;
168 let rows = read_rows_until_complete(&mut guard, &fields).await?;
169
170 send_sync(&mut guard).await?;
171 read_until_ready(&mut guard).await?;
172
173 Ok(rows)
174 }
175
176 pub async fn execute_raw(&self, sql: &str, params: &[&dyn ToSql]) -> Result<u64> {
177 let mut guard = self.inner.lock().await;
178
179 if params.is_empty() {
180 return simple_execute(&mut guard, sql).await;
181 }
182
183 let stmt_name = sql_statement_name(sql);
184 let param_oids: Vec<u32> = params.iter().map(|p| p.pg_type().oid).collect();
185
186 if let Some((old_name, _)) = guard.statement_cache.get(sql) {
187 if *old_name != stmt_name {
188 let old = old_name.clone();
189 guard.statement_cache.pop(sql);
190 send_close_statement(&mut guard, &old).await?;
191 }
192 }
193
194 if !guard.statement_cache.contains(sql) {
195 send_parse(&mut guard, &stmt_name, sql, ¶m_oids).await?;
196 read_until_parse_complete(&mut guard).await?;
197 guard
198 .statement_cache
199 .put(sql.to_string(), (stmt_name.clone(), Vec::new()));
200 }
201
202 let cached_stmt_name = guard
203 .statement_cache
204 .get(sql)
205 .map(|(n, _)| n.clone())
206 .expect("just cached");
207
208 let param_binary: Vec<Vec<u8>> = params.iter().map(|p| p.to_sql()).collect::<Result<Vec<Vec<u8>>>>()?;
209 let param_formats: Vec<crate::types::Format> = params.iter().map(|_| crate::types::Format::Binary).collect();
210
211 send_bind(
212 &mut guard,
213 "",
214 &cached_stmt_name,
215 ¶m_formats,
216 ¶m_binary,
217 crate::types::Format::Binary,
218 )
219 .await?;
220 read_until_bind_complete(&mut guard).await?;
221
222 send_execute(&mut guard, "", 0).await?;
223 let rows_affected = read_command_complete(&mut guard).await?;
224
225 send_sync(&mut guard).await?;
226 read_until_ready(&mut guard).await?;
227
228 Ok(rows_affected)
229 }
230
231 pub(crate) async fn ping(&self) -> Result<()> {
232 self.query_raw("SELECT 1", &[]).await?;
233 Ok(())
234 }
235}
236
237async fn tls_handshake(tcp: tokio::net::TcpStream, host: &str) -> Result<PgStream> {
238 let msg = FrontendMessage::SslRequest;
239 let encoded = msg.encode();
240 let (mut reader, mut writer) = tokio::io::split(tcp);
241
242 writer.write_all(&encoded).await?;
243
244 let mut response = [0u8; 1];
245 reader.read_exact(&mut response).await?;
246
247 if response[0] == b'S' {
248 let config = rustls::ClientConfig::builder()
249 .with_root_certificates(rustls::RootCertStore::empty())
250 .with_no_client_auth();
251
252 let connector = TlsConnector::from(Arc::new(config));
253 let server_name = ServerName::try_from(host.to_string())
254 .map_err(|_| PgError::Config(format!("invalid hostname: {}", host)))?;
255
256 let tls_stream = connector
257 .connect(server_name, reader.unsplit(writer))
258 .await
259 .map_err(|e| PgError::Tls(Box::new(e)))?;
260 Ok(PgStream::Tls(tls_stream))
261 } else {
262 let tcp = reader.unsplit(writer);
263 Ok(PgStream::Plain(tcp))
264 }
265}
266
267async fn do_startup(inner: &mut Inner, params: &ConnectParams) -> Result<()> {
268 let mut kv = vec![
269 ("client_encoding".to_string(), "UTF8".to_string()),
270 ("user".to_string(), params.user.clone()),
271 ];
272 if let Some(ref db) = params.dbname {
273 kv.push(("database".to_string(), db.clone()));
274 }
275
276 let msg = FrontendMessage::Startup(kv);
277 inner.write_all(&msg.encode()).await?;
278
279 loop {
280 let backend = read_message(inner).await?;
281 match backend {
282 BackendMessage::AuthenticationOk => {
283 read_until_ready(inner).await?;
284 return Ok(());
285 }
286 BackendMessage::AuthenticationCleartextPassword => {
287 let password = params.password.as_deref().unwrap_or("");
288 let msg = FrontendMessage::Password(password.to_string());
289 inner.write_all(&msg.encode()).await?;
290 }
291 BackendMessage::AuthenticationSasl(mechanisms) => {
292 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
293 return Err(PgError::Auth("server does not support SCRAM-SHA-256".into()));
294 }
295 let mut scram = ScramClient::new(¶ms.user, params.password.as_deref().unwrap_or(""));
296
297 let initial = scram.client_first_message().as_bytes().to_vec();
298 let msg = FrontendMessage::SaslInitialResponse("SCRAM-SHA-256".to_string(), initial);
299 inner.write_all(&msg.encode()).await?;
300
301 let continue_msg = read_message(inner).await?;
302 match continue_msg {
303 BackendMessage::AuthenticationSaslContinue(data) => {
304 scram.parse_server_first_message(&data)?;
305 let final_msg = scram.build_client_final_message();
306 let msg = FrontendMessage::SaslResponse(final_msg);
307 inner.write_all(&msg.encode()).await?;
308 }
309 _ => return Err(PgError::Protocol("expected SASL continue".into())),
310 }
311
312 let final_msg = read_message(inner).await?;
313 match final_msg {
314 BackendMessage::AuthenticationSaslFinal(data) => {
315 scram.parse_server_final_message(&data)?;
316 }
317 _ => return Err(PgError::Protocol("expected SASL final".into())),
318 }
319 }
320 BackendMessage::ParameterStatus(_, _) => {}
321 BackendMessage::BackendKeyData(pid, key) => {
322 inner._pid = pid;
323 inner._secret_key = key;
324 }
325 BackendMessage::ReadyForQuery(_) => {
326 return Ok(());
327 }
328 _ => return Err(PgError::Protocol("unexpected message during startup".into())),
329 }
330 }
331}
332
333async fn simple_query(inner: &mut Inner, sql: &str) -> Result<Vec<Row>> {
334 let msg = FrontendMessage::Query(sql.to_string());
335 inner.write_all(&msg.encode()).await?;
336
337 let mut rows = Vec::new();
338 let mut fields = Vec::new();
339
340 loop {
341 let backend = read_message(inner).await?;
342 match backend {
343 BackendMessage::RowDescription(fds) => {
344 fields = fds;
345 }
346 BackendMessage::DataRow(cols) => {
347 let row = Row::new(&fields, &cols);
348 rows.push(row);
349 }
350 BackendMessage::CommandComplete(_) => {}
351 BackendMessage::ReadyForQuery(_) => {
352 return Ok(rows);
353 }
354 BackendMessage::EmptyQueryResponse => {}
355 BackendMessage::NoticeResponse(_) => {}
356 _ => {
357 return Err(PgError::Protocol("unexpected message in simple query".into()));
358 }
359 }
360 }
361}
362
363async fn simple_execute(inner: &mut Inner, sql: &str) -> Result<u64> {
364 let msg = FrontendMessage::Query(sql.to_string());
365 inner.write_all(&msg.encode()).await?;
366
367 let mut rows_affected = 0u64;
368
369 loop {
370 let backend = read_message(inner).await?;
371 match backend {
372 BackendMessage::RowDescription(_) => {}
373 BackendMessage::DataRow(_) => {}
374 BackendMessage::CommandComplete(tag) => {
375 parse_command_tag(&tag, &mut rows_affected);
376 }
377 BackendMessage::ReadyForQuery(_) => {
378 return Ok(rows_affected);
379 }
380 BackendMessage::EmptyQueryResponse => {}
381 BackendMessage::NoticeResponse(_) => {}
382 _ => {
383 return Err(PgError::Protocol("unexpected message in simple execute".into()));
384 }
385 }
386 }
387}
388
/// Extracts the row count from a CommandComplete tag such as `"INSERT 0 5"`
/// or `"UPDATE 12"`. The count is the tag's last space-separated word.
///
/// If that word parses as `u64`, it is stored into `*affected` and returned.
/// Otherwise (e.g. `"CREATE TABLE"`), `*affected` is left untouched and `0`
/// is returned.
fn parse_command_tag(tag: &str, affected: &mut u64) -> u64 {
    let last_word = tag.rsplit_once(' ').map_or(tag, |(_, tail)| tail);
    match last_word.parse::<u64>() {
        Ok(count) => {
            *affected = count;
            count
        }
        Err(_) => 0,
    }
}
397
398async fn send_parse(inner: &mut Inner, stmt_name: &str, sql: &str, param_oids: &[u32]) -> Result<()> {
399 let msg = FrontendMessage::Parse(stmt_name.to_string(), sql.to_string(), param_oids.to_vec());
400 inner.write_all(&msg.encode()).await?;
401 Ok(())
402}
403
404async fn read_until_parse_complete(inner: &mut Inner) -> Result<()> {
405 loop {
406 let backend = read_message(inner).await?;
407 match backend {
408 BackendMessage::ParseComplete => return Ok(()),
409 BackendMessage::NoticeResponse(_) => {}
410 _ => return Err(PgError::Protocol("expected ParseComplete".into())),
411 }
412 }
413}
414
415async fn send_bind(
416 inner: &mut Inner,
417 portal: &str,
418 stmt: &str,
419 formats: &[crate::types::Format],
420 params: &[Vec<u8>],
421 result_format: crate::types::Format,
422) -> Result<()> {
423 let msg = FrontendMessage::Bind(
424 portal.to_string(),
425 stmt.to_string(),
426 formats.to_vec(),
427 params.to_vec(),
428 result_format,
429 );
430 inner.write_all(&msg.encode()).await?;
431 Ok(())
432}
433
434async fn read_until_bind_complete(inner: &mut Inner) -> Result<()> {
435 loop {
436 let backend = read_message(inner).await?;
437 match backend {
438 BackendMessage::BindComplete => return Ok(()),
439 BackendMessage::NoticeResponse(_) => {}
440 _ => return Err(PgError::Protocol("expected BindComplete".into())),
441 }
442 }
443}
444
445async fn send_describe_statement(inner: &mut Inner, stmt: &str) -> Result<()> {
446 let msg = FrontendMessage::Describe(b'S', stmt.to_string());
447 inner.write_all(&msg.encode()).await?;
448 Ok(())
449}
450
451async fn send_describe_portal(inner: &mut Inner, portal: &str) -> Result<()> {
452 let msg = FrontendMessage::Describe(b'P', portal.to_string());
453 inner.write_all(&msg.encode()).await?;
454 Ok(())
455}
456
457async fn read_describe_response(inner: &mut Inner) -> Result<Option<Vec<FieldDescription>>> {
458 loop {
459 let backend = read_message(inner).await?;
460 match backend {
461 BackendMessage::RowDescription(fields) => return Ok(Some(fields)),
462 BackendMessage::NoData => return Ok(None),
463 BackendMessage::NoticeResponse(_) => {}
464 _ => return Err(PgError::Protocol("expected RowDescription or NoData".into())),
465 }
466 }
467}
468
469async fn send_execute(inner: &mut Inner, portal: &str, max_rows: i32) -> Result<()> {
470 let msg = FrontendMessage::Execute(portal.to_string(), max_rows);
471 inner.write_all(&msg.encode()).await?;
472 Ok(())
473}
474
475async fn read_rows_until_complete(inner: &mut Inner, fields: &[FieldDescription]) -> Result<Vec<Row>> {
476 let mut rows = Vec::new();
477
478 loop {
479 let backend = read_message(inner).await?;
480 match backend {
481 BackendMessage::DataRow(cols) => {
482 let row = Row::new(fields, &cols);
483 rows.push(row);
484 }
485 BackendMessage::CommandComplete(_) => {
486 return Ok(rows);
487 }
488 BackendMessage::PortalSuspended => {
489 return Ok(rows);
490 }
491 BackendMessage::NoticeResponse(_) => {}
492 _ => return Err(PgError::Protocol("expected DataRow or CommandComplete".into())),
493 }
494 }
495}
496
497async fn read_command_complete(inner: &mut Inner) -> Result<u64> {
498 let mut affected = 0u64;
499 loop {
500 let backend = read_message(inner).await?;
501 match backend {
502 BackendMessage::CommandComplete(tag) => {
503 parse_command_tag(&tag, &mut affected);
504 return Ok(affected);
505 }
506 BackendMessage::NoticeResponse(_) => {}
507 _ => return Err(PgError::Protocol("expected CommandComplete".into())),
508 }
509 }
510}
511
512async fn send_sync(inner: &mut Inner) -> Result<()> {
513 let msg = FrontendMessage::Sync;
514 inner.write_all(&msg.encode()).await?;
515 Ok(())
516}
517
518async fn send_close_statement(inner: &mut Inner, stmt: &str) -> Result<()> {
519 let msg = FrontendMessage::Close(b'S', stmt.to_string());
520 inner.write_all(&msg.encode()).await?;
521 read_until_close_complete(inner).await
522}
523
524async fn read_until_close_complete(inner: &mut Inner) -> Result<()> {
525 loop {
526 let backend = read_message(inner).await?;
527 match backend {
528 BackendMessage::CloseComplete => return Ok(()),
529 BackendMessage::NoticeResponse(_) => {}
530 _ => return Err(PgError::Protocol("expected CloseComplete".into())),
531 }
532 }
533}
534
535async fn read_until_ready(inner: &mut Inner) -> Result<u8> {
536 loop {
537 let backend = read_message(inner).await?;
538 match backend {
539 BackendMessage::ReadyForQuery(status) => return Ok(status),
540 BackendMessage::NoticeResponse(_) => {}
541 BackendMessage::ParameterStatus(_, _) => {}
542 BackendMessage::CommandComplete(_) => {}
543 _ => {
544 return Err(PgError::Protocol("expected ReadyForQuery".into()));
545 }
546 }
547 }
548}
549
550async fn read_message(inner: &mut Inner) -> Result<BackendMessage> {
551 loop {
552 if let Some(msg) = BackendDecoder::decode(&mut inner.buf)? {
553 return Ok(msg);
554 }
555 inner.buf.reserve(4096);
556 let n = inner.stream.read_buf(&mut inner.buf).await.map_err(PgError::Io)?;
557 if n == 0 {
558 return Err(PgError::Protocol("connection closed by server".into()));
559 }
560 }
561}
562
563impl Inner {
564 async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
565 AsyncWriteExt::write_all(&mut self.stream, buf)
566 .await
567 .map_err(PgError::Io)
568 }
569}