Skip to content

Commit 9737d2b

Browse files
committed
Fix InitCountingShim handshake query handling
1 parent 9085760 commit 9737d2b

File tree

2 files changed

+172
-145
lines changed

2 files changed

+172
-145
lines changed

mysql/src/writers.rs

Lines changed: 143 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515
use std::io::{self, Write};
1616

17-
use byteorder::{LittleEndian, WriteBytesExt};
18-
1917
use crate::myc::constants::{CapabilityFlags, StatusFlags};
2018
use crate::myc::io::WriteMysqlExt;
2119
use crate::packet_writer::PacketWriter;
2220
use crate::{Column, ColumnFlags, ColumnType, ErrorKind, OkResponse};
21+
use byteorder::{LittleEndian, WriteBytesExt};
22+
use tokio::io::AsyncWrite;
2323

2424
const BIN_GENERAL_CI: u16 = 0x3f;
2525

@@ -65,6 +65,147 @@ pub(crate) async fn write_eof_packet<W: AsyncWrite + Unpin>(
6565
w.end_packet().await
6666
}
6767

68+
pub(crate) async fn write_ok_packet<W: AsyncWrite + Unpin>(
69+
w: &mut PacketWriter<W>,
70+
client_capabilities: CapabilityFlags,
71+
ok_packet: OkResponse,
72+
) -> io::Result<()> {
73+
w.write_u8(ok_packet.header)?; // OK packet type
74+
w.write_lenenc_int(ok_packet.affected_rows)?;
75+
w.write_lenenc_int(ok_packet.last_insert_id)?;
76+
if client_capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41) {
77+
w.write_u16::<LittleEndian>(ok_packet.status_flags.bits())?;
78+
w.write_u16::<LittleEndian>(ok_packet.warnings)?;
79+
} else if client_capabilities.contains(CapabilityFlags::CLIENT_TRANSACTIONS) {
80+
w.write_u16::<LittleEndian>(ok_packet.status_flags.bits())?;
81+
}
82+
83+
// Only session-tracking clients expect length-encoded info per protocol; otherwise emit raw text.
84+
let has_session_track = client_capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK);
85+
let send_info = !ok_packet.info.is_empty() || has_session_track;
86+
if send_info {
87+
if has_session_track {
88+
w.write_lenenc_str(ok_packet.info.as_bytes())?;
89+
} else {
90+
w.write_all(ok_packet.info.as_bytes())?;
91+
}
92+
}
93+
94+
// Session state info is optional and only sent if flag is set
95+
if has_session_track
96+
&& ok_packet
97+
.status_flags
98+
.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED)
99+
{
100+
w.write_lenenc_str(ok_packet.session_state_info.as_bytes())?;
101+
}
102+
w.end_packet().await
103+
}
104+
105+
pub async fn write_err<W: AsyncWrite + Unpin>(
106+
err: ErrorKind,
107+
msg: &[u8],
108+
w: &mut PacketWriter<W>,
109+
) -> io::Result<()> {
110+
w.write_u8(0xFF)?;
111+
w.write_u16::<LittleEndian>(err as u16)?;
112+
w.write_u8(b'#')?;
113+
w.write_all(err.sqlstate())?;
114+
w.write_all(msg)?;
115+
w.end_packet().await
116+
}
117+
118+
pub(crate) async fn write_prepare_ok<'a, PI, CI, W>(
119+
id: u32,
120+
params: PI,
121+
columns: CI,
122+
w: &mut PacketWriter<W>,
123+
client_capabilities: CapabilityFlags,
124+
) -> io::Result<()>
125+
where
126+
PI: IntoIterator<Item = &'a Column>,
127+
CI: IntoIterator<Item = &'a Column>,
128+
<PI as IntoIterator>::IntoIter: ExactSizeIterator,
129+
<CI as IntoIterator>::IntoIter: ExactSizeIterator,
130+
W: AsyncWrite + Unpin,
131+
{
132+
let pi = params.into_iter();
133+
let ci = columns.into_iter();
134+
135+
// first, write out COM_STMT_PREPARE_OK
136+
w.write_u8(0x00)?;
137+
w.write_u32::<LittleEndian>(id)?;
138+
w.write_u16::<LittleEndian>(ci.len() as u16)?;
139+
w.write_u16::<LittleEndian>(pi.len() as u16)?;
140+
w.write_u8(0x00)?;
141+
w.write_u16::<LittleEndian>(0)?; // number of warnings
142+
w.end_packet().await?;
143+
144+
if pi.len() > 0 {
145+
write_column_definitions_41(pi, w, client_capabilities, false).await?;
146+
}
147+
if ci.len() > 0 {
148+
write_column_definitions_41(ci, w, client_capabilities, false).await?;
149+
}
150+
Ok(())
151+
}
152+
153+
/// works for Protocol::ColumnDefinition41 is set
154+
/// see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_column_definition.html
155+
pub(crate) async fn write_column_definitions_41<'a, I, W>(
156+
i: I,
157+
w: &mut PacketWriter<W>,
158+
client_capabilities: CapabilityFlags,
159+
is_com_field_list: bool,
160+
) -> io::Result<()>
161+
where
162+
I: IntoIterator<Item = &'a Column>,
163+
W: AsyncWrite + Unpin,
164+
{
165+
for c in i {
166+
w.write_lenenc_str(b"def")?;
167+
w.write_lenenc_str(b"")?;
168+
w.write_lenenc_str(c.table.as_bytes())?;
169+
w.write_lenenc_str(b"")?;
170+
w.write_lenenc_str(c.column.as_bytes())?;
171+
w.write_lenenc_str(b"")?;
172+
w.write_lenenc_int(0xC)?;
173+
w.write_u16::<LittleEndian>(column_charset(c))?;
174+
w.write_u32::<LittleEndian>(1024)?;
175+
w.write_u8(c.coltype as u8)?;
176+
w.write_u16::<LittleEndian>(c.colflags.bits())?;
177+
w.write_all(&[0x00])?; // decimals
178+
w.write_all(&[0x00, 0x00])?; // unused
179+
180+
if is_com_field_list {
181+
w.write_all(&[0xfb])?;
182+
}
183+
w.end_packet().await?;
184+
}
185+
186+
if !client_capabilities.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
187+
write_eof_packet(w, StatusFlags::empty()).await
188+
} else {
189+
Ok(())
190+
}
191+
}
192+
193+
pub(crate) async fn column_definitions<'a, I, W>(
194+
i: I,
195+
w: &mut PacketWriter<W>,
196+
client_capabilities: CapabilityFlags,
197+
) -> io::Result<()>
198+
where
199+
I: IntoIterator<Item = &'a Column>,
200+
<I as IntoIterator>::IntoIter: ExactSizeIterator,
201+
W: AsyncWrite + Unpin,
202+
{
203+
let i = i.into_iter();
204+
w.write_lenenc_int(i.len() as u64)?;
205+
w.end_packet().await?;
206+
write_column_definitions_41(i, w, client_capabilities, false).await
207+
}
208+
68209
#[cfg(test)]
69210
mod tests {
70211
use super::*;
@@ -301,146 +442,3 @@ mod tests {
301442
assert_eq!(ok_packet.status_flags(), StatusFlags::empty());
302443
}
303444
}
304-
305-
pub(crate) async fn write_ok_packet<W: AsyncWrite + Unpin>(
306-
w: &mut PacketWriter<W>,
307-
client_capabilities: CapabilityFlags,
308-
ok_packet: OkResponse,
309-
) -> io::Result<()> {
310-
w.write_u8(ok_packet.header)?; // OK packet type
311-
w.write_lenenc_int(ok_packet.affected_rows)?;
312-
w.write_lenenc_int(ok_packet.last_insert_id)?;
313-
if client_capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41) {
314-
w.write_u16::<LittleEndian>(ok_packet.status_flags.bits())?;
315-
w.write_u16::<LittleEndian>(ok_packet.warnings)?;
316-
} else if client_capabilities.contains(CapabilityFlags::CLIENT_TRANSACTIONS) {
317-
w.write_u16::<LittleEndian>(ok_packet.status_flags.bits())?;
318-
}
319-
320-
// Only session-tracking clients expect length-encoded info per protocol; otherwise emit raw text.
321-
let has_session_track = client_capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK);
322-
let send_info = !ok_packet.info.is_empty() || has_session_track;
323-
if send_info {
324-
if has_session_track {
325-
w.write_lenenc_str(ok_packet.info.as_bytes())?;
326-
} else {
327-
w.write_all(ok_packet.info.as_bytes())?;
328-
}
329-
}
330-
331-
// Session state info is optional and only sent if flag is set
332-
if has_session_track
333-
&& ok_packet
334-
.status_flags
335-
.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED)
336-
{
337-
w.write_lenenc_str(ok_packet.session_state_info.as_bytes())?;
338-
}
339-
w.end_packet().await
340-
}
341-
342-
pub async fn write_err<W: AsyncWrite + Unpin>(
343-
err: ErrorKind,
344-
msg: &[u8],
345-
w: &mut PacketWriter<W>,
346-
) -> io::Result<()> {
347-
w.write_u8(0xFF)?;
348-
w.write_u16::<LittleEndian>(err as u16)?;
349-
w.write_u8(b'#')?;
350-
w.write_all(err.sqlstate())?;
351-
w.write_all(msg)?;
352-
w.end_packet().await
353-
}
354-
355-
use tokio::io::AsyncWrite;
356-
357-
pub(crate) async fn write_prepare_ok<'a, PI, CI, W>(
358-
id: u32,
359-
params: PI,
360-
columns: CI,
361-
w: &mut PacketWriter<W>,
362-
client_capabilities: CapabilityFlags,
363-
) -> io::Result<()>
364-
where
365-
PI: IntoIterator<Item = &'a Column>,
366-
CI: IntoIterator<Item = &'a Column>,
367-
<PI as IntoIterator>::IntoIter: ExactSizeIterator,
368-
<CI as IntoIterator>::IntoIter: ExactSizeIterator,
369-
W: AsyncWrite + Unpin,
370-
{
371-
let pi = params.into_iter();
372-
let ci = columns.into_iter();
373-
374-
// first, write out COM_STMT_PREPARE_OK
375-
w.write_u8(0x00)?;
376-
w.write_u32::<LittleEndian>(id)?;
377-
w.write_u16::<LittleEndian>(ci.len() as u16)?;
378-
w.write_u16::<LittleEndian>(pi.len() as u16)?;
379-
w.write_u8(0x00)?;
380-
w.write_u16::<LittleEndian>(0)?; // number of warnings
381-
w.end_packet().await?;
382-
383-
if pi.len() > 0 {
384-
write_column_definitions_41(pi, w, client_capabilities, false).await?;
385-
}
386-
if ci.len() > 0 {
387-
write_column_definitions_41(ci, w, client_capabilities, false).await?;
388-
}
389-
Ok(())
390-
}
391-
392-
/// works for Protocol::ColumnDefinition41 is set
393-
/// see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_column_definition.html
394-
pub(crate) async fn write_column_definitions_41<'a, I, W>(
395-
i: I,
396-
w: &mut PacketWriter<W>,
397-
client_capabilities: CapabilityFlags,
398-
is_com_field_list: bool,
399-
) -> io::Result<()>
400-
where
401-
I: IntoIterator<Item = &'a Column>,
402-
W: AsyncWrite + Unpin,
403-
{
404-
for c in i {
405-
w.write_lenenc_str(b"def")?;
406-
w.write_lenenc_str(b"")?;
407-
w.write_lenenc_str(c.table.as_bytes())?;
408-
w.write_lenenc_str(b"")?;
409-
w.write_lenenc_str(c.column.as_bytes())?;
410-
w.write_lenenc_str(b"")?;
411-
w.write_lenenc_int(0xC)?;
412-
w.write_u16::<LittleEndian>(column_charset(c))?;
413-
w.write_u32::<LittleEndian>(1024)?;
414-
w.write_u8(c.coltype as u8)?;
415-
w.write_u16::<LittleEndian>(c.colflags.bits())?;
416-
w.write_all(&[0x00])?; // decimals
417-
w.write_all(&[0x00, 0x00])?; // unused
418-
419-
if is_com_field_list {
420-
w.write_all(&[0xfb])?;
421-
}
422-
w.end_packet().await?;
423-
}
424-
425-
if !client_capabilities.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
426-
write_eof_packet(w, StatusFlags::empty()).await
427-
} else {
428-
Ok(())
429-
}
430-
}
431-
432-
pub(crate) async fn column_definitions<'a, I, W>(
433-
i: I,
434-
w: &mut PacketWriter<W>,
435-
client_capabilities: CapabilityFlags,
436-
) -> io::Result<()>
437-
where
438-
I: IntoIterator<Item = &'a Column>,
439-
<I as IntoIterator>::IntoIter: ExactSizeIterator,
440-
W: AsyncWrite + Unpin,
441-
{
442-
let i = i.into_iter();
443-
w.write_lenenc_int(i.len() as u64)?;
444-
w.end_packet().await?;
445-
write_column_definitions_41(i, w, client_capabilities, false).await
446-
}

