Skip to content

Commit

Permalink
Fixes try_execute_command message parsing bug (#560)
Browse files Browse the repository at this point in the history
* Fixes try_execute_command message parsing bug

* Fix initial segment logic

* Add test
  • Loading branch information
zainkabani authored Aug 24, 2023
1 parent 4301ab0 commit be549f3
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/query_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess};
use crate::pool::PoolSettings;
use crate::sharding::Sharder;

use std::cmp;
use std::collections::BTreeSet;
use std::io::Cursor;
use std::{cmp, mem};

/// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 7] = [
Expand Down Expand Up @@ -141,6 +141,7 @@ impl QueryRouter {
let mut message_cursor = Cursor::new(message_buffer);

let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32() as usize;

// Check for any sharding regex matches in any queries
match code as char {
Expand All @@ -150,9 +151,13 @@ impl QueryRouter {
|| self.pool_settings.sharding_key_regex.is_some()
{
// Check only the first block of bytes configured by the pool settings
let len = message_cursor.get_i32() as usize;
let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit);
let initial_segment = String::from_utf8_lossy(&message_buffer[0..seg]);

let query_start_index = mem::size_of::<u8>() + mem::size_of::<i32>();

let initial_segment = String::from_utf8_lossy(
&message_buffer[query_start_index..query_start_index + seg],
);

// Check for a shard_id included in the query
if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex {
Expand Down Expand Up @@ -192,7 +197,6 @@ impl QueryRouter {
return None;
}

let _len = message_cursor.get_i32() as usize;
let query = message_cursor.read_string().unwrap();

let regex_set = match CUSTOM_SQL_REGEX_SET.get() {
Expand Down Expand Up @@ -1291,6 +1295,11 @@ mod test {
// Shard should start out unset
assert_eq!(qr.active_shard, None);

// Don't panic when short query eg. ; is sent
let q0 = simple_query(";");
assert!(qr.try_execute_command(&q0) == None);
assert_eq!(qr.active_shard, None);

// Make sure setting it works
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
assert!(qr.try_execute_command(&q1) == None);
Expand Down

0 comments on commit be549f3

Please sign in to comment.