Skip to content

Commit d9d512f

Browse files
authored
Merge pull request blackbeam#114 from Will-Low/master
Make SQLSTATE optional in ServerError
2 parents 76bd713 + 8990ff2 commit d9d512f

File tree

1 file changed

+96
-40
lines changed

1 file changed

+96
-40
lines changed

src/packets/mod.rs

Lines changed: 96 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -838,11 +838,14 @@ impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
838838
sbuf.parse_unchecked::<ErrPacketHeader>(())?;
839839
let code: RawInt<LeU16> = sbuf.parse_unchecked(())?;
840840

841-
// We assume that CLIENT_PROTOCOL_41 was set
842841
if *code == 0xFFFF && capabilities.contains(CapabilityFlags::CLIENT_PROGRESS_OBSOLETE) {
843842
buf.parse(()).map(ErrPacket::Progress)
844843
} else {
845-
buf.parse(*code).map(ErrPacket::Error)
844+
buf.parse((
845+
*code,
846+
capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41),
847+
))
848+
.map(ErrPacket::Error)
846849
}
847850
}
848851
}
@@ -872,18 +875,70 @@ impl<'a> fmt::Display for ErrPacket<'a> {
872875
}
873876
}
874877

878+
define_header!(
879+
SqlStateMarker,
880+
InvalidSqlStateMarker("Invalid SqlStateMarker value"),
881+
b'#'
882+
);
883+
884+
/// MySql error state.
885+
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
886+
pub struct SqlState {
887+
__state_marker: SqlStateMarker,
888+
state: [u8; 5],
889+
}
890+
891+
impl SqlState {
892+
/// Creates new sql state.
893+
pub fn new(state: [u8; 5]) -> Self {
894+
Self {
895+
__state_marker: SqlStateMarker::new(),
896+
state,
897+
}
898+
}
899+
900+
/// Returns an sql state as bytes.
901+
pub fn as_bytes(&self) -> [u8; 5] {
902+
self.state
903+
}
904+
905+
/// Returns an sql state as a string (lossy converted).
906+
pub fn as_str(&self) -> Cow<'_, str> {
907+
String::from_utf8_lossy(&self.state)
908+
}
909+
}
910+
911+
impl<'de> MyDeserialize<'de> for SqlState {
912+
const SIZE: Option<usize> = Some(6);
913+
type Ctx = ();
914+
915+
fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
916+
Ok(Self {
917+
__state_marker: buf.parse(())?,
918+
state: buf.parse(())?,
919+
})
920+
}
921+
}
922+
923+
impl MySerialize for SqlState {
924+
fn serialize(&self, buf: &mut Vec<u8>) {
925+
self.__state_marker.serialize(buf);
926+
self.state.serialize(buf);
927+
}
928+
}
929+
875930
/// MySql error packet.
876931
///
877932
/// May hold an error or a progress report.
878933
#[derive(Debug, Clone, PartialEq)]
879934
pub struct ServerError<'a> {
880935
code: RawInt<LeU16>,
881-
state: [u8; 5],
936+
state: Option<SqlState>,
882937
message: RawBytes<'a, EofBytes>,
883938
}
884939

