Skip to content

Commit e9a87a8

Browse files
committed
feat(mysql): encode OK packet info when session tracking
- Advertise CLIENT_SESSION_TRACK so the server side negotiates session state with modern clients. - Ensure write_ok_packet always length-encodes the info field when session tracking, CLIENT_DEPRECATE_EOF, or 0xFE OK headers are used, avoiding MySQL JDBC payload mismatches. - Add coverage that captures raw OK packets across capability combinations to assert the expected encoding.
1 parent 9737d2b commit e9a87a8

File tree

4 files changed

+213
-1
lines changed

4 files changed

+213
-1
lines changed

mysql/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ where
310310
| CapabilityFlags::CLIENT_PLUGIN_AUTH
311311
| CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
312312
| CapabilityFlags::CLIENT_CONNECT_WITH_DB
313+
| CapabilityFlags::CLIENT_SESSION_TRACK
313314
| CapabilityFlags::CLIENT_DEPRECATE_EOF;
314315

315316
#[cfg(feature = "tls")]

mysql/src/tests/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414

1515
mod commands;
1616
mod packet;
17+
mod writers;
1718
mod value;

mysql/src/tests/writers.rs

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
// Copyright 2021 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// Note to developers: you can find decent overviews of the protocol at
16+
//
17+
// https://github.com/cwarden/mysql-proxy/blob/master/doc/protocol.rst
18+
//
19+
// and
20+
//
21+
// https://mariadb.com/kb/en/library/clientserver-protocol/
22+
//
23+
// Wireshark also does a pretty good job at parsing the MySQL protocol.
24+
25+
use tokio::io::{duplex, AsyncReadExt};
26+
27+
use crate::packet_writer::PacketWriter;
28+
use crate::writers::write_ok_packet;
29+
use crate::{CapabilityFlags, OkResponse};
30+
31+
async fn capture_ok_payload(info: &str, capabilities: CapabilityFlags, header: u8) -> Vec<u8> {
32+
let (mut client, server) = duplex(1024);
33+
let mut writer = PacketWriter::new(server);
34+
35+
let ok_packet = OkResponse {
36+
header,
37+
info: info.to_string(),
38+
..Default::default()
39+
};
40+
41+
write_ok_packet(&mut writer, capabilities, ok_packet)
42+
.await
43+
.expect("write_ok_packet succeeds");
44+
45+
let mut header_buf = [0u8; 4];
46+
client
47+
.read_exact(&mut header_buf)
48+
.await
49+
.expect("payload header available");
50+
let payload_len =
51+
(header_buf[0] as usize) | ((header_buf[1] as usize) << 8) | ((header_buf[2] as usize) << 16);
52+
let mut payload = vec![0u8; payload_len];
53+
client
54+
.read_exact(&mut payload)
55+
.await
56+
.expect("payload body available");
57+
payload
58+
}
59+
60+
fn parse_lenenc_int(data: &[u8]) -> (u64, usize) {
61+
match data[0] {
62+
0xFC => {
63+
let len = u16::from_le_bytes([data[1], data[2]]) as u64;
64+
(len, 3)
65+
}
66+
0xFD => {
67+
let len = (data[1] as u64) | ((data[2] as u64) << 8) | ((data[3] as u64) << 16);
68+
(len, 4)
69+
}
70+
0xFE => {
71+
let mut buf = [0u8; 8];
72+
buf.copy_from_slice(&data[1..9]);
73+
(u64::from_le_bytes(buf), 9)
74+
}
75+
v => (v as u64, 1),
76+
}
77+
}
78+
79+
fn consume_ok_prefix(payload: &[u8]) -> (usize, u8, u16, u16) {
80+
let mut idx = 0;
81+
let header = payload[idx];
82+
idx += 1;
83+
84+
let (affected_rows, consumed) = parse_lenenc_int(&payload[idx..]);
85+
assert_eq!(affected_rows, 0);
86+
idx += consumed;
87+
88+
let (last_insert_id, consumed) = parse_lenenc_int(&payload[idx..]);
89+
assert_eq!(last_insert_id, 0);
90+
idx += consumed;
91+
92+
let status = u16::from_le_bytes([payload[idx], payload[idx + 1]]);
93+
idx += 2;
94+
95+
let warnings = u16::from_le_bytes([payload[idx], payload[idx + 1]]);
96+
idx += 2;
97+
98+
(idx, header, status, warnings)
99+
}
100+
101+
#[tokio::test]
102+
async fn ok_packet_info_lenenc_when_session_track() {
103+
let info = "Read 1 rows, 1.00 B in 0.007 sec.";
104+
let payload = capture_ok_payload(
105+
info,
106+
CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_SESSION_TRACK,
107+
0x00,
108+
)
109+
.await;
110+
111+
let (mut idx, header, status, warnings) = consume_ok_prefix(&payload);
112+
assert_eq!(header, 0x00);
113+
assert_eq!(status, 0);
114+
assert_eq!(warnings, 0);
115+
116+
let (info_len, consumed) = parse_lenenc_int(&payload[idx..]);
117+
assert_eq!(info_len as usize, info.len());
118+
idx += consumed;
119+
120+
let encoded = &payload[idx..idx + info.len()];
121+
assert_eq!(encoded, info.as_bytes());
122+
}
123+
124+
#[tokio::test]
125+
async fn ok_packet_info_lenenc_when_deprecate_eof() {
126+
let info = "Read 1 rows, 1.00 B in 0.007 sec.";
127+
let payload = capture_ok_payload(
128+
info,
129+
CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_DEPRECATE_EOF,
130+
0x00,
131+
)
132+
.await;
133+
134+
let (mut idx, header, status, warnings) = consume_ok_prefix(&payload);
135+
assert_eq!(header, 0x00);
136+
assert_eq!(status, 0);
137+
assert_eq!(warnings, 0);
138+
139+
let (info_len, consumed) = parse_lenenc_int(&payload[idx..]);
140+
assert_eq!(info_len as usize, info.len());
141+
idx += consumed;
142+
143+
let encoded = &payload[idx..idx + info.len()];
144+
assert_eq!(encoded, info.as_bytes());
145+
}
146+
147+
#[tokio::test]
148+
async fn ok_packet_info_lenenc_when_header_fe() {
149+
let info = "Read 1 rows, 1.00 B in 0.007 sec.";
150+
let payload = capture_ok_payload(
151+
info,
152+
CapabilityFlags::CLIENT_PROTOCOL_41,
153+
0xfe,
154+
)
155+
.await;
156+
157+
let (mut idx, header, status, warnings) = consume_ok_prefix(&payload);
158+
assert_eq!(header, 0xfe);
159+
assert_eq!(status, 0);
160+
assert_eq!(warnings, 0);
161+
162+
let (info_len, consumed) = parse_lenenc_int(&payload[idx..]);
163+
assert_eq!(info_len as usize, info.len());
164+
idx += consumed;
165+
166+
let encoded = &payload[idx..idx + info.len()];
167+
assert_eq!(encoded, info.as_bytes());
168+
}
169+
170+
#[tokio::test]
171+
async fn ok_packet_info_plain_when_no_flags() {
172+
let info = "Read 1 rows, 1.00 B in 0.007 sec.";
173+
let payload = capture_ok_payload(info, CapabilityFlags::CLIENT_PROTOCOL_41, 0x00).await;
174+
175+
let (idx, header, status, warnings) = consume_ok_prefix(&payload);
176+
assert_eq!(header, 0x00);
177+
assert_eq!(status, 0);
178+
assert_eq!(warnings, 0);
179+
180+
let encoded = &payload[idx..];
181+
assert_eq!(encoded, info.as_bytes());
182+
}
183+
184+
#[tokio::test]
185+
async fn ok_packet_info_extended_lenenc_with_flags() {
186+
let info = "x".repeat(300);
187+
let payload = capture_ok_payload(
188+
&info,
189+
CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_SESSION_TRACK,
190+
0x00,
191+
)
192+
.await;
193+
194+
let (mut idx, header, status, warnings) = consume_ok_prefix(&payload);
195+
assert_eq!(header, 0x00);
196+
assert_eq!(status, 0);
197+
assert_eq!(warnings, 0);
198+
199+
let (info_len, consumed) = parse_lenenc_int(&payload[idx..]);
200+
assert_eq!(consumed, 3); // expect 0xFC marker with two-byte length
201+
assert_eq!(payload[idx], 0xFC);
202+
assert_eq!(info_len as usize, info.len());
203+
idx += consumed;
204+
205+
let encoded = &payload[idx..idx + info.len()];
206+
assert_eq!(encoded, info.as_bytes());
207+
}

mysql/src/writers.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,12 @@ pub(crate) async fn write_ok_packet<W: AsyncWrite + Unpin>(
8282

8383
// Only session-tracking clients expect length-encoded info per protocol; otherwise emit raw text.
8484
let has_session_track = client_capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK);
85+
let expect_lenenc_info = has_session_track
86+
|| ok_packet.header == 0xfe
87+
|| client_capabilities.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF);
8588
let send_info = !ok_packet.info.is_empty() || has_session_track;
8689
if send_info {
87-
if has_session_track {
90+
if expect_lenenc_info {
8891
w.write_lenenc_str(ok_packet.info.as_bytes())?;
8992
} else {
9093
w.write_all(ok_packet.info.as_bytes())?;

0 commit comments

Comments
 (0)