Skip to content

Commit

Permalink
Add Manual host banning to PgCat (#340)
Browse files Browse the repository at this point in the history
Sometimes we want an admin to be able to ban a host for some time to route traffic away from that host for reasons like partial outages, replication lag, and scheduled maintenance.

We can achieve this today using a configuration update but a quicker approach is to send a control command to PgCat that bans the replica for some specified duration.

This command does not change the current banning rules like

Primaries cannot be banned
When all replicas are banned, all replicas are unbanned
  • Loading branch information
drdrsh authored and levkk committed Mar 9, 2023
1 parent e01c4fb commit c03f01a
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 13 deletions.
172 changes: 172 additions & 0 deletions src/admin.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use crate::config::Role;
use crate::pool::BanReason;
/// Admin database.
use bytes::{Buf, BufMut, BytesMut};
use log::{error, info, trace};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::time::Instant;

use crate::config::{get_config, reload_config, VERSION};
Expand Down Expand Up @@ -53,6 +56,14 @@ where
let query_parts: Vec<&str> = query.trim_end_matches(';').split_whitespace().collect();

match query_parts[0].to_ascii_uppercase().as_str() {
"BAN" => {
trace!("BAN");
ban(stream, query_parts).await
}
"UNBAN" => {
trace!("UNBAN");
unban(stream, query_parts).await
}
"RELOAD" => {
trace!("RELOAD");
reload(stream, client_server_map).await
Expand All @@ -74,6 +85,10 @@ where
shutdown(stream).await
}
"SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
"BANS" => {
trace!("SHOW BANS");
show_bans(stream).await
}
"CONFIG" => {
trace!("SHOW CONFIG");
show_config(stream).await
Expand Down Expand Up @@ -350,6 +365,163 @@ where
custom_protocol_response_ok(stream, "SET").await
}

/// Bans a host from being used
async fn ban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let host = match tokens.get(1) {
Some(host) => host,
None => return error_response(stream, "usage: BAN hostname duration_seconds").await,
};

let duration_seconds = match tokens.get(2) {
Some(duration_seconds) => match duration_seconds.parse::<i64>() {
Ok(duration_seconds) => duration_seconds,
Err(_) => {
return error_response(stream, "duration_seconds must be an integer").await;
}
},
None => return error_response(stream, "usage: BAN hostname duration_seconds").await,
};

if duration_seconds <= 0 {
return error_response(stream, "duration_seconds must be >= 0").await;
}

let columns = vec![
("db", DataType::Text),
("user", DataType::Text),
("role", DataType::Text),
("host", DataType::Text),
];
let mut res = BytesMut::new();
res.put(row_description(&columns));

for (id, pool) in get_all_pools().iter() {
for address in pool.get_addresses_from_host(host) {
if !pool.is_banned(&address) {
pool.ban(&address, BanReason::AdminBan(duration_seconds), -1);
res.put(data_row(&vec![
id.db.clone(),
id.user.clone(),
address.role.to_string(),
address.host,
]));
}
}
}

res.put(command_complete("BAN"));

// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
}

/// Clear a host for use
async fn unban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let host = match tokens.get(1) {
Some(host) => host,
None => return error_response(stream, "UNBAN command requires a hostname to unban").await,
};

let columns = vec![
("db", DataType::Text),
("user", DataType::Text),
("role", DataType::Text),
("host", DataType::Text),
];
let mut res = BytesMut::new();
res.put(row_description(&columns));

for (id, pool) in get_all_pools().iter() {
for address in pool.get_addresses_from_host(host) {
if pool.is_banned(&address) {
pool.unban(&address);
res.put(data_row(&vec![
id.db.clone(),
id.user.clone(),
address.role.to_string(),
address.host,
]));
}
}
}

res.put(command_complete("UNBAN"));

// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
}

/// Shows all the bans
async fn show_bans<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let columns = vec![
("db", DataType::Text),
("user", DataType::Text),
("role", DataType::Text),
("host", DataType::Text),
("reason", DataType::Text),
("ban_time", DataType::Text),
("ban_duration_seconds", DataType::Text),
("ban_remaining_seconds", DataType::Text),
];
let mut res = BytesMut::new();
res.put(row_description(&columns));

// The block should be pretty quick so we cache the time outside
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs() as i64;

for (id, pool) in get_all_pools().iter() {
for (address, (ban_reason, ban_time)) in pool.get_bans().iter() {
let ban_duration = match ban_reason {
BanReason::AdminBan(duration) => *duration,
_ => pool.settings.ban_time,
};
let remaining = ban_duration - (now - ban_time.timestamp());
if remaining <= 0 {
continue;
}
res.put(data_row(&vec![
id.db.clone(),
id.user.clone(),
address.role.to_string(),
address.host.clone(),
format!("{:?}", ban_reason),
ban_time.to_string(),
ban_duration.to_string(),
remaining.to_string(),
]));
}
}

