Skip to main content

pg/protocol/
frontend.rs

1use bytes::{BufMut, Bytes, BytesMut};
2
3use crate::types::Format;
4
5fn startup_message(params: &[(&str, &str)]) -> Bytes {
6    let mut buf = BytesMut::new();
7    let body_len = 4 + 4 + params.iter().map(|(k, v)| k.len() + v.len() + 2).sum::<usize>() + 1;
8    buf.put_i32(body_len as i32);
9    buf.put_i32(196608);
10    for (k, v) in params {
11        buf.put_slice(k.as_bytes());
12        buf.put_u8(0);
13        buf.put_slice(v.as_bytes());
14        buf.put_u8(0);
15    }
16    buf.put_u8(0);
17    buf.freeze()
18}
19
20fn ssl_request() -> Bytes {
21    let mut buf = BytesMut::new();
22    buf.put_i32(8);
23    buf.put_i32(80877103);
24    buf.freeze()
25}
26
27fn password_message(password: &str) -> Bytes {
28    let mut buf = BytesMut::new();
29    buf.put_u8(b'p');
30    let body = password.as_bytes();
31    buf.put_i32((body.len() + 4 + 1) as i32);
32    buf.put_slice(body);
33    buf.put_u8(0);
34    buf.freeze()
35}
36
37fn sasl_initial_response(mechanism: &str, initial_data: &[u8]) -> Bytes {
38    let mechanism_c = mechanism.as_bytes();
39    let initial_len = mechanism_c.len() + 1 + 4 + initial_data.len();
40    let mut buf = BytesMut::new();
41    buf.put_u8(b'p');
42    buf.put_i32((initial_len + 4) as i32);
43    buf.put_slice(mechanism_c);
44    buf.put_u8(0);
45    buf.put_i32(initial_data.len() as i32);
46    buf.put_slice(initial_data);
47    buf.freeze()
48}
49
50fn sasl_response(data: &[u8]) -> Bytes {
51    let mut buf = BytesMut::new();
52    buf.put_u8(b'p');
53    buf.put_i32((data.len() + 4) as i32);
54    buf.put_slice(data);
55    buf.freeze()
56}
57
58fn query(sql: &str) -> Bytes {
59    let sql_bytes = sql.as_bytes();
60    let mut buf = BytesMut::new();
61    buf.put_u8(b'Q');
62    buf.put_i32((sql_bytes.len() + 4 + 1) as i32);
63    buf.put_slice(sql_bytes);
64    buf.put_u8(0);
65    buf.freeze()
66}
67
68fn parse(stmt_name: &str, sql: &str, param_oids: &[u32]) -> Bytes {
69    let sql_bytes = sql.as_bytes();
70    let name_bytes = stmt_name.as_bytes();
71    let body_len = name_bytes.len() + 1 + sql_bytes.len() + 1 + 2 + param_oids.len() * 4;
72    let mut buf = BytesMut::new();
73    buf.put_u8(b'P');
74    buf.put_i32((body_len + 4) as i32);
75    buf.put_slice(name_bytes);
76    buf.put_u8(0);
77    buf.put_slice(sql_bytes);
78    buf.put_u8(0);
79    buf.put_i16(param_oids.len() as i16);
80    for oid in param_oids {
81        buf.put_i32(*oid as i32);
82    }
83    buf.freeze()
84}
85
86fn bind(portal: &str, stmt: &str, param_formats: &[Format], params: &[Vec<u8>], result_format: Format) -> Bytes {
87    let portal_bytes = portal.as_bytes();
88    let stmt_bytes = stmt.as_bytes();
89    let params_len: usize = params.iter().map(|p| p.len() + 4).sum();
90    let body_len = portal_bytes.len()
91        + 1
92        + stmt_bytes.len()
93        + 1
94        + 2
95        + param_formats.len() * 2
96        + 2
97        + params.len() * 4
98        + params_len
99        + 2;
100    let mut buf = BytesMut::new();
101    buf.put_u8(b'B');
102    buf.put_i32((body_len + 4) as i32);
103    buf.put_slice(portal_bytes);
104    buf.put_u8(0);
105    buf.put_slice(stmt_bytes);
106    buf.put_u8(0);
107    buf.put_i16(param_formats.len() as i16);
108    for fmt in param_formats {
109        buf.put_i16(match fmt {
110            Format::Text => 0,
111            Format::Binary => 1,
112        });
113    }
114    buf.put_i16(params.len() as i16);
115    for p in params {
116        if p.is_empty() {
117            buf.put_i32(-1);
118        } else {
119            buf.put_i32(p.len() as i32);
120            buf.put_slice(p);
121        }
122    }
123    buf.put_i16(match result_format {
124        Format::Text => 0,
125        Format::Binary => 1,
126    });
127    buf.freeze()
128}
129
130fn execute(portal: &str, max_rows: i32) -> Bytes {
131    let portal_bytes = portal.as_bytes();
132    let mut buf = BytesMut::new();
133    buf.put_u8(b'E');
134    buf.put_i32((portal_bytes.len() + 1 + 4 + 4) as i32);
135    buf.put_slice(portal_bytes);
136    buf.put_u8(0);
137    buf.put_i32(max_rows);
138    buf.freeze()
139}
140
141fn describe(kind: u8, name: &str) -> Bytes {
142    let name_bytes = name.as_bytes();
143    let mut buf = BytesMut::new();
144    buf.put_u8(b'D');
145    buf.put_i32((4 + 1 + name_bytes.len() + 1) as i32);
146    buf.put_u8(kind);
147    buf.put_slice(name_bytes);
148    buf.put_u8(0);
149    buf.freeze()
150}
151
152fn sync() -> Bytes {
153    let mut buf = BytesMut::new();
154    buf.put_u8(b'S');
155    buf.put_i32(4 + 4);
156    buf.freeze()
157}
158
159fn close(kind: u8, name: &str) -> Bytes {
160    let name_bytes = name.as_bytes();
161    let mut buf = BytesMut::new();
162    buf.put_u8(b'C');
163    buf.put_i32((4 + 1 + name_bytes.len() + 1) as i32);
164    buf.put_u8(kind);
165    buf.put_slice(name_bytes);
166    buf.put_u8(0);
167    buf.freeze()
168}
169
170fn terminate() -> Bytes {
171    let mut buf = BytesMut::new();
172    buf.put_u8(b'X');
173    buf.put_i32(4 + 4);
174    buf.freeze()
175}
176
177pub enum FrontendMessage {
178    Startup(Vec<(String, String)>),
179    SslRequest,
180    Password(String),
181    SaslInitialResponse(String, Vec<u8>),
182    SaslResponse(Vec<u8>),
183    Query(String),
184    Parse(String, String, Vec<u32>),
185    Bind(String, String, Vec<Format>, Vec<Vec<u8>>, Format),
186    Execute(String, i32),
187    Describe(u8, String),
188    Sync,
189    Close(u8, String),
190    Terminate,
191}
192
193impl FrontendMessage {
194    pub fn encode(&self) -> Bytes {
195        match self {
196            FrontendMessage::Startup(params) => {
197                let pairs: Vec<(&str, &str)> = params.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
198                startup_message(&pairs)
199            }
200            FrontendMessage::SslRequest => ssl_request(),
201            FrontendMessage::Password(p) => password_message(p),
202            FrontendMessage::SaslInitialResponse(mech, data) => sasl_initial_response(mech, data),
203            FrontendMessage::SaslResponse(data) => sasl_response(data),
204            FrontendMessage::Query(sql) => query(sql),
205            FrontendMessage::Parse(name, sql, oids) => parse(name, sql, oids),
206            FrontendMessage::Bind(portal, stmt, formats, params, result_fmt) => {
207                bind(portal, stmt, formats, params, *result_fmt)
208            }
209            FrontendMessage::Execute(portal, rows) => execute(portal, *rows),
210            FrontendMessage::Describe(kind, name) => describe(*kind, name),
211            FrontendMessage::Sync => sync(),
212            FrontendMessage::Close(kind, name) => close(*kind, name),
213            FrontendMessage::Terminate => terminate(),
214        }
215    }
216}