Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removes message cloning operation required for query router #285

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -693,11 +693,11 @@ where
let current_shard = query_router.shard();

// Handle all custom protocol commands, if any.
match query_router.try_execute_command(message.clone()) {
match query_router.try_execute_command(&message) {
// Normal query, not a custom command.
None => {
if query_router.query_parser_enabled() {
query_router.infer(message.clone());
query_router.infer(&message);
}
}

Expand Down
1 change: 1 addition & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ pub enum Error {
TlsError,
StatementTimeout,
ShuttingDown,
ParseBytesError(String),
}
17 changes: 17 additions & 0 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tokio::net::TcpStream;

use crate::errors::Error;
use std::collections::HashMap;
use std::io::{BufRead, Cursor};
use std::mem;

/// Postgres data type mappings
Expand Down Expand Up @@ -536,3 +537,19 @@ pub fn server_parameter_message(key: &str, value: &str) -> BytesMut {

server_info
}

pub trait BytesMutReader {
fn read_string(&mut self) -> Result<String, Error>;
}

impl BytesMutReader for Cursor<&BytesMut> {
/// Should only be used when reading strings from the message protocol.
/// Can be used to read multiple strings from the same message which are separated by the null byte
fn read_string(&mut self) -> Result<String, Error> {
let mut buf = vec![];
match self.read_until(b'\0', &mut buf) {
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
Err(err) => return Err(Error::ParseBytesError(err.to_string())),
}
}
}
94 changes: 45 additions & 49 deletions src/query_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;

use crate::config::Role;
use crate::messages::BytesMutReader;
use crate::pool::PoolSettings;
use crate::sharding::Sharder;

use std::collections::BTreeSet;
use std::io::Cursor;

/// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 7] = [
Expand Down Expand Up @@ -107,16 +109,18 @@ impl QueryRouter {
}

/// Try to parse a command and execute it.
pub fn try_execute_command(&mut self, mut buf: BytesMut) -> Option<(Command, String)> {
let code = buf.get_u8() as char;
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> {
let mut message_cursor = Cursor::new(message_buffer);

let code = message_cursor.get_u8() as char;

// Only simple protocol supported for commands.
if code != 'Q' {
return None;
}

let len = buf.get_i32() as usize;
let query = String::from_utf8_lossy(&buf[..len - 5]).to_string(); // Ignore the terminating NULL.
let _len = message_cursor.get_i32() as usize;
let query = message_cursor.read_string().unwrap();

let regex_set = match CUSTOM_SQL_REGEX_SET.get() {
Some(regex_set) => regex_set,
Expand Down Expand Up @@ -256,37 +260,29 @@ impl QueryRouter {
}

/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, mut buf: BytesMut) -> bool {
pub fn infer(&mut self, message_buffer: &BytesMut) -> bool {
debug!("Inferring role");

let code = buf.get_u8() as char;
let len = buf.get_i32() as usize;
let mut message_cursor = Cursor::new(message_buffer);

let code = message_cursor.get_u8() as char;
let _len = message_cursor.get_i32() as usize;

let query = match code {
// Query
'Q' => {
let query = String::from_utf8_lossy(&buf[..len - 5]).to_string();
let query = message_cursor.read_string().unwrap();
debug!("Query: '{}'", query);
query
}

// Parse (prepared statement)
'P' => {
let mut start = 0;

// Skip the name of the prepared statement.
while buf[start] != 0 && start < buf.len() {
start += 1;
}
start += 1; // Skip terminating null

// Find the end of the prepared stmt (\0)
let mut end = start;
while buf[end] != 0 && end < buf.len() {
end += 1;
}
// Reads statement name
message_cursor.read_string().unwrap();

let query = String::from_utf8_lossy(&buf[start..end]).to_string();
// Reads query string
let query = message_cursor.read_string().unwrap();

debug!("Prepared statement: '{}'", query);

Expand Down Expand Up @@ -519,10 +515,10 @@ mod test {
fn test_infer_replica() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert!(qr.query_parser_enabled());

assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);

let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"),
Expand All @@ -534,7 +530,7 @@ mod test {

for query in queries {
// It's a recognized query
assert!(qr.infer(query));
assert!(qr.infer(&query));
assert_eq!(qr.role(), Some(Role::Replica));
}
}
Expand All @@ -553,7 +549,7 @@ mod test {

for query in queries {
// It's a recognized query
assert!(qr.infer(query));
assert!(qr.infer(&query));
assert_eq!(qr.role(), Some(Role::Primary));
}
}
Expand All @@ -563,18 +559,18 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);

assert!(qr.infer(query));
assert!(qr.infer(&query));
assert_eq!(qr.role(), None);
}

#[test]
fn test_infer_parse_prepared() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);

let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
Expand All @@ -585,7 +581,7 @@ mod test {
res.put(prepared_stmt);
res.put_i16(0);

assert!(qr.infer(res));
assert!(qr.infer(&res));
assert_eq!(qr.role(), Some(Role::Replica));
}

Expand Down Expand Up @@ -668,23 +664,23 @@ mod test {
// SetShardingKey
let query = simple_query("SET SHARDING KEY TO 13");
assert_eq!(
qr.try_execute_command(query),
qr.try_execute_command(&query),
Some((Command::SetShardingKey, String::from("0")))
);
assert_eq!(qr.shard(), 0);

// SetShard
let query = simple_query("SET SHARD TO '1'");
assert_eq!(
qr.try_execute_command(query),
qr.try_execute_command(&query),
Some((Command::SetShard, String::from("1")))
);
assert_eq!(qr.shard(), 1);

// ShowShard
let query = simple_query("SHOW SHARD");
assert_eq!(
qr.try_execute_command(query),
qr.try_execute_command(&query),
Some((Command::ShowShard, String::from("1")))
);

Expand All @@ -702,7 +698,7 @@ mod test {
for (idx, role) in roles.iter().enumerate() {
let query = simple_query(&format!("SET SERVER ROLE TO '{}'", role));
assert_eq!(
qr.try_execute_command(query),
qr.try_execute_command(&query),
Some((Command::SetServerRole, String::from(*role)))
);
assert_eq!(qr.role(), verify_roles[idx],);
Expand All @@ -711,7 +707,7 @@ mod test {
// ShowServerRole
let query = simple_query("SHOW SERVER ROLE");
assert_eq!(
qr.try_execute_command(query),
qr.try_execute_command(&query),
Some((Command::ShowServerRole, String::from(*role)))
);
}
Expand All @@ -721,14 +717,14 @@ mod test {

for (idx, primary_reads) in primary_reads.iter().enumerate() {
assert_eq!(
qr.try_execute_command(simple_query(&format!(
qr.try_execute_command(&simple_query(&format!(
"SET PRIMARY READS TO {}",
primary_reads
))),
Some((Command::SetPrimaryReads, String::from(*primary_reads)))
);
assert_eq!(
qr.try_execute_command(simple_query("SHOW PRIMARY READS")),
qr.try_execute_command(&simple_query("SHOW PRIMARY READS")),
Some((
Command::ShowPrimaryReads,
String::from(primary_reads_enabled[idx])
Expand All @@ -742,23 +738,23 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);

assert!(qr.try_execute_command(query) != None);
assert!(qr.try_execute_command(&query) != None);
assert!(qr.query_parser_enabled());
assert_eq!(qr.role(), None);

let query = simple_query("INSERT INTO test_table VALUES (1)");
assert!(qr.infer(query));
assert!(qr.infer(&query));
assert_eq!(qr.role(), Some(Role::Primary));

let query = simple_query("SELECT * FROM test_table");
assert!(qr.infer(query));
assert!(qr.infer(&query));
assert_eq!(qr.role(), Some(Role::Replica));

assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(query) != None);
assert!(qr.try_execute_command(&query) != None);
assert!(!qr.query_parser_enabled());
}

Expand Down Expand Up @@ -794,16 +790,16 @@ mod test {
assert!(!qr.primary_reads_enabled());

let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(q1) != None);
assert!(qr.try_execute_command(&q1) != None);
assert_eq!(qr.active_role.unwrap(), Role::Primary);

let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(q2) != None);
assert!(qr.try_execute_command(&q2) != None);
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);

// Here we go :)
let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)");
assert!(qr.infer(q3));
assert!(qr.infer(&q3));
assert_eq!(qr.shard(), 1);
}

Expand All @@ -812,13 +808,13 @@ mod test {
QueryRouter::setup();

let mut qr = QueryRouter::new();
assert!(qr.infer(simple_query("BEGIN; SELECT 1; COMMIT;")));
assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;")));
assert_eq!(qr.role(), Role::Primary);

assert!(qr.infer(simple_query("SELECT 1; SELECT 2;")));
assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;")));
assert_eq!(qr.role(), Role::Replica);

assert!(qr.infer(simple_query(
assert!(qr.infer(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
)));
assert_eq!(qr.role(), Role::Primary);
Expand Down
2 changes: 1 addition & 1 deletion tests/ruby/load_balancing_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
threads = Array.new(slow_query_count) do
Thread.new do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SELECT pg_sleep(1)")
conn.async_exec("BEGIN")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's faster than sleeping, that's for sure.

end
end

Expand Down