Skip to main content

pg/protocol/
backend.rs

1use bytes::Buf;
2
3use crate::error::{DbError, PgError};
4
5#[derive(Debug)]
6pub enum BackendMessage {
7    AuthenticationOk,
8    AuthenticationCleartextPassword,
9    AuthenticationMD5Password([u8; 4]),
10    AuthenticationSasl(Vec<String>),
11    AuthenticationSaslContinue(Vec<u8>),
12    AuthenticationSaslFinal(Vec<u8>),
13    AuthenticationKerberosV5,
14    AuthenticationScmCredential,
15    AuthenticationGss,
16    AuthenticationSspi,
17    AuthenticationGssContinue(Vec<u8>),
18    BackendKeyData(i32, i32),
19    ParameterStatus(String, String),
20    ReadyForQuery(u8),
21    RowDescription(Vec<FieldDescription>),
22    DataRow(Vec<Option<Vec<u8>>>),
23    CommandComplete(String),
24    ParseComplete,
25    BindComplete,
26    CloseComplete,
27    PortalSuspended,
28    ErrorResponse(DbError),
29    NoticeResponse(DbError),
30    NotificationResponse(i32, String, String),
31    EmptyQueryResponse,
32    NoData,
33}
34
/// Description of one result column, as carried in a `RowDescription` (`'T'`) message.
///
/// Fields appear here in the order they are laid out on the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// OID of the table the column belongs to (zero when it is not a table column).
    pub table_oid: u32,
    /// Attribute number of the column within that table (zero when not a table column).
    pub column_attr: i16,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Data type size; negative values denote variable-width types.
    pub type_size: i16,
    /// Type modifier (type-specific; -1 when none).
    pub type_mod: i32,
    /// Format code of the column values (0 = text, 1 = binary).
    pub format: i16,
}
45
46pub struct BackendDecoder;
47
48impl BackendDecoder {
49    pub fn decode(buf: &mut bytes::BytesMut) -> Result<Option<BackendMessage>, PgError> {
50        if buf.len() < 5 {
51            return Ok(None);
52        }
53
54        let tag = buf[0];
55        let len = (&buf[1..5]).get_i32() as usize;
56
57        if buf.len() < 5 + len - 4 {
58            return Ok(None);
59        }
60
61        buf.advance(5);
62        let mut payload = buf.split_to(len - 4);
63
64        let msg = match tag {
65            b'R' => Self::decode_authentication(&mut payload)?,
66            b'K' => Self::decode_backend_key_data(&mut payload),
67            b'S' => Self::decode_parameter_status(&mut payload),
68            b'Z' => Self::decode_ready_for_query(&mut payload),
69            b'T' => Self::decode_row_description(&mut payload)?,
70            b'D' => Self::decode_data_row(&mut payload),
71            b'C' => Self::decode_command_complete(&mut payload),
72            b'1' => BackendMessage::ParseComplete,
73            b'2' => BackendMessage::BindComplete,
74            b'3' => BackendMessage::CloseComplete,
75            b's' => BackendMessage::PortalSuspended,
76            b'E' => {
77                let err = Self::decode_error(&mut payload);
78                return Err(PgError::Server(err));
79            }
80            b'N' => {
81                let err = Self::decode_error(&mut payload);
82                BackendMessage::NoticeResponse(err)
83            }
84            b'A' => Self::decode_notification_response(&mut payload),
85            b'I' => BackendMessage::EmptyQueryResponse,
86            b'n' => BackendMessage::NoData,
87            other => return Err(PgError::Protocol(format!("unknown message tag: {:?}", other as char))),
88        };
89
90        Ok(Some(msg))
91    }
92
93    fn decode_authentication(buf: &mut bytes::BytesMut) -> Result<BackendMessage, PgError> {
94        let kind = buf.get_i32();
95
96        match kind {
97            0 => Ok(BackendMessage::AuthenticationOk),
98            2 => Ok(BackendMessage::AuthenticationKerberosV5),
99            3 => Ok(BackendMessage::AuthenticationCleartextPassword),
100            5 => {
101                let mut salt = [0u8; 4];
102                buf.copy_to_slice(&mut salt);
103                Ok(BackendMessage::AuthenticationMD5Password(salt))
104            }
105            6 => Ok(BackendMessage::AuthenticationScmCredential),
106            7 => Ok(BackendMessage::AuthenticationGss),
107            8 => Ok(BackendMessage::AuthenticationGssContinue(buf.to_vec())),
108            9 => Ok(BackendMessage::AuthenticationSspi),
109            10 => {
110                let mut mechanisms = Vec::new();
111                while buf.has_remaining() {
112                    let b = buf.get_u8();
113                    if b == 0 {
114                        break;
115                    }
116                    let mut s = vec![b];
117                    while buf.has_remaining() {
118                        let b = buf.get_u8();
119                        if b == 0 {
120                            break;
121                        }
122                        s.push(b);
123                    }
124                    mechanisms.push(String::from_utf8_lossy(&s).to_string());
125                }
126                Ok(BackendMessage::AuthenticationSasl(mechanisms))
127            }
128            11 => Ok(BackendMessage::AuthenticationSaslContinue(buf.to_vec())),
129            12 => Ok(BackendMessage::AuthenticationSaslFinal(buf.to_vec())),
130            _ => Err(PgError::Auth(format!("unknown auth method: {}", kind))),
131        }
132    }
133
134    fn decode_backend_key_data(buf: &mut bytes::BytesMut) -> BackendMessage {
135        let pid = buf.get_i32();
136        let key = buf.get_i32();
137        BackendMessage::BackendKeyData(pid, key)
138    }
139
140    fn decode_parameter_status(buf: &mut bytes::BytesMut) -> BackendMessage {
141        let key = read_cstring(buf);
142        let value = read_cstring(buf);
143        BackendMessage::ParameterStatus(key, value)
144    }
145
146    fn decode_ready_for_query(buf: &mut bytes::BytesMut) -> BackendMessage {
147        let status = buf.get_u8();
148        BackendMessage::ReadyForQuery(status)
149    }
150
151    fn decode_row_description(buf: &mut bytes::BytesMut) -> Result<BackendMessage, PgError> {
152        let count = buf.get_i16();
153        let mut fields = Vec::with_capacity(count as usize);
154        for _ in 0..count {
155            let name = read_cstring(buf);
156            let table_oid = buf.get_u32();
157            let column_attr = buf.get_i16();
158            let type_oid = buf.get_u32();
159            let type_size = buf.get_i16();
160            let type_mod = buf.get_i32();
161            let format = buf.get_i16();
162            fields.push(FieldDescription {
163                name,
164                table_oid,
165                column_attr,
166                type_oid,
167                type_size,
168                type_mod,
169                format,
170            });
171        }
172        Ok(BackendMessage::RowDescription(fields))
173    }
174
175    fn decode_data_row(buf: &mut bytes::BytesMut) -> BackendMessage {
176        let count = buf.get_i16();
177        let mut columns = Vec::with_capacity(count as usize);
178        for _ in 0..count {
179            let len = buf.get_i32();
180            if len == -1 {
181                columns.push(None);
182            } else {
183                let mut data = vec![0u8; len as usize];
184                buf.copy_to_slice(&mut data);
185                columns.push(Some(data));
186            }
187        }
188        BackendMessage::DataRow(columns)
189    }
190
191    fn decode_command_complete(buf: &mut bytes::BytesMut) -> BackendMessage {
192        let tag = read_cstring(buf);
193        BackendMessage::CommandComplete(tag)
194    }
195
196    fn decode_error(buf: &mut bytes::BytesMut) -> DbError {
197        let mut severity = String::new();
198        let mut code = String::new();
199        let mut message = String::new();
200        let mut detail = None;
201        let mut hint = None;
202        let mut position = None;
203
204        while buf.has_remaining() {
205            let field_type = buf.get_u8();
206            if field_type == 0 {
207                break;
208            }
209            let value = read_cstring(buf);
210            match field_type {
211                b'S' => severity = value,
212                b'C' => code = value,
213                b'M' => message = value,
214                b'D' => detail = Some(value),
215                b'H' => hint = Some(value),
216                b'P' => position = value.parse().ok(),
217                _ => {}
218            }
219        }
220
221        DbError {
222            severity,
223            code,
224            message,
225            detail,
226            hint,
227            position,
228        }
229    }
230
231    fn decode_notification_response(buf: &mut bytes::BytesMut) -> BackendMessage {
232        let pid = buf.get_i32();
233        let channel = read_cstring(buf);
234        let payload = read_cstring(buf);
235        BackendMessage::NotificationResponse(pid, channel, payload)
236    }
237}
238
239fn read_cstring(buf: &mut bytes::BytesMut) -> String {
240    let mut s = Vec::new();
241    loop {
242        let b = buf.get_u8();
243        if b == 0 {
244            break;
245        }
246        s.push(b);
247    }
248    String::from_utf8_lossy(&s).to_string()
249}