From d5713764961b3ffc70dbfe6470f749df073ac8d3 Mon Sep 17 00:00:00 2001 From: Mostafa Abdelraouf Date: Thu, 19 Jan 2023 12:44:20 -0600 Subject: [PATCH] Sync with upstream (#115) Contains chore(deps): bump tokio from 1.24.1 to 1.24.2 (#286) Log error messages for network failures (#289) Removes message cloning operation required for query router (#285) Add more metrics to prometheus endpoint (#263) --- Cargo.lock | 8 +- src/admin.rs | 18 +- src/client.rs | 50 +++--- src/config.rs | 13 -- src/errors.rs | 1 + src/messages.rs | 63 +++++-- src/prometheus.rs | 278 +++++++++++++++++++++++------- src/query_router.rs | 94 +++++----- src/server.rs | 4 +- tests/ruby/Gemfile.lock | 18 +- tests/ruby/load_balancing_spec.rb | 2 +- 11 files changed, 367 insertions(+), 182 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c7794cc43..0119e7e75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -633,9 +633,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.6" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba1ef8814b5c993410bb3adfad7a5ed269563e4a2f90c41f5d85be7fb47133bf" +checksum = "7ff9f3fef3968a3ec5945535ed654cb38ff72d7495a25619e2247fb15a2ed9ba" dependencies = [ "cfg-if", "libc", @@ -1043,9 +1043,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.24.1" +version = "1.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d9f76183f91ecfb55e1d7d5602bd1d979e38a3a522fe900241cf195624d67ae" +checksum = "597a12a59981d9e3c38d216785b0c37399f6e415e8d0712047620f189371b0bb" dependencies = [ "autocfg", "bytes", diff --git a/src/admin.rs b/src/admin.rs index 4460f9821..5879114ac 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -171,7 +171,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show PgCat version. @@ -189,7 +189,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show utilization of connection pools for each shard and replicas. @@ -250,7 +250,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show shards and replicas. @@ -317,7 +317,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Ignore any SET commands the client sends. @@ -349,7 +349,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Shows current configuration. @@ -395,7 +395,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show shard and replicas statistics. @@ -455,7 +455,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show currently connected clients @@ -505,7 +505,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show currently connected servers @@ -559,5 +559,5 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } diff --git a/src/client.rs b/src/client.rs index b55906b2b..15fe21d91 100644 --- a/src/client.rs +++ b/src/client.rs @@ -693,11 +693,11 @@ where let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. - match query_router.try_execute_command(message.clone()) { + match query_router.try_execute_command(&message) { // Normal query, not a custom command. None => { if query_router.query_parser_enabled() { - query_router.infer(message.clone()); + query_router.infer(&message); } } @@ -861,7 +861,7 @@ where 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop(code, message, server, &address, &pool) + self.send_and_receive_loop(code, Some(&message), server, &address, &pool) .await?; if !server.in_transaction() { @@ -931,14 +931,8 @@ where } } - self.send_and_receive_loop( - code, - self.buffer.clone(), - server, - &address, - &pool, - ) - .await?; + self.send_and_receive_loop(code, None, server, &address, &pool) + .await?; self.buffer.clear(); @@ -955,21 +949,32 @@ where // CopyData 'd' => { - // Forward the data to the server, - // don't buffer it since it can be rather large. - self.send_server_message(server, message, &address, &pool) - .await?; + self.buffer.put(&message[..]); + + // Want to limit buffer size + if self.buffer.len() > 8196 { + // Forward the data to the server, + self.send_server_message(server, &self.buffer, &address, &pool) + .await?; + self.buffer.clear(); + } } // CopyDone or CopyFail // Copy is done, successfully or not. 'c' | 'f' => { - self.send_server_message(server, message, &address, &pool) + // We may already have some copy data in the buffer, add this message to buffer + self.buffer.put(&message[..]); + + self.send_server_message(server, &self.buffer, &address, &pool) .await?; + // Clear the buffer + self.buffer.clear(); + let response = self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, response).await { + match write_all_half(&mut self.write, &response).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -1016,13 +1021,18 @@ where async fn send_and_receive_loop( &mut self, code: char, - message: BytesMut, + message: Option<&BytesMut>, server: &mut Server, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { debug!("Sending {} to server", code); + let message = match message { + Some(message) => message, + None => &self.buffer, + }; + self.send_server_message(server, message, address, pool) .await?; @@ -1032,7 +1042,7 @@ where loop { let response = self.receive_server_message(server, address, pool).await?; - match write_all_half(&mut self.write, response).await { + match write_all_half(&mut self.write, &response).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -1058,7 +1068,7 @@ where async fn send_server_message( &self, server: &mut Server, - message: BytesMut, + message: &BytesMut, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { diff --git a/src/config.rs b/src/config.rs index e79b93378..219f0debc 100644 --- a/src/config.rs +++ b/src/config.rs @@ -624,19 +624,6 @@ impl Config { "[pool: {}] Pool mode: {:?}", pool_name, pool_config.pool_mode ); - let connect_timeout = match pool_config.connect_timeout { - Some(connect_timeout) => connect_timeout, - None => self.general.connect_timeout, - }; - info!( - "[pool: {}] Connection timeout: {}ms", - pool_name, connect_timeout - ); - let idle_timeout = match pool_config.idle_timeout { - Some(idle_timeout) => idle_timeout, - None => self.general.idle_timeout, - }; - info!("[pool: {}] Idle timeout: {}ms", pool_name, idle_timeout); info!( "[pool: {}] Load Balancing mode: {:?}", pool_name, pool_config.load_balancing_mode diff --git a/src/errors.rs b/src/errors.rs index 7789a8a77..4ac23a855 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -13,4 +13,5 @@ pub enum Error { TlsError, StatementTimeout, ShuttingDown, + ParseBytesError(String), } diff --git a/src/messages.rs b/src/messages.rs index e83155036..e7c36747a 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -7,6 +7,7 @@ use tokio::net::TcpStream; use crate::errors::Error; use std::collections::HashMap; +use std::io::{BufRead, Cursor}; use std::mem; /// Postgres data type mappings @@ -136,9 +137,10 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu match stream.write_all(&startup).await { Ok(_) => Ok(()), - Err(_) => { + Err(err) => { return Err(Error::SocketError(format!( - "Error writing startup to server socket" + "Error writing startup to server socket - Error: {:?}", + err ))) } } @@ -258,7 +260,7 @@ where res.put_i32(len); res.put_slice(&set_complete[..]); - write_all_half(stream, res).await?; + write_all_half(stream, &res).await?; ready_for_query(stream).await } @@ -308,7 +310,7 @@ where res.put_i32(error.len() as i32 + 4); res.put(error); - write_all_half(stream, res).await + write_all_half(stream, &res).await } pub async fn wrong_password(stream: &mut S, user: &str) -> Result<(), Error> @@ -370,7 +372,7 @@ where // CommandComplete res.put(command_complete("SELECT 1")); - write_all_half(stream, res).await?; + write_all_half(stream, &res).await?; ready_for_query(stream).await } @@ -454,18 +456,28 @@ where { match stream.write_all(&buf).await { Ok(_) => Ok(()), - Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))), + Err(err) => { + return Err(Error::SocketError(format!( + "Error writing to socket - Error: {:?}", + err + ))) + } } } /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). -pub async fn write_all_half(stream: &mut S, buf: BytesMut) -> Result<(), Error> +pub async fn write_all_half(stream: &mut S, buf: &BytesMut) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { - match stream.write_all(&buf).await { + match stream.write_all(buf).await { Ok(_) => Ok(()), - Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))), + Err(err) => { + return Err(Error::SocketError(format!( + "Error writing to socket - Error: {:?}", + err + ))) + } } } @@ -476,19 +488,20 @@ where { let code = match stream.read_u8().await { Ok(code) => code, - Err(_) => { + Err(err) => { return Err(Error::SocketError(format!( - "Error reading message code from socket" + "Error reading message code from socket - Error {:?}", + err ))) } }; let len = match stream.read_i32().await { Ok(len) => len, - Err(_) => { + Err(err) => { return Err(Error::SocketError(format!( - "Error reading message len from socket, code: {:?}", - code + "Error reading message len from socket - Code: {:?}, Error: {:?}", + code, err ))) } }; @@ -509,10 +522,10 @@ where .await { Ok(_) => (), - Err(_) => { + Err(err) => { return Err(Error::SocketError(format!( - "Error reading message from socket, code: {:?}", - code + "Error reading message from socket - Code: {:?}, Error: {:?}", + code, err ))) } }; @@ -536,3 +549,19 @@ pub fn server_parameter_message(key: &str, value: &str) -> BytesMut { server_info } + +pub trait BytesMutReader { + fn read_string(&mut self) -> Result; +} + +impl BytesMutReader for Cursor<&BytesMut> { + /// Should only be used when reading strings from the message protocol. + /// Can be used to read multiple strings from the same message which are separated by the null byte + fn read_string(&mut self) -> Result { + let mut buf = vec![]; + match self.read_until(b'\0', &mut buf) { + Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), + Err(err) => return Err(Error::ParseBytesError(err.to_string())), + } + } +} diff --git a/src/prometheus.rs b/src/prometheus.rs index ec011b396..e596f9f99 100644 --- a/src/prometheus.rs +++ b/src/prometheus.rs @@ -8,7 +8,7 @@ use std::net::SocketAddr; use crate::config::Address; use crate::pool::get_all_pools; -use crate::stats::get_address_stats; +use crate::stats::{get_address_stats, get_pool_stats, get_server_stats, ServerInformation}; struct MetricHelpType { help: &'static str, @@ -19,113 +19,141 @@ struct MetricHelpType { // counters only increase // gauges can arbitrarily increase or decrease static METRIC_HELP_AND_TYPES_LOOKUP: phf::Map<&'static str, MetricHelpType> = phf_map! { - "total_query_count" => MetricHelpType { + "stats_total_query_count" => MetricHelpType { help: "Number of queries sent by all clients", ty: "counter", }, - "total_query_time" => MetricHelpType { + "stats_total_query_time" => MetricHelpType { help: "Total amount of time for queries to execute", ty: "counter", }, - "total_received" => MetricHelpType { + "stats_total_received" => MetricHelpType { help: "Number of bytes received from the server", ty: "counter", }, - "total_sent" => MetricHelpType { + "stats_total_sent" => MetricHelpType { help: "Number of bytes sent to the server", ty: "counter", }, - "total_xact_count" => MetricHelpType { + "stats_total_xact_count" => MetricHelpType { help: "Total number of transactions started by the client", ty: "counter", }, - "total_xact_time" => MetricHelpType { + "stats_total_xact_time" => MetricHelpType { help: "Total amount of time for all transactions to execute", ty: "counter", }, - "total_wait_time" => MetricHelpType { + "stats_total_wait_time" => MetricHelpType { help: "Total time client waited for a server connection", ty: "counter", }, - "avg_query_count" => MetricHelpType { + "stats_avg_query_count" => MetricHelpType { help: "Average of total_query_count every 15 seconds", ty: "gauge", }, - "avg_query_time" => MetricHelpType { + "stats_avg_query_time" => MetricHelpType { help: "Average time taken for queries to execute every 15 seconds", ty: "gauge", }, - "avg_recv" => MetricHelpType { + "stats_avg_recv" => MetricHelpType { help: "Average of total_received bytes every 15 seconds", ty: "gauge", }, - "avg_sent" => MetricHelpType { + "stats_avg_sent" => MetricHelpType { help: "Average of total_sent bytes every 15 seconds", ty: "gauge", }, - "avg_errors" => MetricHelpType { + "stats_avg_errors" => MetricHelpType { help: "Average number of errors every 15 seconds", ty: "gauge", }, - "avg_xact_count" => MetricHelpType { + "stats_avg_xact_count" => MetricHelpType { help: "Average of total_xact_count every 15 seconds", ty: "gauge", }, - "avg_xact_time" => MetricHelpType { + "stats_avg_xact_time" => MetricHelpType { help: "Average of total_xact_time every 15 seconds", ty: "gauge", }, - "avg_wait_time" => MetricHelpType { + "stats_avg_wait_time" => MetricHelpType { help: "Average of total_wait_time every 15 seconds", ty: "gauge", }, - "maxwait_us" => MetricHelpType { + "pools_maxwait_us" => MetricHelpType { help: "The time a client waited for a server connection in microseconds", ty: "gauge", }, - "maxwait" => MetricHelpType { + "pools_maxwait" => MetricHelpType { help: "The time a client waited for a server connection in seconds", ty: "gauge", }, - "cl_waiting" => MetricHelpType { + "pools_cl_waiting" => MetricHelpType { help: "How many clients are waiting for a connection from the pool", ty: "gauge", }, - "cl_active" => MetricHelpType { + "pools_cl_active" => MetricHelpType { help: "How many clients are actively communicating with a server", ty: "gauge", }, - "cl_idle" => MetricHelpType { + "pools_cl_idle" => MetricHelpType { help: "How many clients are idle", ty: "gauge", }, - "sv_idle" => MetricHelpType { + "pools_sv_idle" => MetricHelpType { help: "How many server connections are idle", ty: "gauge", }, - "sv_active" => MetricHelpType { + "pools_sv_active" => MetricHelpType { help: "How many server connections are actively communicating with a client", ty: "gauge", }, - "sv_login" => MetricHelpType { + "pools_sv_login" => MetricHelpType { help: "How many server connections are currently being created", ty: "gauge", }, - "sv_tested" => MetricHelpType { + "pools_sv_tested" => MetricHelpType { help: "How many server connections are currently waiting on a health check to succeed", ty: "gauge", }, + "servers_bytes_received" => MetricHelpType { + help: "Volume in bytes of network traffic received by server", + ty: "gauge", + }, + "servers_bytes_sent" => MetricHelpType { + help: "Volume in bytes of network traffic sent by server", + ty: "gauge", + }, + "servers_transaction_count" => MetricHelpType { + help: "Number of transactions executed by server", + ty: "gauge", + }, + "servers_query_count" => MetricHelpType { + help: "Number of queries executed by server", + ty: "gauge", + }, + "servers_error_count" => MetricHelpType { + help: "Number of errors", + ty: "gauge", + }, + "databases_pool_size" => MetricHelpType { + help: "Maximum number of server connections", + ty: "gauge", + }, + "databases_current_connections" => MetricHelpType { + help: "Current number of connections for this database", + ty: "gauge", + }, }; -struct PrometheusMetric { +struct PrometheusMetric { name: String, help: String, ty: String, labels: HashMap<&'static str, String>, - value: i64, + value: Value, } -impl fmt::Display for PrometheusMetric { +impl fmt::Display for PrometheusMetric { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let formatted_labels = self .labels @@ -145,50 +173,81 @@ impl fmt::Display for PrometheusMetric { } } -impl PrometheusMetric { - fn new(address: &Address, name: &str, value: i64) -> Option { - let mut labels = HashMap::new(); - labels.insert("host", address.host.clone()); - labels.insert("shard", address.shard.to_string()); - labels.insert("role", address.role.to_string()); - labels.insert("database", address.database.to_string()); - +impl PrometheusMetric { + fn from_name( + name: &str, + value: V, + labels: HashMap<&'static str, String>, + ) -> Option> { METRIC_HELP_AND_TYPES_LOOKUP .get(name) - .map(|metric| PrometheusMetric { + .map(|metric| PrometheusMetric:: { name: name.to_owned(), help: metric.help.to_owned(), ty: metric.ty.to_owned(), - labels, value, + labels, }) } + + fn from_database_info( + address: &Address, + name: &str, + value: u32, + ) -> Option> { + let mut labels = HashMap::new(); + labels.insert("host", address.host.clone()); + labels.insert("shard", address.shard.to_string()); + labels.insert("role", address.role.to_string()); + labels.insert("pool", address.pool_name.clone()); + labels.insert("database", address.database.to_string()); + + Self::from_name(&format!("databases_{}", name), value, labels) + } + + fn from_server_info( + address: &Address, + name: &str, + value: u64, + ) -> Option> { + let mut labels = HashMap::new(); + labels.insert("host", address.host.clone()); + labels.insert("shard", address.shard.to_string()); + labels.insert("role", address.role.to_string()); + labels.insert("pool", address.pool_name.clone()); + labels.insert("database", address.database.to_string()); + + Self::from_name(&format!("servers_{}", name), value, labels) + } + + fn from_address(address: &Address, name: &str, value: i64) -> Option> { + let mut labels = HashMap::new(); + labels.insert("host", address.host.clone()); + labels.insert("shard", address.shard.to_string()); + labels.insert("pool", address.pool_name.clone()); + labels.insert("role", address.role.to_string()); + labels.insert("database", address.database.to_string()); + + Self::from_name(&format!("stats_{}", name), value, labels) + } + + fn from_pool(pool: &(String, String), name: &str, value: i64) -> Option> { + let mut labels = HashMap::new(); + labels.insert("pool", pool.0.clone()); + labels.insert("user", pool.1.clone()); + + Self::from_name(&format!("pools_{}", name), value, labels) + } } async fn prometheus_stats(request: Request) -> Result, hyper::http::Error> { match (request.method(), request.uri().path()) { (&Method::GET, "/metrics") => { - let stats: HashMap> = get_address_stats(); - let mut lines = Vec::new(); - for (_, pool) in get_all_pools() { - for shard in 0..pool.shards() { - for server in 0..pool.servers(shard) { - let address = pool.address(shard, server); - if let Some(address_stats) = stats.get(&address.id) { - for (key, value) in address_stats.iter() { - if let Some(prometheus_metric) = - PrometheusMetric::new(address, key, *value) - { - lines.push(prometheus_metric.to_string()); - } else { - warn!("Metric {} not implemented for {}", key, address.name()); - } - } - } - } - } - } + push_address_stats(&mut lines); + push_pool_stats(&mut lines); + push_server_stats(&mut lines); + push_database_stats(&mut lines); Response::builder() .header("content-type", "text/plain; version=0.0.4") @@ -200,6 +259,109 @@ async fn prometheus_stats(request: Request) -> Result, hype } } +// Adds metrics shown in a SHOW STATS admin command. +fn push_address_stats(lines: &mut Vec) { + let address_stats: HashMap> = get_address_stats(); + for (_, pool) in get_all_pools() { + for shard in 0..pool.shards() { + for server in 0..pool.servers(shard) { + let address = pool.address(shard, server); + if let Some(address_stats) = address_stats.get(&address.id) { + for (key, value) in address_stats.iter() { + if let Some(prometheus_metric) = + PrometheusMetric::::from_address(address, key, *value) + { + lines.push(prometheus_metric.to_string()); + } else { + warn!("Metric {} not implemented for {}", key, address.name()); + } + } + } + } + } + } +} + +// Adds relevant metrics shown in a SHOW POOLS admin command. +fn push_pool_stats(lines: &mut Vec) { + let pool_stats = get_pool_stats(); + for (pool, stats) in pool_stats.iter() { + for (name, value) in stats.iter() { + if let Some(prometheus_metric) = PrometheusMetric::::from_pool(pool, name, *value) + { + lines.push(prometheus_metric.to_string()); + } else { + warn!( + "Metric {} not implemented for ({},{})", + name, pool.0, pool.1 + ); + } + } + } +} + +// Adds relevant metrics shown in a SHOW DATABASES admin command. +fn push_database_stats(lines: &mut Vec) { + for (_, pool) in get_all_pools() { + let pool_config = pool.settings.clone(); + for shard in 0..pool.shards() { + for server in 0..pool.servers(shard) { + let address = pool.address(shard, server); + let pool_state = pool.pool_state(shard, server); + + let metrics = vec![ + ("pool_size", pool_config.user.pool_size), + ("current_connections", pool_state.connections), + ]; + for (key, value) in metrics { + if let Some(prometheus_metric) = + PrometheusMetric::::from_database_info(address, key, value) + { + lines.push(prometheus_metric.to_string()); + } else { + warn!("Metric {} not implemented for {}", key, address.name()); + } + } + } + } + } +} + +// Adds relevant metrics shown in a SHOW SERVERS admin command. +fn push_server_stats(lines: &mut Vec) { + let server_stats = get_server_stats(); + let mut server_stats_by_addresses = HashMap::::new(); + for (_, info) in server_stats { + server_stats_by_addresses.insert(info.address_name.clone(), info); + } + + for (_, pool) in get_all_pools() { + for shard in 0..pool.shards() { + for server in 0..pool.servers(shard) { + let address = pool.address(shard, server); + if let Some(server_info) = server_stats_by_addresses.get(&address.name()) { + let metrics = [ + ("bytes_received", server_info.bytes_received), + ("bytes_sent", server_info.bytes_sent), + ("transaction_count", server_info.transaction_count), + ("query_count", server_info.query_count), + ("error_count", server_info.error_count), + ]; + for (key, value) in metrics { + if let Some(prometheus_metric) = + PrometheusMetric::::from_server_info(address, key, value) + { + lines.push(prometheus_metric.to_string()); + } else { + warn!("Metric {} not implemented for {}", key, address.name()); + } + } + } + } + } + } +} + pub async fn start_metric_server(http_addr: SocketAddr) { let http_service_factory = make_service_fn(|_conn| async { Ok::<_, hyper::Error>(service_fn(prometheus_stats)) }); diff --git a/src/query_router.rs b/src/query_router.rs index 03f460191..9f9dcd76e 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -10,10 +10,12 @@ use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use crate::config::Role; +use crate::messages::BytesMutReader; use crate::pool::PoolSettings; use crate::sharding::Sharder; use std::collections::BTreeSet; +use std::io::Cursor; /// Regexes used to parse custom commands. const CUSTOM_SQL_REGEXES: [&str; 7] = [ @@ -107,16 +109,18 @@ impl QueryRouter { } /// Try to parse a command and execute it. - pub fn try_execute_command(&mut self, mut buf: BytesMut) -> Option<(Command, String)> { - let code = buf.get_u8() as char; + pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> { + let mut message_cursor = Cursor::new(message_buffer); + + let code = message_cursor.get_u8() as char; // Only simple protocol supported for commands. if code != 'Q' { return None; } - let len = buf.get_i32() as usize; - let query = String::from_utf8_lossy(&buf[..len - 5]).to_string(); // Ignore the terminating NULL. + let _len = message_cursor.get_i32() as usize; + let query = message_cursor.read_string().unwrap(); let regex_set = match CUSTOM_SQL_REGEX_SET.get() { Some(regex_set) => regex_set, @@ -256,37 +260,29 @@ impl QueryRouter { } /// Try to infer which server to connect to based on the contents of the query. - pub fn infer(&mut self, mut buf: BytesMut) -> bool { + pub fn infer(&mut self, message_buffer: &BytesMut) -> bool { debug!("Inferring role"); - let code = buf.get_u8() as char; - let len = buf.get_i32() as usize; + let mut message_cursor = Cursor::new(message_buffer); + + let code = message_cursor.get_u8() as char; + let _len = message_cursor.get_i32() as usize; let query = match code { // Query 'Q' => { - let query = String::from_utf8_lossy(&buf[..len - 5]).to_string(); + let query = message_cursor.read_string().unwrap(); debug!("Query: '{}'", query); query } // Parse (prepared statement) 'P' => { - let mut start = 0; - - // Skip the name of the prepared statement. - while buf[start] != 0 && start < buf.len() { - start += 1; - } - start += 1; // Skip terminating null - - // Find the end of the prepared stmt (\0) - let mut end = start; - while buf[end] != 0 && end < buf.len() { - end += 1; - } + // Reads statement name + message_cursor.read_string().unwrap(); - let query = String::from_utf8_lossy(&buf[start..end]).to_string(); + // Reads query string + let query = message_cursor.read_string().unwrap(); debug!("Prepared statement: '{}'", query); @@ -519,10 +515,10 @@ mod test { fn test_infer_replica() { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); + assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr.query_parser_enabled()); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); let queries = vec![ simple_query("SELECT * FROM items WHERE id = 5"), @@ -534,7 +530,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Replica)); } } @@ -553,7 +549,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Primary)); } } @@ -563,9 +559,9 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SELECT * FROM items WHERE id = 5"); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), None); } @@ -573,8 +569,8 @@ mod test { fn test_infer_parse_prepared() { QueryRouter::setup(); let mut qr = QueryRouter::new(); - qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); let prepared_stmt = BytesMut::from( &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], @@ -585,7 +581,7 @@ mod test { res.put(prepared_stmt); res.put_i16(0); - assert!(qr.infer(res)); + assert!(qr.infer(&res)); assert_eq!(qr.role(), Some(Role::Replica)); } @@ -668,7 +664,7 @@ mod test { // SetShardingKey let query = simple_query("SET SHARDING KEY TO 13"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetShardingKey, String::from("0"))) ); assert_eq!(qr.shard(), 0); @@ -676,7 +672,7 @@ mod test { // SetShard let query = simple_query("SET SHARD TO '1'"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetShard, String::from("1"))) ); assert_eq!(qr.shard(), 1); @@ -684,7 +680,7 @@ mod test { // ShowShard let query = simple_query("SHOW SHARD"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::ShowShard, String::from("1"))) ); @@ -702,7 +698,7 @@ mod test { for (idx, role) in roles.iter().enumerate() { let query = simple_query(&format!("SET SERVER ROLE TO '{}'", role)); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetServerRole, String::from(*role))) ); assert_eq!(qr.role(), verify_roles[idx],); @@ -711,7 +707,7 @@ mod test { // ShowServerRole let query = simple_query("SHOW SERVER ROLE"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::ShowServerRole, String::from(*role))) ); } @@ -721,14 +717,14 @@ mod test { for (idx, primary_reads) in primary_reads.iter().enumerate() { assert_eq!( - qr.try_execute_command(simple_query(&format!( + qr.try_execute_command(&simple_query(&format!( "SET PRIMARY READS TO {}", primary_reads ))), Some((Command::SetPrimaryReads, String::from(*primary_reads))) ); assert_eq!( - qr.try_execute_command(simple_query("SHOW PRIMARY READS")), + qr.try_execute_command(&simple_query("SHOW PRIMARY READS")), Some(( Command::ShowPrimaryReads, String::from(primary_reads_enabled[idx]) @@ -742,23 +738,23 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SET SERVER ROLE TO 'auto'"); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); - assert!(qr.try_execute_command(query) != None); + assert!(qr.try_execute_command(&query) != None); assert!(qr.query_parser_enabled()); assert_eq!(qr.role(), None); let query = simple_query("INSERT INTO test_table VALUES (1)"); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Primary)); let query = simple_query("SELECT * FROM test_table"); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Replica)); assert!(qr.query_parser_enabled()); let query = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(query) != None); + assert!(qr.try_execute_command(&query) != None); assert!(!qr.query_parser_enabled()); } @@ -794,16 +790,16 @@ mod test { assert!(!qr.primary_reads_enabled()); let q1 = simple_query("SET SERVER ROLE TO 'primary'"); - assert!(qr.try_execute_command(q1) != None); + assert!(qr.try_execute_command(&q1) != None); assert_eq!(qr.active_role.unwrap(), Role::Primary); let q2 = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(q2) != None); + assert!(qr.try_execute_command(&q2) != None); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); // Here we go :) let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)"); - assert!(qr.infer(q3)); + assert!(qr.infer(&q3)); assert_eq!(qr.shard(), 1); } @@ -812,13 +808,13 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.infer(simple_query("BEGIN; SELECT 1; COMMIT;"))); + assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;"))); assert_eq!(qr.role(), Role::Primary); - assert!(qr.infer(simple_query("SELECT 1; SELECT 2;"))); + assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;"))); assert_eq!(qr.role(), Role::Replica); - assert!(qr.infer(simple_query( + assert!(qr.infer(&simple_query( "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" ))); assert_eq!(qr.role(), Role::Primary); diff --git a/src/server.rs b/src/server.rs index 05a3b770e..f2a6d387d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -381,7 +381,7 @@ impl Server { } /// Send messages to the server from the client. - pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> { + pub async fn send(&mut self, messages: &BytesMut) -> Result<(), Error> { self.stats.data_sent(messages.len(), self.server_id); match write_all_half(&mut self.write, messages).await { @@ -593,7 +593,7 @@ impl Server { pub async fn query(&mut self, query: &str) -> Result<(), Error> { let query = simple_query(query); - self.send(query).await?; + self.send(&query).await?; loop { let _ = self.recv().await?; diff --git a/tests/ruby/Gemfile.lock b/tests/ruby/Gemfile.lock index 65d8bce7a..f49468053 100644 --- a/tests/ruby/Gemfile.lock +++ b/tests/ruby/Gemfile.lock @@ -1,12 +1,12 @@ GEM remote: https://rubygems.org/ specs: - activemodel (7.0.3.1) - activesupport (= 7.0.3.1) - activerecord (7.0.3.1) - activemodel (= 7.0.3.1) - activesupport (= 7.0.3.1) - activesupport (7.0.3.1) + activemodel (7.0.4.1) + activesupport (= 7.0.4.1) + activerecord (7.0.4.1) + activemodel (= 7.0.4.1) + activesupport (= 7.0.4.1) + activesupport (7.0.4.1) concurrent-ruby (~> 1.0, >= 1.0.2) i18n (>= 1.6, < 2) minitest (>= 5.1) @@ -14,9 +14,9 @@ GEM ast (2.4.2) concurrent-ruby (1.1.10) diff-lcs (1.5.0) - i18n (1.11.0) + i18n (1.12.0) concurrent-ruby (~> 1.0) - minitest (5.16.2) + minitest (5.17.0) parallel (1.22.1) parser (3.1.2.0) ast (~> 2.4.1) @@ -53,7 +53,7 @@ GEM toml (0.3.0) parslet (>= 1.8.0, < 3.0.0) toxiproxy (2.0.1) - tzinfo (2.0.4) + tzinfo (2.0.5) concurrent-ruby (~> 1.0) unicode-display_width (2.1.0) diff --git a/tests/ruby/load_balancing_spec.rb b/tests/ruby/load_balancing_spec.rb index 8be066df4..fccf0a859 100644 --- a/tests/ruby/load_balancing_spec.rb +++ b/tests/ruby/load_balancing_spec.rb @@ -93,7 +93,7 @@ threads = Array.new(slow_query_count) do Thread.new do conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) - conn.async_exec("SELECT pg_sleep(1)") + conn.async_exec("BEGIN") end end