diff --git a/src/config.rs b/src/config.rs index d13bee62..1cb37595 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,6 +13,7 @@ use toml; use crate::errors::Error; use crate::pool::{ClientServerMap, ConnectionPool}; +use crate::sharding::ShardingFunction; use crate::tls::{load_certs, load_keys}; pub const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -179,31 +180,31 @@ pub struct General { } impl General { - fn default_host() -> String { + pub fn default_host() -> String { "0.0.0.0".into() } - fn default_port() -> i16 { + pub fn default_port() -> i16 { 5432 } - fn default_connect_timeout() -> u64 { + pub fn default_connect_timeout() -> u64 { 1000 } - fn default_shutdown_timeout() -> u64 { + pub fn default_shutdown_timeout() -> u64 { 60000 } - fn default_healthcheck_timeout() -> u64 { + pub fn default_healthcheck_timeout() -> u64 { 1000 } - fn default_healthcheck_delay() -> u64 { + pub fn default_healthcheck_delay() -> u64 { 30000 } - fn default_ban_time() -> i64 { + pub fn default_ban_time() -> i64 { 60 } } @@ -211,15 +212,15 @@ impl General { impl Default for General { fn default() -> General { General { - host: General::default_host(), - port: General::default_port(), + host: Self::default_host(), + port: Self::default_port(), enable_prometheus_exporter: Some(false), prometheus_exporter_port: 9930, connect_timeout: General::default_connect_timeout(), - shutdown_timeout: General::default_shutdown_timeout(), - healthcheck_timeout: General::default_healthcheck_timeout(), - healthcheck_delay: General::default_healthcheck_delay(), - ban_time: General::default_ban_time(), + shutdown_timeout: Self::default_shutdown_timeout(), + healthcheck_timeout: Self::default_healthcheck_timeout(), + healthcheck_delay: Self::default_healthcheck_delay(), + ban_time: Self::default_ban_time(), autoreload: false, tls_certificate: None, tls_private_key: None, @@ -263,31 +264,61 @@ pub struct Pool { #[serde(default)] // False pub primary_reads_enabled: bool, - #[serde(default = "General::default_connect_timeout")] - pub connect_timeout: u64, + pub connect_timeout: Option, - pub sharding_function: String, + pub sharding_function: ShardingFunction, pub shards: BTreeMap, pub users: BTreeMap, } impl Pool { - fn default_pool_mode() -> PoolMode { + pub fn default_pool_mode() -> PoolMode { PoolMode::Transaction } + + pub fn validate(&self) -> Result<(), Error> { + match self.default_role.as_ref() { + "any" => (), + "primary" => (), + "replica" => (), + other => { + error!( + "Query router default_role must be 'primary', 'replica', or 'any', got: '{}'", + other + ); + return Err(Error::BadConfig); + } + }; + + for (shard_idx, shard) in &self.shards { + match shard_idx.parse::() { + Ok(_) => (), + Err(_) => { + error!( + "Shard '{}' is not a valid number, shards must be numbered starting at 0", + shard_idx + ); + return Err(Error::BadConfig); + } + }; + shard.validate()?; + } + + Ok(()) + } } impl Default for Pool { fn default() -> Pool { Pool { - pool_mode: Pool::default_pool_mode(), + pool_mode: Self::default_pool_mode(), shards: BTreeMap::from([(String::from("1"), Shard::default())]), users: BTreeMap::default(), default_role: String::from("any"), query_parser_enabled: false, primary_reads_enabled: false, - sharding_function: "pg_bigint_hash".to_string(), - connect_timeout: General::default_connect_timeout(), + sharding_function: ShardingFunction::PgBigintHash, + connect_timeout: None, } } } @@ -306,6 +337,45 @@ pub struct Shard { pub servers: Vec, } +impl Shard { + pub fn validate(&self) -> Result<(), Error> { + // We use addresses as unique identifiers, + // let's make sure they are unique in the config as well. + let mut dup_check = HashSet::new(); + let mut primary_count = 0; + + if self.servers.len() == 0 { + error!("Shard {} has no servers configured", self.database); + return Err(Error::BadConfig); + } + + for server in &self.servers { + dup_check.insert(server); + + // Check that we define only zero or one primary. + match server.role { + Role::Primary => primary_count += 1, + _ => (), + }; + } + + if primary_count > 1 { + error!( + "Shard {} has more than on primary configured", + self.database + ); + return Err(Error::BadConfig); + } + + if dup_check.len() != self.servers.len() { + error!("Shard {} contains duplicate server configs", self.database); + return Err(Error::BadConfig); + } + + Ok(()) + } +} + impl Default for Shard { fn default() -> Shard { Shard { @@ -326,7 +396,7 @@ pub struct Config { // so we should always put simple fields before nested fields // in all serializable structs to avoid ValueAfterTable errors // These errors occur when the toml serializer is about to produce - // ambigous toml structure like the one below + // ambiguous toml structure like the one below // [main] // field1_under_main = 1 // field2_under_main = 2 @@ -341,7 +411,7 @@ pub struct Config { } impl Config { - fn default_path() -> String { + pub fn default_path() -> String { String::from("pgcat.toml") } } @@ -349,7 +419,7 @@ impl Config { impl Default for Config { fn default() -> Config { Config { - path: Config::default_path(), + path: Self::default_path(), general: General::default(), pools: HashMap::default(), } @@ -381,7 +451,7 @@ impl From<&Config> for std::collections::HashMap { ), ( format!("pools.{}.sharding_function", pool_name), - pool.sharding_function.clone(), + pool.sharding_function.to_string(), ), ( format!("pools.{:?}.shard_count", pool_name), @@ -477,9 +547,18 @@ impl Config { "[pool: {}] Pool mode: {:?}", pool_name, pool_config.pool_mode ); + let connect_timeout = match pool_config.connect_timeout { + Some(connect_timeout) => connect_timeout, + None => self.general.connect_timeout, + }; + info!( + "[pool: {}] Connection timeout: {}ms", + pool_name, connect_timeout + ); info!( "[pool: {}] Sharding function: {}", - pool_name, pool_config.sharding_function + pool_name, + pool_config.sharding_function.to_string() ); info!( "[pool: {}] Primary reads: {}", @@ -512,6 +591,50 @@ impl Config { } } } + + pub fn validate(&mut self) -> Result<(), Error> { + // Validate TLS! + match self.general.tls_certificate.clone() { + Some(tls_certificate) => { + match load_certs(&Path::new(&tls_certificate)) { + Ok(_) => { + // Cert is okay, but what about the private key? + match self.general.tls_private_key.clone() { + Some(tls_private_key) => { + match load_keys(&Path::new(&tls_private_key)) { + Ok(_) => (), + Err(err) => { + error!( + "tls_private_key is incorrectly configured: {:?}", + err + ); + return Err(Error::BadConfig); + } + } + } + + None => { + error!("tls_certificate is set, but the tls_private_key is not"); + return Err(Error::BadConfig); + } + }; + } + + Err(err) => { + error!("tls_certificate is incorrectly configured: {:?}", err); + return Err(Error::BadConfig); + } + } + } + None => (), + }; + + for (_, pool) in &mut self.pools { + pool.validate()?; + } + + Ok(()) + } } /// Get a read-only instance of the configuration @@ -548,110 +671,7 @@ pub async fn parse(path: &str) -> Result<(), Error> { } }; - // Validate TLS! - match config.general.tls_certificate.clone() { - Some(tls_certificate) => { - match load_certs(&Path::new(&tls_certificate)) { - Ok(_) => { - // Cert is okay, but what about the private key? - match config.general.tls_private_key.clone() { - Some(tls_private_key) => match load_keys(&Path::new(&tls_private_key)) { - Ok(_) => (), - Err(err) => { - error!("tls_private_key is incorrectly configured: {:?}", err); - return Err(Error::BadConfig); - } - }, - - None => { - error!("tls_certificate is set, but the tls_private_key is not"); - return Err(Error::BadConfig); - } - }; - } - - Err(err) => { - error!("tls_certificate is incorrectly configured: {:?}", err); - return Err(Error::BadConfig); - } - } - } - None => (), - }; - - for (pool_name, mut pool) in &mut config.pools { - // Copy the connect timeout over for hashing. - pool.connect_timeout = config.general.connect_timeout; - - match pool.sharding_function.as_ref() { - "pg_bigint_hash" => (), - "sha1" => (), - _ => { - error!( - "Supported sharding functions are: 'pg_bigint_hash', 'sha1', got: '{}' in pool {} settings", - pool.sharding_function, - pool_name - ); - return Err(Error::BadConfig); - } - }; - - match pool.default_role.as_ref() { - "any" => (), - "primary" => (), - "replica" => (), - other => { - error!( - "Query router default_role must be 'primary', 'replica', or 'any', got: '{}'", - other - ); - return Err(Error::BadConfig); - } - }; - - for shard in &pool.shards { - // We use addresses as unique identifiers, - // let's make sure they are unique in the config as well. - let mut dup_check = HashSet::new(); - let mut primary_count = 0; - - match shard.0.parse::() { - Ok(_) => (), - Err(_) => { - error!( - "Shard '{}' is not a valid number, shards must be numbered starting at 0", - shard.0 - ); - return Err(Error::BadConfig); - } - }; - - if shard.1.servers.len() == 0 { - error!("Shard {} has no servers configured", shard.0); - return Err(Error::BadConfig); - } - - for server in &shard.1.servers { - dup_check.insert(server); - - // Check that we define only zero or one primary. - match server.role { - Role::Primary => primary_count += 1, - _ => (), - }; - } - - if primary_count > 1 { - error!("Shard {} has more than on primary configured", &shard.0); - return Err(Error::BadConfig); - } - - if dup_check.len() != shard.1.servers.len() { - error!("Shard {} contains duplicate server configs", &shard.0); - return Err(Error::BadConfig); - } - } - } + config.validate()?; config.path = path.to_string(); diff --git a/src/pool.rs b/src/pool.rs index d9c9e7d6..815a2b8b 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -181,11 +181,14 @@ impl ConnectionPool { get_reporter(), ); + let connect_timeout = match pool_config.connect_timeout { + Some(connect_timeout) => connect_timeout, + None => config.general.connect_timeout, + }; + let pool = Pool::builder() .max_size(user.pool_size) - .connection_timeout(std::time::Duration::from_millis( - pool_config.connect_timeout, - )) + .connection_timeout(std::time::Duration::from_millis(connect_timeout)) .test_on_check_out(false) .build(manager) .await @@ -221,11 +224,7 @@ impl ConnectionPool { }, query_parser_enabled: pool_config.query_parser_enabled.clone(), primary_reads_enabled: pool_config.primary_reads_enabled, - sharding_function: match pool_config.sharding_function.as_str() { - "pg_bigint_hash" => ShardingFunction::PgBigintHash, - "sha1" => ShardingFunction::Sha1, - _ => unreachable!(), - }, + sharding_function: pool_config.sharding_function, }, }; diff --git a/src/sharding.rs b/src/sharding.rs index c332c601..c5ab45e7 100644 --- a/src/sharding.rs +++ b/src/sharding.rs @@ -1,3 +1,4 @@ +use serde_derive::{Deserialize, Serialize}; /// Implements various sharding functions. use sha1::{Digest, Sha1}; @@ -5,12 +6,23 @@ use sha1::{Digest, Sha1}; const PARTITION_HASH_SEED: u64 = 0x7A5B22367996DCFD; /// The sharding functions we support. -#[derive(Debug, PartialEq, Copy, Clone)] +#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize, Hash, std::cmp::Eq)] pub enum ShardingFunction { + #[serde(alias = "pg_bigint_hash", alias = "PgBigintHash")] PgBigintHash, + #[serde(alias = "sha1", alias = "Sha1")] Sha1, } +impl ToString for ShardingFunction { + fn to_string(&self) -> String { + match *self { + ShardingFunction::PgBigintHash => "pg_bigint_hash".to_string(), + ShardingFunction::Sha1 => "sha1".to_string(), + } + } +} + /// The sharder. pub struct Sharder { /// Number of shards in the cluster.