diff --git a/pisa-proxy/protocol/mysql/src/column.rs b/pisa-proxy/protocol/mysql/src/column.rs index edb42297..f8b0ec06 100644 --- a/pisa-proxy/protocol/mysql/src/column.rs +++ b/pisa-proxy/protocol/mysql/src/column.rs @@ -16,7 +16,7 @@ use bytes::{Buf, BufMut, BytesMut}; use crate::{mysql_const::ColumnType, util::{ BufExt, BufMutExt, get_length }}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ColumnInfo { pub schema: Option, pub table_name: Option, diff --git a/pisa-proxy/protocol/mysql/src/row.rs b/pisa-proxy/protocol/mysql/src/row.rs index a82ec6f4..9918b4d5 100644 --- a/pisa-proxy/protocol/mysql/src/row.rs +++ b/pisa-proxy/protocol/mysql/src/row.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; +use std::{sync::Arc, str::FromStr}; use crate::{ column::ColumnInfo, @@ -37,8 +37,8 @@ pub enum RowDataTyp> { pub struct RowPartData { pub data: Box<[u8]>, pub start_idx: usize, - pub start_part_idx: usize, - pub end_part_idx: usize, + pub part_encode_length: usize, + pub part_data_length: usize, } crate::gen_row_data!(RowDataTyp, Text(RowDataText), Binary(RowDataBinary)); @@ -83,8 +83,7 @@ impl> RowData for RowDataText { fn decode_with_name(&mut self, name: &str) -> value::Result { let row_data = self.get_row_data_with_name(name)?; match row_data { - Some(data) => Value::from(&data.data[data.start_part_idx..data.end_part_idx]), - + Some(data) => Value::from(&data.data), _ => Ok(None), } } @@ -104,10 +103,10 @@ impl> RowData for RowDataText { return Ok(Some( RowPartData { - data: self.buf.as_ref()[idx..idx + (pos + length) as usize].into(), + data: self.buf.as_ref()[idx + pos as usize .. idx + (pos + length) as usize].into(), start_idx: idx, - start_part_idx: pos as usize, - end_part_idx: (pos + length) as usize, + part_encode_length: pos as usize, + part_data_length: length as usize, } )); } @@ -239,12 +238,13 @@ impl> RowData for RowDataBinary { // Need to add packet header and null_map to returnd data let raw_data = &self.buf.as_ref()[start_pos + pos as usize..(start_pos + pos as usize + length as usize)]; + println!("eeeeeeeeeeeee {:?}", &raw_data[..]); return Ok(Some( RowPartData { data: raw_data.into(), start_idx: start_pos, - start_part_idx: pos as usize, - end_part_idx: (pos + length) as usize, + part_encode_length: pos as usize, + part_data_length: length as usize, } )) } @@ -254,6 +254,27 @@ impl> RowData for RowDataBinary { } } + +// Box has default 'static bound, use `'e` lifetime relax bound. +pub fn decode_with_name<'e, T: AsRef<[u8]>, V: Value + std::str::FromStr>(row_data: &mut RowDataTyp, name: &str, is_binary: bool) -> Result, Box > +where + T: AsRef<[u8]>, + V: Value + std::str::FromStr, + ::Err: std::error::Error + Sync + Send + 'e +{ + if is_binary { + row_data.decode_with_name::(name) + } else { + let new_value = row_data.decode_with_name::(name)?; + if let Some(new_value) = new_value { + let new_value = new_value.parse::()?; + Ok(Some(new_value)) + } else { + Ok(None) + } + } +} + #[cfg(test)] mod test { use std::sync::Arc; diff --git a/pisa-proxy/protocol/mysql/src/value.rs b/pisa-proxy/protocol/mysql/src/value.rs index b2c81579..592f8fe2 100644 --- a/pisa-proxy/protocol/mysql/src/value.rs +++ b/pisa-proxy/protocol/mysql/src/value.rs @@ -17,7 +17,8 @@ use chrono::{Duration, NaiveDateTime, NaiveDate, NaiveTime}; use crate::err::DecodeRowError; -pub type Result = std::result::Result, Box>; +type BoxError = Box; +pub type Result = std::result::Result, BoxError>; pub trait Value: Sized { type Item: Convert; diff --git a/pisa-proxy/proxy/strategy/Cargo.toml b/pisa-proxy/proxy/strategy/Cargo.toml index 47bcb56d..43a6f743 100644 --- a/pisa-proxy/proxy/strategy/Cargo.toml +++ b/pisa-proxy/proxy/strategy/Cargo.toml @@ -29,3 +29,4 @@ aho-corasick = "0.7.19" itertools = "0.10.4" thiserror = "1.0" crc32fast = "1.3.2" +paste = "1.0.9" diff --git a/pisa-proxy/proxy/strategy/src/sharding_rewrite/generic_meta.rs b/pisa-proxy/proxy/strategy/src/sharding_rewrite/generic_meta.rs index e0d9927b..8f04bfd7 100644 --- a/pisa-proxy/proxy/strategy/src/sharding_rewrite/generic_meta.rs +++ b/pisa-proxy/proxy/strategy/src/sharding_rewrite/generic_meta.rs @@ -1,38 +1,58 @@ // Copyright 2022 SphereEx Authors -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -use crate::config::{StrategyType, Sharding, ShardingAlgorithmName}; +use endpoint::endpoint::Endpoint; + +use crate::config::{Sharding, ShardingAlgorithmName, StrategyType}; + +#[derive(Debug)] +pub(crate) struct ShardingMetaBaseInfo<'a> { + pub column: (Option<&'a str>, Option<&'a str>), + pub count: (Option, Option), + pub algo: (Option<&'a ShardingAlgorithmName>, Option<&'a ShardingAlgorithmName>), +} pub trait ShardingMeta { fn get_sharding_column(&self) -> (Option<&str>, Option<&str>); fn get_algo(&self) -> (Option<&ShardingAlgorithmName>, Option<&ShardingAlgorithmName>); - fn get_sharding_count(&self) -> (Option, Option); + fn get_sharding_count(&self) -> (Option, Option); + fn get_actual_schema<'a>( + &'a self, + endpoints: &'a [Endpoint], + idx: Option, + ) -> Option<&'a str>; + fn get_endpoint<'a>( + &'a self, + endpoints: &'a [Endpoint], + idx: Option, + ) -> Option; + fn get_strategy_typ(&self) -> super::StrategyTyp; } /// Todo: use macro generate impl ShardingMeta for Sharding { fn get_sharding_column(&self) -> (Option<&str>, Option<&str>) { if let Some(strategy) = &self.database_strategy { - return strategy.get_sharding_column() + return strategy.get_sharding_column(); } if let Some(strategy) = &self.table_strategy { - return strategy.get_sharding_column() + return strategy.get_sharding_column(); } if let Some(strategy) = &self.database_table_strategy { - return strategy.get_sharding_column() + return strategy.get_sharding_column(); } (None, None) @@ -40,53 +60,77 @@ impl ShardingMeta for Sharding { fn get_algo(&self) -> (Option<&ShardingAlgorithmName>, Option<&ShardingAlgorithmName>) { if let Some(strategy) = &self.database_strategy { - return strategy.get_algo() + return strategy.get_algo(); } if let Some(strategy) = &self.table_strategy { - return strategy.get_algo() + return strategy.get_algo(); } if let Some(strategy) = &self.database_table_strategy { - return strategy.get_algo() + return strategy.get_algo(); } (None, None) } - fn get_sharding_count(&self) -> (Option, Option) { + fn get_sharding_count(&self) -> (Option, Option) { if let Some(_) = &self.database_strategy { - return (Some(self.actual_datanodes.len() as u64), None) + return (Some(self.actual_datanodes.len() as u32), None); } if let Some(strategy) = &self.table_strategy { - return (None, strategy.get_sharding_count().1) + return (None, strategy.get_sharding_count().1); } if let Some(strategy) = &self.database_table_strategy { - return (Some(self.actual_datanodes.len() as u64), strategy.get_sharding_count().1) + return (Some(self.actual_datanodes.len() as u32), strategy.get_sharding_count().1); } (None, None) } + + fn get_actual_schema<'a>( + &self, + endpoints: &'a [Endpoint], + idx: Option, + ) -> Option<&'a str> { + if self.database_strategy.is_some() || self.database_table_strategy.is_some() { + let ep = endpoints.iter().find(|ep| ep.name == self.actual_datanodes[idx.unwrap()]); + return ep.map(|x| x.db.as_str()); + } + + None + } + + fn get_endpoint(&self, endpoints: &[Endpoint], idx: Option) -> Option { + let idx = if self.table_strategy.is_some() { 0 } else { idx.unwrap() }; + endpoints.iter().find(|ep| ep.name == self.actual_datanodes[idx]).map(|x| x.clone()) + } + + fn get_strategy_typ(&self) -> super::StrategyTyp { + if self.database_strategy.is_some() { + super::StrategyTyp::Database + } else if self.table_strategy.is_some() { + super::StrategyTyp::Table + } else { + super::StrategyTyp::DatabaseTable + } + } } impl ShardingMeta for StrategyType { fn get_sharding_column(&self) -> (Option<&str>, Option<&str>) { match self { - Self::DatabaseStrategyConfig(config) => { - (Some(&config.database_sharding_column), None) - }, + Self::DatabaseStrategyConfig(config) => (Some(&config.database_sharding_column), None), Self::DatabaseTableStrategyConfig(config) => { (Some(&config.database_sharding_column), Some(&config.table_sharding_column)) - }, + } - Self::TableStrategyConfig(config) => { - (None, Some(&config.table_sharding_column)) - }, + Self::TableStrategyConfig(config) => (None, Some(&config.table_sharding_column)), - _ => (None, None) + _ => (None, None), } } @@ -94,35 +138,48 @@ impl ShardingMeta for StrategyType { match self { Self::DatabaseStrategyConfig(config) => { (Some(&config.database_sharding_algorithm_name), None) - }, + } - Self::DatabaseTableStrategyConfig(config) => { - (Some(&config.database_sharding_algorithm_name), Some(&config.table_sharding_algorithm_name)) - }, + Self::DatabaseTableStrategyConfig(config) => ( + Some(&config.database_sharding_algorithm_name), + Some(&config.table_sharding_algorithm_name), + ), Self::TableStrategyConfig(config) => { (None, Some(&config.table_sharding_algorithm_name)) - }, + } - _ => (None, None) + _ => (None, None), } } - fn get_sharding_count(&self) -> (Option, Option) { + fn get_sharding_count(&self) -> (Option, Option) { match self { Self::DatabaseStrategyConfig(_) => { unimplemented!() - }, + } - Self::DatabaseTableStrategyConfig(config) => { - (None, Some(config.shading_count.into())) - }, + Self::DatabaseTableStrategyConfig(config) => (None, Some(config.shading_count.into())), - Self::TableStrategyConfig(config) => { - (None, Some(config.sharding_count.into())) - }, + Self::TableStrategyConfig(config) => (None, Some(config.sharding_count.into())), - _ => (None, None) + _ => (None, None), } } -} \ No newline at end of file + + fn get_actual_schema<'a>( + &'a self, + _endpoints: &'a [Endpoint], + _idx: Option, + ) -> Option<&'a str> { + None + } + + fn get_endpoint(&self, _endpoints: &[Endpoint], _idx: Option) -> Option { + None + } + + fn get_strategy_typ(&self) -> super::StrategyTyp { + unimplemented!() + } +} diff --git a/pisa-proxy/proxy/strategy/src/sharding_rewrite/macros.rs b/pisa-proxy/proxy/strategy/src/sharding_rewrite/macros.rs new file mode 100644 index 00000000..fcd77a73 --- /dev/null +++ b/pisa-proxy/proxy/strategy/src/sharding_rewrite/macros.rs @@ -0,0 +1,23 @@ +// Copyright 2022 SphereEx Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[macro_export] +macro_rules! get_meta_detail { + ($meta:ident, $($meta_typ:ident),*) => { + paste! { + $(let $meta_typ = $meta.[]();)* + } + + } +} diff --git a/pisa-proxy/proxy/strategy/src/sharding_rewrite/meta.rs b/pisa-proxy/proxy/strategy/src/sharding_rewrite/meta.rs index a8e10811..5f710879 100644 --- a/pisa-proxy/proxy/strategy/src/sharding_rewrite/meta.rs +++ b/pisa-proxy/proxy/strategy/src/sharding_rewrite/meta.rs @@ -24,7 +24,6 @@ enum ScanState { Field(Option), Order(OrderDirection), Group, - Where(Vec), OnCond(Vec), // Option means whethe has `alias_name` Avg(mysql_parser::Span, Option, bool), @@ -35,21 +34,23 @@ enum ScanState { type TableMeta = TableIdent; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum FieldWrapFunc { Min, Max, - None, + Count, + None, } impl AsRef for FieldWrapFunc { - fn as_ref(&self) -> &str { + fn as_ref(&self) -> &str { match self { Self::Max => "max", Self::Min => "min", + Self::Count => "count", Self::None => "none", } - } + } } #[derive(Debug)] @@ -133,7 +134,7 @@ pub struct AvgMeta { pub span: mysql_parser::Span, // IS avg(t) pub avg_field_name: String, - // IS `t` of avg(t) + // IS `t` of avg(t) pub field_name: String, pub distinct: bool, } @@ -200,6 +201,17 @@ impl RewriteMetaData { input, } } + + fn get_field_name( + input: &str, + span: &mysql_parser::Span, + alias_name: &Option, + ) -> String { + match alias_name { + Some(name) => name.clone(), + None => input[span.start()..span.end()].to_string(), + } + } } macro_rules! gen_push_func { @@ -256,10 +268,6 @@ impl Transformer for RewriteMetaData { self.state = ScanState::TableName; } - Node::WhereClause(_) => { - self.state = ScanState::Where(vec![]); - } - Node::TableRef(..) => { self.state = ScanState::OnCond(vec![]); } @@ -273,6 +281,7 @@ impl Transformer for RewriteMetaData { Item::TableWild(val) => { self.push_field(FieldMeta::TableWild(val.clone())) } + _ => {} } } @@ -286,39 +295,60 @@ impl Transformer for RewriteMetaData { match &item.expr { Expr::SimpleIdentExpr(Value::Ident { .. }) => { let name = match &item.alias_name { - Some(name) => name.clone(), + Some(name) => name.clone(), None => self.input[item.span.start()..item.span.end()].to_string(), }; - self.push_field(FieldMeta::Ident(FieldMetaIdent { span: item.span.clone(), name, wrap_func: FieldWrapFunc::None })) + self.push_field(FieldMeta::Ident(FieldMetaIdent { + span: item.span.clone(), + name, + wrap_func: FieldWrapFunc::None, + })) } - Expr::SetFuncSpecExpr(e) => { - match e.as_ref() { - Expr::AggExpr(e) => { - match e.name { - AggFuncName::Avg => { - self.state = ScanState::Avg(item.span, item.alias_name.clone(), e.distinct); - }, - - AggFuncName::Max => { - self.state = ScanState::FieldWrapFunc(item.span, FieldWrapFunc::Max, item.alias_name.clone()); - } - - AggFuncName::Min => { - self.state = ScanState::FieldWrapFunc(item.span, FieldWrapFunc::Min, item.alias_name.clone()); - } - - _ => {} + Expr::SetFuncSpecExpr(e) => match e.as_ref() { + Expr::AggExpr(e) => { + match e.name { + AggFuncName::Avg => { + self.state = ScanState::Avg( + item.span, + item.alias_name.clone(), + e.distinct, + ); } - return false; + + AggFuncName::Max => { + self.state = ScanState::FieldWrapFunc( + item.span, + FieldWrapFunc::Max, + item.alias_name.clone(), + ); + } + + AggFuncName::Min => { + self.state = ScanState::FieldWrapFunc( + item.span, + FieldWrapFunc::Min, + item.alias_name.clone(), + ); + } + + AggFuncName::Count => { + self.state = ScanState::FieldWrapFunc( + item.span, + FieldWrapFunc::Count, + item.alias_name.clone(), + ); + } + + _ => {} } - _ => {} - + return false; } - } + _ => {} + }, _ => {} } - + return true; } @@ -354,48 +384,71 @@ impl Transformer for RewriteMetaData { if let Value::Ident { span, value, quoted } = val { let name = if *quoted { self.ac.replace_all(value, &[""]) } else { value.clone() }; - self.push_field(FieldMeta::Ident(FieldMetaIdent { span: *span, name, wrap_func: FieldWrapFunc::None })) + self.push_field(FieldMeta::Ident(FieldMetaIdent { + span: *span, + name, + wrap_func: FieldWrapFunc::None, + })) } } InsertIdent::TableWild(val) => self.push_field(FieldMeta::TableWild(val.clone())), }, Node::Expr(e) => match e { + Expr::Ori(_val) => match &self.state { + ScanState::FieldWrapFunc(span, wrap_func, alias_name) => { + let wrap_func = wrap_func.clone(); + let span = span.clone(); + + let name = Self::get_field_name(&self.input, &span, alias_name); + self.push_field(FieldMeta::Ident(FieldMetaIdent { span, name, wrap_func })); + } + _ => {} + }, + Expr::SimpleIdentExpr(Value::Ident { span, value, .. }) => match &mut self.state { ScanState::Field(alias_name) => { let name = match alias_name { - Some(name) => name.clone(), + Some(name) => name.clone(), None => value.to_string(), }; - self.push_field(FieldMeta::Ident(FieldMetaIdent { span: *span, name, wrap_func: FieldWrapFunc::None })) + self.push_field(FieldMeta::Ident(FieldMetaIdent { + span: *span, + name, + wrap_func: FieldWrapFunc::None, + })) } ScanState::FieldWrapFunc(span, wrap_func, alias_name) => { - let name = match alias_name { - Some(name) => name.clone(), - None => self.input[span.start()..span.end()].to_string(), - }; let wrap_func = wrap_func.clone(); let span = span.clone(); - self.push_field(FieldMeta::Ident( FieldMetaIdent { span, name, wrap_func })) + + let name = Self::get_field_name(&self.input, &span, alias_name); + self.push_field(FieldMeta::Ident(FieldMetaIdent { span, name, wrap_func })); } ScanState::Order(direction) => { let direction = direction.clone(); - self.push_order(OrderMeta { span: *span, name: value.to_string(), direction }) + self.push_order(OrderMeta { + span: *span, + name: value.to_string(), + direction, + }) } ScanState::Group => { self.push_group(GroupMeta { span: *span, name: value.to_string() }) } - ScanState::Where(args) => { - args.push(value.to_string()); - } ScanState::Avg(arg, alias_name, distinct) => { let avg_field_name = match alias_name { Some(name) => name.clone(), - None => self.input.as_str()[arg.start()..arg.end()].to_string(), + None => self.input.as_str()[arg.start()..arg.end()].to_string(), }; let arg = arg.clone(); let distinct = *distinct; - let meta = AvgMeta { span: arg, avg_field_name, field_name: value.to_string(), distinct }; + let meta = AvgMeta { + span: arg, + avg_field_name, + field_name: value.to_string(), + distinct, + }; self.push_avg(meta); self.state = ScanState::Empty; } @@ -426,23 +479,38 @@ impl Transformer for RewriteMetaData { } } - Expr::BinaryOperationExpr { span: _, operator, left: _, right } => { - match &mut self.state { - ScanState::Where(args) => { - if let Expr::LiteralExpr(_) = **right { - let op = operator.format(); - if op.as_str() == "=" { - args.push(op); - } + Expr::BinaryOperationExpr { span: _, operator, left, right } => { + if operator.format() != "=" { + return false; + } + + let where_left = + if let Expr::SimpleIdentExpr(Value::Ident { value, .. }) = &**left { + Some(value.to_string()) + } else { + None + }; + + let where_right = + if let Expr::LiteralExpr(Value::Num { value, signed, .. }) = &**right { + if *signed { + Some(WhereMetaRightDataType::SignedNum(value.to_string())) + } else { + Some(WhereMetaRightDataType::Num(value.to_string())) } - } + } else { + None + }; - _ => {} + if where_left.is_some() && where_right.is_some() { + let where_meta = WhereMeta::BinaryExpr { + left: where_left.unwrap(), + right: where_right.unwrap(), + }; + self.push_where(where_meta); } } - - Expr::InExpr { .. } => { self.prev_expr_type = Some("In".to_string()); } @@ -457,31 +525,12 @@ impl Transformer for RewriteMetaData { self.prev_expr_type = None } - Expr::LiteralExpr(Value::Num { span, value, signed }) => { - match &mut self.state { - ScanState::Where(args) => { - if args.len() > 1 && args[0] == "=" { - let right_value = if *signed { - WhereMetaRightDataType::SignedNum(value.to_string()) - } else { - WhereMetaRightDataType::Num(value.to_string()) - }; - - let where_meta = WhereMeta::BinaryExpr { - left: args[1].clone(), - right: right_value, - }; - self.push_where(where_meta); - } - self.state = ScanState::Empty; - } - - ScanState::InsertRowValue(args) => { - args.push(InsertValue { span: *span, value: value.clone() }); - } - _ => {} + Expr::LiteralExpr(Value::Num { span, value, .. }) => match &mut self.state { + ScanState::InsertRowValue(args) => { + args.push(InsertValue { span: *span, value: value.clone() }); } - } + _ => {} + }, _ => {} }, @@ -509,7 +558,23 @@ mod test { use mysql_parser::{ast::Visitor, parser::Parser}; use super::RewriteMetaData; - + use crate::sharding_rewrite::meta::FieldMeta; + + #[test] + fn test_count() { + let parser = Parser::new(); + let input = "select count(*) from t"; + let mut ast = parser.parse(input).unwrap(); + let mut meta = RewriteMetaData::new(input.to_string()); + let _ = ast[0].visit(&mut meta); + let meta = meta.get_fields(); + if let FieldMeta::Ident(meta) = &meta[0][0] { + assert_eq!(meta.wrap_func.as_ref(), "count") + } else { + assert!(false) + } + } + #[test] fn test_get_alias_name() { let parser = Parser::new(); @@ -520,14 +585,12 @@ mod test { let avg_meta = meta.get_avgs(); assert_eq!(avg_meta[0][0].avg_field_name, "tt"); - let input = "select avg(ss) from t"; let mut ast = parser.parse(input).unwrap(); let mut meta = RewriteMetaData::new(input.to_string()); let _ = ast[0].visit(&mut meta); let avg_meta = meta.get_avgs(); assert_eq!(avg_meta[0][0].avg_field_name, "avg(ss)"); - } #[test] @@ -625,4 +688,4 @@ mod test { assert_eq!(meta.tables.len(), input.5); } } -} \ No newline at end of file +} diff --git a/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs b/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs index 0db96026..afcf694b 100644 --- a/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs +++ b/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs @@ -13,30 +13,33 @@ // limitations under the License. mod generic_meta; +mod macros; pub mod meta; pub mod rewrite_const; use std::vec; +use crc32fast::Hasher; use endpoint::endpoint::Endpoint; use indexmap::IndexMap; -use crc32fast::Hasher; -use mysql_parser::ast::{SqlStmt, Visitor, TableIdent }; -use crate::config::NodeGroup; -use crate::sharding_rewrite::meta::AvgMeta; -use crate::sharding_rewrite::meta::FieldWrapFunc; -use crate::sharding_rewrite::meta::GroupMeta; -use crate::sharding_rewrite::meta::OrderMeta; -use crate::sharding_rewrite::rewrite_const::*; -use crate::sharding_rewrite::meta::OrderDirection; - -use self::meta::FieldMetaIdent; -use self::{meta::{ - FieldMeta, InsertValsMeta, RewriteMetaData, WhereMeta, WhereMetaRightDataType, -}, generic_meta::ShardingMeta}; +use mysql_parser::ast::{SqlStmt, TableIdent, Visitor}; +use paste::paste; + +use self::{ + generic_meta::{ShardingMeta, ShardingMetaBaseInfo}, + meta::{ + FieldMeta, FieldMetaIdent, InsertValsMeta, RewriteMetaData, WhereMeta, + WhereMetaRightDataType, + }, +}; use crate::{ - config::{Sharding, ShardingAlgorithmName, StrategyType}, + config::{NodeGroup, Sharding, ShardingAlgorithmName}, + get_meta_detail, rewrite::{ShardingRewriteInput, ShardingRewriter}, + sharding_rewrite::{ + meta::{AvgMeta, FieldWrapFunc, GroupMeta, OrderDirection, OrderMeta}, + rewrite_const::*, + }, }; pub trait CalcShardingIdx { @@ -46,9 +49,7 @@ pub trait CalcShardingIdx { impl CalcShardingIdx for u64 { fn calc(self, algo: &ShardingAlgorithmName, sharding_count: u64) -> Option { match algo { - ShardingAlgorithmName::Mod => { - Some(self.wrapping_rem(sharding_count)) - }, + ShardingAlgorithmName::Mod => Some(self.wrapping_rem(sharding_count)), ShardingAlgorithmName::CRC32Mod => { let mut hasher = Hasher::new(); hasher.update(&self.to_be_bytes()); @@ -62,9 +63,7 @@ impl CalcShardingIdx for u64 { impl CalcShardingIdx for i64 { fn calc(self, algo: &ShardingAlgorithmName, sharding_count: i64) -> Option { match algo { - ShardingAlgorithmName::Mod => { - Some(self.wrapping_rem(sharding_count) as u64) - }, + ShardingAlgorithmName::Mod => Some(self.wrapping_rem(sharding_count) as u64), ShardingAlgorithmName::CRC32Mod => { let mut hasher = Hasher::new(); hasher.update(&self.to_be_bytes()); @@ -78,9 +77,7 @@ impl CalcShardingIdx for i64 { impl CalcShardingIdx for f64 { fn calc(self, algo: &ShardingAlgorithmName, sharding_count: f64) -> Option { match algo { - ShardingAlgorithmName::Mod => { - Some((self % sharding_count).round() as u64) - }, + ShardingAlgorithmName::Mod => Some((self % sharding_count).round() as u64), ShardingAlgorithmName::CRC32Mod => { let mut hasher = Hasher::new(); hasher.update(&self.to_be_bytes()); @@ -93,52 +90,53 @@ impl CalcShardingIdx for f64 { #[derive(Debug)] pub enum RewriteChange { - DatabaseChange(DatabaseChange), + TableChange(TableChange), AvgChange(AvgChange), OrderChange(OrderChange), GroupChange(GroupChange), } #[derive(Debug, Clone)] -pub struct DatabaseChange { +pub struct TableChange { pub span: mysql_parser::Span, + pub table: Option, + pub database: Option, + pub rule: Sharding, +} + +#[derive(Debug, Clone)] +pub struct TableChangeDetail { pub target: String, pub shard_idx: u64, - pub rule: Sharding, } #[derive(Debug)] pub struct AvgChange { // avg sql rewrite target // example: AVG(pbl): avg_count: PBL_AVG_DERIVED_COUNT_00000, avg_sum: PBL_AVG_DERIVED_SUM_00000 - pub target: IndexMap:: + pub target: IndexMap, } #[derive(Debug)] pub struct OrderChange { - pub target: IndexMap::, - pub direction: OrderDirection + pub target: IndexMap, + pub direction: OrderDirection, } impl Default for OrderChange { fn default() -> OrderChange { - OrderChange { - target: IndexMap::new(), - direction: OrderDirection::Asc, - } + OrderChange { target: IndexMap::new(), direction: OrderDirection::Asc } } } #[derive(Debug)] pub struct GroupChange { - pub target: IndexMap:: + pub target: IndexMap, } impl Default for GroupChange { fn default() -> GroupChange { - GroupChange { - target: IndexMap::new(), - } + GroupChange { target: IndexMap::new() } } } @@ -149,6 +147,7 @@ pub struct ShardingRewriteOutput { pub data_source: DataSource, pub sharding_column: Option, pub min_max_fields: Vec, + pub count_field: Option, } #[derive(Debug, Clone, PartialEq)] @@ -178,8 +177,8 @@ pub struct ShardingRewrite { #[derive(Debug, thiserror::Error)] pub enum ShardingRewriteError { - #[error("sharding column not found {0:?}")] - ShardingColumnNotFound(String), + #[error("sharding column not found")] + ShardingColumnNotFound, #[error("parse str to u64 error {0:?}")] ParseIntError(#[from] std::num::ParseIntError), @@ -187,8 +186,8 @@ pub enum ShardingRewriteError { #[error("parse str to u64 error {0:?}")] ParseFloatError(#[from] std::num::ParseFloatError), - #[error("calc mod error")] - CalcModError, + #[error("calc sharding idx error")] + CalcShardingIdxError, #[error("enpoint not found when using actual_datanodes")] EndpointNotFound, @@ -197,27 +196,59 @@ pub enum ShardingRewriteError { FieldsIsEmpty, #[error("database is not found")] - DatabaseNotFound + DatabaseNotFound, } struct ChangeInsertMeta { - row_sharding_value: String, - row_value_span: mysql_parser::Span, + sharding_value: InsertShardingValue, + value_span: mysql_parser::Span, +} + +struct InsertShardingValue { + database: Option, + table: Option, } -enum StrategyTyp { +#[derive(Debug)] +struct ShardingIdx { + database: Option, + table: Option, + span: mysql_parser::Span, +} + +#[derive(Debug)] +pub enum StrategyTyp { + DatabaseTable, Database, Table, } +enum DatabaseTableStrategyPart { + Database(usize), + Table(usize), + All, +} + pub enum OrderGroupChange { Order(OrderMeta), Group(GroupMeta), } impl ShardingRewrite { - pub fn new(rules: Vec, endpoints: Vec, node_group_config: Option, has_rw: bool) -> Self { - ShardingRewrite { rules, raw_sql: "".to_string(), endpoints, node_group_config, has_rw , default_db: None, } + pub fn new( + rules: Vec, + endpoints: Vec, + node_group_config: Option, + has_rw: bool, + ) -> Self { + ShardingRewrite { + rules, + raw_sql: "".to_string(), + endpoints, + node_group_config, + has_rw, + default_db: None, + } } pub fn get_endpoints(&self) -> &Vec { @@ -232,99 +263,191 @@ impl ShardingRewrite { self.default_db = db; } - fn database_strategy( + fn database_table_strategy( &self, meta: RewriteMetaData, try_tables: Vec<(u8, Sharding, &TableIdent)>, ) -> Result, ShardingRewriteError> { - let wheres = meta.get_wheres(); - let inserts = meta.get_inserts(); - let fields = meta.get_fields(); - let avgs = meta.get_avgs(); - let orders = meta.get_orders(); - let groups = meta.get_groups(); - + get_meta_detail!(meta, wheres, inserts, fields, avgs, orders, groups); if !inserts.is_empty() { if fields.is_empty() { - return Err(ShardingRewriteError::FieldsIsEmpty) + return Err(ShardingRewriteError::FieldsIsEmpty); } return self.change_insert_sql(try_tables, fields, inserts); } if wheres.is_empty() { - return Ok(self.database_strategy_iproduct(try_tables, avgs, fields, orders, groups)); + return Ok(self.database_table_strategy_iproduct( + DatabaseTableStrategyPart::All, + try_tables, + avgs, + fields, + orders, + groups, + )); } - let wheres = Self::find_try_where(StrategyTyp::Database, &try_tables, wheres)?.into_iter().filter_map(|x| - match x { - Some((idx, num, _)) => Some((idx, num)), - None => None, - } - ).collect::>(); + let wheres = Self::find_try_where(StrategyTyp::DatabaseTable, &try_tables, wheres)? + .into_iter() + .filter_map(|x| (!x.1.is_empty()).then(|| x)) + .collect::>(); - let expect_sum = wheres[0].1 as usize * wheres.len(); - let sum: usize = wheres.iter().map(|x| x.1).sum::() as usize; + if wheres.is_empty() { + return Ok(self.database_table_strategy_iproduct( + DatabaseTableStrategyPart::All, + try_tables, + avgs, + fields, + orders, + groups, + )); + } - if expect_sum != sum { - return Ok(self.database_strategy_iproduct(try_tables, avgs, fields, orders, groups)); + // We don't consider subquery for now. + // The table sharding exists only. + if wheres[0].1[0].is_none() { + let table_idx = wheres[0].1[1].unwrap() as usize; + return Ok(self.database_table_strategy_iproduct( + DatabaseTableStrategyPart::Table(table_idx), + try_tables, + avgs, + fields, + orders, + groups, + )); + } + + // The database sharding exists only. + if wheres[0].1[1].is_none() { + let db_idx = wheres[0].1[0].unwrap() as usize; + return Ok(self.database_table_strategy_iproduct( + DatabaseTableStrategyPart::Database(db_idx), + try_tables, + avgs, + fields, + orders, + groups, + )); } - let node_fn = |rule: &Sharding, shard_idx: u64| { + let db_fn = |rule: &Sharding, shard_idx: u64| { let node = &rule.actual_datanodes[shard_idx as usize]; let ep = self.endpoints.iter().find(|x| x.name.eq(node)).unwrap(); Some(ep.db.as_str()) }; - let would_changes = self.get_database_change_plan(&try_tables, &wheres, node_fn); + + let would_changes = + self.get_table_change_plan(StrategyTyp::DatabaseTable, &try_tables, &wheres, db_fn); let mut target_sql = self.raw_sql.to_string(); - let shard_idx = would_changes[0].1.shard_idx; - let sharding_rule = &would_changes[0].1.rule.clone(); - let changes = Self::database_change_apply(&mut target_sql, would_changes); + let changes = + Self::table_change_apply(&mut target_sql, StrategyTyp::DatabaseTable, &would_changes); - let ep = self - .endpoints - .iter() - .find(|e| e.name == sharding_rule.actual_datanodes[shard_idx as usize]).ok_or_else(|| ShardingRewriteError::EndpointNotFound)?; + let shard_idx = would_changes[0].1.database.as_ref().unwrap().shard_idx; + let sharding_rule = &would_changes[0].1.rule; - let data_source = if self.has_rw { - DataSource::NodeGroup(sharding_rule.actual_datanodes[0].clone()) - } else { - DataSource::Endpoint(ep.clone()) - }; + let data_source = self.gen_data_source(sharding_rule, shard_idx as usize)?; - let sharding_column = sharding_rule.database_strategy.as_ref().unwrap().get_sharding_column().0.unwrap().to_string(); - let min_max_fields = changes.iter().map(|x| self.find_min_max_fields(&x.0, fields)).flatten().collect(); + let sharding_column = sharding_rule.get_sharding_column().0.unwrap().to_string(); + let min_max_fields = + changes.iter().map(|x| self.find_min_max_fields(&x.0, fields)).flatten().collect(); let mut changes = changes.into_iter().map(|x| x.1).collect::>(); - let (order_change, group_change) = Self::change_order_group(&mut target_sql, orders, groups, fields, shard_idx); + let _ = self.change_order_group( + &mut target_sql, + &mut changes, + orders, + groups, + fields, + shard_idx, + ); + self.change_avg(&mut target_sql, &mut changes, avgs, shard_idx, 0); + + Ok(vec![ShardingRewriteOutput { + changes, + target_sql, + data_source, + sharding_column: Some(sharding_column), + min_max_fields, + count_field: Self::get_count_field(fields), + }]) + } + + fn database_strategy( + &self, + meta: RewriteMetaData, + try_tables: Vec<(u8, Sharding, &TableIdent)>, + ) -> Result, ShardingRewriteError> { + get_meta_detail!(meta, wheres, inserts, fields, avgs, orders, groups); + if !inserts.is_empty() { + if fields.is_empty() { + return Err(ShardingRewriteError::FieldsIsEmpty); + } - if !order_change.target.is_empty() { - changes.push(RewriteChange::OrderChange(order_change)) + return self.change_insert_sql(try_tables, fields, inserts); } - - if !group_change.target.is_empty() { - changes.push(RewriteChange::GroupChange(group_change)) + + if wheres.is_empty() { + return Ok(self.database_strategy_iproduct(try_tables, avgs, fields, orders, groups)); } - if !avgs.is_empty() { - let target = Self::change_avg(&mut target_sql, avgs, shard_idx, 0); - changes.push(RewriteChange::AvgChange(AvgChange{target})); + let wheres = Self::find_try_where(StrategyTyp::Database, &try_tables, wheres)?; + if wheres.is_empty() { + return Ok(self.database_strategy_iproduct(try_tables, avgs, fields, orders, groups)); } - Ok( - vec![ - ShardingRewriteOutput { - changes, - target_sql, - data_source, - sharding_column: Some(sharding_column), - min_max_fields, - } - ] - ) + let expect_sum = wheres[0].1[0].unwrap() as usize * wheres.len(); + let sum: usize = wheres.iter().map(|x| x.1[0].unwrap()).sum::() as usize; + + if expect_sum != sum { + return Ok(self.database_strategy_iproduct(try_tables, avgs, fields, orders, groups)); + } + + let db_fn = |rule: &Sharding, shard_idx: u64| { + let node = &rule.actual_datanodes[shard_idx as usize]; + let ep = self.endpoints.iter().find(|x| x.name.eq(node)).unwrap(); + Some(ep.db.as_str()) + }; + + let mut target_sql = self.raw_sql.to_string(); + let would_changes = + self.get_table_change_plan(StrategyTyp::Database, &try_tables, &wheres, db_fn); + let changes = + Self::table_change_apply(&mut target_sql, StrategyTyp::Database, &would_changes); + + // Currently, we do not consider the endpoint corresponding to the subquery when the subquery exists. + let shard_idx = would_changes[0].1.database.as_ref().unwrap().shard_idx; + let sharding_rule = &would_changes[0].1.rule; + + let data_source = self.gen_data_source(sharding_rule, shard_idx as usize)?; + + let sharding_column = sharding_rule.get_sharding_column().0.unwrap().to_string(); + let min_max_fields = + changes.iter().map(|x| self.find_min_max_fields(&x.0, fields)).flatten().collect(); + + let mut changes = changes.into_iter().map(|x| x.1).collect::>(); + + let _ = self.change_order_group( + &mut target_sql, + &mut changes, + orders, + groups, + fields, + shard_idx, + ); + self.change_avg(&mut target_sql, &mut changes, avgs, shard_idx, 0); + + Ok(vec![ShardingRewriteOutput { + changes, + target_sql, + data_source, + sharding_column: Some(sharding_column), + min_max_fields, + count_field: Self::get_count_field(fields), + }]) } fn table_strategy( @@ -332,94 +455,77 @@ impl ShardingRewrite { meta: RewriteMetaData, try_tables: Vec<(u8, Sharding, &TableIdent)>, ) -> Result, ShardingRewriteError> { - let wheres = meta.get_wheres(); - let inserts = meta.get_inserts(); - let fields = meta.get_fields(); - let avgs = meta.get_avgs(); - let orders = meta.get_orders(); - let groups = meta.get_groups(); + get_meta_detail!(meta, wheres, inserts, fields, avgs, orders, groups); if !inserts.is_empty() { if fields.is_empty() { - return Err(ShardingRewriteError::FieldsIsEmpty) + return Err(ShardingRewriteError::FieldsIsEmpty); } return self.change_insert_sql(try_tables, fields, inserts); } if wheres.is_empty() { - return Ok(self.table_strategy_iproduct(try_tables.clone(), avgs, fields, orders, groups)); + return Ok(self.table_strategy_iproduct(try_tables, avgs, fields, orders, groups)); } - let wheres = Self::find_try_where(StrategyTyp::Table, &try_tables, wheres)?.into_iter().filter_map(|x| { - match x { - Some((idx, num, _)) => Some((idx, num)), - None => None, - } - }).collect::>(); - + let wheres = Self::find_try_where(StrategyTyp::Table, &try_tables, wheres)?; if wheres.is_empty() { - return Ok(self.table_strategy_iproduct(try_tables.clone(), avgs, fields, orders, groups)); + return Ok(self.table_strategy_iproduct(try_tables, avgs, fields, orders, groups)); } - let expect_sum = wheres[0].1 as usize * wheres.len(); - let sum: usize = wheres.iter().map(|x| x.1).sum::() as usize; + let expect_sum = wheres[0].1[0].unwrap() as usize * wheres.len(); + let sum: usize = wheres.iter().map(|x| x.1[0].unwrap()).sum::() as usize; if expect_sum != sum { return Ok(self.table_strategy_iproduct(try_tables, avgs, fields, orders, groups)); } - let node_fn = |_rule: &Sharding, _shard_idx: u64| { None }; - let would_changes = self.get_database_change_plan(&try_tables, &wheres, node_fn); + let db_fn = |_rule: &Sharding, _shard_idx: u64| None; + let would_changes = + self.get_table_change_plan(StrategyTyp::Table, &try_tables, &wheres, db_fn); let mut target_sql = self.raw_sql.clone(); - let sharding_rule = &would_changes[0].1.rule.clone(); - let shard_idx: u64 = would_changes[0].1.shard_idx; - - let changes = Self::database_change_apply(&mut target_sql, would_changes); - - let data_source = if self.has_rw { - DataSource::NodeGroup(sharding_rule.actual_datanodes[0].clone()) - } else { - let ep = self.endpoints.iter().find(|e| e.name == sharding_rule.actual_datanodes[0]).ok_or_else(|| ShardingRewriteError::EndpointNotFound)?; - DataSource::Endpoint(ep.clone()) - }; - - let sharding_column = sharding_rule.get_sharding_column().1.unwrap(); + let sharding_rule = &would_changes[0].1.rule; + let shard_idx: u64 = would_changes[0].1.table.as_ref().unwrap().shard_idx; - let min_max_fields = changes.iter().map(|x| self.find_min_max_fields(&x.0, fields)).flatten().collect(); - let mut changes = changes.into_iter().map(|x| x.1).collect::>(); + let data_source = self.gen_data_source(sharding_rule, 0)?; - let (order_change, group_change) = Self::change_order_group(&mut target_sql, orders, groups, fields, shard_idx); + let changes = Self::table_change_apply(&mut target_sql, StrategyTyp::Table, &would_changes); - if !order_change.target.is_empty() { - changes.push(RewriteChange::OrderChange(order_change)); - } - - if !group_change.target.is_empty() { - changes.push(RewriteChange::GroupChange(group_change)); - } + let sharding_column = sharding_rule.get_sharding_column().1.unwrap(); - if !avgs.is_empty() { - let target = Self::change_avg(&mut target_sql, avgs, shard_idx, 0); - changes.push(RewriteChange::AvgChange(AvgChange { target })); - } + let min_max_fields = + changes.iter().map(|x| self.find_min_max_fields(&x.0, fields)).flatten().collect(); + let mut changes = changes.into_iter().map(|x| x.1).collect::>(); - Ok( - vec![ - ShardingRewriteOutput { - changes, - target_sql: target_sql.to_string(), - data_source, - sharding_column: Some(sharding_column.to_string()), - min_max_fields, - } - ] - ) + let _ = self.change_order_group( + &mut target_sql, + &mut changes, + orders, + groups, + fields, + shard_idx, + ); + self.change_avg(&mut target_sql, &mut changes, avgs, shard_idx, 0); + + Ok(vec![ShardingRewriteOutput { + changes, + target_sql: target_sql.to_string(), + data_source, + sharding_column: Some(sharding_column.to_string()), + min_max_fields, + count_field: Self::get_count_field(fields), + }]) } - - fn get_database_change_plan<'a, F>(&self, try_tables: &[(u8, Sharding, &TableIdent)], wheres: &[(u8, u64)], node_fn: F) -> Vec<(u8, DatabaseChange)> + fn get_table_change_plan<'a, F>( + &self, + strategy_typ: StrategyTyp, + try_tables: &[(u8, Sharding, &TableIdent)], + wheres: &[(u8, Vec>)], + db_fn: F, + ) -> Vec<(u8, TableChange)> where F: Fn(&Sharding, u64) -> Option<&'a str>, { @@ -428,14 +534,67 @@ impl ShardingRewrite { .filter_map(|x| { let w = wheres.iter().find(|w| w.0 == x.0); if let Some(w) = w { - let node = node_fn(&x.1, w.1).unwrap_or_else(|| ""); - let target = self.change_table(x.2, node, w.1); - Some((x.0, DatabaseChange { - span: x.2.span, - shard_idx: w.1, - target, - rule: x.1.clone(), - })) + //let actual_db = db_fn(&x.1, w.1).unwrap_or_else(|| ""); + match strategy_typ { + StrategyTyp::Database => { + let actual_db = db_fn(&x.1, w.1[0].unwrap()).unwrap_or_else(|| ""); + let target = self.change_table(x.2, actual_db, w.1[0].unwrap()); + Some(( + x.0, + TableChange { + span: x.2.span, + rule: x.1.clone(), + database: Some(TableChangeDetail { + target, + shard_idx: w.1[0].unwrap(), + }), + table: None, + }, + )) + } + StrategyTyp::Table => { + let actual_db = db_fn(&x.1, w.1[0].unwrap()).unwrap_or_else(|| ""); + let target = self.change_table(x.2, actual_db, w.1[0].unwrap()); + Some(( + x.0, + TableChange { + span: x.2.span, + rule: x.1.clone(), + table: Some(TableChangeDetail { + target, + shard_idx: w.1[0].unwrap(), + }), + database: None, + }, + )) + } + StrategyTyp::DatabaseTable => { + let db_idx = w.1[0].unwrap(); + let actual_db = db_fn(&x.1, db_idx).unwrap_or_else(|| ""); + let _ = self.change_table(x.2, actual_db, w.1[0].unwrap()); + + let table = TableIdent { + span: x.2.span.clone(), + schema: Some(actual_db.to_string()), + name: x.2.name.clone(), + }; + + let table_idx = w.1[1].unwrap(); + let target = self.change_table(&table, "", table_idx); + Some(( + x.0, + TableChange { + span: x.2.span, + rule: x.1.clone(), + database: Some(TableChangeDetail { + target: target.clone(), + shard_idx: db_idx, + }), + table: Some(TableChangeDetail { target, shard_idx: table_idx }), + }, + )) + } + } } else { None } @@ -443,15 +602,47 @@ impl ShardingRewrite { .collect::>() } - fn database_change_apply(target_sql: &mut String, would_changes: Vec<(u8, DatabaseChange)>) -> Vec<(u8, RewriteChange)> { + fn table_change_apply( + target_sql: &mut String, + strategy_typ: StrategyTyp, + would_changes: &[(u8, TableChange)], + ) -> Vec<(u8, RewriteChange)> { let mut offset = 0; would_changes .into_iter() .map(|x| { - Self::change_sql(target_sql, x.1.span, &x.1.target, offset); - offset = x.1.target.len() - x.1.span.len(); - (x.0, RewriteChange::DatabaseChange(x.1)) + match strategy_typ { + StrategyTyp::Database => { + Self::change_sql( + target_sql, + x.1.span, + &x.1.database.as_ref().unwrap().target, + offset, + ); + offset = x.1.database.as_ref().unwrap().target.len() - x.1.span.len(); + } + StrategyTyp::Table => { + Self::change_sql( + target_sql, + x.1.span, + &x.1.table.as_ref().unwrap().target, + offset, + ); + offset = x.1.table.as_ref().unwrap().target.len() - x.1.span.len(); + } + StrategyTyp::DatabaseTable => { + Self::change_sql( + target_sql, + x.1.span, + &x.1.database.as_ref().unwrap().target, + offset, + ); + offset = x.1.database.as_ref().unwrap().target.len() - x.1.span.len(); + } + }; + + (x.0, RewriteChange::TableChange(x.1.clone())) }) .collect::>() } @@ -476,18 +667,24 @@ impl ShardingRewrite { let has_default_db = if let Some(default_db) = default_db { if self.has_rw { let node_group = self.node_group_config.as_ref().unwrap(); - rule.actual_datanodes.iter().find(|x| { - node_group.members.iter().find(|g| &g.name == *x).is_some() - }).is_some() + rule.actual_datanodes + .iter() + .find(|x| node_group.members.iter().find(|g| &g.name == *x).is_some()) + .is_some() } else { - rule.actual_datanodes.iter().find(|x| { - endpoints.find(|ep| &ep.name == *x && default_db == &ep.db).is_some() - }).is_some() + rule.actual_datanodes + .iter() + .find(|x| { + endpoints + .find(|ep| &ep.name == *x && default_db == &ep.db) + .is_some() + }) + .is_some() } } else { false }; - + if meta.schema.is_some() || has_default_db { (idx, Some(rule.clone()), true) } else { @@ -531,406 +728,745 @@ impl ShardingRewrite { .collect::>() } - fn find_try_where<'a>(strategy_typ: StrategyTyp, try_tables: &[(u8, Sharding, &TableIdent)], wheres: &'a IndexMap>) -> Result>, ShardingRewriteError> { - Self::find_where(wheres, |query_id, meta| { - let rule = try_tables.iter().find(|x| x.0 == query_id); - if let Some(rule) = rule { - let (sharding_column, algo, sharding_count) = match strategy_typ { - StrategyTyp::Database => { - (rule.1.get_sharding_column().0.unwrap(), rule.1.get_algo().0.unwrap(), rule.1.get_sharding_count().0.unwrap()) - } + fn find_try_where<'a>( + strategy_typ: StrategyTyp, + try_tables: &[(u8, Sharding, &TableIdent)], + wheres: &'a IndexMap>, + ) -> Result>)>, ShardingRewriteError> { + let mut res = vec![]; - StrategyTyp::Table => { - (rule.1.get_sharding_column().1.unwrap(), rule.1.get_algo().1.unwrap(), rule.1.get_sharding_count().1.unwrap()) - } - }; + match strategy_typ { + StrategyTyp::Database => { + for (query_id, metas) in wheres.iter() { + let rule = try_tables.iter().find(|x| x.0 == *query_id).unwrap(); - return Self::parse_where( - meta, - algo, - sharding_count, - query_id, - sharding_column, - ); + let sharding_column = rule.1.get_sharding_column().0; + let sharding_count = rule.1.get_sharding_count().0; + let algo = rule.1.get_algo().0; + + let db_right = Self::get_where_match_right(metas, sharding_column.unwrap()); + + if let Some(right) = db_right { + let idx = Self::parse_where(right, sharding_count.unwrap(), algo.unwrap())?; + res.push((*query_id, vec![idx])); + } + } } - Ok(None) - }) - } + StrategyTyp::Table => { + for (query_id, metas) in wheres.iter() { + let rule = try_tables.iter().find(|x| x.0 == *query_id).unwrap(); - fn find_where( - wheres: &IndexMap>, - calc_fn: F, - ) -> Result>, ShardingRewriteError> - where - F: Fn(u8, &WhereMeta) -> Result, ShardingRewriteError>, - { - wheres - .iter() - .filter_map(|(k, v)| { - Some( - v.iter() - .filter_map(|x| { - let res = calc_fn(*k, x); - match res { - Ok(None) => None, - _ => Some(res), - } - }) - .collect::>(), - ) - }) - .flatten() - .collect::, _>>() - } + let sharding_column = rule.1.get_sharding_column().1; + let sharding_count = rule.1.get_sharding_count().1; + let algo = rule.1.get_algo().1; - fn parse_where<'b>( - meta: &'b WhereMeta, - algo: &ShardingAlgorithmName, - sharding_count: u64, - query_id: u8, - sharding_column: &str, - ) -> Result, ShardingRewriteError> { - match meta { - WhereMeta::BinaryExpr { left, right } => { - let left = left.replace("`", ""); - if left != sharding_column { - return Ok(None); + let table_right = Self::get_where_match_right(metas, sharding_column.unwrap()); + + if let Some(right) = table_right { + let idx = Self::parse_where(right, sharding_count.unwrap(), algo.unwrap())?; + res.push((*query_id, vec![idx])); + } } + } - let num = match right { - WhereMetaRightDataType::Num(val) => { - let val = val.parse::()?; - val.calc(algo, sharding_count) - }, + StrategyTyp::DatabaseTable => { + for (query_id, metas) in wheres.iter() { + let mut shards = vec![None, None]; + let rule = try_tables.iter().find(|x| x.0 == *query_id).unwrap(); - WhereMetaRightDataType::SignedNum(val) => { - let val = val.parse::()?; - val.calc(algo, sharding_count as i64) - }, + let sharding_column = rule.1.get_sharding_column(); + let sharding_count = rule.1.get_sharding_count(); + let algo = rule.1.get_algo(); + + let db_right = Self::get_where_match_right(metas, sharding_column.0.unwrap()); + let table_right = + Self::get_where_match_right(metas, sharding_column.1.unwrap()); - WhereMetaRightDataType::FloatNum(val) => { - let val = val.parse::()?; - val.calc(algo, sharding_count as f64) + if let Some(right) = db_right { + let idx = + Self::parse_where(right, sharding_count.0.unwrap(), algo.0.unwrap())?; + shards[0] = idx; + } + + if let Some(right) = table_right { + let idx = + Self::parse_where(right, sharding_count.1.unwrap(), algo.1.unwrap())?; + shards[1] = idx; } - _ => return Ok(None), - }; - if let Some(num) = num { - return Ok(Some((query_id, num, meta))); + res.push((*query_id, shards)); } + } + }; + + Ok(res) + } + + fn get_where_match_right<'b>( + metas: &'b [WhereMeta], + sharding_column: &'b str, + ) -> Option<&'b WhereMetaRightDataType> { + metas.iter().find_map(|x| { + let WhereMeta::BinaryExpr { left, right } = x; + let left = left.replace("`", ""); + left.eq(sharding_column).then(|| right) + }) + } + + fn parse_where( + right: &WhereMetaRightDataType, + sharding_count: u32, + sharding_algo: &ShardingAlgorithmName, + ) -> Result, ShardingRewriteError> { + let num = match right { + WhereMetaRightDataType::Num(val) => { + let val = val.parse::()?; + val.calc(sharding_algo, sharding_count as u64) + } + + WhereMetaRightDataType::SignedNum(val) => { + let val = val.parse::()?; + val.calc(sharding_algo, sharding_count as i64) + } - Ok(None) + WhereMetaRightDataType::FloatNum(val) => { + let val = val.parse::()?; + val.calc(sharding_algo, sharding_count as f64) } + _ => return Ok(None), + }; + + if let Some(num) = num { + return Ok(Some(num)); } - } - fn change_order_group( - target_sql: &mut String, - orders: &IndexMap>, - groups: &IndexMap>, - fields: &IndexMap>, - idx: u64 - ) -> (OrderChange, GroupChange) { - let mut order_changes = OrderChange::default(); - let mut group_changes = GroupChange::default(); - - for (query_id, field) in fields.iter() { - if *query_id == 1 { - let last_span = match &field[&field.len() - 1] { - FieldMeta::Ident(field) => { - field.span - } - _ => unreachable!(), - }; - let first_span = match &field[0] { - FieldMeta::Ident(field) => { - field.span - }, - _ => unreachable!(), - }; + Ok(None) + } - let len = last_span.start() + last_span.len() - first_span.start(); + fn change_order_group(&self, target_sql: &mut String, changes: &mut Vec, orders: &IndexMap>, groups: &IndexMap>, fields: &IndexMap>, + idx: u64, + ) -> Result { + // Currently, we don't consider the case that query_id not equal 1. + let fields = fields.get(&1).ok_or_else(|| ())?; - let mut ori_field = String::with_capacity(len); - let fields_list = field.iter().map(|x| { - match x { - FieldMeta::Ident(field) => { - ori_field += &format!("{}, ", &field.name).to_string(); - field.name.clone() - }, - _ => unreachable!() - } - }).collect::>(); - - for field in fields_list.into_iter() { - if !orders.is_empty() { - for order in orders[query_id].iter() { - if field.replace("`", "") == order.name.replace("`", "") { - order_changes.target.insert(ORDER_FIELD.to_string(), order.name.clone()); - order_changes.target.insert(ORDER_TARGET.to_string(), "".to_string()); - order_changes.direction = order.direction.clone(); - return (order_changes, group_changes); - } - } - } - - if !groups.is_empty() { - for group in groups[query_id].iter() { - if field.replace("`", "") == group.name.replace("`", "") { - group_changes.target.insert(GROUP_FIELD.to_string(), group.name.clone()); - group_changes.target.insert(GROUP_TARGET.to_string(), "".to_string()); - return (order_changes, group_changes); - } - } - } - } - - if !orders.is_empty() || !groups.is_empty() { - for _ in 0..len { - target_sql.remove(first_span.start()); + let last_span = match fields.last() { + Some(FieldMeta::Ident(field)) => field.span, + _ => return Ok(0), + }; + + let default_orders = Vec::::new(); + let orders = orders.get(&1).map_or_else(|| &default_orders, |v| v); + // Get the order field that does not exist in the fields. + let not_exist_order_fields = orders.iter().filter_map(|x| { + let exist_field = fields.iter().find(|field| { + if let FieldMeta::Ident(field_ident) = field { + if field_ident.name.replace("`", "") == x.name.replace("`", "") { + return true; } } + false + }); - if !orders.is_empty() { - for (field_query_id, field_meta) in fields.into_iter() { - if let Some(order_metas) = orders.get(field_query_id) { - for order in order_metas.into_iter() { - if let None = field_meta.into_iter().find(|&x| { - match x { - FieldMeta::Ident(field) => { - if *field.name == order.name { - true - } else { - false - } - } - _ => unreachable!(), - } - }) { - let target_field = format!("{}{} {} ", ori_field, order.name, AS); - let target_as = if order.name.contains("`") { - let new_order_name = order.name.replace("`", ""); - format!("{}_{}_{:05}", new_order_name.to_ascii_uppercase(), ORDER_BY_DERIVED, idx) - } else { - format!("{}_{}_{:05}", order.name.to_ascii_uppercase(), ORDER_BY_DERIVED, idx) - }; - let target = format!("{}{}", target_field, target_as); - target_sql.insert_str(first_span.start(), &target.clone()); - order_changes.target.insert(ORDER_TARGET.to_string(), target_as); - order_changes.target.insert(ORDER_FIELD.to_string(), order.name.clone()); - order_changes.direction = order.direction.clone(); - } - } - } - } - } + exist_field.is_none().then(|| x) + }).collect::>(); - if !groups.is_empty() { - for (field_query_id, field_meta) in fields.iter() { - if let Some(group_metas) = groups.get(field_query_id) { - for group in group_metas.into_iter() { - if let None = field_meta.into_iter().find(|&x| { - match x { - FieldMeta::Ident(field) => { - if *field.name == group.name { - true - } else { - false - } - } - _ => unreachable!(), - } - }) { - let target_field = format!("{}{} {} ", ori_field, group.name, AS); - let target_as = if group.name.contains("`") { - let new_group_name = group.name.replace("`", ""); - format!("{}_{}_{:05}", new_group_name.to_ascii_uppercase(), GROUP_BY_DERIVED, idx) - } else { - format!("{}_{}_{:05}", group.name.to_ascii_uppercase(), GROUP_BY_DERIVED, idx) - }; - let target = format!("{}{}", target_field, target_as); - target_sql.insert_str(first_span.start(), &target.clone()); - group_changes.target.insert(GROUP_TARGET.to_string(), target_as); - group_changes.target.insert(GROUP_FIELD.to_string(), group.name.clone()); - } - } - } + let default_groups = Vec::::new(); + let groups = groups.get(&1).map_or_else(|| &default_groups, |v| v); + + // Get the group field that does not exist in the fields. + let not_exist_group_fields = groups.iter().filter_map(|x| { + let exist_field = fields.iter().find(|field| { + if let FieldMeta::Ident(field_ident) = field { + if field_ident.name.replace("`", "") == x.name.replace("`", "") { + return true; } } - } + false + }); + + exist_field.is_none().then(|| x) + }).collect::>(); + + let mut target_fields = vec![]; + + for field in not_exist_order_fields { + let order_as_name = field.name.replace("`", ""); + let target_field = format!( + "{} {} {}_{}_{:05}", + field.name, + AS, + order_as_name.to_ascii_uppercase(), + ORDER_BY_DERIVED, + idx + ); + target_fields.push(target_field.clone()); + + let mut change_target = IndexMap::new(); + change_target.insert(ORDER_TARGET.to_string(), target_field); + change_target.insert(ORDER_FIELD.to_string(), field.name.clone()); + let order_change = OrderChange { + target: change_target, + direction: field.direction.clone(), + }; + changes.push(RewriteChange::OrderChange(order_change)); + } + + for field in not_exist_group_fields { + let group_as_name = field.name.replace("`", ""); + let target_field = format!( + "{} {} {}_{}_{:05}", + field.name, + AS, + group_as_name.to_ascii_uppercase(), + GROUP_BY_DERIVED, + idx + ); + target_fields.push(target_field.clone()); + + let mut change_target = IndexMap::new(); + change_target.insert(GROUP_TARGET.to_string(), target_field); + change_target.insert(GROUP_FIELD.to_string(), field.name.clone()); + let group_change = GroupChange { + target: change_target, + }; + changes.push(RewriteChange::GroupChange(group_change)); + } + + if target_fields.is_empty() { + return Ok(0); } - return (order_changes, group_changes) + let mut target_fields_string = target_fields.join(", "); + target_fields_string.insert_str(0, ", "); + target_sql.insert_str(last_span.end(), &target_fields_string); + + return Ok(target_fields_string.len()); } - fn change_avg(target_sql: &mut String, avgs: &IndexMap>, idx: u64, offset: usize) -> IndexMap { - let mut target = String::with_capacity(AVG_DERIVED_COUNT.len() + AVG_DERIVED_SUM.len()); + fn change_avg( + &self, + target_sql: &mut String, + changes: &mut Vec, + avgs: &IndexMap>, + idx: u64, + offset: usize, + ) { + if avgs.is_empty() { + return; + } + let mut res = IndexMap::new(); + let mut target = String::with_capacity(AVG_DERIVED_COUNT.len() + AVG_DERIVED_SUM.len()); + for (_, avg) in avgs.iter() { - let last_span = avg[avg.len() - 1].span; - let first_span = avg[0].span; + let last_span = avg.last().unwrap().span; + let first_span = avg.first().unwrap().span; let len = last_span.start() + last_span.len() - first_span.start(); - + for _ in 0..len { target_sql.remove(first_span.start() + offset); } - + for avg_meta in avg { res.insert(AVG_FIELD.to_string(), avg_meta.avg_field_name.clone()); let target_count = &format!("{}({}) {} ", COUNT, avg_meta.field_name, AS); - let target_count_as = &format!("{}_{}_{:05}", avg_meta.field_name.to_ascii_uppercase(), AVG_DERIVED_COUNT, idx); + let target_count_as = &format!( + "{}_{}_{:05}", + avg_meta.field_name.to_ascii_uppercase(), + AVG_DERIVED_COUNT, + idx + ); res.insert(AVG_COUNT.to_string(), target_count_as.to_string()); let target_sum = &format!("{}({}) {} ", SUM, avg_meta.field_name, AS); - let target_sum_as = &format!("{}_{}_{:05}", avg_meta.field_name.to_ascii_uppercase(), AVG_DERIVED_SUM, idx); + let target_sum_as = &format!( + "{}_{}_{:05}", + avg_meta.field_name.to_ascii_uppercase(), + AVG_DERIVED_SUM, + idx + ); res.insert(AVG_SUM.to_string(), target_sum_as.to_string()); - target += &format!("{}{}, {}{}", target_count, target_count_as, target_sum, target_sum_as); + target += &format!( + "{}{}, {}{}", + target_count, target_count_as, target_sum, target_sum_as + ); if avg_meta.span != last_span { target.push(','); target.push(' '); } } - + target_sql.insert_str(first_span.start() + offset, &target); } - res + changes.push(RewriteChange::AvgChange(AvgChange { target: res })); } - fn change_insert_sql(&self, try_tables: Vec<(u8, Sharding, &TableIdent)>, fields: &IndexMap>,inserts: &IndexMap>) -> Result, ShardingRewriteError> { - let outputs = try_tables.into_iter().map(|(query_id, rule, table)| { - let (sharding_count, sharding_column, algo ) = if rule.table_strategy.is_some() { - (rule.get_sharding_count().1.unwrap(), rule.get_sharding_column().1.unwrap(), rule.get_algo().1.unwrap()) - - } else if rule.database_strategy.is_some() { - (rule.get_sharding_count().0.unwrap(), rule.get_sharding_column().0.unwrap(), rule.get_algo().0.unwrap()) - } else { - unreachable!() - }; + fn change_insert_sql( + &self, + try_tables: Vec<(u8, Sharding, &TableIdent)>, + fields: &IndexMap>, + inserts: &IndexMap>, + ) -> Result, ShardingRewriteError> { + let outputs = try_tables + .into_iter() + .map(|(query_id, rule, table)| { + let meta_base_info = if rule.table_strategy.is_some() { + ShardingMetaBaseInfo { + column: (rule.get_sharding_column().1, None), + count: (rule.get_sharding_count().1, None), + algo: (rule.get_algo().1, None), + } + } else if rule.database_strategy.is_some() { + ShardingMetaBaseInfo { + column: (rule.get_sharding_column().0, None), + count: (rule.get_sharding_count().0, None), + algo: (rule.get_algo().0, None), + } + } else { + ShardingMetaBaseInfo { + column: rule.get_sharding_column(), + count: rule.get_sharding_count(), + algo: rule.get_algo(), + } + }; - self.change_insert_sql_inner( - &rule, - &table, - &inserts.get(&query_id).unwrap(), - &fields.get(&query_id).unwrap(), - sharding_column, - algo, - sharding_count, - ) - }).collect::, _>>()?.into_iter().flatten().collect::>(); + self.change_insert_sql_inner( + &rule, + &table, + &inserts.get(&query_id).unwrap(), + &fields.get(&query_id).unwrap(), + meta_base_info, + Self::get_count_field(fields), + ) + }) + .collect::, _>>()? + .into_iter() + .flatten() + .collect::>(); Ok(outputs) } - fn change_insert_sql_inner( + fn change_insert_sql_inner<'b>( &self, rule: &Sharding, table: &TableIdent, inserts: &[InsertValsMeta], fields: &[FieldMeta], - sharding_column: &str, - algo: &ShardingAlgorithmName, - sharding_count: u64, + meta_base_info: ShardingMetaBaseInfo<'b>, + count_field: Option, ) -> Result, ShardingRewriteError> { - let changes = Self::change_insert(inserts, fields, sharding_column, algo, sharding_count)?; - let row_start_idx = changes[0].1.start(); - let row_prefix_text = &self.raw_sql[0..row_start_idx]; - let mut change_rows = IndexMap::::new(); - let mut idx: usize = 0; - - for change in changes.iter() { - let target_table = self.change_table(table, "", change.0); - let mut target_row_prefix_text = row_prefix_text.to_string(); - if rule.table_strategy.is_some() { - Self::change_sql(&mut target_row_prefix_text, table.span, &target_table, 0); - } else if rule.database_strategy.is_some() { - idx = change.0 as usize; + let strategy_typ = rule.get_strategy_typ(); + let changes = Self::change_insert(&strategy_typ, inserts, fields, meta_base_info)?; + let row_start_idx = changes[0].span.start(); + let sql_prefix_text = &self.raw_sql[0..row_start_idx]; + + let mut sqls = IndexMap::<(Option, Option), String>::new(); + + match strategy_typ { + StrategyTyp::Database => { + for change in changes.iter() { + let db_idx = change.database.unwrap(); + let actual_db = rule + .get_actual_schema(&self.endpoints, Some(db_idx as usize)) + .unwrap_or_else(|| ""); + let target = self.change_table(table, actual_db, db_idx); + let mut target_sql_prefix_text = sql_prefix_text.to_string(); + let row_value_text = self.get_change_insert_row( + &mut target_sql_prefix_text, + &target, + table.span, + change, + ); + sqls.entry((Some(db_idx), None)) + .or_insert(target_sql_prefix_text) + .push_str(&row_value_text); + } } - let mut row_text = self.raw_sql[change.1.start()..change.1.end()].to_string(); - row_text.push_str(", "); - change_rows.entry(change.0).or_insert(target_row_prefix_text).push_str(&row_text); - } - let data_source = if self.has_rw { - DataSource::NodeGroup(rule.actual_datanodes[0].clone()) - } else { - if rule.table_strategy.is_some() { - let ep = self.endpoints.iter().find(|e| e.name == rule.actual_datanodes[0]).ok_or_else(|| ShardingRewriteError::EndpointNotFound)?; - DataSource::Endpoint(ep.clone()) - } else if rule.database_strategy.is_some() { - let ep = &self.endpoints[idx]; - DataSource::Endpoint(ep.clone()) - } else { - unreachable!() + StrategyTyp::Table => { + for change in changes.iter() { + let db_idx = change.table.unwrap(); + let target = self.change_table(table, "", db_idx); + let mut target_sql_prefix_text = sql_prefix_text.to_string(); + let row_value_text = self.get_change_insert_row( + &mut target_sql_prefix_text, + &target, + table.span, + change, + ); + sqls.entry((Some(db_idx), None)) + .or_insert(target_sql_prefix_text) + .push_str(&row_value_text); + } } - }; - - let outputs = change_rows.into_iter().map(|(_, v)| { - ShardingRewriteOutput { - changes: vec![], - target_sql: v.trim_end_matches(", ").to_string(), - data_source: data_source.clone(), - sharding_column: Some(sharding_column.to_string()), - min_max_fields: vec![], + + StrategyTyp::DatabaseTable => { + for change in changes.iter() { + let db_idx = change.database.unwrap(); + let actual_db = rule + .get_actual_schema(&self.endpoints, Some(db_idx as usize)) + .unwrap_or_else(|| ""); + let _ = self.change_table(table, actual_db, db_idx); + let table = TableIdent { + span: table.span.clone(), + schema: Some(actual_db.to_string()), + name: table.name.clone(), + }; + + let table_idx = change.table.unwrap(); + let target = self.change_table(&table, "", table_idx); + + let mut target_sql_prefix_text = sql_prefix_text.to_string(); + let row_value_text = self.get_change_insert_row( + &mut target_sql_prefix_text, + &target, + table.span, + change, + ); + sqls.entry((Some(db_idx), Some(table_idx))) + .or_insert(target_sql_prefix_text) + .push_str(&row_value_text); + } } + } - }).collect::>(); + let outputs = sqls + .into_iter() + .map(|((idx, _), v)| -> Result { + let data_source = self.gen_data_source(rule, idx.unwrap() as usize)?; + let sharding_column = rule.get_sharding_column().0.map(|x| x.to_string()); + + Ok(ShardingRewriteOutput { + changes: vec![], + target_sql: v.trim_end_matches(", ").to_string(), + data_source: data_source.clone(), + sharding_column, + min_max_fields: vec![], + count_field: count_field.clone(), + }) + }) + .collect::, _>>()?; Ok(outputs) } - fn change_insert( + fn get_change_insert_row( + &self, + prefix_text: &mut String, + target: &str, + span: mysql_parser::Span, + sharding_idx: &ShardingIdx, + ) -> String { + Self::change_sql(prefix_text, span, target, 0); + let mut row_value_text = + self.raw_sql[sharding_idx.span.start()..sharding_idx.span.end()].to_string(); + row_value_text.push_str(", "); + row_value_text + } + + fn change_insert<'b>( + strategy_typ: &StrategyTyp, inserts: &[InsertValsMeta], fields: &[FieldMeta], - sharding_column: &str, - algo: &ShardingAlgorithmName, - sharding_count: u64, - ) -> Result, ShardingRewriteError> { - let insert_values = Self::find_inserts(inserts, fields, sharding_column)?; - + meta_base_info: ShardingMetaBaseInfo<'b>, + ) -> Result, ShardingRewriteError> { + let insert_values = Self::find_inserts(&strategy_typ, inserts, fields, &meta_base_info)?; + let mut changes = vec![]; - for value in insert_values.iter() { - let shard_value = - value.row_sharding_value.parse::().map_err(ShardingRewriteError::from)?; - let shard_value = shard_value - .calc(algo, sharding_count) - .ok_or_else(|| ShardingRewriteError::CalcModError)?; - - changes.push((shard_value, value.row_value_span)) - } + + match strategy_typ { + StrategyTyp::Database => { + for value in insert_values.iter() { + let idx = Self::calc_database_or_table_sharding_idx( + value.sharding_value.database.as_ref(), + &meta_base_info, + )?; + changes.push(ShardingIdx { + database: Some(idx), + table: None, + span: value.value_span, + }) + } + } + + StrategyTyp::Table => { + for value in insert_values.iter() { + let idx = Self::calc_database_or_table_sharding_idx( + value.sharding_value.table.as_ref(), + &meta_base_info, + )?; + + changes.push(ShardingIdx { + database: None, + table: Some(idx), + span: value.value_span, + }) + } + } + + StrategyTyp::DatabaseTable => { + for value in insert_values.iter() { + let (db_idx, table_idx) = Self::calc_database_and_table_sharding_idx( + value.sharding_value.database.as_ref(), + value.sharding_value.table.as_ref(), + &meta_base_info, + )?; + + changes.push(ShardingIdx { + database: Some(db_idx), + table: Some(table_idx), + span: value.value_span, + }) + } + } + }; Ok(changes) } - fn find_inserts( + fn calc_database_or_table_sharding_idx<'b>( + value: Option<&String>, + meta_base_info: &'b ShardingMetaBaseInfo<'b>, + ) -> Result { + let algo = meta_base_info.algo.0.unwrap(); + let sharding_count = meta_base_info.count.0.unwrap() as u64; + let value = value + .ok_or_else(|| ShardingRewriteError::CalcShardingIdxError)? + .parse::() + .map_err(ShardingRewriteError::from)?; + value.calc(algo, sharding_count).ok_or_else(|| ShardingRewriteError::CalcShardingIdxError) + } + + fn calc_database_and_table_sharding_idx<'b>( + db_value: Option<&String>, + table_value: Option<&String>, + meta_base_info: &ShardingMetaBaseInfo<'b>, + ) -> Result<(u64, u64), ShardingRewriteError> { + let algo = meta_base_info.algo.0.unwrap(); + let sharding_count = meta_base_info.count.0.unwrap() as u64; + + let db_value = db_value + .ok_or_else(|| ShardingRewriteError::CalcShardingIdxError)? + .parse::() + .map_err(ShardingRewriteError::from)?; + let db_idx = db_value + .calc(algo, sharding_count) + .ok_or_else(|| ShardingRewriteError::CalcShardingIdxError)?; + + let table_value = table_value + .ok_or_else(|| ShardingRewriteError::CalcShardingIdxError)? + .parse::() + .map_err(ShardingRewriteError::from)?; + let table_idx = table_value + .calc(algo, sharding_count) + .ok_or_else(|| ShardingRewriteError::CalcShardingIdxError)?; + + Ok((db_idx, table_idx)) + } + + fn find_inserts<'b>( + strategy_typ: &StrategyTyp, inserts: &[InsertValsMeta], fields: &[FieldMeta], - sharding_column: &str, + meta_base_info: &ShardingMetaBaseInfo<'b>, ) -> Result, ShardingRewriteError> { - let field = fields + let (db_idx, table_idx) = match strategy_typ { + StrategyTyp::Database => { + let idx = Self::find_insert_field_idx(fields, |idx, field| { + (field.name == meta_base_info.column.0.unwrap()).then(|| idx) + })?; + (Some(idx), None) + } + StrategyTyp::Table => { + let idx = Self::find_insert_field_idx(fields, |idx, field| { + (field.name == meta_base_info.column.0.unwrap()).then(|| idx) + })?; + (None, Some(idx)) + } + StrategyTyp::DatabaseTable => { + let db_idx = Self::find_insert_field_idx(fields, |idx, field| { + (field.name == meta_base_info.column.0.unwrap()).then(|| idx) + })?; + + let table_idx = Self::find_insert_field_idx(fields, |idx, field| { + (field.name == meta_base_info.column.1.unwrap()).then(|| idx) + })?; + (Some(db_idx), Some(table_idx)) + } + }; + + if db_idx.is_some() && table_idx.is_some() { + let changes = inserts + .iter() + .map(|x| ChangeInsertMeta { + sharding_value: InsertShardingValue { + database: Some(x.values[db_idx.unwrap()].value.clone()), + table: Some(x.values[table_idx.unwrap()].value.clone()), + }, + value_span: x.span, + }) + .collect::>(); + + return Ok(changes); + } + + if let Some(db_idx) = db_idx { + let db_changes = inserts + .iter() + .map(|x| ChangeInsertMeta { + sharding_value: InsertShardingValue { + database: Some(x.values[db_idx].value.clone()), + table: None, + }, + value_span: x.span, + }) + .collect::>(); + + return Ok(db_changes); + } + + if let Some(table_idx) = table_idx { + let table_changes = inserts + .iter() + .map(|x| ChangeInsertMeta { + sharding_value: InsertShardingValue { + database: None, + table: Some(x.values[table_idx].value.clone()), + }, + value_span: x.span, + }) + .collect::>(); + + return Ok(table_changes); + } + + Ok(vec![]) + } + + fn find_insert_field_idx( + fields: &[FieldMeta], + find_f: F, + ) -> Result + where + F: Fn(usize, &FieldMetaIdent) -> Option, + { + fields .iter() .enumerate() - .find_map(|x| { - if let FieldMeta::Ident(field) = x.1 { - if field.name == sharding_column { - return Some((x.0, field.span, field.name.clone())); + .find_map(|x| if let FieldMeta::Ident(field) = x.1 { find_f(x.0, field) } else { None }) + .ok_or_else(|| ShardingRewriteError::ShardingColumnNotFound) + } + + fn database_table_strategy_iproduct( + &self, + part: DatabaseTableStrategyPart, + tables: Vec<(u8, Sharding, &TableIdent)>, + avgs: &IndexMap>, + fields: &IndexMap>, + orders: &IndexMap>, + groups: &IndexMap>, + ) -> Vec { + let mut outputs = vec![]; + + for t in tables.iter() { + let db_sharding_column = t.1.get_sharding_column().0.map(|x| x.to_string()); + let sharding_count = t.1.get_sharding_count().1.unwrap() as u64; + + let actual_nodes = if let DatabaseTableStrategyPart::Database(idx) = part { + vec![t.1.actual_datanodes[idx].clone()] + } else { + t.1.actual_datanodes.clone() + }; + + for (idx, node) in actual_nodes.iter().enumerate() { + let ep = self.endpoints.iter().find(|e| e.name.eq(node)).unwrap().clone(); + let _ = self.change_table(t.2, &ep.db, 0); + let table = + TableIdent { span: t.2.span, schema: Some(ep.db), name: t.2.name.clone() }; + let mut group_changes = IndexMap::)>::new(); + let data_source = self.gen_data_source(&t.1, idx).unwrap(); + + if let DatabaseTableStrategyPart::Table(idx) = part { + let target = self.change_table(&table, "", idx as u64); + let change = TableChange { + span: t.2.span, + table: Some(TableChangeDetail { target, shard_idx: idx as u64 }), + database: None, + rule: t.1.clone(), + }; + group_changes.entry(idx as usize).or_insert((t.0, vec![])).1.push(change); + } else { + for table_idx in 0..sharding_count { + let target = self.change_table(&table, "", table_idx); + let change = TableChange { + span: t.2.span, + table: Some(TableChangeDetail { target, shard_idx: table_idx as u64 }), + database: None, + rule: t.1.clone(), + }; + group_changes + .entry(table_idx as usize) + .or_insert((t.0, vec![])) + .1 + .push(change); } } - None - }) - .ok_or_else(|| { - ShardingRewriteError::ShardingColumnNotFound(sharding_column.to_string()) - })?; + for (_group, changes) in group_changes.into_iter() { + let mut offset = 0; + let mut target_sql = self.raw_sql.to_string(); + + for change in changes.1.iter() { + Self::change_sql( + &mut target_sql, + change.span, + &change.table.as_ref().unwrap().target, + offset, + ); + offset = change.table.as_ref().unwrap().target.len() - change.span.len(); + } - Ok(inserts - .iter() - .enumerate() - .map(|(_, insert)| ChangeInsertMeta { - row_sharding_value: insert.values[field.0].value.clone(), - row_value_span: insert.span.clone(), - }) - .collect::>()) + let min_max_fields = self.find_min_max_fields(&changes.0, fields); + let mut rewrite_changes: Vec = + changes.1.iter().map(|x| RewriteChange::TableChange(x.clone())).collect(); + + let _ = self.change_order_group( + &mut target_sql, + &mut rewrite_changes, + orders, + groups, + fields, + changes.1[0].table.as_ref().unwrap().shard_idx, + ); + + if !avgs.is_empty() { + if changes.1[0].span.start() > avgs[0][0].span.start() { + offset = 0; + } + self.change_avg( + &mut target_sql, + &mut rewrite_changes, + avgs, + changes.1[0].table.as_ref().unwrap().shard_idx, + offset, + ); + } + + outputs.push(ShardingRewriteOutput { + changes: rewrite_changes, + target_sql, + data_source: data_source.clone(), + sharding_column: db_sharding_column.clone(), + min_max_fields, + count_field: Self::get_count_field(fields), + }) + } + } + } + + outputs } fn database_strategy_iproduct( @@ -942,23 +1478,21 @@ impl ShardingRewrite { groups: &IndexMap>, ) -> Vec { let mut output = vec![]; - + let mut sharding_column = None; - let mut group_changes = IndexMap::)>::new(); + let mut group_changes = IndexMap::)>::new(); for t in tables.iter() { - if let StrategyType::DatabaseStrategyConfig(config) = &t.1.database_strategy.as_ref().unwrap() { - sharding_column = Some(config.database_sharding_column.clone()) - } + sharding_column = Some(t.1.get_sharding_column().0.unwrap().to_string()); for (idx, node) in t.1.actual_datanodes.iter().enumerate() { let ep = self.endpoints.iter().find(|e| e.name.eq(node)).unwrap().clone(); let target = self.change_table(t.2, &ep.db, 0); - let change = DatabaseChange { + let change = TableChange { span: t.2.span, - target, - shard_idx: idx as u64, + database: Some(TableChangeDetail { target, shard_idx: idx as u64 }), + table: None, rule: t.1.clone(), }; @@ -972,29 +1506,39 @@ impl ShardingRewrite { let ep = self.endpoints[group].clone(); for change in changes.1.iter() { - Self::change_sql(&mut target_sql, change.span, &change.target, offset); - offset = change.target.len() - change.span.len(); + Self::change_sql( + &mut target_sql, + change.span, + &change.database.as_ref().unwrap().target, + offset, + ); + offset = change.database.as_ref().unwrap().target.len() - change.span.len(); } let min_max_fields = self.find_min_max_fields(&changes.0, fields); - let (order_change, group_change) = Self::change_order_group(&mut target_sql, orders, groups, fields, changes.1[0].shard_idx); - - let mut rewrite_changes: Vec = changes.1.iter().map(|x| RewriteChange::DatabaseChange(x.clone())).collect(); + let mut rewrite_changes: Vec = + changes.1.iter().map(|x| RewriteChange::TableChange(x.clone())).collect(); + + let _ = self.change_order_group( + &mut target_sql, + &mut rewrite_changes, + orders, + groups, + fields, + changes.1[0].database.as_ref().unwrap().shard_idx, + ); - if !order_change.target.is_empty() { - rewrite_changes.push(RewriteChange::OrderChange(order_change)) - } - - if !group_change.target.is_empty() { - rewrite_changes.push(RewriteChange::GroupChange(group_change)) - } - if !avgs.is_empty() { if changes.1[0].span.start() > avgs[0][0].span.start() { - offset = 0; + offset = 0; } - let target = Self::change_avg(&mut target_sql, avgs, changes.1[0].shard_idx, offset); - rewrite_changes.push(RewriteChange::AvgChange(AvgChange { target })) + self.change_avg( + &mut target_sql, + &mut rewrite_changes, + avgs, + changes.1[0].database.as_ref().unwrap().shard_idx, + offset, + ); } output.push(ShardingRewriteOutput { @@ -1003,6 +1547,7 @@ impl ShardingRewrite { data_source: DataSource::Endpoint(ep), sharding_column: sharding_column.clone(), min_max_fields, + count_field: Self::get_count_field(fields), }) } output @@ -1017,31 +1562,23 @@ impl ShardingRewrite { groups: &IndexMap>, ) -> Vec { let mut output = vec![]; - let mut group_changes = IndexMap::)>::new(); - let mut sharding_column = None; - let mut data_source = None; + let mut group_changes = IndexMap::)>::new(); + let mut sharding_column = None; + let data_source = self.gen_data_source(&tables[0].1, 0).unwrap(); for t in tables.iter() { - // Fixme: it should execute once only. - data_source = if self.has_rw { - Some(DataSource::NodeGroup(t.1.actual_datanodes[0].clone())) - } else { - Some(DataSource::Endpoint(self.endpoints[0].clone())) - }; - sharding_column = Some(t.1.get_sharding_column().1.unwrap().to_string()); let sharding_count = t.1.get_sharding_count().1.unwrap(); for idx in 0..sharding_count as u64 { let target = self.change_table(t.2, "", idx); - let change = DatabaseChange { + let change = TableChange { span: t.2.span, - target, - shard_idx: idx as u64, + table: Some(TableChangeDetail { target, shard_idx: idx as u64 }), + database: None, rule: t.1.clone(), }; - group_changes.entry(idx as usize).or_insert((t.0, vec![])).1.push(change); } } @@ -1051,77 +1588,81 @@ impl ShardingRewrite { let mut target_sql = self.raw_sql.clone(); for change in changes.1.iter() { - Self::change_sql(&mut target_sql, change.span, &change.target, offset); - offset = change.target.len() - change.span.len(); - + Self::change_sql( + &mut target_sql, + change.span, + &change.table.as_ref().unwrap().target, + offset, + ); + offset = change.table.as_ref().unwrap().target.len() - change.span.len(); } let min_max_fields = self.find_min_max_fields(&changes.0, fields); - let (order_change, group_change) = Self::change_order_group(&mut target_sql, orders, groups, fields, changes.1[0].shard_idx); - - let mut rewrite_changes: Vec = changes.1.iter().map(|x| RewriteChange::DatabaseChange(x.clone())).collect(); - - if !order_change.target.is_empty() { - rewrite_changes.push(RewriteChange::OrderChange(order_change)) - } - - if !group_change.target.is_empty() { - rewrite_changes.push(RewriteChange::GroupChange(group_change)); - } + let mut rewrite_changes: Vec = + changes.1.iter().map(|x| RewriteChange::TableChange(x.clone())).collect(); + let _ = self.change_order_group( + &mut target_sql, + &mut rewrite_changes, + orders, + groups, + fields, + changes.1[0].table.as_ref().unwrap().shard_idx, + ); if !avgs.is_empty() { if changes.1[0].span.start() > avgs[0][0].span.start() { - offset = 0; + offset = 0; } - let target = Self::change_avg(&mut target_sql, avgs, changes.1[0].shard_idx, offset); - rewrite_changes.push(RewriteChange::AvgChange(AvgChange{target})); + self.change_avg( + &mut target_sql, + &mut rewrite_changes, + avgs, + changes.1[0].table.as_ref().unwrap().shard_idx, + offset, + ); } output.push(ShardingRewriteOutput { - changes: rewrite_changes, + changes: rewrite_changes, target_sql: target_sql.to_string(), - data_source: data_source.as_ref().unwrap().clone(), + data_source: data_source.clone(), sharding_column: sharding_column.clone(), min_max_fields, + count_field: Self::get_count_field(fields), }) } output } - fn change_table(&self, table: &TableIdent, actual_node: &str, table_idx: u64) -> String { - let schema = if let Some(schema) = table.schema.as_ref() { - schema - } else { - self.default_db.as_ref().unwrap() - }; + fn change_table(&self, table: &TableIdent, actual_db: &str, table_idx: u64) -> String { + let db = table.schema.as_ref().unwrap_or_else(|| self.default_db.as_ref().unwrap()); - let mut target = String::with_capacity(schema.len()); + let mut target = String::with_capacity(db.len()); - if actual_node.is_empty() { + if actual_db.is_empty() { target.push('`'); - target.push_str(schema); + target.push_str(db); target.push('`'); target.push('.'); - if table.name.contains("`") { - target.push('`'); + if table.name.contains("`") { + target.push('`'); let name = table.name.replace("`", ""); target.push_str(&format!("{}_{:05}", &name, table_idx)); target.push('`'); } else { target.push_str(&format!("{}_{:05}", &table.name, table_idx)); } - } else { - if schema.contains("`") { + if db.contains("`") { target.push('`'); - target.push_str(actual_node); + target.push_str(actual_db); target.push('`'); } else { - target.push_str(actual_node); + target.push_str(actual_db); } - + target.push_str("."); target.push_str(&table.name); } @@ -1141,27 +1682,61 @@ impl ShardingRewrite { meta } - fn find_min_max_fields(&self, idx: &u8, fields: &IndexMap>) -> Vec { + fn find_min_max_fields( + &self, + idx: &u8, + fields: &IndexMap>, + ) -> Vec { match fields.get(idx) { Some(fields) => { let mut min_max_fields = vec![]; for f in fields.into_iter() { match f { - FieldMeta::Ident(field) => { - match field.wrap_func { - FieldWrapFunc::Min | FieldWrapFunc::Max => { - min_max_fields.push(FieldMetaIdent { span: field.span, name: field.name.clone(), wrap_func: field.wrap_func.clone() }) - }, - _ => {} + FieldMeta::Ident(field) => match field.wrap_func { + FieldWrapFunc::Min | FieldWrapFunc::Max => { + min_max_fields.push(FieldMetaIdent { + span: field.span, + name: field.name.clone(), + wrap_func: field.wrap_func.clone(), + }) } + _ => {} }, _ => {} } } min_max_fields - }, + } - None => vec![] + None => vec![], + } + } + + fn get_count_field(fields: &IndexMap>) -> Option { + fields.values().find_map(|f| { + f.iter().find_map(|x| { + if let FieldMeta::Ident(meta) = x { + if meta.wrap_func == FieldWrapFunc::Count { + return Some(meta.clone()); + } + } + None + }) + }) + } + + fn gen_data_source( + &self, + rule: &Sharding, + shard_idx: usize, + ) -> Result { + if self.has_rw { + Ok(DataSource::NodeGroup(rule.actual_datanodes[0].clone())) + } else { + let ep = rule + .get_endpoint(&self.endpoints, Some(shard_idx)) + .ok_or_else(|| ShardingRewriteError::EndpointNotFound)?; + Ok(DataSource::Endpoint(ep.clone())) } } } @@ -1171,7 +1746,7 @@ impl ShardingRewriter for ShardingRewrite { fn rewrite(&mut self, mut input: ShardingRewriteInput) -> Self::Output { self.set_raw_sql(input.raw_sql); self.set_default_db(input.default_db); - + let meta = self.get_meta(&mut input.ast); let tables = meta.get_tables().clone(); let try_tables = self.find_table_rule(&tables); @@ -1180,19 +1755,18 @@ impl ShardingRewriter for ShardingRewrite { return Ok(vec![]); } - // Strategy according to first element of `try_tables`. let rule = &try_tables[0].1; if rule.database_strategy.is_some() { return self.database_strategy(meta, try_tables); - } - - if rule.table_strategy.is_some() { + } else if rule.table_strategy.is_some() { return self.table_strategy(meta, try_tables); + } else if rule.database_table_strategy.is_some() { + return self.database_table_strategy(meta, try_tables); } - return Ok(vec![]) + return Ok(vec![]); } } @@ -1202,7 +1776,54 @@ mod test { use mysql_parser::parser::Parser; use super::ShardingRewrite; - use crate::{config::{DatabaseStrategyConfig, Sharding, ShardingAlgorithmName, StrategyType}, rewrite::{ShardingRewriteInput, ShardingRewriter}, sharding_rewrite::DataSource}; + use crate::{ + config::{ + DatabaseStrategyConfig, DatabaseTableStrategyConfig, Sharding, ShardingAlgorithmName, + StrategyType, + }, + rewrite::{ShardingRewriteInput, ShardingRewriter}, + sharding_rewrite::DataSource, + }; + + fn get_database_table_sharding_config() -> (Vec, Vec) { + ( + vec![Sharding { + table_name: "tshard".to_string(), + actual_datanodes: vec!["ds0".to_string(), "ds1".to_string()], + binding_tables: None, + broadcast_tables: None, + database_strategy: None, + table_strategy: None, + database_table_strategy: Some(StrategyType::DatabaseTableStrategyConfig( + DatabaseTableStrategyConfig { + database_sharding_algorithm_name: ShardingAlgorithmName::Mod, + database_sharding_column: "didx".to_string(), + table_sharding_algorithm_name: ShardingAlgorithmName::Mod, + table_sharding_column: "idx".to_string(), + shading_count: 4, + }, + )), + }], + vec![ + Endpoint { + weight: 1, + name: String::from("ds0"), + db: String::from("db0"), + user: String::from("user"), + password: String::from("password"), + addr: String::from("127.0.0.1"), + }, + Endpoint { + weight: 1, + name: String::from("ds1"), + db: String::from("db1"), + user: String::from("user"), + password: String::from("password"), + addr: String::from("127.0.0.2"), + }, + ], + ) + } fn get_database_sharding_config() -> (Vec, Vec) { ( @@ -1269,6 +1890,95 @@ mod test { ) } + #[test] + fn test_database_table_sharding_strategy() { + let config = get_database_table_sharding_config(); + let raw_sql = "SELECT * FROM db.tshard where idx > 3"; + let parser = Parser::new(); + let ast = parser.parse(raw_sql).unwrap(); + let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), None, false); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + default_db: Some("db".to_string()), + }; + let res = sr.rewrite(input).unwrap(); + let sqls = res.iter().map(|x| x.target_sql.clone()).collect::>(); + assert_eq!(sqls[0], "SELECT * FROM `db0`.tshard_00000 where idx > 3"); + assert_eq!(sqls[1], "SELECT * FROM `db0`.tshard_00001 where idx > 3"); + assert_eq!(sqls[2], "SELECT * FROM `db0`.tshard_00002 where idx > 3"); + assert_eq!(sqls[3], "SELECT * FROM `db0`.tshard_00003 where idx > 3"); + assert_eq!(sqls[4], "SELECT * FROM `db1`.tshard_00000 where idx > 3"); + assert_eq!(sqls[5], "SELECT * FROM `db1`.tshard_00001 where idx > 3"); + assert_eq!(sqls[6], "SELECT * FROM `db1`.tshard_00002 where idx > 3"); + assert_eq!(sqls[7], "SELECT * FROM `db1`.tshard_00003 where idx > 3"); + } + + #[test] + fn test_database_table_sharding_strategy_where() { + let config = get_database_table_sharding_config(); + let raw_sql = "SELECT * FROM db.tshard where idx = 3"; + let parser = Parser::new(); + let ast = parser.parse(raw_sql).unwrap(); + let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), None, false); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + default_db: None, + }; + let res = sr.rewrite(input).unwrap(); + let sqls = res.iter().map(|x| x.target_sql.clone()).collect::>(); + + assert_eq!(sqls[0], "SELECT * FROM `db0`.tshard_00003 where idx = 3"); + assert_eq!(sqls[1], "SELECT * FROM `db1`.tshard_00003 where idx = 3"); + + let raw_sql = "SELECT * FROM db.tshard where didx = 5"; + let ast = parser.parse(raw_sql).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + default_db: None, + }; + + let res = sr.rewrite(input).unwrap(); + let sqls = res.iter().map(|x| x.target_sql.clone()).collect::>(); + assert_eq!(sqls[0], "SELECT * FROM `db1`.tshard_00000 where didx = 5"); + assert_eq!(sqls[1], "SELECT * FROM `db1`.tshard_00001 where didx = 5"); + assert_eq!(sqls[2], "SELECT * FROM `db1`.tshard_00002 where didx = 5"); + assert_eq!(sqls[3], "SELECT * FROM `db1`.tshard_00003 where didx = 5"); + + let raw_sql = "SELECT * FROM db.tshard where didx = 6 and idx = 3"; + let ast = parser.parse(raw_sql).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + default_db: None, + }; + + let res = sr.rewrite(input).unwrap(); + let sqls = res.iter().map(|x| x.target_sql.clone()).collect::>(); + + assert_eq!(sqls[0], "SELECT * FROM `db0`.tshard_00003 where didx = 6 and idx = 3"); + } + + #[test] + fn test_database_table_sharding_strategy_insert() { + let config = get_database_table_sharding_config(); + let raw_sql = "INSERT INTO db.tshard(didx, idx) VALUES (12, 22), (13, 57), (19, 37)"; + let parser = Parser::new(); + let ast = parser.parse(raw_sql).unwrap(); + let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), None, false); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + default_db: None, + }; + let res = sr.rewrite(input).unwrap(); + let sqls = res.iter().map(|x| x.target_sql.clone()).collect::>(); + assert_eq!(sqls[0], "INSERT INTO `db0`.tshard_00000(didx, idx) VALUES (12, 22)"); + assert_eq!(sqls[1], "INSERT INTO `db1`.tshard_00001(didx, idx) VALUES (13, 57), (19, 37)"); + } + #[test] fn test_database_sharding_strategy() { let config = get_database_sharding_config(); @@ -1305,10 +2015,26 @@ mod test { assert_eq!( res.into_iter().map(|x| x.target_sql).collect::>(), vec![ - "SELECT idx from db0.tshard where idx = 3 and idx = (SELECT idx from db0.tshard where idx = 4)", - "SELECT idx from db1.tshard where idx = 3 and idx = (SELECT idx from db1.tshard where idx = 4)", + "SELECT idx from db0.tshard where idx = 3 and idx = (SELECT idx from db0.tshard where idx = 4)", + "SELECT idx from db1.tshard where idx = 3 and idx = (SELECT idx from db1.tshard where idx = 4)", ], ); + + let raw_sql = "INSERT INTO db.tshard(idx, tt) VALUES (12, 22), (13, 55), (16, 77)"; + let ast = parser.parse(raw_sql).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + default_db: None, + }; + let res = sr.rewrite(input).unwrap(); + assert_eq!( + res.into_iter().map(|x| x.target_sql).collect::>(), + vec![ + "INSERT INTO db0.tshard(idx, tt) VALUES (12, 22), (16, 77)", + "INSERT INTO db1.tshard(idx, tt) VALUES (13, 55)" + ] + ); } #[test] @@ -1343,7 +2069,10 @@ mod test { }; let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), None, false); let res = sr.rewrite(input).unwrap(); - assert_eq!(res[0].target_sql, "SELECT idx from `db`.tshard_00000 where idx = 4".to_string()); + assert_eq!( + res[0].target_sql, + "SELECT idx from `db`.tshard_00000 where idx = 4".to_string() + ); let raw_sql = "SELECT idx from db.`tshard` where idx = 3 and idx = (SELECT idx from db.tshard where idx = 3)".to_string(); let ast = parser.parse(&raw_sql).unwrap(); @@ -1372,7 +2101,6 @@ mod test { "SELECT idx from `db`.tshard_00003 where idx = 3 and idx = (SELECT idx from `db`.tshard_00003 where idx = 4)", ], ); - } #[test] @@ -1413,9 +2141,10 @@ mod test { let res = sr.rewrite(input).unwrap(); assert_eq!(res[0].target_sql, "SELECT COUNT(price) AS PRICE_AVG_DERIVED_COUNT_00000, SUM(price) AS PRICE_AVG_DERIVED_SUM_00000 FROM `db`.tshard_00000 WHERE idx > 3"); - let raw_sql = "SELECT AVG(pbl), AVG(znl), AVG(ngl) FROM db.tshard WHERE idx > 3".to_string(); + let raw_sql = + "SELECT AVG(pbl), AVG(znl), AVG(ngl) FROM db.tshard WHERE idx > 3".to_string(); let ast = parser.parse(&raw_sql).unwrap(); - let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(),None, false); + let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), None, false); let input = ShardingRewriteInput { raw_sql: raw_sql.clone(), ast: ast[0].clone(), @@ -1437,7 +2166,7 @@ mod test { let raw_sql = "SELECT AVG(znl) FROM db.tshard"; let ast = parser.parse(raw_sql).unwrap(); - let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(),None, false); + let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), None, false); let input = ShardingRewriteInput { raw_sql: raw_sql.to_string(), ast: ast[0].clone(), @@ -1445,8 +2174,9 @@ mod test { }; let res = sr.rewrite(input).unwrap(); assert_eq!(res[0].target_sql, "SELECT COUNT(znl) AS ZNL_AVG_DERIVED_COUNT_00000, SUM(znl) AS ZNL_AVG_DERIVED_SUM_00000 FROM `db`.tshard_00000"); - - let raw_sql = "SELECT * from db.tshard where znl > (SELECT AVG(znl) from db.tshard)".to_string(); + + let raw_sql = + "SELECT * from db.tshard where znl > (SELECT AVG(znl) from db.tshard)".to_string(); let ast = parser.parse(&raw_sql).unwrap(); let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), None, false); let input = ShardingRewriteInput { @@ -1474,9 +2204,7 @@ mod test { let res = sr.rewrite(input).unwrap(); assert_eq!( res.into_iter().map(|x| x.target_sql).collect::>(), - vec![ - "UPDATE `db`.tshard_00002 set a=1 where idx = 2" - ], + vec!["UPDATE `db`.tshard_00002 set a=1 where idx = 2"], ); let raw_sql = "DELETE FROM db.tshard where idx = 1"; @@ -1489,9 +2217,7 @@ mod test { let res = sr.rewrite(input).unwrap(); assert_eq!( res.into_iter().map(|x| x.target_sql).collect::>(), - vec![ - "DELETE FROM `db`.tshard_00001 where idx = 1" - ], + vec!["DELETE FROM `db`.tshard_00001 where idx = 1"], ); } @@ -1510,9 +2236,7 @@ mod test { let res = sr.rewrite(input).unwrap(); assert_eq!( res.into_iter().map(|x| x.target_sql).collect::>(), - vec![ - "UPDATE `db`.tshard_00002 set a=1 where idx = 2" - ], + vec!["UPDATE `db`.tshard_00002 set a=1 where idx = 2"], ); } @@ -1548,7 +2272,8 @@ mod test { "SELECT order_id, order_item_id, user_id AS USER_ID_GROUP_BY_DERIVED_00000 FROM `db`.tshard_00000 GROUP BY user_id" ); - let raw_sql = "SELECT order_id, order_item_id from db.tshard where user_id > 3 ORDER BY user_id"; + let raw_sql = + "SELECT order_id, order_item_id from db.tshard where user_id > 3 ORDER BY user_id"; let ast = parser.parse(raw_sql).unwrap(); let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), None, false); let input = ShardingRewriteInput { @@ -1576,7 +2301,8 @@ mod test { "SELECT order_id, order_item_id, user_id AS USER_ID_ORDER_BY_DERIVED_00000 FROM `db`.tshard_00000 WHERE id in (SELECT s_id, ngl, znl from `db`.tshard_00000) ORDER BY user_id" ); - let raw_sql = "SELECT order_id, order_item_id from db.tshard where idx = 3 ORDER BY `user_id`"; + let raw_sql = + "SELECT order_id, order_item_id from db.tshard where idx = 3 ORDER BY `user_id`"; let ast = parser.parse(raw_sql).unwrap(); let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), None, false); let input = ShardingRewriteInput { @@ -1610,7 +2336,10 @@ mod test { ast: ast[0].clone(), }; let res = sr.rewrite(input).unwrap(); - assert_eq!(res[0].target_sql, "SELECT id, idx AS IDX_ORDER_BY_DERIVED_00000 FROM `db`.tshard_00000 ORDER BY idx DESC"); + assert_eq!( + res[0].target_sql, + "SELECT id, idx AS IDX_ORDER_BY_DERIVED_00000 FROM `db`.tshard_00000 ORDER BY idx DESC" + ); } #[test] @@ -1626,9 +2355,6 @@ mod test { ast: ast[0].clone(), }; let res = sr.rewrite(input).unwrap(); - assert_eq!( - res[0].data_source, - DataSource::NodeGroup("ds001".to_string()) - ); + assert_eq!(res[0].data_source, DataSource::NodeGroup("ds001".to_string())); } } diff --git a/pisa-proxy/proxy/strategy/src/sharding_rewrite/rewrite_const.rs b/pisa-proxy/proxy/strategy/src/sharding_rewrite/rewrite_const.rs index a91324fe..19e135ed 100644 --- a/pisa-proxy/proxy/strategy/src/sharding_rewrite/rewrite_const.rs +++ b/pisa-proxy/proxy/strategy/src/sharding_rewrite/rewrite_const.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - pub const ORDER_BY_DERIVED: &str = "ORDER_BY_DERIVED"; pub const GROUP_BY_DERIVED: &str = "GROUP_BY_DERIVED"; pub const AVG_DERIVED_COUNT: &str = "AVG_DERIVED_COUNT"; @@ -26,4 +25,4 @@ pub const GROUP_FIELD: &str = "group_field"; pub const GROUP_TARGET: &str = "group_target"; pub const AVG_FIELD: &str = "avg_field"; pub const AVG_COUNT: &str = "avg_count"; -pub const AVG_SUM: &str = "avg_sum"; \ No newline at end of file +pub const AVG_SUM: &str = "avg_sum"; diff --git a/pisa-proxy/runtime/mysql/src/server/executor.rs b/pisa-proxy/runtime/mysql/src/server/executor.rs index afeea6cc..01ab07de 100644 --- a/pisa-proxy/runtime/mysql/src/server/executor.rs +++ b/pisa-proxy/runtime/mysql/src/server/executor.rs @@ -15,7 +15,7 @@ use std::{ marker::PhantomData, sync::{atomic::Ordering, Arc}, - vec, ops::Div, + vec, }; use bytes::{BytesMut, BufMut, Buf}; @@ -27,16 +27,16 @@ use mysql_protocol::{ conn::{ClientConn, SessionAttr}, stmt::Stmt, }, - column::{Column, ColumnInfo, decode_column}, + column::{ColumnInfo, decode_column}, err::ProtocolError, mysql_const::*, - row::{RowData, RowDataBinary, RowDataText, RowDataTyp}, + row::{RowData, RowDataBinary, RowDataText, RowDataTyp, decode_with_name, RowPartData}, server::codec::{make_eof_packet, CommonPacket, PacketSend}, - util::{is_eof, length_encode_int, BufMutExt}, + util::{length_encode_int, BufMutExt}, }; use pisa_error::error::{Error, ErrorKind}; use rayon::prelude::*; -use strategy::sharding_rewrite::{DataSource, ShardingRewriteOutput, RewriteChange, meta::FieldWrapFunc, rewrite_const::{AVG_COUNT, AVG_SUM, AVG_FIELD}}; +use strategy::sharding_rewrite::{DataSource, ShardingRewriteOutput, RewriteChange, meta::FieldWrapFunc, rewrite_const::{AVG_COUNT, AVG_SUM}}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Decoder, Encoder}; @@ -47,6 +47,8 @@ use crate::{ use byteorder::{ByteOrder, LittleEndian}; +use super::util::filter_avg_column; + #[derive(Debug, thiserror::Error)] pub enum ExecuteError { #[error("execute sql: {0:?} error")] @@ -57,6 +59,11 @@ pub struct Executor { _phant: PhantomData<(T, C)>, } +struct ColumnUpdate { + ori_columns: Arc<[ColumnInfo]>, + new_columns: Arc<[ColumnInfo]>, +} + impl Executor where T: AsyncRead + AsyncWrite + Unpin + Send, @@ -157,7 +164,7 @@ where .codec_mut() .encode(PacketSend::EncodeOffset(header[4..].into(), 0), &mut buf); - let col_info = Self::get_columns(req, merge_stream, cols, &mut buf).await?; + let column_update = Self::get_columns(req, merge_stream, cols, &mut buf).await?; // read eof let _ = merge_stream.next().await; @@ -169,7 +176,7 @@ where merge_stream.set_state(MergeResultsetState::Row); // get rows - Self::get_rows(req, merge_stream, &mut buf, sharding_column, col_info, is_binary).await?; + Self::get_rows(req, merge_stream, &mut buf, sharding_column, column_update, is_binary).await?; let _ = req .framed @@ -185,33 +192,34 @@ where stream: &mut MergeStream>, buf: &mut BytesMut, sharding_column: Option, - col_info: Arc<[ColumnInfo]>, + column_update: ColumnUpdate, is_binary: bool, ) -> Result<(), Error> { - let row_data = match is_binary { + let mut row_data = match is_binary { false => { - let row_data_text = RowDataText::new(col_info.clone(), &[][..]); + let row_data_text = RowDataText::new(column_update.ori_columns.clone(), &[][..]); RowDataTyp::Text(row_data_text) } true => { - let row_data_binary = RowDataBinary::new(col_info.clone(), &[][..]); + let row_data_binary = RowDataBinary::new(column_update.ori_columns.clone(), &[][..]); RowDataTyp::Binary(row_data_binary) } }; - // Save min or max row data - let mut agg_buf = Vec::with_capacity(1024); - while let Some(chunk) = stream.next().await { let mut chunk = chunk .into_par_iter().map(|x| x.unwrap()).collect::, _>>().map_err(ErrorKind::from)?; + for i in chunk.iter() { + println!("or min_max {:?}", &i[..]); + } let ro = &req.rewrite_outputs[0]; - Self::handle_min_max(ro, &mut chunk, row_data.clone(), is_binary, &mut agg_buf)?; + Self::handle_min_max(ro, &mut chunk, row_data.clone(), is_binary)?; + for i in chunk.iter() { + println!("min_max {:?}", &i[..]); + } - let mut avg_chunk = vec![]; - let avg_change = ro.changes.iter().find_map(|x| { if let RewriteChange::AvgChange(change) = x { Some(change) @@ -223,74 +231,100 @@ where if let Some(avg) = avg_change { let count_field = avg.target.get(AVG_COUNT).unwrap(); let sum_field = avg.target.get(AVG_SUM).unwrap(); - avg_chunk = chunk.par_iter().map(|x| { + + let (count_data, sum_data): (Vec<_>, Vec<_>) = chunk.par_iter().map(|x| -> Result<(u64, u64), Error> { + println!("xxx {:?}", &x[..]); let mut row_data = row_data.clone(); - row_data.with_buf(&x[4..]); - if is_binary { - let count = row_data.decode_with_name::(count_field).unwrap().unwrap(); - let sum = row_data.decode_with_name::(sum_field).unwrap().unwrap(); - (count, sum) + let count = decode_with_name::<&[u8], u64>(&mut row_data, &count_field, is_binary).map_err(|e| ErrorKind::Runtime(e))?.unwrap_or_else(|| 0); + // Sum type is `MYSQL_TYPE_NEWDECIMAL` in binary, so need to convet to `String` type. + let sum = if is_binary { + let sum = decode_with_name::<&[u8], String>(&mut row_data, &sum_field, is_binary).map_err(|e| ErrorKind::from(e))?; + if let Some(sum) = sum { + sum.parse::().map_err(|e| ErrorKind::Runtime(e.into()))? + } else { + 0 + } } else { - let count = row_data.decode_with_name::(count_field).unwrap().unwrap(); - let count = count.parse::().unwrap(); - - let sum = row_data.decode_with_name::(sum_field).unwrap().unwrap(); - let sum = sum.parse::().unwrap(); - (count, sum) - } - }).collect::>(); + decode_with_name::<&[u8], u64>(&mut row_data, &sum_field, is_binary).map_err(|e| ErrorKind::Runtime(e))?.unwrap_or_else(|| 0) + }; + + Ok((count, sum)) + }).collect::, _>>()?.par_iter().cloned().unzip(); - let count: u64 = avg_chunk.par_iter().map(|x| x.0).sum(); - let sum: u64 = avg_chunk.par_iter().map(|x| x.1).sum(); + let count: u64 = count_data.par_iter().sum(); + let sum: u64 = sum_data.par_iter().sum(); + println!("count {:?}, sum {:?}", count, sum); - let _ = chunk.par_iter_mut().map(|x| { + chunk.par_iter_mut().for_each(|x| { let mut row_data = row_data.clone(); - let mut data = x.split_off(4); - row_data.with_buf(&data[..]); - + row_data.with_buf(&x[4..]); let count_data = row_data.get_row_data_with_name(count_field).unwrap().unwrap(); let sum_data = row_data.get_row_data_with_name(sum_field).unwrap().unwrap(); + let part_data = RowPartData { + data: vec![].into(), + start_idx: count_data.start_idx, + part_encode_length: count_data.part_encode_length + sum_data.part_encode_length, + part_data_length: count_data.part_data_length + sum_data.part_data_length, + }; - for _ in count_data.start_idx .. sum_data.end_part_idx + 2 { - data.get_u8(); - } - - x.extend_from_slice(&data[..]); - - let mut buf = Vec::with_capacity(32); let avg = format!("{:.4}", (sum as f64 / count as f64)); - buf.put_lenc_int(avg.len() as u64, true); - buf.put_slice(avg.as_bytes()); - x.put_slice(&buf); - }).collect::>(); + + row_data_cut_merge(x, &part_data, |data: &mut BytesMut| { + data.put_lenc_int(avg.len() as u64, false); + data.extend_from_slice(avg.as_bytes()); + }); + }); // When columns has count and sum field only, return directly. - if col_info.len() == 2 { + if column_update.ori_columns.len() == 2 { let _ = req .framed .codec_mut() .encode(PacketSend::EncodeOffset(chunk[0][4..].into(), buf.len()), buf); return Ok(()); } - } - - if col_info.len() == ro.min_max_fields.len() { - if is_binary { - agg_buf.insert(0, 0); - for _ in 0..(col_info.len() + 7 + 2) >> 3 { - agg_buf.insert(1, 0); + + row_data = match is_binary { + false => { + let row_data_text = RowDataText::new(column_update.new_columns.clone(), &[][..]); + RowDataTyp::Text(row_data_text) } - } - let _ = req - .framed - .codec_mut() - .encode(PacketSend::EncodeOffset(agg_buf[..].into(), buf.len()), buf); - return Ok(()); + true => { + let row_data_binary = RowDataBinary::new(column_update.new_columns.clone(), &[][..]); + RowDataTyp::Binary(row_data_binary) + } + }; + } + + if let Some(count_field) = &ro.count_field { + let count_sum = chunk.par_iter().map(|x| { + let mut row_data = row_data.clone(); + row_data.with_buf(&x[4..]); + println!("count x {:?}", &x[4..]); + decode_with_name::<&[u8], u64>(&mut row_data, &count_field.name, is_binary).unwrap().unwrap() + }).sum::(); + println!("count_sun {:?}", count_sum); + + let chunk_data = &chunk[0]; + let mut row_data = row_data.clone(); + row_data.with_buf(&chunk_data[4..]); + let row_part_data = row_data.get_row_data_with_name(&count_field.name).map_err(|e| ErrorKind::Runtime(e))?.unwrap(); + chunk.par_iter_mut().for_each(|x| { + row_data_cut_merge(x, &row_part_data, |data: &mut BytesMut| { + if is_binary { + data.extend_from_slice(&count_sum.to_le_bytes()[..]) + } else { + let count_sum_str = count_sum.to_string(); + data.put_lenc_int(count_sum_str.len() as u64, false); + data.extend_from_slice(count_sum_str.as_bytes()); + } + }); + }); } if let Some(name) = &sharding_column { - if let Some(_) = col_info.iter().find(|col_info| col_info.column_name.eq(name)) { + if let Some(_) = column_update.ori_columns.iter().find(|col_info| col_info.column_name.eq(name)) { chunk.par_sort_by_cached_key(|x| { let mut row_data = row_data.clone(); row_data.with_buf(&x[4..]); @@ -304,7 +338,16 @@ where } } + if chunk.par_iter().min() == chunk.par_iter().max() { + let _ = req + .framed + .codec_mut() + .encode(PacketSend::EncodeOffset(chunk[0][4..].into(), buf.len()), buf); + return Ok(()) + } + for row in chunk.iter() { + println!("end row {:?}", &row[..]); let _ = req .framed .codec_mut() @@ -315,65 +358,42 @@ where Ok(()) } - fn handle_min_max(ro: &ShardingRewriteOutput, chunk: &mut [BytesMut], row_data: RowDataTyp<&[u8]>, is_binary: bool, agg_buf: &mut B) -> Result<(), Error> { - for (idx, mmf) in ro.min_max_fields.iter().enumerate() { + fn handle_min_max(ro: &ShardingRewriteOutput, chunk: &mut [BytesMut], row_data: RowDataTyp<&[u8]>, is_binary: bool) -> Result<(), Error> { + for (_, mmf) in ro.min_max_fields.iter().enumerate() { match mmf.wrap_func { FieldWrapFunc::Max => { chunk.par_sort_unstable_by(|a, b| { let mut row_data = row_data.clone(); - - row_data.with_buf(&a[4..]); - let a = if is_binary { - row_data.decode_with_name::(&mmf.name).unwrap().unwrap() - } else { - let value = row_data.decode_with_name::(&mmf.name).unwrap().unwrap(); - value.parse::().unwrap() - }; - - row_data.with_buf(&b[4..]); - let b = if is_binary { - row_data.decode_with_name::(&mmf.name).unwrap().unwrap() - } else { - let value = row_data.decode_with_name::(&mmf.name).unwrap().unwrap(); - value.parse::().unwrap() - }; - + let (a, b) = get_min_max_value(&mut row_data, is_binary, &mmf.name, a, b); b.cmp(&a) }); + } FieldWrapFunc::Min => { chunk.par_sort_unstable_by(|a, b| { let mut row_data = row_data.clone(); - - row_data.with_buf(&a[4..]); - let a = if is_binary { - row_data.decode_with_name::(&mmf.name).unwrap().unwrap() - } else { - let value = row_data.decode_with_name::(&mmf.name).unwrap().unwrap(); - value.parse::().unwrap() - }; - - row_data.with_buf(&b[4..]); - let b = if is_binary { - row_data.decode_with_name::(&mmf.name).unwrap().unwrap() - } else { - let value = row_data.decode_with_name::(&mmf.name).unwrap().unwrap(); - value.parse::().unwrap() - }; - + let (a, b) = get_min_max_value(&mut row_data, is_binary, &mmf.name, a, b); a.cmp(&b) }); } - FieldWrapFunc::None => break + _ => break } + let chunk_data = &chunk[0]; let mut row_data = row_data.clone(); - row_data.with_buf(&chunk[0][4..]); - let row = row_data.get_row_data_with_name(&mmf.name).map_err(|e| ErrorKind::Runtime(e))?; - if let Some(row) = row { - agg_buf.put_slice(&row.data); - } + row_data.with_buf(&chunk_data[4..]); + let row_part_data = row_data.get_row_data_with_name(&mmf.name).map_err(|e| ErrorKind::Runtime(e))?.unwrap(); + chunk.par_iter_mut().for_each(|x| { + row_data_cut_merge(x, &row_part_data, |data: &mut BytesMut| { + if is_binary { + data.extend_from_slice(&row_part_data.data); + } else { + data.put_lenc_int(row_part_data.part_data_length as u64, false); + data.extend_from_slice(&row_part_data.data); + } + }) + }); } Ok(()) @@ -384,8 +404,9 @@ where stream: &mut MergeStream>, column_length: u64, buf: &mut BytesMut, - ) -> Result, Error> { - let mut col_buf = Vec::with_capacity(100); + ) -> Result { + let mut ori_columns = Vec::with_capacity(32); + let mut new_columns = Vec::with_capacity(32); let mut idx: Option = None; let ro = &req.rewrite_outputs[0]; @@ -397,25 +418,8 @@ where } }); - let mut avg_column_buf = Vec::with_capacity(128); - if let Some(change) = avg_change { - let avg_field = change.target.get(AVG_FIELD).unwrap(); - let avg_column = ColumnInfo { - schema: None, - table_name: None, - column_name: avg_field.to_string(), - charset: 0x3f, - column_length: 0x46, - column_type: ColumnType::MYSQL_TYPE_NEWDECIMAL, - column_flag: 0x0080, - decimals: 4, - }; - - avg_column.encode(&mut avg_column_buf); - } - - //let column_infos = Vec::with_capacity(column_length as usize); - let mut is_add_avg_column = false; + //let mut avg_column_buf = Vec::with_capacity(128); + let mut is_added_avg_column = false; for _ in 0..column_length { let data = stream.next().await; @@ -441,34 +445,37 @@ where } }; - col_buf.extend_from_slice(&data[..]); let column_info = decode_column(&data[4..]); + ori_columns.push(column_info.clone()); + if let Some(change) = avg_change { - let avg_count = change.target.get(AVG_COUNT).unwrap(); - let avg_sum = change.target.get(AVG_SUM).unwrap(); - if &column_info.column_name == avg_count || &column_info.column_name == avg_sum { - if !is_add_avg_column { + let filter_res = filter_avg_column(change, &column_info, is_added_avg_column); + if let Some(avg_column) = filter_res.1 { + if !filter_res.0.is_empty() { let _ = req .framed .codec_mut() - .encode(PacketSend::EncodeOffset(avg_column_buf[..].into(), buf.len()), buf); + .encode(PacketSend::EncodeOffset(filter_res.0.into(), buf.len()), buf); + new_columns.push(avg_column) } - - is_add_avg_column = true; + + is_added_avg_column = true; continue; } } + new_columns.push(column_info); + let _ = req .framed .codec_mut() .encode(PacketSend::EncodeOffset(data[4..].into(), buf.len()), buf); } - let col_info = col_buf.as_slice().decode_columns(); - let arc_col_info: Arc<[ColumnInfo]> = col_info.into_boxed_slice().into(); - - Ok(arc_col_info) + Ok(ColumnUpdate { + ori_columns: ori_columns.into(), + new_columns: new_columns.into(), + }) } fn get_shard_one_data( @@ -638,10 +645,24 @@ where } } -#[cfg(test)] -mod test { - #[test] - fn test() { - assert_eq!(1, 1); - } +fn row_data_cut_merge(ori_data: &mut BytesMut, row_part_data: &RowPartData, f: F) +where F: FnOnce(&mut BytesMut) +{ + let mut data = ori_data.split_off(4); + let mut data_remain = data.split_off(row_part_data.start_idx); + + f(&mut data); + + let _ = data_remain.split_to(row_part_data.part_encode_length + row_part_data.part_data_length); + data.extend_from_slice(&data_remain); + ori_data.extend_from_slice(&data); +} + +fn get_min_max_value<'a>(row_data: &mut RowDataTyp<&'a [u8]>, is_binary: bool, name: &'a str, a: &'a BytesMut, b: &'a BytesMut) -> (u64, u64) { + row_data.with_buf(&a[4..]); + let a = decode_with_name::<&[u8],u64>(row_data, name, is_binary).unwrap().unwrap(); + + row_data.with_buf(&b[4..]); + let b = decode_with_name::<&[u8],u64>(row_data, name, is_binary).unwrap().unwrap(); + (a, b) } diff --git a/pisa-proxy/runtime/mysql/src/server/mod.rs b/pisa-proxy/runtime/mysql/src/server/mod.rs index 8005eba2..448cb0b0 100644 --- a/pisa-proxy/runtime/mysql/src/server/mod.rs +++ b/pisa-proxy/runtime/mysql/src/server/mod.rs @@ -18,4 +18,5 @@ pub mod server; pub use server::*; mod executor; -pub mod stmt_cache; \ No newline at end of file +pub mod stmt_cache; +mod util; \ No newline at end of file diff --git a/pisa-proxy/runtime/mysql/src/server/server.rs b/pisa-proxy/runtime/mysql/src/server/server.rs index b835a675..9be199cd 100644 --- a/pisa-proxy/runtime/mysql/src/server/server.rs +++ b/pisa-proxy/runtime/mysql/src/server/server.rs @@ -29,10 +29,10 @@ use mysql_protocol::{ err::MySQLError, }, session::{SessionMut, Session}, - util::{is_eof, length_encode_int}, + util::{is_eof, length_encode_int}, column::{Column, ColumnInfo}, }; use pisa_error::error::{Error, ErrorKind}; -use strategy::{route::RouteInputTyp, sharding_rewrite::ShardingRewriteOutput}; +use strategy::{route::RouteInputTyp, sharding_rewrite::{ShardingRewriteOutput, RewriteChange}}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Decoder, Encoder}; use tracing::{debug, error}; @@ -47,7 +47,7 @@ use crate::{ use std::sync::atomic::Ordering; -use super::executor::Executor; +use super::{executor::Executor, util::filter_avg_column}; pub struct PisaMySQLService { _phat: PhantomData<(T, C)>, @@ -176,7 +176,21 @@ where let mut buf = BytesMut::with_capacity(128); let mut data = vec![0]; data.extend_from_slice(&u32::to_le_bytes(stmt.stmt_id)); - data.extend_from_slice(&u16::to_le_bytes(stmt.cols_count)); + let avg_change = req.rewrite_outputs[0].changes.iter().find_map(|x| { + if let RewriteChange::AvgChange(change) = x { + Some(change) + } else { + None + } + }); + + if avg_change.is_some() { + data.extend_from_slice(&u16::to_le_bytes(stmt.cols_count - 1)); + } else { + data.extend_from_slice(&u16::to_le_bytes(stmt.cols_count)); + } + + data.extend_from_slice(&u16::to_le_bytes(stmt.params_count)); data.extend_from_slice(&[0, 0, 0]); @@ -199,7 +213,24 @@ where } if !stmt.cols_data.is_empty() { + let mut is_added_avg_column = false; for col_data in stmt.cols_data { + let column_info = (&col_data[4..]).decode_column(); + if let Some(change) = avg_change { + let filter_res = filter_avg_column(change, &column_info, is_added_avg_column); + if filter_res.1.is_some() { + if !filter_res.0.is_empty() { + let _ = req + .framed + .codec_mut() + .encode(PacketSend::EncodeOffset(filter_res.0.into(), buf.len()), &mut buf); + } + + is_added_avg_column = true; + continue; + } + } + let _ = req .framed .codec_mut() @@ -557,7 +588,7 @@ where Ok(()) } - async fn sharding_command_not_support( + async fn _sharding_command_not_support( cx: &mut ReqContext, command: &str, ) -> Result<(), Error> { diff --git a/pisa-proxy/runtime/mysql/src/server/util.rs b/pisa-proxy/runtime/mysql/src/server/util.rs new file mode 100644 index 00000000..6ce3b53f --- /dev/null +++ b/pisa-proxy/runtime/mysql/src/server/util.rs @@ -0,0 +1,44 @@ +// Copyright 2022 SphereEx Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use mysql_protocol::{column::ColumnInfo, mysql_const::ColumnType}; +use strategy::sharding_rewrite::{AvgChange, rewrite_const::{AVG_COUNT, AVG_SUM, AVG_FIELD}}; + +pub fn filter_avg_column(change: &AvgChange, column_info: &ColumnInfo, is_added: bool) -> (Vec, Option) { + let avg_count = change.target.get(AVG_COUNT).unwrap(); + let avg_sum = change.target.get(AVG_SUM).unwrap(); + let avg_field = change.target.get(AVG_FIELD).unwrap(); + let avg_column = ColumnInfo { + schema: None, + table_name: None, + column_name: avg_field.to_string(), + charset: 0x3f, + column_length: 0x46, + column_type: ColumnType::MYSQL_TYPE_NEWDECIMAL, + column_flag: 0x0080, + decimals: 4, + }; + + if &column_info.column_name == avg_count || &column_info.column_name == avg_sum { + if !is_added { + let mut avg_column_buf = Vec::with_capacity(128); + avg_column.encode(&mut avg_column_buf); + return (avg_column_buf, Some(avg_column)) + } else { + return (vec![], Some(avg_column)) + } + } + + (vec![], None) +} \ No newline at end of file diff --git a/pisa-proxy/runtime/mysql/src/transaction_fsm.rs b/pisa-proxy/runtime/mysql/src/transaction_fsm.rs index 77fdcb2e..11f62d39 100644 --- a/pisa-proxy/runtime/mysql/src/transaction_fsm.rs +++ b/pisa-proxy/runtime/mysql/src/transaction_fsm.rs @@ -113,6 +113,7 @@ pub fn query_rewrite( data_source: strategy::sharding_rewrite::DataSource::Endpoint(x.clone()), sharding_column: None, min_max_fields: vec![], + count_field: None, }) .collect::>();