diff --git a/src/client.rs b/src/client.rs index cfe12c0e..15fe21d9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -693,11 +693,11 @@ where let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. - match query_router.try_execute_command(message.clone()) { + match query_router.try_execute_command(&message) { // Normal query, not a custom command. None => { if query_router.query_parser_enabled() { - query_router.infer(message.clone()); + query_router.infer(&message); } } diff --git a/src/errors.rs b/src/errors.rs index 7789a8a7..4ac23a85 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -13,4 +13,5 @@ pub enum Error { TlsError, StatementTimeout, ShuttingDown, + ParseBytesError(String), } diff --git a/src/messages.rs b/src/messages.rs index 45a827c8..0b7ad966 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -7,6 +7,7 @@ use tokio::net::TcpStream; use crate::errors::Error; use std::collections::HashMap; +use std::io::{BufRead, Cursor}; use std::mem; /// Postgres data type mappings @@ -536,3 +537,19 @@ pub fn server_parameter_message(key: &str, value: &str) -> BytesMut { server_info } + +pub trait BytesMutReader { + fn read_string(&mut self) -> Result; +} + +impl BytesMutReader for Cursor<&BytesMut> { + /// Should only be used when reading strings from the message protocol. + /// Can be used to read multiple strings from the same message which are separated by the null byte + fn read_string(&mut self) -> Result { + let mut buf = vec![]; + match self.read_until(b'\0', &mut buf) { + Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), + Err(err) => return Err(Error::ParseBytesError(err.to_string())), + } + } +} diff --git a/src/query_router.rs b/src/query_router.rs index 03f46019..9f9dcd76 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -10,10 +10,12 @@ use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use crate::config::Role; +use crate::messages::BytesMutReader; use crate::pool::PoolSettings; use crate::sharding::Sharder; use std::collections::BTreeSet; +use std::io::Cursor; /// Regexes used to parse custom commands. const CUSTOM_SQL_REGEXES: [&str; 7] = [ @@ -107,16 +109,18 @@ impl QueryRouter { } /// Try to parse a command and execute it. - pub fn try_execute_command(&mut self, mut buf: BytesMut) -> Option<(Command, String)> { - let code = buf.get_u8() as char; + pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> { + let mut message_cursor = Cursor::new(message_buffer); + + let code = message_cursor.get_u8() as char; // Only simple protocol supported for commands. if code != 'Q' { return None; } - let len = buf.get_i32() as usize; - let query = String::from_utf8_lossy(&buf[..len - 5]).to_string(); // Ignore the terminating NULL. + let _len = message_cursor.get_i32() as usize; + let query = message_cursor.read_string().unwrap(); let regex_set = match CUSTOM_SQL_REGEX_SET.get() { Some(regex_set) => regex_set, @@ -256,37 +260,29 @@ impl QueryRouter { } /// Try to infer which server to connect to based on the contents of the query. - pub fn infer(&mut self, mut buf: BytesMut) -> bool { + pub fn infer(&mut self, message_buffer: &BytesMut) -> bool { debug!("Inferring role"); - let code = buf.get_u8() as char; - let len = buf.get_i32() as usize; + let mut message_cursor = Cursor::new(message_buffer); + + let code = message_cursor.get_u8() as char; + let _len = message_cursor.get_i32() as usize; let query = match code { // Query 'Q' => { - let query = String::from_utf8_lossy(&buf[..len - 5]).to_string(); + let query = message_cursor.read_string().unwrap(); debug!("Query: '{}'", query); query } // Parse (prepared statement) 'P' => { - let mut start = 0; - - // Skip the name of the prepared statement. - while buf[start] != 0 && start < buf.len() { - start += 1; - } - start += 1; // Skip terminating null - - // Find the end of the prepared stmt (\0) - let mut end = start; - while buf[end] != 0 && end < buf.len() { - end += 1; - } + // Reads statement name + message_cursor.read_string().unwrap(); - let query = String::from_utf8_lossy(&buf[start..end]).to_string(); + // Reads query string + let query = message_cursor.read_string().unwrap(); debug!("Prepared statement: '{}'", query); @@ -519,10 +515,10 @@ mod test { fn test_infer_replica() { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); + assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr.query_parser_enabled()); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); let queries = vec![ simple_query("SELECT * FROM items WHERE id = 5"), @@ -534,7 +530,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Replica)); } } @@ -553,7 +549,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Primary)); } } @@ -563,9 +559,9 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SELECT * FROM items WHERE id = 5"); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), None); } @@ -573,8 +569,8 @@ mod test { fn test_infer_parse_prepared() { QueryRouter::setup(); let mut qr = QueryRouter::new(); - qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); let prepared_stmt = BytesMut::from( &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], @@ -585,7 +581,7 @@ mod test { res.put(prepared_stmt); res.put_i16(0); - assert!(qr.infer(res)); + assert!(qr.infer(&res)); assert_eq!(qr.role(), Some(Role::Replica)); } @@ -668,7 +664,7 @@ mod test { // SetShardingKey let query = simple_query("SET SHARDING KEY TO 13"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetShardingKey, String::from("0"))) ); assert_eq!(qr.shard(), 0); @@ -676,7 +672,7 @@ mod test { // SetShard let query = simple_query("SET SHARD TO '1'"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetShard, String::from("1"))) ); assert_eq!(qr.shard(), 1); @@ -684,7 +680,7 @@ mod test { // ShowShard let query = simple_query("SHOW SHARD"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::ShowShard, String::from("1"))) ); @@ -702,7 +698,7 @@ mod test { for (idx, role) in roles.iter().enumerate() { let query = simple_query(&format!("SET SERVER ROLE TO '{}'", role)); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetServerRole, String::from(*role))) ); assert_eq!(qr.role(), verify_roles[idx],); @@ -711,7 +707,7 @@ mod test { // ShowServerRole let query = simple_query("SHOW SERVER ROLE"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::ShowServerRole, String::from(*role))) ); } @@ -721,14 +717,14 @@ mod test { for (idx, primary_reads) in primary_reads.iter().enumerate() { assert_eq!( - qr.try_execute_command(simple_query(&format!( + qr.try_execute_command(&simple_query(&format!( "SET PRIMARY READS TO {}", primary_reads ))), Some((Command::SetPrimaryReads, String::from(*primary_reads))) ); assert_eq!( - qr.try_execute_command(simple_query("SHOW PRIMARY READS")), + qr.try_execute_command(&simple_query("SHOW PRIMARY READS")), Some(( Command::ShowPrimaryReads, String::from(primary_reads_enabled[idx]) @@ -742,23 +738,23 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SET SERVER ROLE TO 'auto'"); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); - assert!(qr.try_execute_command(query) != None); + assert!(qr.try_execute_command(&query) != None); assert!(qr.query_parser_enabled()); assert_eq!(qr.role(), None); let query = simple_query("INSERT INTO test_table VALUES (1)"); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Primary)); let query = simple_query("SELECT * FROM test_table"); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Replica)); assert!(qr.query_parser_enabled()); let query = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(query) != None); + assert!(qr.try_execute_command(&query) != None); assert!(!qr.query_parser_enabled()); } @@ -794,16 +790,16 @@ mod test { assert!(!qr.primary_reads_enabled()); let q1 = simple_query("SET SERVER ROLE TO 'primary'"); - assert!(qr.try_execute_command(q1) != None); + assert!(qr.try_execute_command(&q1) != None); assert_eq!(qr.active_role.unwrap(), Role::Primary); let q2 = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(q2) != None); + assert!(qr.try_execute_command(&q2) != None); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); // Here we go :) let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)"); - assert!(qr.infer(q3)); + assert!(qr.infer(&q3)); assert_eq!(qr.shard(), 1); } @@ -812,13 +808,13 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.infer(simple_query("BEGIN; SELECT 1; COMMIT;"))); + assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;"))); assert_eq!(qr.role(), Role::Primary); - assert!(qr.infer(simple_query("SELECT 1; SELECT 2;"))); + assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;"))); assert_eq!(qr.role(), Role::Replica); - assert!(qr.infer(simple_query( + assert!(qr.infer(&simple_query( "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" ))); assert_eq!(qr.role(), Role::Primary); diff --git a/tests/ruby/load_balancing_spec.rb b/tests/ruby/load_balancing_spec.rb index 8be066df..fccf0a85 100644 --- a/tests/ruby/load_balancing_spec.rb +++ b/tests/ruby/load_balancing_spec.rb @@ -93,7 +93,7 @@ threads = Array.new(slow_query_count) do Thread.new do conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) - conn.async_exec("SELECT pg_sleep(1)") + conn.async_exec("BEGIN") end end