Skip to content

Commit

Permalink
Sync with upstream (postgresml#115)
Browse files Browse the repository at this point in the history
Contains
chore(deps): bump tokio from 1.24.1 to 1.24.2 (postgresml#286)
Log error messages for network failures (postgresml#289)
Removes message cloning operation required for query router (postgresml#285)
Add more metrics to prometheus endpoint (postgresml#263)
  • Loading branch information
drdrsh authored Jan 19, 2023
1 parent 7bc866b commit d571376
Show file tree
Hide file tree
Showing 11 changed files with 367 additions and 182 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 9 additions & 9 deletions src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, res).await
write_all_half(stream, &res).await
}

/// Show PgCat version.
Expand All @@ -189,7 +189,7 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, res).await
write_all_half(stream, &res).await
}

/// Show utilization of connection pools for each shard and replicas.
Expand Down Expand Up @@ -250,7 +250,7 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, res).await
write_all_half(stream, &res).await
}

/// Show shards and replicas.
Expand Down Expand Up @@ -317,7 +317,7 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, res).await
write_all_half(stream, &res).await
}

/// Ignore any SET commands the client sends.
Expand Down Expand Up @@ -349,7 +349,7 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, res).await
write_all_half(stream, &res).await
}

/// Shows current configuration.
Expand Down Expand Up @@ -395,7 +395,7 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, res).await
write_all_half(stream, &res).await
}

/// Show shard and replicas statistics.
Expand Down Expand Up @@ -455,7 +455,7 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, res).await
write_all_half(stream, &res).await
}

/// Show currently connected clients
Expand Down Expand Up @@ -505,7 +505,7 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, res).await
write_all_half(stream, &res).await
}

/// Show currently connected servers
Expand Down Expand Up @@ -559,5 +559,5 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, res).await
write_all_half(stream, &res).await
}
50 changes: 30 additions & 20 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -693,11 +693,11 @@ where
let current_shard = query_router.shard();

