diff --git a/core/Cargo.toml b/core/Cargo.toml index accbee1cfb1..86458781520 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -178,7 +178,7 @@ services-pcloud = [] services-persy = ["dep:persy", "internal-tokio-rt"] services-postgresql = ["dep:sqlx", "sqlx?/postgres"] services-redb = ["dep:redb", "internal-tokio-rt"] -services-redis = ["dep:redis", "redis?/tokio-rustls-comp"] +services-redis = ["dep:redis","dep:bb8","redis?/tokio-rustls-comp"] services-redis-native-tls = ["services-redis", "redis?/tokio-native-tls-comp"] services-rocksdb = ["dep:rocksdb", "internal-tokio-rt"] services-s3 = [ diff --git a/core/src/services/redis/backend.rs b/core/src/services/redis/backend.rs index 8d08e1216a8..bf8cff20185 100644 --- a/core/src/services/redis/backend.rs +++ b/core/src/services/redis/backend.rs @@ -15,23 +15,19 @@ // specific language governing permissions and limitations // under the License. -use std::fmt::Debug; -use std::fmt::Formatter; -use std::path::PathBuf; -use std::time::Duration; - +use bb8::RunError; use http::Uri; -use redis::aio::ConnectionManager; use redis::cluster::ClusterClient; use redis::cluster::ClusterClientBuilder; -use redis::cluster_async::ClusterConnection; -use redis::AsyncCommands; use redis::Client; use redis::ConnectionAddr; use redis::ConnectionInfo; use redis::ProtocolVersion; use redis::RedisConnectionInfo; -use redis::RedisError; +use std::fmt::Debug; +use std::fmt::Formatter; +use std::path::PathBuf; +use std::time::Duration; use tokio::sync::OnceCell; use crate::raw::adapters::kv; @@ -39,6 +35,7 @@ use crate::raw::*; use crate::services::RedisConfig; use crate::*; +use super::core::*; const DEFAULT_REDIS_ENDPOINT: &str = "tcp://127.0.0.1:6379"; const DEFAULT_REDIS_PORT: u16 = 6379; @@ -272,18 +269,12 @@ impl RedisBuilder { /// Backend for redis services. pub type RedisBackend = kv::Backend; -#[derive(Clone)] -enum RedisConnection { - Normal(ConnectionManager), - Cluster(ClusterConnection), -} - #[derive(Clone)] pub struct Adapter { addr: String, client: Option, cluster_client: Option, - conn: OnceCell, + conn: OnceCell>, default_ttl: Option, } @@ -299,26 +290,39 @@ impl Debug for Adapter { } impl Adapter { - async fn conn(&self) -> Result { - Ok(self + async fn conn(&self) -> Result> { + let pool = self .conn .get_or_try_init(|| async { - if let Some(client) = self.client.clone() { - ConnectionManager::new(client.clone()) - .await - .map(RedisConnection::Normal) - } else { - self.cluster_client - .clone() - .unwrap() - .get_async_connection() - .await - .map(RedisConnection::Cluster) - } + bb8::Pool::builder() + .build(self.get_redis_connection_manager()) + .await + .map_err(|err| { + Error::new(ErrorKind::ConfigInvalid, "connect to redis failed") + .set_source(err) + }) }) - .await - .map_err(format_redis_error)? - .clone()) + .await?; + pool.get().await.map_err(|err| match err { + RunError::TimedOut => { + Error::new(ErrorKind::Unexpected, "get connection from pool failed").set_temporary() + } + RunError::User(err) => err, + }) + } + + fn get_redis_connection_manager(&self) -> RedisConnectionManager { + if let Some(_client) = self.client.clone() { + RedisConnectionManager { + client: self.client.clone(), + cluster_client: None, + } + } else { + RedisConnectionManager { + client: None, + cluster_client: self.cluster_client.clone(), + } + } } } @@ -337,69 +341,27 @@ impl kv::Adapter for Adapter { } async fn get(&self, key: &str) -> Result> { - let conn = self.conn().await?; - let result: Option = match conn { - RedisConnection::Normal(mut conn) => conn.get(key).await.map_err(format_redis_error), - RedisConnection::Cluster(mut conn) => conn.get(key).await.map_err(format_redis_error), - }?; - Ok(result.map(Buffer::from)) + let mut conn = self.conn().await?; + let result = conn.get(key).await?; + Ok(result) } async fn set(&self, key: &str, value: Buffer) -> Result<()> { - let conn = self.conn().await?; + let mut conn = self.conn().await?; let value = value.to_vec(); - match self.default_ttl { - Some(ttl) => match conn { - RedisConnection::Normal(mut conn) => conn - .set_ex(key, value, ttl.as_secs()) - .await - .map_err(format_redis_error)?, - RedisConnection::Cluster(mut conn) => conn - .set_ex(key, value, ttl.as_secs()) - .await - .map_err(format_redis_error)?, - }, - None => match conn { - RedisConnection::Normal(mut conn) => { - conn.set(key, value).await.map_err(format_redis_error)? - } - RedisConnection::Cluster(mut conn) => { - conn.set(key, value).await.map_err(format_redis_error)? - } - }, - } + conn.set(key, value, self.default_ttl).await?; Ok(()) } async fn delete(&self, key: &str) -> Result<()> { - let conn = self.conn().await?; - match conn { - RedisConnection::Normal(mut conn) => { - let _: () = conn.del(key).await.map_err(format_redis_error)?; - } - RedisConnection::Cluster(mut conn) => { - let _: () = conn.del(key).await.map_err(format_redis_error)?; - } - } + let mut conn = self.conn().await?; + conn.delete(key).await?; Ok(()) } async fn append(&self, key: &str, value: &[u8]) -> Result<()> { - let conn = self.conn().await?; - match conn { - RedisConnection::Normal(mut conn) => { - () = conn.append(key, value).await.map_err(format_redis_error)?; - } - RedisConnection::Cluster(mut conn) => { - () = conn.append(key, value).await.map_err(format_redis_error)?; - } - } + let mut conn = self.conn().await?; + conn.append(key, value).await?; Ok(()) } } - -pub fn format_redis_error(e: RedisError) -> Error { - Error::new(ErrorKind::Unexpected, e.category()) - .set_source(e) - .set_temporary() -} diff --git a/core/src/services/redis/core.rs b/core/src/services/redis/core.rs new file mode 100644 index 00000000000..041ed87169b --- /dev/null +++ b/core/src/services/redis/core.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Buffer; +use crate::Error; +use crate::ErrorKind; + +use redis::aio::ConnectionLike; +use redis::aio::ConnectionManager; + +use redis::cluster::ClusterClient; +use redis::cluster_async::ClusterConnection; +use redis::from_redis_value; +use redis::AsyncCommands; +use redis::Client; +use redis::RedisError; + +use std::time::Duration; + +#[derive(Clone)] +pub enum RedisConnection { + Normal(ConnectionManager), + Cluster(ClusterConnection), +} +impl RedisConnection { + pub async fn get(&mut self, key: &str) -> crate::Result> { + let result: Option = match self { + RedisConnection::Normal(ref mut conn) => { + conn.get(key).await.map_err(format_redis_error) + } + RedisConnection::Cluster(ref mut conn) => { + conn.get(key).await.map_err(format_redis_error) + } + }?; + Ok(result.map(Buffer::from)) + } + + pub async fn set( + &mut self, + key: &str, + value: Vec, + ttl: Option, + ) -> crate::Result<()> { + let value = value.to_vec(); + if let Some(ttl) = ttl { + match self { + RedisConnection::Normal(ref mut conn) => conn + .set_ex(key, value, ttl.as_secs()) + .await + .map_err(format_redis_error)?, + RedisConnection::Cluster(ref mut conn) => conn + .set_ex(key, value, ttl.as_secs()) + .await + .map_err(format_redis_error)?, + } + } else { + match self { + RedisConnection::Normal(ref mut conn) => { + conn.set(key, value).await.map_err(format_redis_error)? + } + RedisConnection::Cluster(ref mut conn) => { + conn.set(key, value).await.map_err(format_redis_error)? + } + } + } + + Ok(()) + } + + pub async fn delete(&mut self, key: &str) -> crate::Result<()> { + match self { + RedisConnection::Normal(ref mut conn) => { + let _: () = conn.del(key).await.map_err(format_redis_error)?; + } + RedisConnection::Cluster(ref mut conn) => { + let _: () = conn.del(key).await.map_err(format_redis_error)?; + } + } + + Ok(()) + } + + pub async fn append(&mut self, key: &str, value: &[u8]) -> crate::Result<()> { + match self { + RedisConnection::Normal(ref mut conn) => { + () = conn.append(key, value).await.map_err(format_redis_error)?; + } + RedisConnection::Cluster(ref mut conn) => { + () = conn.append(key, value).await.map_err(format_redis_error)?; + } + } + Ok(()) + } +} + +#[derive(Clone)] +pub struct RedisConnectionManager { + pub client: Option, + pub cluster_client: Option, +} + +#[async_trait::async_trait] +impl bb8::ManageConnection for RedisConnectionManager { + type Connection = RedisConnection; + type Error = Error; + + async fn connect(&self) -> Result { + if let Some(client) = self.client.clone() { + ConnectionManager::new(client.clone()) + .await + .map_err(format_redis_error) + .map(RedisConnection::Normal) + } else { + self.cluster_client + .clone() + .unwrap() + .get_async_connection() + .await + .map_err(format_redis_error) + .map(RedisConnection::Cluster) + } + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let pong_value = match conn { + RedisConnection::Normal(ref mut conn) => conn + .send_packed_command(&redis::cmd("PING")) + .await + .map_err(format_redis_error)?, + + RedisConnection::Cluster(ref mut conn) => conn + .req_packed_command(&redis::cmd("PING")) + .await + .map_err(format_redis_error)?, + }; + let pong: String = from_redis_value(&pong_value).map_err(format_redis_error)?; + + if pong == "PONG" { + Ok(()) + } else { + Err(Error::new(ErrorKind::Unexpected, "PING ERROR")) + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} + +pub fn format_redis_error(e: RedisError) -> Error { + Error::new(ErrorKind::Unexpected, e.category()) + .set_source(e) + .set_temporary() +} diff --git a/core/src/services/redis/mod.rs b/core/src/services/redis/mod.rs index a1dc12d620f..7ec39a5e832 100644 --- a/core/src/services/redis/mod.rs +++ b/core/src/services/redis/mod.rs @@ -19,6 +19,7 @@ mod backend; #[cfg(feature = "services-redis")] pub use backend::RedisBuilder as Redis; - mod config; +#[cfg(feature = "services-redis")] +mod core; pub use config::RedisConfig;