885940
impl<'a> ServerError<'a> {
886-
pub fn new(code: u16, state: [u8; 5], msg: impl Into<Cow<'a, [u8]>>) -> Self {
941+
pub fn new(code: u16, state: Option<SqlState>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
887942
Self {
888943
code: RawInt::new(code),
889944
state,
@@ -897,13 +952,8 @@ impl<'a> ServerError<'a> {
897952
}
898953

899954
/// Returns an sql state.
900-
pub fn sql_state_ref(&self) -> [u8; 5] {
901-
self.state
902-
}
903-
904-
/// Returns an sql state as a string (lossy converted).
905-
pub fn sql_state_str(&self) -> Cow<'_, str> {
906-
String::from_utf8_lossy(&self.state[..])
955+
pub fn sql_state_ref(&self) -> Option<&SqlState> {
956+
self.state.as_ref()
907957
}
908958

909959
/// Returns an error message.
@@ -927,43 +977,48 @@ impl<'a> ServerError<'a> {
927977

928978
impl<'de> MyDeserialize<'de> for ServerError<'de> {
929979
const SIZE: Option<usize> = None;
930-
/// An error packet error code.
931-
type Ctx = u16;
932-
933-
fn deserialize(code: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
934-
match buf.0[0] {
935-
b'#' => {
936-
buf.skip(1);
937-
Ok(ServerError {
938-
code: RawInt::new(code),
939-
state: buf.parse(())?,
940-
message: buf.parse(())?,
941-
})
980+
/// An error packet error code + whether CLIENT_PROTOCOL_41 capability was negotiated.
981+
type Ctx = (u16, bool);
982+
983+
fn deserialize((code, protocol_41): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
984+
let server_error = if protocol_41 {
985+
ServerError {
986+
code: RawInt::new(code),
987+
state: Some(buf.parse(())?),
988+
message: buf.parse(())?,
942989
}
943-
_ => Ok(ServerError {
990+
} else {
991+
ServerError {
944992
code: RawInt::new(code),
945-
state: *b"HY000",
993+
state: None,
946994
message: buf.parse(())?,
947-
}),
948-
}
995+
}
996+
};
997+
Ok(server_error)
949998
}
950999
}
9511000

9521001
impl MySerialize for ServerError<'_> {
9531002
fn serialize(&self, buf: &mut Vec<u8>) {
954-
buf.put_u8(b'#');
955-
buf.put_slice(&self.state[..]);
1003+
if let Some(state) = &self.state {
1004+
state.serialize(buf);
1005+
}
9561006
self.message.serialize(buf);
9571007
}
9581008
}
9591009

9601010
impl fmt::Display for ServerError<'_> {
9611011
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1012+
let sql_state_str = self
1013+
.sql_state_ref()
1014+
.map(|s| format!(" ({})", s.as_str()))
1015+
.unwrap_or_default();
1016+
9621017
write!(
9631018
f,
964-
"ERROR {} ({}): {}",
1019+
"ERROR {}{}: {}",
9651020
self.error_code(),
966-
self.sql_state_str(),
1021+
sql_state_str,
9671022
self.message_str()
9681023
)
9691024
}
@@ -3570,21 +3625,22 @@ mod test {
35703625
\x6f\x6e\x6e\x65\x63\x74\x69\x6f\x6e\x73";
35713626
const PROGRESS_PACKET: &[u8] = b"\xff\xff\xff\x01\x01\x0a\xcc\x5b\x00\x0astage name";
35723627

3573-
let err_packet =
3574-
ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET)).unwrap();
3575-
let err_packet = err_packet.server_error();
3576-
assert_eq!(err_packet.error_code(), 1096);
3577-
assert_eq!(err_packet.sql_state_str(), "HY000");
3578-
assert_eq!(err_packet.message_str(), "No tables used");
3579-
35803628
let err_packet = ErrPacket::deserialize(
35813629
CapabilityFlags::CLIENT_PROTOCOL_41,
3582-
&mut ParseBuf(ERR_PACKET_NO_STATE),
3630+
&mut ParseBuf(ERR_PACKET),
35833631
)
35843632
.unwrap();
3633+
let err_packet = err_packet.server_error();
3634+
assert_eq!(err_packet.error_code(), 1096);
3635+
assert_eq!(err_packet.sql_state_ref().unwrap().as_str(), "HY000");
3636+
assert_eq!(err_packet.message_str(), "No tables used");
3637+
3638+
let err_packet =
3639+
ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET_NO_STATE))
3640+
.unwrap();
35853641
let server_error = err_packet.server_error();
35863642
assert_eq!(server_error.error_code(), 1040);
3587-
assert_eq!(server_error.sql_state_str(), "HY000");
3643+
assert_eq!(server_error.sql_state_ref(), None);
35883644
assert_eq!(server_error.message_str(), "Too many connections");
35893645

35903646
let err_packet = ErrPacket::deserialize(

0 commit comments

Comments
 (0)