@@ -838,11 +838,14 @@ impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
838
838
sbuf. parse_unchecked :: < ErrPacketHeader > ( ( ) ) ?;
839
839
let code: RawInt < LeU16 > = sbuf. parse_unchecked ( ( ) ) ?;
840
840
841
- // We assume that CLIENT_PROTOCOL_41 was set
842
841
if * code == 0xFFFF && capabilities. contains ( CapabilityFlags :: CLIENT_PROGRESS_OBSOLETE ) {
843
842
buf. parse ( ( ) ) . map ( ErrPacket :: Progress )
844
843
} else {
845
- buf. parse ( * code) . map ( ErrPacket :: Error )
844
+ buf. parse ( (
845
+ * code,
846
+ capabilities. contains ( CapabilityFlags :: CLIENT_PROTOCOL_41 ) ,
847
+ ) )
848
+ . map ( ErrPacket :: Error )
846
849
}
847
850
}
848
851
}
@@ -872,18 +875,70 @@ impl<'a> fmt::Display for ErrPacket<'a> {
872
875
}
873
876
}
874
877
878
+ define_header ! (
879
+ SqlStateMarker ,
880
+ InvalidSqlStateMarker ( "Invalid SqlStateMarker value" ) ,
881
+ b'#'
882
+ ) ;
883
+
884
+ /// MySql error state.
885
+ #[ derive( Debug , Clone , Copy , Eq , PartialEq , Hash ) ]
886
+ pub struct SqlState {
887
+ __state_marker : SqlStateMarker ,
888
+ state : [ u8 ; 5 ] ,
889
+ }
890
+
891
+ impl SqlState {
892
+ /// Creates new sql state.
893
+ pub fn new ( state : [ u8 ; 5 ] ) -> Self {
894
+ Self {
895
+ __state_marker : SqlStateMarker :: new ( ) ,
896
+ state,
897
+ }
898
+ }
899
+
900
+ /// Returns an sql state as bytes.
901
+ pub fn as_bytes ( & self ) -> [ u8 ; 5 ] {
902
+ self . state
903
+ }
904
+
905
+ /// Returns an sql state as a string (lossy converted).
906
+ pub fn as_str ( & self ) -> Cow < ' _ , str > {
907
+ String :: from_utf8_lossy ( & self . state )
908
+ }
909
+ }
910
+
911
+ impl < ' de > MyDeserialize < ' de > for SqlState {
912
+ const SIZE : Option < usize > = Some ( 6 ) ;
913
+ type Ctx = ( ) ;
914
+
915
+ fn deserialize ( ( ) : Self :: Ctx , buf : & mut ParseBuf < ' de > ) -> io:: Result < Self > {
916
+ Ok ( Self {
917
+ __state_marker : buf. parse ( ( ) ) ?,
918
+ state : buf. parse ( ( ) ) ?,
919
+ } )
920
+ }
921
+ }
922
+
923
+ impl MySerialize for SqlState {
924
+ fn serialize ( & self , buf : & mut Vec < u8 > ) {
925
+ self . __state_marker . serialize ( buf) ;
926
+ self . state . serialize ( buf) ;
927
+ }
928
+ }
929
+
875
930
/// MySql error packet.
876
931
///
877
932
/// May hold an error or a progress report.
878
933
#[ derive( Debug , Clone , PartialEq ) ]
879
934
pub struct ServerError < ' a > {
880
935
code : RawInt < LeU16 > ,
881
- state : [ u8 ; 5 ] ,
936
+ state : Option < SqlState > ,
882
937
message : RawBytes < ' a , EofBytes > ,
883
938
}
884
939
885
940
impl < ' a > ServerError < ' a > {
886
- pub fn new ( code : u16 , state : [ u8 ; 5 ] , msg : impl Into < Cow < ' a , [ u8 ] > > ) -> Self {
941
+ pub fn new ( code : u16 , state : Option < SqlState > , msg : impl Into < Cow < ' a , [ u8 ] > > ) -> Self {
887
942
Self {
888
943
code : RawInt :: new ( code) ,
889
944
state,
@@ -897,13 +952,8 @@ impl<'a> ServerError<'a> {
897
952
}
898
953
899
954
/// Returns an sql state.
900
- pub fn sql_state_ref ( & self ) -> [ u8 ; 5 ] {
901
- self . state
902
- }
903
-
904
- /// Returns an sql state as a string (lossy converted).
905
- pub fn sql_state_str ( & self ) -> Cow < ' _ , str > {
906
- String :: from_utf8_lossy ( & self . state [ ..] )
955
+ pub fn sql_state_ref ( & self ) -> Option < & SqlState > {
956
+ self . state . as_ref ( )
907
957
}
908
958
909
959
/// Returns an error message.
@@ -927,43 +977,48 @@ impl<'a> ServerError<'a> {
927
977
928
978
impl < ' de > MyDeserialize < ' de > for ServerError < ' de > {
929
979
const SIZE : Option < usize > = None ;
930
- /// An error packet error code.
931
- type Ctx = u16 ;
932
-
933
- fn deserialize ( code : Self :: Ctx , buf : & mut ParseBuf < ' de > ) -> io:: Result < Self > {
934
- match buf. 0 [ 0 ] {
935
- b'#' => {
936
- buf. skip ( 1 ) ;
937
- Ok ( ServerError {
938
- code : RawInt :: new ( code) ,
939
- state : buf. parse ( ( ) ) ?,
940
- message : buf. parse ( ( ) ) ?,
941
- } )
980
+ /// An error packet error code + whether CLIENT_PROTOCOL_41 capability was negotiated.
981
+ type Ctx = ( u16 , bool ) ;
982
+
983
+ fn deserialize ( ( code, protocol_41) : Self :: Ctx , buf : & mut ParseBuf < ' de > ) -> io:: Result < Self > {
984
+ let server_error = if protocol_41 {
985
+ ServerError {
986
+ code : RawInt :: new ( code) ,
987
+ state : Some ( buf. parse ( ( ) ) ?) ,
988
+ message : buf. parse ( ( ) ) ?,
942
989
}
943
- _ => Ok ( ServerError {
990
+ } else {
991
+ ServerError {
944
992
code : RawInt :: new ( code) ,
945
- state : * b"HY000" ,
993
+ state : None ,
946
994
message : buf. parse ( ( ) ) ?,
947
- } ) ,
948
- }
995
+ }
996
+ } ;
997
+ Ok ( server_error)
949
998
}
950
999
}
951
1000
952
1001
impl MySerialize for ServerError < ' _ > {
953
1002
fn serialize ( & self , buf : & mut Vec < u8 > ) {
954
- buf. put_u8 ( b'#' ) ;
955
- buf. put_slice ( & self . state [ ..] ) ;
1003
+ if let Some ( state) = & self . state {
1004
+ state. serialize ( buf) ;
1005
+ }
956
1006
self . message . serialize ( buf) ;
957
1007
}
958
1008
}
959
1009
960
1010
impl fmt:: Display for ServerError < ' _ > {
961
1011
fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
1012
+ let sql_state_str = self
1013
+ . sql_state_ref ( )
1014
+ . map ( |s| format ! ( " ({})" , s. as_str( ) ) )
1015
+ . unwrap_or_default ( ) ;
1016
+
962
1017
write ! (
963
1018
f,
964
- "ERROR {} ({}) : {}" ,
1019
+ "ERROR {}{} : {}" ,
965
1020
self . error_code( ) ,
966
- self . sql_state_str( ) ,
1021
+ sql_state_str,
967
1022
self . message_str( )
968
1023
)
969
1024
}
@@ -3570,21 +3625,22 @@ mod test {
3570
3625
\x6f \x6e \x6e \x65 \x63 \x74 \x69 \x6f \x6e \x73 ";
3571
3626
const PROGRESS_PACKET : & [ u8 ] = b"\xff \xff \xff \x01 \x01 \x0a \xcc \x5b \x00 \x0a stage name" ;
3572
3627
3573
- let err_packet =
3574
- ErrPacket :: deserialize ( CapabilityFlags :: empty ( ) , & mut ParseBuf ( ERR_PACKET ) ) . unwrap ( ) ;
3575
- let err_packet = err_packet. server_error ( ) ;
3576
- assert_eq ! ( err_packet. error_code( ) , 1096 ) ;
3577
- assert_eq ! ( err_packet. sql_state_str( ) , "HY000" ) ;
3578
- assert_eq ! ( err_packet. message_str( ) , "No tables used" ) ;
3579
-
3580
3628
let err_packet = ErrPacket :: deserialize (
3581
3629
CapabilityFlags :: CLIENT_PROTOCOL_41 ,
3582
- & mut ParseBuf ( ERR_PACKET_NO_STATE ) ,
3630
+ & mut ParseBuf ( ERR_PACKET ) ,
3583
3631
)
3584
3632
. unwrap ( ) ;
3633
+ let err_packet = err_packet. server_error ( ) ;
3634
+ assert_eq ! ( err_packet. error_code( ) , 1096 ) ;
3635
+ assert_eq ! ( err_packet. sql_state_ref( ) . unwrap( ) . as_str( ) , "HY000" ) ;
3636
+ assert_eq ! ( err_packet. message_str( ) , "No tables used" ) ;
3637
+
3638
+ let err_packet =
3639
+ ErrPacket :: deserialize ( CapabilityFlags :: empty ( ) , & mut ParseBuf ( ERR_PACKET_NO_STATE ) )
3640
+ . unwrap ( ) ;
3585
3641
let server_error = err_packet. server_error ( ) ;
3586
3642
assert_eq ! ( server_error. error_code( ) , 1040 ) ;
3587
- assert_eq ! ( server_error. sql_state_str ( ) , "HY000" ) ;
3643
+ assert_eq ! ( server_error. sql_state_ref ( ) , None ) ;
3588
3644
assert_eq ! ( server_error. message_str( ) , "Too many connections" ) ;
3589
3645
3590
3646
let err_packet = ErrPacket :: deserialize (
0 commit comments