Skip to main content

pg/
connection.rs

1use std::{
2    collections::HashMap,
3    hash::{Hash, Hasher},
4    num::NonZeroUsize,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use bytes::BytesMut;
11use lru::LruCache;
12use rustls::pki_types::ServerName;
13use tokio::{
14    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
15    sync::Mutex,
16};
17use tokio_rustls::TlsConnector;
18
19use crate::{
20    config::ConnectParams,
21    encode::ToSql,
22    error::{PgError, Result},
23    protocol::{BackendDecoder, BackendMessage, FieldDescription, FrontendMessage, ScramClient},
24    row::Row,
25};
26
27enum PgStream {
28    Plain(tokio::net::TcpStream),
29    Tls(tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
30}
31
32impl AsyncRead for PgStream {
33    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
34        match self.get_mut() {
35            PgStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
36            PgStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
37        }
38    }
39}
40
41impl AsyncWrite for PgStream {
42    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
43        match self.get_mut() {
44            PgStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
45            PgStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
46        }
47    }
48
49    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
50        match self.get_mut() {
51            PgStream::Plain(s) => Pin::new(s).poll_flush(cx),
52            PgStream::Tls(s) => Pin::new(s).poll_flush(cx),
53        }
54    }
55
56    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
57        match self.get_mut() {
58            PgStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
59            PgStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
60        }
61    }
62}
63
64struct Inner {
65    stream: PgStream,
66    buf: BytesMut,
67    _pid: i32,
68    _secret_key: i32,
69    _parameter_status: HashMap<String, String>,
70    statement_cache: LruCache<String, (String, Vec<FieldDescription>)>,
71}
72
73#[derive(Clone)]
74pub struct Connection {
75    inner: Arc<Mutex<Inner>>,
76}
77
78impl std::fmt::Debug for Connection {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("Connection").finish_non_exhaustive()
81    }
82}
83
/// Derives a prepared-statement name from the SQL text.
///
/// The name is `s` followed by the lowercase hex of the text's `DefaultHasher`
/// hash. `DefaultHasher::new()` always starts from the same state, so identical
/// SQL maps to the same name within a process.
///
/// NOTE(review): two different SQL strings whose 64-bit hashes collide would
/// share a name, and the server would presumably reject preparing the second.
fn sql_statement_name(sql: &str) -> String {
    let digest = {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(sql, &mut hasher);
        hasher.finish()
    };
    format!("s{digest:x}")
}
89
90impl Connection {
91    pub async fn connect(params: &ConnectParams) -> Result<Self> {
92        let addr = format!("{}:{}", params.host, params.port);
93        let tcp = tokio::time::timeout(params.connect_timeout, tokio::net::TcpStream::connect(&addr))
94            .await
95            .map_err(|_| PgError::Protocol(format!("connection timeout to {}", addr)))?
96            .map_err(PgError::Io)?;
97
98        tcp.set_nodelay(true).ok();
99
100        let stream = tls_handshake(tcp, &params.host).await?;
101
102        let mut inner = Inner {
103            stream,
104            buf: BytesMut::with_capacity(8192),
105            _pid: 0,
106            _secret_key: 0,
107            _parameter_status: HashMap::new(),
108            statement_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
109        };
110
111        do_startup(&mut inner, params).await?;
112
113        Ok(Connection {
114            inner: Arc::new(Mutex::new(inner)),
115        })
116    }
117
118    pub async fn query_raw(&self, sql: &str, params: &[&dyn ToSql]) -> Result<Vec<Row>> {
119        let mut guard = self.inner.lock().await;
120
121        if params.is_empty() {
122            return simple_query(&mut guard, sql).await;
123        }
124
125        let stmt_name = sql_statement_name(sql);
126        let param_oids: Vec<u32> = params.iter().map(|p| p.pg_type().oid).collect();
127
128        let cached_name = guard.statement_cache.get(sql).map(|(name, _)| name.clone());
129        if let Some(ref old_name) = cached_name {
130            if *old_name != stmt_name {
131                send_close_statement(&mut guard, old_name).await?;
132                guard.statement_cache.pop(sql);
133            }
134        }
135
136        if !guard.statement_cache.contains(sql) {
137            send_parse(&mut guard, &stmt_name, sql, &param_oids).await?;
138            read_until_parse_complete(&mut guard).await?;
139
140            send_describe_statement(&mut guard, &stmt_name).await?;
141            let fields = match read_describe_response(&mut guard).await? {
142                Some(fields) => fields,
143                None => Vec::new(),
144            };
145
146            guard.statement_cache.put(sql.to_string(), (stmt_name.clone(), fields));
147        }
148
149        let (cached_stmt_name, fields) = guard.statement_cache.get(sql).expect("just cached").clone();
150
151        let param_binary: Vec<Vec<u8>> = params.iter().map(|p| p.to_sql()).collect::<Result<Vec<Vec<u8>>>>()?;
152        let param_formats: Vec<crate::types::Format> = params.iter().map(|_| crate::types::Format::Binary).collect();
153
154        send_bind(
155            &mut guard,
156            "",
157            &cached_stmt_name,
158            &param_formats,
159            &param_binary,
160            crate::types::Format::Binary,
161        )
162        .await?;
163        read_until_bind_complete(&mut guard).await?;
164
165        send_describe_portal(&mut guard, "").await?;
166
167        send_execute(&mut guard, "", 0).await?;
168        let rows = read_rows_until_complete(&mut guard, &fields).await?;
169
170        send_sync(&mut guard).await?;
171        read_until_ready(&mut guard).await?;
172
173        Ok(rows)
174    }
175
176    pub async fn execute_raw(&self, sql: &str, params: &[&dyn ToSql]) -> Result<u64> {
177        let mut guard = self.inner.lock().await;
178
179        if params.is_empty() {
180            return simple_execute(&mut guard, sql).await;
181        }
182
183        let stmt_name = sql_statement_name(sql);
184        let param_oids: Vec<u32> = params.iter().map(|p| p.pg_type().oid).collect();
185
186        if let Some((old_name, _)) = guard.statement_cache.get(sql) {
187            if *old_name != stmt_name {
188                let old = old_name.clone();
189                guard.statement_cache.pop(sql);
190                send_close_statement(&mut guard, &old).await?;
191            }
192        }
193
194        if !guard.statement_cache.contains(sql) {
195            send_parse(&mut guard, &stmt_name, sql, &param_oids).await?;
196            read_until_parse_complete(&mut guard).await?;
197            guard
198                .statement_cache
199                .put(sql.to_string(), (stmt_name.clone(), Vec::new()));
200        }
201
202        let cached_stmt_name = guard
203            .statement_cache
204            .get(sql)
205            .map(|(n, _)| n.clone())
206            .expect("just cached");
207
208        let param_binary: Vec<Vec<u8>> = params.iter().map(|p| p.to_sql()).collect::<Result<Vec<Vec<u8>>>>()?;
209        let param_formats: Vec<crate::types::Format> = params.iter().map(|_| crate::types::Format::Binary).collect();
210
211        send_bind(
212            &mut guard,
213            "",
214            &cached_stmt_name,
215            &param_formats,
216            &param_binary,
217            crate::types::Format::Binary,
218        )
219        .await?;
220        read_until_bind_complete(&mut guard).await?;
221
222        send_execute(&mut guard, "", 0).await?;
223        let rows_affected = read_command_complete(&mut guard).await?;
224
225        send_sync(&mut guard).await?;
226        read_until_ready(&mut guard).await?;
227
228        Ok(rows_affected)
229    }
230
231    pub(crate) async fn ping(&self) -> Result<()> {
232        self.query_raw("SELECT 1", &[]).await?;
233        Ok(())
234    }
235}
236
237async fn tls_handshake(tcp: tokio::net::TcpStream, host: &str) -> Result<PgStream> {
238    let msg = FrontendMessage::SslRequest;
239    let encoded = msg.encode();
240    let (mut reader, mut writer) = tokio::io::split(tcp);
241
242    writer.write_all(&encoded).await?;
243
244    let mut response = [0u8; 1];
245    reader.read_exact(&mut response).await?;
246
247    if response[0] == b'S' {
248        let config = rustls::ClientConfig::builder()
249            .with_root_certificates(rustls::RootCertStore::empty())
250            .with_no_client_auth();
251
252        let connector = TlsConnector::from(Arc::new(config));
253        let server_name = ServerName::try_from(host.to_string())
254            .map_err(|_| PgError::Config(format!("invalid hostname: {}", host)))?;
255
256        let tls_stream = connector
257            .connect(server_name, reader.unsplit(writer))
258            .await
259            .map_err(|e| PgError::Tls(Box::new(e)))?;
260        Ok(PgStream::Tls(tls_stream))
261    } else {
262        let tcp = reader.unsplit(writer);
263        Ok(PgStream::Plain(tcp))
264    }
265}
266
267async fn do_startup(inner: &mut Inner, params: &ConnectParams) -> Result<()> {
268    let mut kv = vec![
269        ("client_encoding".to_string(), "UTF8".to_string()),
270        ("user".to_string(), params.user.clone()),
271    ];
272    if let Some(ref db) = params.dbname {
273        kv.push(("database".to_string(), db.clone()));
274    }
275
276    let msg = FrontendMessage::Startup(kv);
277    inner.write_all(&msg.encode()).await?;
278
279    loop {
280        let backend = read_message(inner).await?;
281        match backend {
282            BackendMessage::AuthenticationOk => {
283                read_until_ready(inner).await?;
284                return Ok(());
285            }
286            BackendMessage::AuthenticationCleartextPassword => {
287                let password = params.password.as_deref().unwrap_or("");
288                let msg = FrontendMessage::Password(password.to_string());
289                inner.write_all(&msg.encode()).await?;
290            }
291            BackendMessage::AuthenticationSasl(mechanisms) => {
292                if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
293                    return Err(PgError::Auth("server does not support SCRAM-SHA-256".into()));
294                }
295                let mut scram = ScramClient::new(&params.user, params.password.as_deref().unwrap_or(""));
296
297                let initial = scram.client_first_message().as_bytes().to_vec();
298                let msg = FrontendMessage::SaslInitialResponse("SCRAM-SHA-256".to_string(), initial);
299                inner.write_all(&msg.encode()).await?;
300
301                let continue_msg = read_message(inner).await?;
302                match continue_msg {
303                    BackendMessage::AuthenticationSaslContinue(data) => {
304                        scram.parse_server_first_message(&data)?;
305                        let final_msg = scram.build_client_final_message();
306                        let msg = FrontendMessage::SaslResponse(final_msg);
307                        inner.write_all(&msg.encode()).await?;
308                    }
309                    _ => return Err(PgError::Protocol("expected SASL continue".into())),
310                }
311
312                let final_msg = read_message(inner).await?;
313                match final_msg {
314                    BackendMessage::AuthenticationSaslFinal(data) => {
315                        scram.parse_server_final_message(&data)?;
316                    }
317                    _ => return Err(PgError::Protocol("expected SASL final".into())),
318                }
319            }
320            BackendMessage::ParameterStatus(_, _) => {}
321            BackendMessage::BackendKeyData(pid, key) => {
322                inner._pid = pid;
323                inner._secret_key = key;
324            }
325            BackendMessage::ReadyForQuery(_) => {
326                return Ok(());
327            }
328            _ => return Err(PgError::Protocol("unexpected message during startup".into())),
329        }
330    }
331}
332
333async fn simple_query(inner: &mut Inner, sql: &str) -> Result<Vec<Row>> {
334    let msg = FrontendMessage::Query(sql.to_string());
335    inner.write_all(&msg.encode()).await?;
336
337    let mut rows = Vec::new();
338    let mut fields = Vec::new();
339
340    loop {
341        let backend = read_message(inner).await?;
342        match backend {
343            BackendMessage::RowDescription(fds) => {
344                fields = fds;
345            }
346            BackendMessage::DataRow(cols) => {
347                let row = Row::new(&fields, &cols);
348                rows.push(row);
349            }
350            BackendMessage::CommandComplete(_) => {}
351            BackendMessage::ReadyForQuery(_) => {
352                return Ok(rows);
353            }
354            BackendMessage::EmptyQueryResponse => {}
355            BackendMessage::NoticeResponse(_) => {}
356            _ => {
357                return Err(PgError::Protocol("unexpected message in simple query".into()));
358            }
359        }
360    }
361}
362
363async fn simple_execute(inner: &mut Inner, sql: &str) -> Result<u64> {
364    let msg = FrontendMessage::Query(sql.to_string());
365    inner.write_all(&msg.encode()).await?;
366
367    let mut rows_affected = 0u64;
368
369    loop {
370        let backend = read_message(inner).await?;
371        match backend {
372            BackendMessage::RowDescription(_) => {}
373            BackendMessage::DataRow(_) => {}
374            BackendMessage::CommandComplete(tag) => {
375                parse_command_tag(&tag, &mut rows_affected);
376            }
377            BackendMessage::ReadyForQuery(_) => {
378                return Ok(rows_affected);
379            }
380            BackendMessage::EmptyQueryResponse => {}
381            BackendMessage::NoticeResponse(_) => {}
382            _ => {
383                return Err(PgError::Protocol("unexpected message in simple execute".into()));
384            }
385        }
386    }
387}
388
/// Extracts the row count from a CommandComplete tag such as `INSERT 0 5` or
/// `UPDATE 3`: the text after the last space, parsed as `u64`.
///
/// On success the count is stored in `*affected` and returned. Otherwise
/// `*affected` is left unchanged and 0 is returned.
fn parse_command_tag(tag: &str, affected: &mut u64) -> u64 {
    let last_word = tag.rsplit_once(' ').map_or(tag, |(_, last)| last);
    match last_word.parse::<u64>() {
        Ok(count) => {
            *affected = count;
            count
        }
        Err(_) => 0,
    }
}
397
398async fn send_parse(inner: &mut Inner, stmt_name: &str, sql: &str, param_oids: &[u32]) -> Result<()> {
399    let msg = FrontendMessage::Parse(stmt_name.to_string(), sql.to_string(), param_oids.to_vec());
400    inner.write_all(&msg.encode()).await?;
401    Ok(())
402}
403
404async fn read_until_parse_complete(inner: &mut Inner) -> Result<()> {
405    loop {
406        let backend = read_message(inner).await?;
407        match backend {
408            BackendMessage::ParseComplete => return Ok(()),
409            BackendMessage::NoticeResponse(_) => {}
410            _ => return Err(PgError::Protocol("expected ParseComplete".into())),
411        }
412    }
413}
414
415async fn send_bind(
416    inner: &mut Inner,
417    portal: &str,
418    stmt: &str,
419    formats: &[crate::types::Format],
420    params: &[Vec<u8>],
421    result_format: crate::types::Format,
422) -> Result<()> {
423    let msg = FrontendMessage::Bind(
424        portal.to_string(),
425        stmt.to_string(),
426        formats.to_vec(),
427        params.to_vec(),
428        result_format,
429    );
430    inner.write_all(&msg.encode()).await?;
431    Ok(())
432}
433
434async fn read_until_bind_complete(inner: &mut Inner) -> Result<()> {
435    loop {
436        let backend = read_message(inner).await?;
437        match backend {
438            BackendMessage::BindComplete => return Ok(()),
439            BackendMessage::NoticeResponse(_) => {}
440            _ => return Err(PgError::Protocol("expected BindComplete".into())),
441        }
442    }
443}
444
445async fn send_describe_statement(inner: &mut Inner, stmt: &str) -> Result<()> {
446    let msg = FrontendMessage::Describe(b'S', stmt.to_string());
447    inner.write_all(&msg.encode()).await?;
448    Ok(())
449}
450
451async fn send_describe_portal(inner: &mut Inner, portal: &str) -> Result<()> {
452    let msg = FrontendMessage::Describe(b'P', portal.to_string());
453    inner.write_all(&msg.encode()).await?;
454    Ok(())
455}
456
457async fn read_describe_response(inner: &mut Inner) -> Result<Option<Vec<FieldDescription>>> {
458    loop {
459        let backend = read_message(inner).await?;
460        match backend {
461            BackendMessage::RowDescription(fields) => return Ok(Some(fields)),
462            BackendMessage::NoData => return Ok(None),
463            BackendMessage::NoticeResponse(_) => {}
464            _ => return Err(PgError::Protocol("expected RowDescription or NoData".into())),
465        }
466    }
467}
468
469async fn send_execute(inner: &mut Inner, portal: &str, max_rows: i32) -> Result<()> {
470    let msg = FrontendMessage::Execute(portal.to_string(), max_rows);
471    inner.write_all(&msg.encode()).await?;
472    Ok(())
473}
474
475async fn read_rows_until_complete(inner: &mut Inner, fields: &[FieldDescription]) -> Result<Vec<Row>> {
476    let mut rows = Vec::new();
477
478    loop {
479        let backend = read_message(inner).await?;
480        match backend {
481            BackendMessage::DataRow(cols) => {
482                let row = Row::new(fields, &cols);
483                rows.push(row);
484            }
485            BackendMessage::CommandComplete(_) => {
486                return Ok(rows);
487            }
488            BackendMessage::PortalSuspended => {
489                return Ok(rows);
490            }
491            BackendMessage::NoticeResponse(_) => {}
492            _ => return Err(PgError::Protocol("expected DataRow or CommandComplete".into())),
493        }
494    }
495}
496
497async fn read_command_complete(inner: &mut Inner) -> Result<u64> {
498    let mut affected = 0u64;
499    loop {
500        let backend = read_message(inner).await?;
501        match backend {
502            BackendMessage::CommandComplete(tag) => {
503                parse_command_tag(&tag, &mut affected);
504                return Ok(affected);
505            }
506            BackendMessage::NoticeResponse(_) => {}
507            _ => return Err(PgError::Protocol("expected CommandComplete".into())),
508        }
509    }
510}
511
512async fn send_sync(inner: &mut Inner) -> Result<()> {
513    let msg = FrontendMessage::Sync;
514    inner.write_all(&msg.encode()).await?;
515    Ok(())
516}
517
518async fn send_close_statement(inner: &mut Inner, stmt: &str) -> Result<()> {
519    let msg = FrontendMessage::Close(b'S', stmt.to_string());
520    inner.write_all(&msg.encode()).await?;
521    read_until_close_complete(inner).await
522}
523
524async fn read_until_close_complete(inner: &mut Inner) -> Result<()> {
525    loop {
526        let backend = read_message(inner).await?;
527        match backend {
528            BackendMessage::CloseComplete => return Ok(()),
529            BackendMessage::NoticeResponse(_) => {}
530            _ => return Err(PgError::Protocol("expected CloseComplete".into())),
531        }
532    }
533}
534
535async fn read_until_ready(inner: &mut Inner) -> Result<u8> {
536    loop {
537        let backend = read_message(inner).await?;
538        match backend {
539            BackendMessage::ReadyForQuery(status) => return Ok(status),
540            BackendMessage::NoticeResponse(_) => {}
541            BackendMessage::ParameterStatus(_, _) => {}
542            BackendMessage::CommandComplete(_) => {}
543            _ => {
544                return Err(PgError::Protocol("expected ReadyForQuery".into()));
545            }
546        }
547    }
548}
549
550async fn read_message(inner: &mut Inner) -> Result<BackendMessage> {
551    loop {
552        if let Some(msg) = BackendDecoder::decode(&mut inner.buf)? {
553            return Ok(msg);
554        }
555        inner.buf.reserve(4096);
556        let n = inner.stream.read_buf(&mut inner.buf).await.map_err(PgError::Io)?;
557        if n == 0 {
558            return Err(PgError::Protocol("connection closed by server".into()));
559        }
560    }
561}
562
563impl Inner {
564    async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
565        AsyncWriteExt::write_all(&mut self.stream, buf)
566            .await
567            .map_err(PgError::Io)
568    }
569}