mysql/tests/it/async.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,37 @@ impl AsyncMysqlShim<BufWriter<OwnedWriteHalf>> for InitCountingShim {
281281
) -> Result<(), Self::Error> {
282282
if query.eq_ignore_ascii_case("SELECT @@socket")
283283
|| query.eq_ignore_ascii_case("SELECT @@wait_timeout")
284+
|| query.eq_ignore_ascii_case("SELECT @@max_allowed_packet")
284285
{
285286
results.completed(OkResponse::default()).await
287+
} else if query.eq_ignore_ascii_case("SELECT @@max_allowed_packet,@@wait_timeout,@@socket")
288+
{
289+
let columns = [
290+
Column {
291+
table: String::new(),
292+
column: "@@max_allowed_packet".to_owned(),
293+
coltype: myc::constants::ColumnType::MYSQL_TYPE_LONG,
294+
colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
295+
},
296+
Column {
297+
table: String::new(),
298+
column: "@@wait_timeout".to_owned(),
299+
coltype: myc::constants::ColumnType::MYSQL_TYPE_LONG,
300+
colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
301+
},
302+
Column {
303+
table: String::new(),
304+
column: "@@socket".to_owned(),
305+
coltype: myc::constants::ColumnType::MYSQL_TYPE_VAR_STRING,
306+
colflags: myc::constants::ColumnFlags::empty(),
307+
},
308+
];
309+
let mut row_writer = results.start(&columns).await?;
310+
row_writer.write_col(67108864u32)?;
311+
row_writer.write_col(28800u32)?;
312+
row_writer.write_col(None::<String>)?;
313+
row_writer.end_row().await?;
314+
row_writer.finish().await
286315
} else {
287316
Err(io::Error::new(
288317
io::ErrorKind::Unsupported,

0 commit comments

Comments
 (0)