From a3beba3321822a4b41f1767fdf9fc6781ea5f04d Mon Sep 17 00:00:00 2001 From: Carter Himmel Date: Tue, 10 Oct 2023 16:50:08 -0600 Subject: [PATCH] chore: save progress --- Cargo.lock | 15 + example/Cargo.toml | 6 +- example/benches/bench.rs | 33 +- example/expanded.rs | 831 +++++++++++++++++++ example/src/entities/person/mod.rs | 5 + example/src/entities/person/queries.rs | 5 + example/src/entities/person_login/queries.rs | 1 + scyllax-cli/src/model.rs | 1 + scyllax-macros-core/src/queries/read.rs | 84 +- scyllax-macros/src/lib.rs | 6 +- scyllax/src/executor.rs | 98 ++- scyllax/src/prelude.rs | 1 + 12 files changed, 1015 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bbdf180..543bd73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -591,12 +591,16 @@ version = "0.1.9-alpha" dependencies = [ "anyhow", "criterion", + "futures", + "futures-util", "pretty_assertions", + "rayon", "scylla", "scyllax", "serde", "serde_json", "tokio", + "tokio-stream", "tracing", "tracing-subscriber", "uuid", @@ -1911,6 +1915,17 @@ dependencies = [ "syn 2.0.29", ] +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml_datetime" version = "0.6.3" diff --git a/example/Cargo.toml b/example/Cargo.toml index 9a7f5d6..d2dced7 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -15,10 +15,14 @@ serde.workspace = true serde_json.workspace = true scylla.workspace = true scyllax = { path = "../scyllax" } -tokio.workspace = true +tokio = { version = "1", features = ["full"] } tracing.workspace = true tracing-subscriber.workspace = true uuid.workspace = true +tokio-stream = "0.1.14" +futures-util = "0.3.28" +futures = "0.3.28" +rayon = "1.8.0" [features] default = ["integration"] diff --git a/example/benches/bench.rs b/example/benches/bench.rs index 2f310d4..2bf69f8 100644 --- a/example/benches/bench.rs +++ b/example/benches/bench.rs @@ -1,23 +1,27 @@ //! benches -use example::entities::person::{self, queries::PersonQueries}; -use scyllax::prelude::{create_session, Executor}; use std::sync::Arc; + +use example::entities::{ + person::{self, queries::PersonQueries}, + PersonEntity, +}; +use scyllax::prelude::{create_session, Executor}; use tracing_subscriber::prelude::*; -async fn test_select(executor: Arc>) { +async fn test_select(executor: Arc>) -> Option { let query = person::queries::GetPersonByEmail { email: "foo1@scyllax.local".to_string(), }; - let _ = executor + executor .execute_read(query) .await - .expect("person not found"); + .expect("person not found") } -const RUNS: usize = 100_000; +const RUNS: usize = 1000; -#[tokio::main] +#[tokio::main(flavor = "multi_thread", worker_threads = 10)] async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::from_default_env()) @@ -30,13 +34,20 @@ async fn main() -> Result<(), anyhow::Error> { let session = create_session(known_nodes, default_keyspace).await?; let executor = Arc::new(Executor::::new(Arc::new(session)).await?); - let start = std::time::Instant::now(); - for _ in 0..RUNS { - test_select(executor.clone()).await; + + let futures: Vec<_> = (0..RUNS) + .map(|_| { + let executor = executor.clone(); + tokio::spawn(test_select(executor)) + }) + .collect(); + let mut res = Vec::with_capacity(futures.len()); + for f in futures.into_iter() { + res.push(f.await.unwrap()); } - let end = std::time::Instant::now(); + let end = std::time::Instant::now(); println!("elapsed: {:#?}", end - start); println!("per run: {:?}", (end - start) / RUNS as u32); diff --git a/example/expanded.rs b/example/expanded.rs index e69de29..e23dcb6 100644 --- a/example/expanded.rs +++ b/example/expanded.rs @@ -0,0 +1,831 @@ +/// All select queries +pub mod queries { + use super::model::{UpsertPerson, UpsertPersonWithTTL}; + use scyllax::prelude::*; + use uuid::Uuid; + ///A collection of prepared statements. + #[allow(non_snake_case)] + pub struct PersonQueries { + #[allow(non_snake_case)] + ///The prepared statement for `GetPersonById`. + pub get_person_by_id: scylla_reexports::PreparedStatement, + #[allow(non_snake_case)] + ///The prepared statement for `GetPeopleByIds`. + pub get_people_by_ids: scylla_reexports::PreparedStatement, + #[allow(non_snake_case)] + ///The prepared statement for `GetPersonByEmail`. + pub get_person_by_email: scylla_reexports::PreparedStatement, + #[allow(non_snake_case)] + ///The prepared statement for `DeletePersonById`. + pub delete_person_by_id: scylla_reexports::PreparedStatement, + #[allow(non_snake_case)] + ///The prepared statement for `UpsertPerson`. + pub upsert_person: scylla_reexports::PreparedStatement, + #[allow(non_snake_case)] + ///The prepared statement for `UpsertPersonWithTTL`. + pub upsert_person_with_ttl: scylla_reexports::PreparedStatement, + #[allow(non_snake_case)] + ///The task for `GetPersonById`. + pub get_person_by_id_task: Option< + tokio::sync::mpsc::Sender>, + >, + #[allow(non_snake_case)] + ///The task for `GetPeopleByIds`. + pub get_people_by_ids_task: Option< + tokio::sync::mpsc::Sender>, + >, + #[allow(non_snake_case)] + ///The task for `GetPersonByEmail`. + pub get_person_by_email_task: Option< + tokio::sync::mpsc::Sender>, + >, + } + #[automatically_derived] + #[allow(non_snake_case)] + impl ::core::fmt::Debug for PersonQueries { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + let names: &'static _ = &[ + "get_person_by_id", + "get_people_by_ids", + "get_person_by_email", + "delete_person_by_id", + "upsert_person", + "upsert_person_with_ttl", + "get_person_by_id_task", + "get_people_by_ids_task", + "get_person_by_email_task", + ]; + let values: &[&dyn ::core::fmt::Debug] = &[ + &self.get_person_by_id, + &self.get_people_by_ids, + &self.get_person_by_email, + &self.delete_person_by_id, + &self.upsert_person, + &self.upsert_person_with_ttl, + &self.get_person_by_id_task, + &self.get_people_by_ids_task, + &&self.get_person_by_email_task, + ]; + ::core::fmt::Formatter::debug_struct_fields_finish( + f, + "PersonQueries", + names, + values, + ) + } + } + #[automatically_derived] + #[allow(non_snake_case)] + impl ::core::clone::Clone for PersonQueries { + #[inline] + fn clone(&self) -> PersonQueries { + PersonQueries { + get_person_by_id: ::core::clone::Clone::clone(&self.get_person_by_id), + get_people_by_ids: ::core::clone::Clone::clone(&self.get_people_by_ids), + get_person_by_email: ::core::clone::Clone::clone( + &self.get_person_by_email, + ), + delete_person_by_id: ::core::clone::Clone::clone( + &self.delete_person_by_id, + ), + upsert_person: ::core::clone::Clone::clone(&self.upsert_person), + upsert_person_with_ttl: ::core::clone::Clone::clone( + &self.upsert_person_with_ttl, + ), + get_person_by_id_task: ::core::clone::Clone::clone( + &self.get_person_by_id_task, + ), + get_people_by_ids_task: ::core::clone::Clone::clone( + &self.get_people_by_ids_task, + ), + get_person_by_email_task: ::core::clone::Clone::clone( + &self.get_person_by_email_task, + ), + } + } + } + ///A collection of prepared statements. + impl scyllax::prelude::QueryCollection for PersonQueries { + #[allow( + clippy::async_yields_async, + clippy::diverging_sub_expression, + clippy::let_unit_value, + clippy::no_effect_underscore_binding, + clippy::shadow_same, + clippy::type_complexity, + clippy::type_repetition_in_bounds, + clippy::used_underscore_binding + )] + fn new<'life0, 'async_trait>( + session: &'life0 scylla::Session, + ) -> ::core::pin::Pin< + Box< + dyn ::core::future::Future< + Output = Result, + > + ::core::marker::Send + 'async_trait, + >, + > + where + 'life0: 'async_trait, + Self: 'async_trait, + { + Box::pin(async move { + if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::< + Result, + > { + return __ret; + } + let __ret: Result = { + Ok(Self { + get_person_by_id: scyllax::prelude::prepare_query( + &session, + GetPersonById::query(), + ) + .await?, + get_people_by_ids: scyllax::prelude::prepare_query( + &session, + GetPeopleByIds::query(), + ) + .await?, + get_person_by_email: scyllax::prelude::prepare_query( + &session, + GetPersonByEmail::query(), + ) + .await?, + delete_person_by_id: scyllax::prelude::prepare_query( + &session, + DeletePersonById::query(), + ) + .await?, + upsert_person: scyllax::prelude::prepare_query( + &session, + UpsertPerson::query(), + ) + .await?, + upsert_person_with_ttl: scyllax::prelude::prepare_query( + &session, + UpsertPersonWithTTL::query(), + ) + .await?, + get_person_by_id_task: None, + get_people_by_ids_task: None, + get_person_by_email_task: None, + }) + }; + #[allow(unreachable_code)] __ret + }) + } + fn register_tasks( + mut self, + executor: std::sync::Arc>, + ) -> Self { + self + .get_person_by_id_task = { + let (tx, rx) = tokio::sync::mpsc::channel(100); + let ex = executor.clone(); + tokio::spawn(async move { + ex.read_task::(rx).await; + }); + Some(tx) + }; + self + .get_people_by_ids_task = { + let (tx, rx) = tokio::sync::mpsc::channel(100); + let ex = executor.clone(); + tokio::spawn(async move { + ex.read_task::(rx).await; + }); + Some(tx) + }; + self + .get_person_by_email_task = { + let (tx, rx) = tokio::sync::mpsc::channel(100); + let ex = executor.clone(); + tokio::spawn(async move { + ex.read_task::(rx).await; + }); + Some(tx) + }; + self + } + } + impl scyllax::prelude::GetPreparedStatement for PersonQueries { + ///Get a prepared statement. + fn get(&self) -> &scyllax::prelude::scylla_reexports::PreparedStatement { + &self.get_person_by_id + } + } + impl scyllax::prelude::GetPreparedStatement for PersonQueries { + ///Get a prepared statement. + fn get(&self) -> &scyllax::prelude::scylla_reexports::PreparedStatement { + &self.get_people_by_ids + } + } + impl scyllax::prelude::GetPreparedStatement for PersonQueries { + ///Get a prepared statement. + fn get(&self) -> &scyllax::prelude::scylla_reexports::PreparedStatement { + &self.get_person_by_email + } + } + impl scyllax::prelude::GetPreparedStatement for PersonQueries { + ///Get a prepared statement. + fn get(&self) -> &scyllax::prelude::scylla_reexports::PreparedStatement { + &self.delete_person_by_id + } + } + impl scyllax::prelude::GetPreparedStatement for PersonQueries { + ///Get a prepared statement. + fn get(&self) -> &scyllax::prelude::scylla_reexports::PreparedStatement { + &self.upsert_person + } + } + impl scyllax::prelude::GetPreparedStatement for PersonQueries { + ///Get a prepared statement. + fn get(&self) -> &scyllax::prelude::scylla_reexports::PreparedStatement { + &self.upsert_person_with_ttl + } + } + impl scyllax::prelude::GetCoalescingSender for PersonQueries { + ///Get a task. + fn get( + &self, + ) -> &tokio::sync::mpsc::Sender> { + &self.get_person_by_id_task.as_ref().unwrap() + } + } + impl scyllax::prelude::GetCoalescingSender for PersonQueries { + ///Get a task. + fn get( + &self, + ) -> &tokio::sync::mpsc::Sender< + scyllax::executor::ShardMessage, + > { + &self.get_people_by_ids_task.as_ref().unwrap() + } + } + impl scyllax::prelude::GetCoalescingSender for PersonQueries { + ///Get a task. + fn get( + &self, + ) -> &tokio::sync::mpsc::Sender< + scyllax::executor::ShardMessage, + > { + &self.get_person_by_email_task.as_ref().unwrap() + } + } + /// Get a [`super::model::PersonEntity`] by its [`uuid::Uuid`] + #[read_query( + query = "select * from person where id = :id limit 1", + return_type = "super::model::PersonEntity" + )] + pub struct GetPersonById { + /// The [`uuid::Uuid`] of the [`super::model::PersonEntity`] to get + #[read_query(coalesce_shard_key)] + pub id: Uuid, + } + #[automatically_derived] + impl ::core::fmt::Debug for GetPersonById { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + ::core::fmt::Formatter::debug_struct_field1_finish( + f, + "GetPersonById", + "id", + &&self.id, + ) + } + } + #[automatically_derived] + impl ::core::clone::Clone for GetPersonById { + #[inline] + fn clone(&self) -> GetPersonById { + GetPersonById { + id: ::core::clone::Clone::clone(&self.id), + } + } + } + #[automatically_derived] + impl ::core::marker::StructuralPartialEq for GetPersonById {} + #[automatically_derived] + impl ::core::cmp::PartialEq for GetPersonById { + #[inline] + fn eq(&self, other: &GetPersonById) -> bool { + self.id == other.id + } + } + impl scylla::_macro_internal::ValueList for GetPersonById { + fn serialized(&self) -> scylla::_macro_internal::SerializedResult { + let mut result = scylla::_macro_internal::SerializedValues::with_capacity( + 1usize, + ); + result.add_value(&self.id)?; + ::std::result::Result::Ok(::std::borrow::Cow::Owned(result)) + } + } + impl scyllax::prelude::Query for GetPersonById { + fn query() -> String { + "select * from person where id = :id limit 1" + .replace("*", &super::model::PersonEntity::keys().join(", ")) + } + fn bind(&self) -> scyllax::prelude::SerializedValuesResult { + let mut values = scylla_reexports::value::SerializedValues::new(); + values.add_named_value("id", &self.id)?; + Ok(values) + } + } + impl scyllax::prelude::ReadQuery for GetPersonById { + type Output = Option; + #[allow( + clippy::async_yields_async, + clippy::diverging_sub_expression, + clippy::let_unit_value, + clippy::no_effect_underscore_binding, + clippy::shadow_same, + clippy::type_complexity, + clippy::type_repetition_in_bounds, + clippy::used_underscore_binding + )] + fn parse_response<'async_trait>( + res: scylla::QueryResult, + ) -> ::core::pin::Pin< + Box< + dyn ::core::future::Future< + Output = Result, + > + ::core::marker::Send + 'async_trait, + >, + > + where + Self: 'async_trait, + { + Box::pin(async move { + if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::< + Result, + > { + return __ret; + } + let res = res; + let __ret: Result = { + match res.single_row_typed::() { + Ok(data) => Ok(Some(data)), + Err(err) => { + use scylla::transport::query_result::SingleRowTypedError; + match err { + SingleRowTypedError::BadNumberOfRows(_) => Ok(None), + _ => { + { + use ::tracing::__macro_support::Callsite as _; + static CALLSITE: ::tracing::callsite::DefaultCallsite = { + static META: ::tracing::Metadata<'static> = { + ::tracing_core::metadata::Metadata::new( + "event example/src/entities/person/queries.rs:12", + "example::entities::person::queries", + ::tracing::Level::ERROR, + Some("example/src/entities/person/queries.rs"), + Some(12u32), + Some("example::entities::person::queries"), + ::tracing_core::field::FieldSet::new( + &["message"], + ::tracing_core::callsite::Identifier(&CALLSITE), + ), + ::tracing::metadata::Kind::EVENT, + ) + }; + ::tracing::callsite::DefaultCallsite::new(&META) + }; + let enabled = ::tracing::Level::ERROR + <= ::tracing::level_filters::STATIC_MAX_LEVEL + && ::tracing::Level::ERROR + <= ::tracing::level_filters::LevelFilter::current() + && { + let interest = CALLSITE.interest(); + !interest.is_never() + && ::tracing::__macro_support::__is_enabled( + CALLSITE.metadata(), + interest, + ) + }; + if enabled { + (|value_set: ::tracing::field::ValueSet| { + let meta = CALLSITE.metadata(); + ::tracing::Event::dispatch(meta, &value_set); + })({ + #[allow(unused_imports)] + use ::tracing::field::{debug, display, Value}; + let mut iter = CALLSITE.metadata().fields().iter(); + CALLSITE + .metadata() + .fields() + .value_set( + &[ + ( + &iter.next().expect("FieldSet corrupted (this is a bug)"), + Some(&format_args!("err: {0:?}", err) as &dyn Value), + ), + ], + ) + }); + } else { + } + }; + Err(scyllax::error::ScyllaxError::SingleRowTyped(err)) + } + } + } + } + }; + #[allow(unreachable_code)] __ret + }) + } + fn shard_key(&self) -> String { + [self.id.to_string()].join(":") + } + } + /// Get many [`super::model::PersonEntity`] by many [`uuid::Uuid`] + #[read_query( + query = "select * from person where id in :ids limit :rowlimit", + return_type = "Vec" + )] + pub struct GetPeopleByIds { + /// The [`uuid::Uuid`]s of the [`super::model::PersonEntity`]s to get + pub ids: Vec, + /// The maximum number of [`super::model::PersonEntity`]s to get + pub rowlimit: i32, + } + #[automatically_derived] + impl ::core::fmt::Debug for GetPeopleByIds { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + ::core::fmt::Formatter::debug_struct_field2_finish( + f, + "GetPeopleByIds", + "ids", + &self.ids, + "rowlimit", + &&self.rowlimit, + ) + } + } + #[automatically_derived] + impl ::core::clone::Clone for GetPeopleByIds { + #[inline] + fn clone(&self) -> GetPeopleByIds { + GetPeopleByIds { + ids: ::core::clone::Clone::clone(&self.ids), + rowlimit: ::core::clone::Clone::clone(&self.rowlimit), + } + } + } + #[automatically_derived] + impl ::core::marker::StructuralPartialEq for GetPeopleByIds {} + #[automatically_derived] + impl ::core::cmp::PartialEq for GetPeopleByIds { + #[inline] + fn eq(&self, other: &GetPeopleByIds) -> bool { + self.ids == other.ids && self.rowlimit == other.rowlimit + } + } + impl scylla::_macro_internal::ValueList for GetPeopleByIds { + fn serialized(&self) -> scylla::_macro_internal::SerializedResult { + let mut result = scylla::_macro_internal::SerializedValues::with_capacity( + 2usize, + ); + result.add_value(&self.ids)?; + result.add_value(&self.rowlimit)?; + ::std::result::Result::Ok(::std::borrow::Cow::Owned(result)) + } + } + impl scyllax::prelude::Query for GetPeopleByIds { + fn query() -> String { + "select * from person where id in :ids limit :rowlimit" + .replace("*", &super::model::PersonEntity::keys().join(", ")) + } + fn bind(&self) -> scyllax::prelude::SerializedValuesResult { + let mut values = scylla_reexports::value::SerializedValues::new(); + values.add_named_value("ids", &self.ids)?; + values.add_named_value("rowlimit", &self.rowlimit)?; + Ok(values) + } + } + impl scyllax::prelude::ReadQuery for GetPeopleByIds { + type Output = Vec; + #[allow( + clippy::async_yields_async, + clippy::diverging_sub_expression, + clippy::let_unit_value, + clippy::no_effect_underscore_binding, + clippy::shadow_same, + clippy::type_complexity, + clippy::type_repetition_in_bounds, + clippy::used_underscore_binding + )] + fn parse_response<'async_trait>( + res: scylla::QueryResult, + ) -> ::core::pin::Pin< + Box< + dyn ::core::future::Future< + Output = Result, + > + ::core::marker::Send + 'async_trait, + >, + > + where + Self: 'async_trait, + { + Box::pin(async move { + if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::< + Result, + > { + return __ret; + } + let res = res; + let __ret: Result = { + match res.rows_typed::() { + Ok(xs) => { + Ok( + xs + .filter_map(|x| x.ok()) + .collect::>(), + ) + } + Err(e) => { + { + use ::tracing::__macro_support::Callsite as _; + static CALLSITE: ::tracing::callsite::DefaultCallsite = { + static META: ::tracing::Metadata<'static> = { + ::tracing_core::metadata::Metadata::new( + "event example/src/entities/person/queries.rs:24", + "example::entities::person::queries", + ::tracing::Level::ERROR, + Some("example/src/entities/person/queries.rs"), + Some(24u32), + Some("example::entities::person::queries"), + ::tracing_core::field::FieldSet::new( + &["message"], + ::tracing_core::callsite::Identifier(&CALLSITE), + ), + ::tracing::metadata::Kind::EVENT, + ) + }; + ::tracing::callsite::DefaultCallsite::new(&META) + }; + let enabled = ::tracing::Level::ERROR + <= ::tracing::level_filters::STATIC_MAX_LEVEL + && ::tracing::Level::ERROR + <= ::tracing::level_filters::LevelFilter::current() + && { + let interest = CALLSITE.interest(); + !interest.is_never() + && ::tracing::__macro_support::__is_enabled( + CALLSITE.metadata(), + interest, + ) + }; + if enabled { + (|value_set: ::tracing::field::ValueSet| { + let meta = CALLSITE.metadata(); + ::tracing::Event::dispatch(meta, &value_set); + })({ + #[allow(unused_imports)] + use ::tracing::field::{debug, display, Value}; + let mut iter = CALLSITE.metadata().fields().iter(); + CALLSITE + .metadata() + .fields() + .value_set( + &[ + ( + &iter.next().expect("FieldSet corrupted (this is a bug)"), + Some(&format_args!("err: {0:?}", e) as &dyn Value), + ), + ], + ) + }); + } else { + } + }; + Ok(::alloc::vec::Vec::new()) + } + } + }; + #[allow(unreachable_code)] __ret + }) + } + fn shard_key(&self) -> String { + String::new() + } + } + /// Get a [`super::model::PersonEntity`] by its email address + #[read_query( + query = "select * from person_by_email where email = :email limit 1", + return_type = "super::model::PersonEntity" + )] + pub struct GetPersonByEmail { + /// The email address of the [`super::model::PersonEntity`] to get + pub email: String, + } + #[automatically_derived] + impl ::core::fmt::Debug for GetPersonByEmail { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + ::core::fmt::Formatter::debug_struct_field1_finish( + f, + "GetPersonByEmail", + "email", + &&self.email, + ) + } + } + #[automatically_derived] + impl ::core::clone::Clone for GetPersonByEmail { + #[inline] + fn clone(&self) -> GetPersonByEmail { + GetPersonByEmail { + email: ::core::clone::Clone::clone(&self.email), + } + } + } + #[automatically_derived] + impl ::core::marker::StructuralPartialEq for GetPersonByEmail {} + #[automatically_derived] + impl ::core::cmp::PartialEq for GetPersonByEmail { + #[inline] + fn eq(&self, other: &GetPersonByEmail) -> bool { + self.email == other.email + } + } + impl scylla::_macro_internal::ValueList for GetPersonByEmail { + fn serialized(&self) -> scylla::_macro_internal::SerializedResult { + let mut result = scylla::_macro_internal::SerializedValues::with_capacity( + 1usize, + ); + result.add_value(&self.email)?; + ::std::result::Result::Ok(::std::borrow::Cow::Owned(result)) + } + } + impl scyllax::prelude::Query for GetPersonByEmail { + fn query() -> String { + "select * from person_by_email where email = :email limit 1" + .replace("*", &super::model::PersonEntity::keys().join(", ")) + } + fn bind(&self) -> scyllax::prelude::SerializedValuesResult { + let mut values = scylla_reexports::value::SerializedValues::new(); + values.add_named_value("email", &self.email)?; + Ok(values) + } + } + impl scyllax::prelude::ReadQuery for GetPersonByEmail { + type Output = Option; + #[allow( + clippy::async_yields_async, + clippy::diverging_sub_expression, + clippy::let_unit_value, + clippy::no_effect_underscore_binding, + clippy::shadow_same, + clippy::type_complexity, + clippy::type_repetition_in_bounds, + clippy::used_underscore_binding + )] + fn parse_response<'async_trait>( + res: scylla::QueryResult, + ) -> ::core::pin::Pin< + Box< + dyn ::core::future::Future< + Output = Result, + > + ::core::marker::Send + 'async_trait, + >, + > + where + Self: 'async_trait, + { + Box::pin(async move { + if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::< + Result, + > { + return __ret; + } + let res = res; + let __ret: Result = { + match res.single_row_typed::() { + Ok(data) => Ok(Some(data)), + Err(err) => { + use scylla::transport::query_result::SingleRowTypedError; + match err { + SingleRowTypedError::BadNumberOfRows(_) => Ok(None), + _ => { + { + use ::tracing::__macro_support::Callsite as _; + static CALLSITE: ::tracing::callsite::DefaultCallsite = { + static META: ::tracing::Metadata<'static> = { + ::tracing_core::metadata::Metadata::new( + "event example/src/entities/person/queries.rs:37", + "example::entities::person::queries", + ::tracing::Level::ERROR, + Some("example/src/entities/person/queries.rs"), + Some(37u32), + Some("example::entities::person::queries"), + ::tracing_core::field::FieldSet::new( + &["message"], + ::tracing_core::callsite::Identifier(&CALLSITE), + ), + ::tracing::metadata::Kind::EVENT, + ) + }; + ::tracing::callsite::DefaultCallsite::new(&META) + }; + let enabled = ::tracing::Level::ERROR + <= ::tracing::level_filters::STATIC_MAX_LEVEL + && ::tracing::Level::ERROR + <= ::tracing::level_filters::LevelFilter::current() + && { + let interest = CALLSITE.interest(); + !interest.is_never() + && ::tracing::__macro_support::__is_enabled( + CALLSITE.metadata(), + interest, + ) + }; + if enabled { + (|value_set: ::tracing::field::ValueSet| { + let meta = CALLSITE.metadata(); + ::tracing::Event::dispatch(meta, &value_set); + })({ + #[allow(unused_imports)] + use ::tracing::field::{debug, display, Value}; + let mut iter = CALLSITE.metadata().fields().iter(); + CALLSITE + .metadata() + .fields() + .value_set( + &[ + ( + &iter.next().expect("FieldSet corrupted (this is a bug)"), + Some(&format_args!("err: {0:?}", err) as &dyn Value), + ), + ], + ) + }); + } else { + } + }; + Err(scyllax::error::ScyllaxError::SingleRowTyped(err)) + } + } + } + } + }; + #[allow(unreachable_code)] __ret + }) + } + fn shard_key(&self) -> String { + String::new() + } + } + /// Get a [`super::model::PersonEntity`] by its [`uuid::Uuid`] + pub struct DeletePersonById { + /// The [`uuid::Uuid`] of the [`super::model::PersonEntity`] to get + pub id: Uuid, + } + #[automatically_derived] + impl ::core::fmt::Debug for DeletePersonById { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + ::core::fmt::Formatter::debug_struct_field1_finish( + f, + "DeletePersonById", + "id", + &&self.id, + ) + } + } + #[automatically_derived] + impl ::core::clone::Clone for DeletePersonById { + #[inline] + fn clone(&self) -> DeletePersonById { + DeletePersonById { + id: ::core::clone::Clone::clone(&self.id), + } + } + } + #[automatically_derived] + impl ::core::marker::StructuralPartialEq for DeletePersonById {} + #[automatically_derived] + impl ::core::cmp::PartialEq for DeletePersonById { + #[inline] + fn eq(&self, other: &DeletePersonById) -> bool { + self.id == other.id + } + } + #[automatically_derived] + impl ::core::hash::Hash for DeletePersonById { + #[inline] + fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () { + ::core::hash::Hash::hash(&self.id, state) + } + } + impl scyllax::prelude::Query for DeletePersonById { + fn query() -> String { + "delete from person where id = :id".to_string() + } + fn bind(&self) -> scyllax::prelude::SerializedValuesResult { + let mut values = scylla_reexports::value::SerializedValues::new(); + values.add_named_value("id", &self.id)?; + Ok(values) + } + } + impl scyllax::prelude::WriteQuery for DeletePersonById {} +} diff --git a/example/src/entities/person/mod.rs b/example/src/entities/person/mod.rs index 56f9e15..e8a7a32 100644 --- a/example/src/entities/person/mod.rs +++ b/example/src/entities/person/mod.rs @@ -1,4 +1,9 @@ //! The Person entity +//! Viewing the expanded code: +//! +//! ```console +//! cargo expand -p example --lib entities::person::model > example/expanded.rs +//! ``` /// The model itself pub mod model; diff --git a/example/src/entities/person/queries.rs b/example/src/entities/person/queries.rs index b069f63..ac743d9 100644 --- a/example/src/entities/person/queries.rs +++ b/example/src/entities/person/queries.rs @@ -9,16 +9,19 @@ create_query_collection!( ); /// Get a [`super::model::PersonEntity`] by its [`uuid::Uuid`] +#[derive(Debug, Clone, PartialEq, ValueList, ReadQuery)] #[read_query( query = "select * from person where id = :id limit 1", return_type = "super::model::PersonEntity" )] pub struct GetPersonById { /// The [`uuid::Uuid`] of the [`super::model::PersonEntity`] to get + #[read_query(coalesce_shard_key)] pub id: Uuid, } /// Get many [`super::model::PersonEntity`] by many [`uuid::Uuid`] +#[derive(Debug, Clone, PartialEq, ValueList, ReadQuery)] #[read_query( query = "select * from person where id in :ids limit :rowlimit", return_type = "Vec" @@ -31,12 +34,14 @@ pub struct GetPeopleByIds { } /// Get a [`super::model::PersonEntity`] by its email address +#[derive(Debug, Clone, PartialEq, ValueList, ReadQuery)] #[read_query( query = "select * from person_by_email where email = :email limit 1", return_type = "super::model::PersonEntity" )] pub struct GetPersonByEmail { /// The email address of the [`super::model::PersonEntity`] to get + #[read_query(coalesce_shard_key)] pub email: String, } diff --git a/example/src/entities/person_login/queries.rs b/example/src/entities/person_login/queries.rs index bbddc86..74ebdb1 100644 --- a/example/src/entities/person_login/queries.rs +++ b/example/src/entities/person_login/queries.rs @@ -9,6 +9,7 @@ create_query_collection!( ); /// Get a [`super::model::PersonLoginEntity`] by its [`uuid::Uuid`] +#[derive(Debug, Clone, PartialEq, ValueList, ReadQuery)] #[read_query( query = "select * from person_login where id = :id limit 1", return_type = "super::model::PersonLoginEntity" diff --git a/scyllax-cli/src/model.rs b/scyllax-cli/src/model.rs index 6003a69..dd4d547 100644 --- a/scyllax-cli/src/model.rs +++ b/scyllax-cli/src/model.rs @@ -22,6 +22,7 @@ pub struct MigrationEntity { } // get the latest version from the database +#[derive(Debug, Clone, PartialEq, ValueList, ReadQuery)] #[read_query( query_nocheck = "select * from migration where bucket = 0 order by version desc limit 1", return_type = "MigrationEntity" diff --git a/scyllax-macros-core/src/queries/read.rs b/scyllax-macros-core/src/queries/read.rs index 827f531..8b24622 100644 --- a/scyllax-macros-core/src/queries/read.rs +++ b/scyllax-macros-core/src/queries/read.rs @@ -1,41 +1,56 @@ -use darling::{export::NestedMeta, FromMeta}; +use darling::{ast, FromDeriveInput, FromField}; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use scyllax_parser::{select::parse_select, SelectQuery, Value, Variable}; -use syn::ItemStruct; +use syn::{DeriveInput, Ident, ItemStruct, Type}; use crate::queries::impl_generic_query; -#[derive(FromMeta)] -pub(crate) struct SelectQueryOptions { - query: Option, - query_nocheck: Option, - return_type: syn::Type, +#[derive(Debug, PartialEq, FromField)] +#[darling(attributes(read_query))] +pub struct ReadQueryDeriveVariable { + pub ident: Option, + pub ty: Type, + #[darling(default)] + pub coalesce_shard_key: bool, } -pub fn expand(args: TokenStream, item: TokenStream) -> TokenStream { - let attr_args = match NestedMeta::parse_meta_list(args.clone()) { - Ok(args) => args, - Err(e) => return darling::Error::from(e).write_errors(), +#[derive(Debug, PartialEq, FromDeriveInput)] +#[darling(attributes(read_query), supports(struct_named))] +pub struct ReadQueryDerive { + pub ident: syn::Ident, + pub data: ast::Data<(), ReadQueryDeriveVariable>, + + #[darling(default)] + pub query: Option, + #[darling(default)] + pub query_nocheck: Option, + pub return_type: syn::Type, +} + +pub fn expand(input: TokenStream) -> TokenStream { + let parsed_input: DeriveInput = match syn::parse2(input.clone()) { + Ok(it) => it, + Err(e) => return e.to_compile_error(), }; - let args = match SelectQueryOptions::from_list(&attr_args) { - Ok(o) => o, + let args = match ReadQueryDerive::from_derive_input(&parsed_input) { + Ok(i) => i, Err(e) => return e.write_errors(), }; + let fields = args + .data + .take_struct() + .expect("Should never be enum") + .fields; if args.query.is_none() && args.query_nocheck.is_none() { - return syn::Error::new_spanned(item, "Either query or query_nocheck must be specified") + return syn::Error::new_spanned(input, "Either query or query_nocheck must be specified") .to_compile_error(); } - let return_type = args.return_type; - - let input: ItemStruct = match syn::parse2(item.clone()) { - Ok(it) => it, - Err(e) => return e.to_compile_error(), - }; - let struct_ident = &input.ident; + let struct_ident = args.ident; + let r#struct = syn::parse2::(input.clone()).unwrap(); // trimmed return_type // eg: Vec -> OrgEntity @@ -82,7 +97,7 @@ pub fn expand(args: TokenStream, item: TokenStream) -> TokenStream { // query parsing let query = if let Some(query) = args.query { - match parse_query(&input, &query) { + match parse_query(&r#struct, &query) { Ok(_) => (), Err(e) => return e.to_compile_error(), }; @@ -136,12 +151,25 @@ pub fn expand(args: TokenStream, item: TokenStream) -> TokenStream { .to_compile_error(); }; - let impl_query = impl_generic_query(&input, query, Some(&inner_entity_type)); + let impl_query = impl_generic_query(&r#struct, query, Some(&inner_entity_type)); - quote! { - #[derive(scylla::ValueList, std::fmt::Debug, std::clone::Clone, PartialEq, Hash)] - #input + let shard_keys = fields + .iter() + .filter(|v| v.coalesce_shard_key) + .map(|v| v.ident.as_ref().unwrap()) + .collect::>(); + let shard_key = if !shard_keys.is_empty() { + // create a redis-like shard key, joining all shard keys with a colon + quote! { + [#(self.#shard_keys.to_string()),*].join(":") + } + } else { + quote! { + String::new() + } + }; + quote! { #impl_query #[scyllax::prelude::async_trait] @@ -153,6 +181,10 @@ pub fn expand(args: TokenStream, item: TokenStream) -> TokenStream { { #parser } + + fn shard_key(&self) -> String { + #shard_key + } } } } diff --git a/scyllax-macros/src/lib.rs b/scyllax-macros/src/lib.rs index 5c4f6cc..92428a3 100644 --- a/scyllax-macros/src/lib.rs +++ b/scyllax-macros/src/lib.rs @@ -30,9 +30,9 @@ use scyllax_macros_core::{entity, json, prepare, queries, r#enum}; /// executor.execute_select(GetPeopleByIds { ids, limit }).await?; /// // -> Vec /// ``` -#[proc_macro_attribute] -pub fn read_query(args: TokenStream, input: TokenStream) -> TokenStream { - queries::read::expand(args.into(), input.into()).into() +#[proc_macro_derive(ReadQuery, attributes(read_query))] +pub fn read_query(input: TokenStream) -> TokenStream { + queries::read::expand(input.into()).into() } /// Apply this attribute to a struct to generate a write query. diff --git a/scyllax/src/executor.rs b/scyllax/src/executor.rs index e370e25..4ae5a69 100644 --- a/scyllax/src/executor.rs +++ b/scyllax/src/executor.rs @@ -7,7 +7,7 @@ use crate::{ }; use scylla::{prepared_statement::PreparedStatement, QueryResult, Session, SessionBuilder}; use std::{collections::HashMap, sync::Arc}; -use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::{sync::{mpsc::{Receiver, Sender, self}, MutexGuard, Mutex}, task::JoinSet}; use tokio::sync::oneshot; /// Creates a new [`CachingSession`] and returns it @@ -35,7 +35,7 @@ pub trait GetCoalescingSender { fn get(&self) -> &Sender>; } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Executor { pub session: Arc, queries: T, @@ -43,12 +43,20 @@ pub struct Executor { pub type ShardMessage = ( Q, - oneshot::Sender::Output, ScyllaxError>>, + oneshot::Sender>, ); type TaskRequestMap = - HashMap::Output, ScyllaxError>>>>; + HashMap>>>; -impl Executor { +type ReadQueryResult = Arc::Output, ScyllaxError>>; + +pub struct QueryRunnerMessage { + key: String, + query: Q, + response_rx: oneshot::Sender>, +} + +impl Executor { pub async fn new(session: Arc) -> Result { let queries = T::new(&session).await?; let executor = Arc::new(Self { @@ -57,6 +65,7 @@ impl Executor { }); let queries = executor.queries.clone().register_tasks(executor); + // let queries = Arc::new(queries); let executor = Self { session, queries }; Ok(executor) @@ -73,45 +82,74 @@ impl Executor { task.send((query, response_tx)).await.unwrap(); match response_rx.await { - Ok(result) => result, + Ok(result) => { + let result = Arc::try_unwrap(result).unwrap(); + result + }, Err(e) => Err(ScyllaxError::ReceiverError(e)), } } - pub async fn read_task(&self, mut rx: Receiver>) + /// the read task is responsible for coalescing requests + pub async fn read_task(&self, mut request_receiver: Receiver>) where - Q: Query + ReadQuery, + Q: Query + ReadQuery + Send + Sync + 'static, T: GetPreparedStatement + GetCoalescingSender, { let mut requests: TaskRequestMap = HashMap::new(); - - while let Some((query, tx)) = rx.recv().await { - let key = query.shard_key(); - - if let Some(senders) = requests.get_mut(&key) { - senders.push(tx); - } else { - let senders = vec![tx]; - requests.insert(key.clone(), senders); - - // Execute the query here and send the result back - // let result = self.execute_read(&query).await; - let statement = self.queries.get_prepared::(); - // FIXME: better error handling - let variables = query.bind().unwrap(); - // FIXME: better error handling - let result = self.session.execute(statement, variables).await.unwrap(); - let parsed = Q::parse_response(result).await; - - if let Some(senders) = requests.remove(&key) { - for tx in senders { - let _ = tx.send(parsed.clone()); + let mut join_set: JoinSet<_> = JoinSet::new(); + + loop { + tokio::select! { + Some((query, tx)) = request_receiver.recv() => { + let key = query.shard_key(); + if let Some(senders) = requests.get_mut(&key) { + senders.push(tx); + } else { + requests.insert(key.clone(), vec![tx]); + // let (response_rx, response_tx) = oneshot::channel(); + // let _ = runner.send(QueryRunnerMessage { key: key.clone(), query, response_rx }).await; + + let session = self.session.clone(); + let statement = self.queries.get_prepared::(); + let handle = Executor::::perform_read_query( + session, + statement, + query + ); + join_set.spawn(async move { + (key, handle.await) + }); + } + }, + Some(join_handle) = join_set.join_next() => { + if let Ok((key, result)) = join_handle { + if let Some(senders) = requests.remove(&key) { + let result = Arc::new(result); + for sender in senders { + let _ = sender.send(result.clone()); + } + } } } } } } + /// this function does the requests themselves + async fn perform_read_query(session: Arc, statement: &PreparedStatement, query: Q) -> Result<::Output, ScyllaxError> + where + Q: Query + ReadQuery + Send + Sync, + T: GetPreparedStatement + GetCoalescingSender, + { + // FIXME: better error handling + let variables = query.bind().unwrap(); + // FIXME: better error handling + let result = session.execute(statement, variables).await.unwrap(); + + Q::parse_response(result).await + } + pub async fn execute_write(&self, query: &Q) -> Result where Q: Query + WriteQuery, diff --git a/scyllax/src/prelude.rs b/scyllax/src/prelude.rs index 00a31d6..19a04f7 100644 --- a/scyllax/src/prelude.rs +++ b/scyllax/src/prelude.rs @@ -9,6 +9,7 @@ pub use crate::{ util::v1_uuid, }; pub use async_trait::async_trait; +pub use scylla_reexports::*; pub use scyllax_macros::*; pub mod scylla_reexports {