1use bytes::Buf;
2
3use crate::error::{DbError, PgError};
4
/// A message received from the backend, as produced by [`BackendDecoder::decode`].
///
/// Each variant notes the wire tag it is decoded from and, for authentication requests
/// (tag `'R'`), the `i32` request code that selects it.
#[derive(Debug)]
pub enum BackendMessage {
    /// `'R'`, code 0.
    AuthenticationOk,
    /// `'R'`, code 3.
    AuthenticationCleartextPassword,
    /// `'R'`, code 5; carries the 4 salt bytes that follow the code.
    AuthenticationMD5Password([u8; 4]),
    /// `'R'`, code 10; carries the advertised SASL mechanism names in the order received.
    AuthenticationSasl(Vec<String>),
    /// `'R'`, code 11; carries the rest of the payload, unparsed.
    AuthenticationSaslContinue(Vec<u8>),
    /// `'R'`, code 12; carries the rest of the payload, unparsed.
    AuthenticationSaslFinal(Vec<u8>),
    /// `'R'`, code 2.
    AuthenticationKerberosV5,
    /// `'R'`, code 6.
    AuthenticationScmCredential,
    /// `'R'`, code 7.
    AuthenticationGss,
    /// `'R'`, code 9.
    AuthenticationSspi,
    /// `'R'`, code 8; carries the rest of the payload, unparsed.
    AuthenticationGssContinue(Vec<u8>),
    /// `'K'`: `(pid, key)`, two big-endian `i32`s in wire order.
    BackendKeyData(i32, i32),
    /// `'S'`: `(key, value)`, both read as NUL-terminated strings.
    ParameterStatus(String, String),
    /// `'Z'`: the single status byte of the message, not interpreted here.
    ReadyForQuery(u8),
    /// `'T'`: one entry per column, in wire order.
    RowDescription(Vec<FieldDescription>),
    /// `'D'`: one entry per column; `None` when the column's length field is -1 (NULL),
    /// otherwise the raw column bytes.
    DataRow(Vec<Option<Vec<u8>>>),
    /// `'C'`: the command tag string.
    CommandComplete(String),
    /// `'1'`.
    ParseComplete,
    /// `'2'`.
    BindComplete,
    /// `'3'`.
    CloseComplete,
    /// `'s'`.
    PortalSuspended,
    /// Not produced by [`BackendDecoder::decode`]: an `'E'` message is returned as
    /// `Err(PgError::Server(..))` instead of as this variant.
    ErrorResponse(DbError),
    /// `'N'`: fields are parsed exactly like those of an `'E'` error message.
    NoticeResponse(DbError),
    /// `'A'`: `(pid, channel, payload)`.
    NotificationResponse(i32, String, String),
    /// `'I'`.
    EmptyQueryResponse,
    /// `'n'`.
    NoData,
}
34
/// One column entry of a `RowDescription` (`'T'`) message.
///
/// Fields are declared in the order they are read from the wire.
#[derive(Debug, Clone)]
pub struct FieldDescription {
    /// Column name, read as a NUL-terminated string (invalid UTF-8 is replaced lossily).
    pub name: String,
    /// Table OID as sent by the server.
    pub table_oid: u32,
    /// Column attribute number as sent by the server.
    pub column_attr: i16,
    /// Data type OID as sent by the server.
    pub type_oid: u32,
    /// Data type size as sent by the server (NOTE(review): negative values presumably
    /// denote variable-width types per the protocol — confirm; not interpreted here).
    pub type_size: i16,
    /// Type modifier as sent by the server.
    pub type_mod: i32,
    /// Format code as sent by the server (NOTE(review): assumed 0 = text, 1 = binary per
    /// the protocol — confirm; not interpreted here).
    pub format: i16,
}
45
/// Stateless decoder for backend messages; all functionality lives in associated
/// functions such as [`BackendDecoder::decode`].
pub struct BackendDecoder;
47
48impl BackendDecoder {
49 pub fn decode(buf: &mut bytes::BytesMut) -> Result<Option<BackendMessage>, PgError> {
50 if buf.len() < 5 {
51 return Ok(None);
52 }
53
54 let tag = buf[0];
55 let len = (&buf[1..5]).get_i32() as usize;
56
57 if buf.len() < 5 + len - 4 {
58 return Ok(None);
59 }
60
61 buf.advance(5);
62 let mut payload = buf.split_to(len - 4);
63
64 let msg = match tag {
65 b'R' => Self::decode_authentication(&mut payload)?,
66 b'K' => Self::decode_backend_key_data(&mut payload),
67 b'S' => Self::decode_parameter_status(&mut payload),
68 b'Z' => Self::decode_ready_for_query(&mut payload),
69 b'T' => Self::decode_row_description(&mut payload)?,
70 b'D' => Self::decode_data_row(&mut payload),
71 b'C' => Self::decode_command_complete(&mut payload),
72 b'1' => BackendMessage::ParseComplete,
73 b'2' => BackendMessage::BindComplete,
74 b'3' => BackendMessage::CloseComplete,
75 b's' => BackendMessage::PortalSuspended,
76 b'E' => {
77 let err = Self::decode_error(&mut payload);
78 return Err(PgError::Server(err));
79 }
80 b'N' => {
81 let err = Self::decode_error(&mut payload);
82 BackendMessage::NoticeResponse(err)
83 }
84 b'A' => Self::decode_notification_response(&mut payload),
85 b'I' => BackendMessage::EmptyQueryResponse,
86 b'n' => BackendMessage::NoData,
87 other => return Err(PgError::Protocol(format!("unknown message tag: {:?}", other as char))),
88 };
89
90 Ok(Some(msg))
91 }
92
93 fn decode_authentication(buf: &mut bytes::BytesMut) -> Result<BackendMessage, PgError> {
94 let kind = buf.get_i32();
95
96 match kind {
97 0 => Ok(BackendMessage::AuthenticationOk),
98 2 => Ok(BackendMessage::AuthenticationKerberosV5),
99 3 => Ok(BackendMessage::AuthenticationCleartextPassword),
100 5 => {
101 let mut salt = [0u8; 4];
102 buf.copy_to_slice(&mut salt);
103 Ok(BackendMessage::AuthenticationMD5Password(salt))
104 }
105 6 => Ok(BackendMessage::AuthenticationScmCredential),
106 7 => Ok(BackendMessage::AuthenticationGss),
107 8 => Ok(BackendMessage::AuthenticationGssContinue(buf.to_vec())),
108 9 => Ok(BackendMessage::AuthenticationSspi),
109 10 => {
110 let mut mechanisms = Vec::new();
111 while buf.has_remaining() {
112 let b = buf.get_u8();
113 if b == 0 {
114 break;
115 }
116 let mut s = vec![b];
117 while buf.has_remaining() {
118 let b = buf.get_u8();
119 if b == 0 {
120 break;
121 }
122 s.push(b);
123 }
124 mechanisms.push(String::from_utf8_lossy(&s).to_string());
125 }
126 Ok(BackendMessage::AuthenticationSasl(mechanisms))
127 }
128 11 => Ok(BackendMessage::AuthenticationSaslContinue(buf.to_vec())),
129 12 => Ok(BackendMessage::AuthenticationSaslFinal(buf.to_vec())),
130 _ => Err(PgError::Auth(format!("unknown auth method: {}", kind))),
131 }
132 }
133
134 fn decode_backend_key_data(buf: &mut bytes::BytesMut) -> BackendMessage {
135 let pid = buf.get_i32();
136 let key = buf.get_i32();
137 BackendMessage::BackendKeyData(pid, key)
138 }
139
140 fn decode_parameter_status(buf: &mut bytes::BytesMut) -> BackendMessage {
141 let key = read_cstring(buf);
142 let value = read_cstring(buf);
143 BackendMessage::ParameterStatus(key, value)
144 }
145
146 fn decode_ready_for_query(buf: &mut bytes::BytesMut) -> BackendMessage {
147 let status = buf.get_u8();
148 BackendMessage::ReadyForQuery(status)
149 }
150
151 fn decode_row_description(buf: &mut bytes::BytesMut) -> Result<BackendMessage, PgError> {
152 let count = buf.get_i16();
153 let mut fields = Vec::with_capacity(count as usize);
154 for _ in 0..count {
155 let name = read_cstring(buf);
156 let table_oid = buf.get_u32();
157 let column_attr = buf.get_i16();
158 let type_oid = buf.get_u32();
159 let type_size = buf.get_i16();
160 let type_mod = buf.get_i32();
161 let format = buf.get_i16();
162 fields.push(FieldDescription {
163 name,
164 table_oid,
165 column_attr,
166 type_oid,
167 type_size,
168 type_mod,
169 format,
170 });
171 }
172 Ok(BackendMessage::RowDescription(fields))
173 }
174
175 fn decode_data_row(buf: &mut bytes::BytesMut) -> BackendMessage {
176 let count = buf.get_i16();
177 let mut columns = Vec::with_capacity(count as usize);
178 for _ in 0..count {
179 let len = buf.get_i32();
180 if len == -1 {
181 columns.push(None);
182 } else {
183 let mut data = vec![0u8; len as usize];
184 buf.copy_to_slice(&mut data);
185 columns.push(Some(data));
186 }
187 }
188 BackendMessage::DataRow(columns)
189 }
190
191 fn decode_command_complete(buf: &mut bytes::BytesMut) -> BackendMessage {
192 let tag = read_cstring(buf);
193 BackendMessage::CommandComplete(tag)
194 }
195
196 fn decode_error(buf: &mut bytes::BytesMut) -> DbError {
197 let mut severity = String::new();
198 let mut code = String::new();
199 let mut message = String::new();
200 let mut detail = None;
201 let mut hint = None;
202 let mut position = None;
203
204 while buf.has_remaining() {
205 let field_type = buf.get_u8();
206 if field_type == 0 {
207 break;
208 }
209 let value = read_cstring(buf);
210 match field_type {
211 b'S' => severity = value,
212 b'C' => code = value,
213 b'M' => message = value,
214 b'D' => detail = Some(value),
215 b'H' => hint = Some(value),
216 b'P' => position = value.parse().ok(),
217 _ => {}
218 }
219 }
220
221 DbError {
222 severity,
223 code,
224 message,
225 detail,
226 hint,
227 position,
228 }
229 }
230
231 fn decode_notification_response(buf: &mut bytes::BytesMut) -> BackendMessage {
232 let pid = buf.get_i32();
233 let channel = read_cstring(buf);
234 let payload = read_cstring(buf);
235 BackendMessage::NotificationResponse(pid, channel, payload)
236 }
237}
238
239fn read_cstring(buf: &mut bytes::BytesMut) -> String {
240 let mut s = Vec::new();
241 loop {
242 let b = buf.get_u8();
243 if b == 0 {
244 break;
245 }
246 s.push(b);
247 }
248 String::from_utf8_lossy(&s).to_string()
249}