// Handle all custom protocol commands, if any.
match query_router.try_execute_command(message.clone()) {
match query_router.try_execute_command(&message) {
// Normal query, not a custom command.
None => {
if query_router.query_parser_enabled() {
query_router.infer(message.clone());
query_router.infer(&message);
}
}

Expand Down Expand Up @@ -861,7 +861,7 @@ where
'Q' => {
debug!("Sending query to server");

self.send_and_receive_loop(code, message, server, &address, &pool)
self.send_and_receive_loop(code, Some(&message), server, &address, &pool)
.await?;

if !server.in_transaction() {
Expand Down Expand Up @@ -931,14 +931,8 @@ where
}
}

self.send_and_receive_loop(
code,
self.buffer.clone(),
server,
&address,
&pool,
)
.await?;
self.send_and_receive_loop(code, None, server, &address, &pool)
.await?;

self.buffer.clear();

Expand All @@ -955,21 +949,32 @@ where

// CopyData
'd' => {
// Forward the data to the server,
// don't buffer it since it can be rather large.
self.send_server_message(server, message, &address, &pool)
.await?;
self.buffer.put(&message[..]);

// Want to limit buffer size
if self.buffer.len() > 8196 {
// Forward the data to the server,
self.send_server_message(server, &self.buffer, &address, &pool)
.await?;
self.buffer.clear();
}
}

// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => {
self.send_server_message(server, message, &address, &pool)
// We may already have some copy data in the buffer, add this message to buffer
self.buffer.put(&message[..]);

self.send_server_message(server, &self.buffer, &address, &pool)
.await?;

// Clear the buffer
self.buffer.clear();

let response = self.receive_server_message(server, &address, &pool).await?;

match write_all_half(&mut self.write, response).await {
match write_all_half(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
Expand Down Expand Up @@ -1016,13 +1021,18 @@ where
async fn send_and_receive_loop(
&mut self,
code: char,
message: BytesMut,
message: Option<&BytesMut>,
server: &mut Server,
address: &Address,
pool: &ConnectionPool,
) -> Result<(), Error> {
debug!("Sending {} to server", code);

let message = match message {
Some(message) => message,
None => &self.buffer,
};

self.send_server_message(server, message, address, pool)
.await?;

Expand All @@ -1032,7 +1042,7 @@ where
loop {
let response = self.receive_server_message(server, address, pool).await?;

match write_all_half(&mut self.write, response).await {
match write_all_half(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
Expand All @@ -1058,7 +1068,7 @@ where
async fn send_server_message(
&self,
server: &mut Server,
message: BytesMut,
message: &BytesMut,
address: &Address,
pool: &ConnectionPool,
) -> Result<(), Error> {
Expand Down
13 changes: 0 additions & 13 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,19 +624,6 @@ impl Config {
"[pool: {}] Pool mode: {:?}",
pool_name, pool_config.pool_mode
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => self.general.connect_timeout,
};
info!(
"[pool: {}] Connection timeout: {}ms",
pool_name, connect_timeout
);
let idle_timeout = match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => self.general.idle_timeout,
};
info!("[pool: {}] Idle timeout: {}ms", pool_name, idle_timeout);
info!(
"[pool: {}] Load Balancing mode: {:?}",
pool_name, pool_config.load_balancing_mode
Expand Down
1 change: 1 addition & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ pub enum Error {
TlsError,
StatementTimeout,
ShuttingDown,
ParseBytesError(String),
}
63 changes: 46 additions & 17 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tokio::net::TcpStream;

use crate::errors::Error;
use std::collections::HashMap;
use std::io::{BufRead, Cursor};
use std::mem;

/// Postgres data type mappings
Expand Down Expand Up @@ -136,9 +137,10 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu

match stream.write_all(&startup).await {
Ok(_) => Ok(()),
Err(_) => {
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing startup to server socket"
"Error writing startup to server socket - Error: {:?}",
err
)))
}
}
Expand Down Expand Up @@ -258,7 +260,7 @@ where
res.put_i32(len);
res.put_slice(&set_complete[..]);

write_all_half(stream, res).await?;
write_all_half(stream, &res).await?;
ready_for_query(stream).await
}

Expand Down Expand Up @@ -308,7 +310,7 @@ where
res.put_i32(error.len() as i32 + 4);
res.put(error);

write_all_half(stream, res).await
write_all_half(stream, &res).await
}

pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
Expand Down Expand Up @@ -370,7 +372,7 @@ where
// CommandComplete
res.put(command_complete("SELECT 1"));

write_all_half(stream, res).await?;
write_all_half(stream, &res).await?;
ready_for_query(stream).await
}

Expand Down Expand Up @@ -454,18 +456,28 @@ where
{
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}

/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
pub async fn write_all_half<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
pub async fn write_all_half<S>(stream: &mut S, buf: &BytesMut) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(&buf).await {
match stream.write_all(buf).await {
Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}

Expand All @@ -476,19 +488,20 @@ where
{
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => {
Err(err) => {
return Err(Error::SocketError(format!(
"Error reading message code from socket"
"Error reading message code from socket - Error {:?}",
err
)))
}
};

let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => {
Err(err) => {
return Err(Error::SocketError(format!(
"Error reading message len from socket, code: {:?}",
code
"Error reading message len from socket - Code: {:?}, Error: {:?}",
code, err
)))
}
};
Expand All @@ -509,10 +522,10 @@ where
.await
{
Ok(_) => (),
Err(_) => {
Err(err) => {
return Err(Error::SocketError(format!(
"Error reading message from socket, code: {:?}",
code
"Error reading message from socket - Code: {:?}, Error: {:?}",
code, err
)))
}
};
Expand All @@ -536,3 +549,19 @@ pub fn server_parameter_message(key: &str, value: &str) -> BytesMut {

server_info
}

pub trait BytesMutReader {
fn read_string(&mut self) -> Result<String, Error>;
}

impl BytesMutReader for Cursor<&BytesMut> {
/// Should only be used when reading strings from the message protocol.
/// Can be used to read multiple strings from the same message which are separated by the null byte
fn read_string(&mut self) -> Result<String, Error> {
let mut buf = vec![];
match self.read_until(b'\0', &mut buf) {
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
Err(err) => return Err(Error::ParseBytesError(err.to_string())),
}
}
}
Loading

0 comments on commit d571376

Please sign in to comment.