res.put(command_complete("SHOW BANS"));

// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
}

/// Reload the configuration file without restarting the process.
async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error>
where
Expand Down
13 changes: 8 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::errors::Error;
use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn};

use std::collections::HashMap;
use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
Expand All @@ -11,7 +14,7 @@ use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::config::{get_config, Address, PoolMode};
use crate::constants::*;
use crate::errors::Error;

use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
Expand Down Expand Up @@ -1135,7 +1138,7 @@ where
match server.send(message).await {
Ok(_) => Ok(()),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, BanReason::MessageSendFailed, self.process_id);
Err(err)
}
}
Expand All @@ -1157,7 +1160,7 @@ where
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
Expand All @@ -1172,7 +1175,7 @@ where
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, self.process_id);
pool.ban(address, BanReason::StatementTimeout, self.process_id);
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
Expand All @@ -1181,7 +1184,7 @@ where
match server.recv().await {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
Expand Down
1 change: 1 addition & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub enum Error {
SocketError(String),
ClientBadStartup,
ProtocolSyncError(String),
BadQuery(String),
ServerError,
BadConfig,
AllServersDown,
Expand Down
57 changes: 49 additions & 8 deletions src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub type SecretKey = i32;
pub type ServerHost = String;
pub type ServerPort = u16;

pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type BanList = Arc<RwLock<Vec<HashMap<Address, (BanReason, NaiveDateTime)>>>>;
pub type ClientServerMap =
Arc<Mutex<HashMap<(ProcessId, SecretKey), (ProcessId, SecretKey, ServerHost, ServerPort)>>>;
pub type PoolMap = HashMap<PoolIdentifier, ConnectionPool>;
Expand All @@ -38,6 +38,17 @@ pub type PoolMap = HashMap<PoolIdentifier, ConnectionPool>;
/// The pool is recreated dynamically when the config is reloaded.
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));

// Reasons for banning a server.
#[derive(Debug, PartialEq, Clone)]
pub enum BanReason {
FailedHealthCheck,
MessageSendFailed,
MessageReceiveFailed,
FailedCheckout,
StatementTimeout,
AdminBan(i64),
}

/// An identifier for a PgCat pool,
/// a database visible to clients.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -489,7 +500,7 @@ impl ConnectionPool {
Ok(conn) => conn,
Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(address, client_process_id);
self.ban(address, BanReason::FailedCheckout, client_process_id);
self.stats
.client_checkout_error(client_process_id, address.id);
continue;
Expand Down Expand Up @@ -582,14 +593,14 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();

self.ban(&address, client_process_id);
self.ban(&address, BanReason::FailedHealthCheck, client_process_id);
return false;
}

/// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, client_id: i32) {
pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) {
// Primary can never be banned
if address.role == Role::Primary {
return;
Expand All @@ -599,12 +610,12 @@ impl ConnectionPool {
let mut guard = self.banlist.write();
error!("Banning {:?}", address);
self.stats.client_ban_error(client_id, address.id);
guard[address.shard].insert(address.clone(), now);
guard[address.shard].insert(address.clone(), (reason, now));
}

/// Clear the replica to receive traffic again. Takes effect immediately
/// for all new transactions.
pub fn _unban(&self, address: &Address) {
pub fn unban(&self, address: &Address) {
let mut guard = self.banlist.write();
guard[address.shard].remove(address);
}
Expand Down Expand Up @@ -653,9 +664,14 @@ impl ConnectionPool {
// Check if ban time is expired
let read_guard = self.banlist.read();
let exceeded_ban_time = match read_guard[address.shard].get(address) {
Some(timestamp) => {
Some((ban_reason, timestamp)) => {
let now = chrono::offset::Utc::now().naive_utc();
now.timestamp() - timestamp.timestamp() > self.settings.ban_time
match ban_reason {
BanReason::AdminBan(duration) => {
now.timestamp() - timestamp.timestamp() > *duration
}
_ => now.timestamp() - timestamp.timestamp() > self.settings.ban_time,
}
}
None => return true,
};
Expand All @@ -679,6 +695,31 @@ impl ConnectionPool {
self.databases.len()
}

pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
let guard = self.banlist.read();
for banlist in guard.iter() {
for (address, (reason, timestamp)) in banlist.iter() {
bans.push((address.clone(), (reason.clone(), timestamp.clone())));
}
}
return bans;
}

/// Get the address from the host url
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
let mut addresses = Vec::new();
for shard in 0..self.shards() {
for server in 0..self.servers(shard) {
let address = self.address(shard, server);
if address.host == host {
addresses.push(address.clone());
}
}
}
addresses
}

/// Get the number of servers (primary and replicas)
/// configured for a shard.
pub fn servers(&self, shard: usize) -> usize {
Expand Down
Loading

0 comments on commit c03f01a

Please sign in to comment.