From f007d9311e0ec727a33e4a465ab48a0bee94cb1d Mon Sep 17 00:00:00 2001 From: Patrick Casey Date: Thu, 19 Sep 2024 16:43:16 -0400 Subject: [PATCH] feat: added preliminary SDK support to "rand_data" plugin --- Cargo.lock | 29 ++ Cargo.toml | 9 +- hipcheck/src/cli.rs | 2 + hipcheck/src/engine.rs | 23 +- hipcheck/src/main.rs | 10 +- hipcheck/src/plugin/mod.rs | 9 +- hipcheck/src/plugin/types.rs | 3 + plugins/dummy_rand_data/src/transport.rs | 13 +- plugins/dummy_rand_data_sdk/Cargo.toml | 13 + .../schema/query_schema_get_rand.json | 3 + plugins/dummy_rand_data_sdk/src/main.rs | 118 +++++ plugins/dummy_sha256/src/main.rs | 4 +- plugins/dummy_sha256/src/transport.rs | 27 +- .../default_policy_expr_request.proto | 12 + .../default_policy_expr_response.proto | 15 + proto/hipcheck/v1/messages/empty.proto | 7 + .../explain_default_query_request.proto | 8 + .../explain_default_query_response.proto | 12 + proto/hipcheck/v1/messages/query.proto | 46 ++ .../hipcheck/v1/messages/query_request.proto | 8 + .../hipcheck/v1/messages/query_response.proto | 8 + .../v1/messages/query_schemas_request.proto | 8 + .../v1/messages/query_schemas_response.proto | 18 + proto/hipcheck/v1/messages/query_state.proto | 19 + .../v1/messages/set_config_request.proto | 8 + .../v1/messages/set_config_response.proto | 9 + proto/hipcheck/v1/plugin_service.proto | 67 +++ sdk/rust/Cargo.toml | 21 + sdk/rust/build.rs | 6 + sdk/rust/src/error.rs | 140 ++++++ sdk/rust/src/lib.rs | 146 ++++++ sdk/rust/src/plugin_engine.rs | 475 ++++++++++++++++++ sdk/rust/src/plugin_server.rs | 184 +++++++ 33 files changed, 1448 insertions(+), 32 deletions(-) create mode 100644 plugins/dummy_rand_data_sdk/Cargo.toml create mode 100644 plugins/dummy_rand_data_sdk/schema/query_schema_get_rand.json create mode 100644 plugins/dummy_rand_data_sdk/src/main.rs create mode 100644 proto/hipcheck/v1/messages/default_policy_expr_request.proto create mode 100644 proto/hipcheck/v1/messages/default_policy_expr_response.proto create mode 100644 proto/hipcheck/v1/messages/empty.proto create mode 100644 proto/hipcheck/v1/messages/explain_default_query_request.proto create mode 100644 proto/hipcheck/v1/messages/explain_default_query_response.proto create mode 100644 proto/hipcheck/v1/messages/query.proto create mode 100644 proto/hipcheck/v1/messages/query_request.proto create mode 100644 proto/hipcheck/v1/messages/query_response.proto create mode 100644 proto/hipcheck/v1/messages/query_schemas_request.proto create mode 100644 proto/hipcheck/v1/messages/query_schemas_response.proto create mode 100644 proto/hipcheck/v1/messages/query_state.proto create mode 100644 proto/hipcheck/v1/messages/set_config_request.proto create mode 100644 proto/hipcheck/v1/messages/set_config_response.proto create mode 100644 proto/hipcheck/v1/plugin_service.proto create mode 100644 sdk/rust/Cargo.toml create mode 100644 sdk/rust/build.rs create mode 100644 sdk/rust/src/error.rs create mode 100644 sdk/rust/src/lib.rs create mode 100644 sdk/rust/src/plugin_engine.rs create mode 100644 sdk/rust/src/plugin_server.rs diff --git a/Cargo.lock b/Cargo.lock index 718d620c..037dd011 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -743,6 +743,17 @@ dependencies = [ "tonic-build", ] +[[package]] +name = "dummy_rand_data_sdk" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "hipcheck-sdk", + "rand", + "tokio", +] + [[package]] name = "dummy_sha256" version = "0.1.0" @@ -1239,6 +1250,24 @@ dependencies = [ "syn 2.0.75", ] +[[package]] +name = "hipcheck-sdk" +version = "0.1.0" +dependencies = [ + 
"anyhow", + "futures", + "indexmap 2.5.0", + "prost", + "rand", + "schemars", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tonic", + "tonic-build", +] + [[package]] name = "hmac" version = "0.12.1" diff --git a/Cargo.toml b/Cargo.toml index 9cb2fe19..f9d0dc2a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,8 +15,10 @@ members = [ "hipcheck", "hipcheck-macros", "xtask", + "plugins/dummy_rand_data_sdk", "plugins/dummy_rand_data", "plugins/dummy_sha256", + "sdk/rust", ] # Make sure Hipcheck is run with `cargo run`. @@ -34,7 +36,12 @@ ci = "github" # The installers to generate for each app installers = ["shell", "powershell"] # Target platforms to build apps for (Rust target-triple syntax) -targets = ["aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-unknown-linux-gnu", "x86_64-pc-windows-msvc"] +targets = [ + "aarch64-apple-darwin", + "x86_64-apple-darwin", + "x86_64-unknown-linux-gnu", + "x86_64-pc-windows-msvc", +] # Which actions to run on pull requests pr-run-mode = "plan" # Whether to install an updater program diff --git a/hipcheck/src/cli.rs b/hipcheck/src/cli.rs index 14e23a3e..1db7cad6 100644 --- a/hipcheck/src/cli.rs +++ b/hipcheck/src/cli.rs @@ -922,6 +922,8 @@ impl TryFrom> for RepoCacheDeleteScope { pub struct PluginArgs { #[arg(long = "async")] pub asynch: bool, + #[arg(long = "sdk")] + pub sdk: bool, } /// The format to report results in. diff --git a/hipcheck/src/engine.rs b/hipcheck/src/engine.rs index 0a4ab53b..02b2d2fd 100644 --- a/hipcheck/src/engine.rs +++ b/hipcheck/src/engine.rs @@ -90,19 +90,22 @@ fn query( }; // Initiate the query. If remote closed or we got our response immediately, // return - println!("Querying {plugin}::{query} with key {key:?}"); + eprintln!("Querying {plugin}::{query} with key {key:?}"); let mut ar = match runtime.block_on(p_handle.query(query, key))? { PluginResponse::RemoteClosed => { return Err(hc_error!("Plugin channel closed unexpected")); } PluginResponse::Completed(v) => return Ok(v), - PluginResponse::AwaitingResult(a) => a, + PluginResponse::AwaitingResult(a) => { + eprintln!("awaiting result: {:?}", a); + a + } }; // Otherwise, the plugin needs more data to continue. Recursively query // (with salsa memo-ization) to get the needed data, and resume our // current query by providing the plugin the answer. loop { - println!("Query needs more info, recursing..."); + eprintln!("Query needs more info, recursing..."); let answer = db .query( ar.publisher.clone(), @@ -111,7 +114,7 @@ fn query( ar.key.clone(), )? .value; - println!("Got answer {answer:?}, resuming"); + eprintln!("Got answer {answer:?}, resuming"); ar = match runtime.block_on(p_handle.resume_query(ar, answer))? { PluginResponse::RemoteClosed => { return Err(hc_error!("Plugin channel closed unexpected")); @@ -138,7 +141,7 @@ pub fn async_query( }; // Initiate the query. If remote closed or we got our response immediately, // return - println!("Querying: {query}, key: {key:?}"); + eprintln!("Querying: {query}, key: {key:?}"); let mut ar = match p_handle.query(query, key).await? { PluginResponse::RemoteClosed => { return Err(hc_error!("Plugin channel closed unexpected")); @@ -152,7 +155,7 @@ pub fn async_query( // (with salsa memo-ization) to get the needed data, and resume our // current query by providing the plugin the answer. loop { - println!("Awaiting result, now recursing"); + eprintln!("Awaiting result, now recursing"); let answer = async_query( Arc::clone(&core), ar.publisher.clone(), @@ -162,12 +165,14 @@ pub fn async_query( ) .await? 
.value; - println!("Resuming query with answer {answer:?}"); + eprintln!("Resuming query with answer {answer:?}"); ar = match p_handle.resume_query(ar, answer).await? { PluginResponse::RemoteClosed => { return Err(hc_error!("Plugin channel closed unexpected")); } - PluginResponse::Completed(v) => return Ok(v), + PluginResponse::Completed(v) => { + return Ok(v); + } PluginResponse::AwaitingResult(a) => a, }; } @@ -198,7 +203,7 @@ impl HcEngineImpl { // independent of Salsa. pub fn new(executor: PluginExecutor, plugins: Vec) -> Result { let runtime = RUNTIME.handle(); - println!("Starting HcPluginCore"); + eprintln!("Starting HcPluginCore"); let core = runtime.block_on(HcPluginCore::new(executor, plugins))?; let mut engine = HcEngineImpl { storage: Default::default(), diff --git a/hipcheck/src/main.rs b/hipcheck/src/main.rs index 6c45f02e..64b0b4f7 100644 --- a/hipcheck/src/main.rs +++ b/hipcheck/src/main.rs @@ -540,7 +540,15 @@ fn cmd_plugin(args: PluginArgs) { use tokio::task::JoinSet; let tgt_dir = "./target/debug"; - let entrypoint1 = pathbuf![tgt_dir, "dummy_rand_data"]; + + let entrypoint1 = match args.sdk { + true => { + pathbuf![tgt_dir, "dummy_rand_data_sdk"] + } + false => { + pathbuf![tgt_dir, "dummy_rand_data"] + } + }; let entrypoint2 = pathbuf![tgt_dir, "dummy_sha256"]; let plugin1 = Plugin { name: "dummy/rand_data".to_owned(), diff --git a/hipcheck/src/plugin/mod.rs b/hipcheck/src/plugin/mod.rs index 338971cd..a8c3327c 100644 --- a/hipcheck/src/plugin/mod.rs +++ b/hipcheck/src/plugin/mod.rs @@ -71,17 +71,22 @@ impl ActivePlugin { pub async fn query(&self, name: String, key: Value) -> Result { let id = self.get_unique_id().await; + + // TODO: remove this unwrap + let (publisher, plugin) = self.channel.name().split_once('/').unwrap(); + // @Todo - check name+key valid for schema let query = Query { id, request: true, - publisher: "".to_owned(), - plugin: self.channel.name().to_owned(), + publisher: publisher.to_owned(), + plugin: plugin.to_owned(), query: name, key, output: serde_json::json!(null), concerns: vec![], }; + Ok(self.channel.query(query).await?.into()) } diff --git a/hipcheck/src/plugin/types.rs b/hipcheck/src/plugin/types.rs index 2687a440..94b7dde4 100644 --- a/hipcheck/src/plugin/types.rs +++ b/hipcheck/src/plugin/types.rs @@ -478,6 +478,7 @@ impl PluginTransport { // Send the query let query: PluginQuery = query.try_into()?; + let id = query.id; self.tx .send(query) @@ -499,6 +500,8 @@ impl PluginTransport { while matches!(state, ReplyInProgress) { // We expect another message. 
Pull it off the existing queue, // or get a new one if we have run out + eprintln!("In progress"); + let next = match msg_chunks.pop_front() { Some(msg) => msg, None => { diff --git a/plugins/dummy_rand_data/src/transport.rs b/plugins/dummy_rand_data/src/transport.rs index cf6f0434..18d10f1c 100644 --- a/plugins/dummy_rand_data/src/transport.rs +++ b/plugins/dummy_rand_data/src/transport.rs @@ -13,7 +13,7 @@ use std::{ pin::Pin, }; use tokio::sync::mpsc::{self, error::TrySendError}; -use tonic::{Status, Streaming}; +use tonic::Status; #[derive(Debug)] pub struct Query { @@ -236,9 +236,12 @@ impl Drop for QuerySession { } } +type PluginQueryStream = + Box> + Send + Unpin + 'static>; + pub struct HcSessionSocket { tx: mpsc::Sender>, - rx: Streaming, + rx: PluginQueryStream, drop_tx: mpsc::Sender, drop_rx: mpsc::Receiver, sessions: SessionTracker, @@ -261,7 +264,7 @@ impl std::fmt::Debug for HcSessionSocket { impl HcSessionSocket { pub fn new( tx: mpsc::Sender>, - rx: Streaming, + rx: impl Stream> + Send + Unpin + 'static, ) -> Self { // channel for QuerySession objects to notify us they dropped // @Todo - make this configurable @@ -269,7 +272,7 @@ impl HcSessionSocket { Self { tx, - rx, + rx: Box::new(rx), drop_tx, drop_rx, sessions: HashMap::new(), @@ -289,7 +292,7 @@ impl HcSessionSocket { } async fn message(&mut self) -> Result, Status> { - let fut = poll_fn(|cx| Pin::new(&mut self.rx).poll_next(cx)); + let fut = poll_fn(|cx| Pin::new(&mut *self.rx).poll_next(cx)); match fut.await { Some(Ok(m)) => Ok(m.query), diff --git a/plugins/dummy_rand_data_sdk/Cargo.toml b/plugins/dummy_rand_data_sdk/Cargo.toml new file mode 100644 index 00000000..d0b3027c --- /dev/null +++ b/plugins/dummy_rand_data_sdk/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "dummy_rand_data_sdk" +version = "0.1.0" +license = "Apache-2.0" +edition = "2021" +publish = false + +[dependencies] +anyhow = "1.0.87" +clap = { version = "4.5.16", features = ["derive"] } +hipcheck-sdk = { path = "../../sdk/rust" } +rand = "0.8.5" +tokio = { version = "1.40.0", features = ["rt"] } diff --git a/plugins/dummy_rand_data_sdk/schema/query_schema_get_rand.json b/plugins/dummy_rand_data_sdk/schema/query_schema_get_rand.json new file mode 100644 index 00000000..8b50ea30 --- /dev/null +++ b/plugins/dummy_rand_data_sdk/schema/query_schema_get_rand.json @@ -0,0 +1,3 @@ +{ + "type": "integer" +} diff --git a/plugins/dummy_rand_data_sdk/src/main.rs b/plugins/dummy_rand_data_sdk/src/main.rs new file mode 100644 index 00000000..dc35ca57 --- /dev/null +++ b/plugins/dummy_rand_data_sdk/src/main.rs @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: Apache-2.0 + +use anyhow::Result; +use clap::Parser; +use hipcheck_sdk::{ + deps::{async_trait, from_str, JsonSchema, Value}, + error::Error, + plugin_engine::PluginEngine, + plugin_server::PluginServer, + NamedQuery, Plugin, Query, +}; + +static GET_RAND_KEY_SCHEMA: &str = include_str!("../schema/query_schema_get_rand.json"); +static GET_RAND_OUTPUT_SCHEMA: &str = include_str!("../schema/query_schema_get_rand.json"); + +fn reduce(input: u64) -> u64 { + input % 7 +} + +/// Plugin that queries hipcheck takes a `Value::Number` as input and performs following steps: +/// - ensures input is u64 +/// - % 7 of input +/// - queries `hipcheck` for sha256 of (% 7 of input) +/// - returns `Value::Number`, where Number is the first `u8` in the sha256 +/// +/// Goals of this plugin +/// - Verify `salsa` memoization is working (there should only ever be 7 queries made to `hipcheck`) +/// - Verify plugins are able to query 
`hipcheck` for additional information +#[derive(Clone, Debug)] +struct RandDataPlugin; + +#[async_trait] +impl Query for RandDataPlugin { + fn input_schema(&self) -> JsonSchema { + from_str(GET_RAND_KEY_SCHEMA).unwrap() + } + + fn output_schema(&self) -> JsonSchema { + from_str(GET_RAND_OUTPUT_SCHEMA).unwrap() + } + + async fn run( + &self, + engine: &mut PluginEngine, + input: Value, + ) -> hipcheck_sdk::error::Result { + let Value::Number(num_size) = input else { + return Err(Error::UnexpectedPluginQueryDataFormat); + }; + + let Some(size) = num_size.as_u64() else { + return Err(Error::UnexpectedPluginQueryDataFormat); + }; + + let reduced_num = reduce(size); + + let value = engine + .query("dummy/sha256/sha256", vec![reduced_num]) + .await?; + + let Value::Array(mut sha256) = value else { + return Err(Error::UnexpectedPluginQueryDataFormat); + }; + + let Value::Number(num) = sha256.pop().unwrap() else { + return Err(Error::UnexpectedPluginQueryDataFormat); + }; + + match num.as_u64() { + Some(val) => return Ok(Value::Number(val.into())), + None => { + return Err(Error::UnexpectedPluginQueryDataFormat); + } + } + } +} + +impl Plugin for RandDataPlugin { + const PUBLISHER: &'static str = "dummy"; + const NAME: &'static str = "rand_data"; + + fn set_config( + &self, + _config: Value, + ) -> std::result::Result<(), hipcheck_sdk::error::ConfigError> { + Ok(()) + } + + fn default_policy_expr(&self) -> hipcheck_sdk::error::Result { + Ok("".to_owned()) + } + + fn explain_default_query(&self) -> hipcheck_sdk::error::Result> { + Ok(Some("generate random data".to_owned())) + } + + fn queries(&self) -> impl Iterator { + vec![NamedQuery { + name: "rand_data", + inner: Box::new(RandDataPlugin), + }] + .into_iter() + } +} + +#[derive(Parser, Debug)] +struct Args { + #[arg(long)] + port: u16, +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), hipcheck_sdk::error::Error> { + let args = Args::try_parse().unwrap(); + PluginServer::register(RandDataPlugin) + .listen(args.port) + .await +} diff --git a/plugins/dummy_sha256/src/main.rs b/plugins/dummy_sha256/src/main.rs index ebeee0e1..df9e9717 100644 --- a/plugins/dummy_sha256/src/main.rs +++ b/plugins/dummy_sha256/src/main.rs @@ -36,10 +36,10 @@ fn sha256(content: &[u8]) -> Vec { } async fn handle_sha256(session: QuerySession, key: &[u8]) -> Result<()> { - println!("Key: {key:02x?}"); + eprintln!("Key: {key:02x?}"); let res = sha256(key); - println!("Hash: {res:02x?}"); + eprintln!("Hash: {res:02x?}"); let output = serde_json::to_value(res)?; let resp = Query { diff --git a/plugins/dummy_sha256/src/transport.rs b/plugins/dummy_sha256/src/transport.rs index cf6f0434..28e3e9de 100644 --- a/plugins/dummy_sha256/src/transport.rs +++ b/plugins/dummy_sha256/src/transport.rs @@ -13,7 +13,7 @@ use std::{ pin::Pin, }; use tokio::sync::mpsc::{self, error::TrySendError}; -use tonic::{Status, Streaming}; +use tonic::Status; #[derive(Debug)] pub struct Query { @@ -109,7 +109,7 @@ impl QuerySession { async fn recv_raw(&mut self) -> Result>> { let mut out = VecDeque::new(); - eprintln!("RAND-session: awaiting raw rx recv"); + eprintln!("SHA256-session: awaiting raw rx recv"); let opt_first = self .rx @@ -121,12 +121,12 @@ impl QuerySession { // Underlying gRPC channel closed return Ok(None); }; - eprintln!("RAND-session: got first msg"); + eprintln!("SHA256-session: got first msg"); out.push_back(first); // If more messages in the queue, opportunistically read more loop { - eprintln!("RAND-session: trying to get additional msg"); + 
eprintln!("SHA256-session: trying to get additional msg"); match self.rx.try_recv() { Ok(Some(msg)) => { @@ -143,12 +143,12 @@ impl QuerySession { } } - eprintln!("RAND-session: got {} msgs", out.len()); + eprintln!("SHA256-session: got {} msgs", out.len()); Ok(Some(out)) } pub async fn send(&self, query: Query) -> Result<()> { - eprintln!("RAND-session: sending query"); + eprintln!("SHA256-session: sending query"); let query = InitiateQueryProtocolResponse { query: Some(self.convert(query)?), @@ -162,13 +162,13 @@ impl QuerySession { pub async fn recv(&mut self) -> Result> { use QueryState::*; - eprintln!("RAND-session: calling recv_raw"); + eprintln!("SHA256-session: calling recv_raw"); let Some(mut msg_chunks) = self.recv_raw().await? else { return Ok(None); }; let mut raw = msg_chunks.pop_front().unwrap(); - eprintln!("RAND-session: recv got raw {raw:?}"); + eprintln!("SHA256-session: recv got raw {raw:?}"); let mut state: QueryState = raw.state.try_into()?; @@ -236,9 +236,12 @@ impl Drop for QuerySession { } } +type PluginQueryStream = + Box> + Send + Unpin + 'static>; + pub struct HcSessionSocket { tx: mpsc::Sender>, - rx: Streaming, + rx: PluginQueryStream, drop_tx: mpsc::Sender, drop_rx: mpsc::Receiver, sessions: SessionTracker, @@ -261,7 +264,7 @@ impl std::fmt::Debug for HcSessionSocket { impl HcSessionSocket { pub fn new( tx: mpsc::Sender>, - rx: Streaming, + rx: impl Stream> + Send + Unpin + 'static, ) -> Self { // channel for QuerySession objects to notify us they dropped // @Todo - make this configurable @@ -269,7 +272,7 @@ impl HcSessionSocket { Self { tx, - rx, + rx: Box::new(rx), drop_tx, drop_rx, sessions: HashMap::new(), @@ -289,7 +292,7 @@ impl HcSessionSocket { } async fn message(&mut self) -> Result, Status> { - let fut = poll_fn(|cx| Pin::new(&mut self.rx).poll_next(cx)); + let fut = poll_fn(|cx| Pin::new(&mut *self.rx).poll_next(cx)); match fut.await { Some(Ok(m)) => Ok(m.query), diff --git a/proto/hipcheck/v1/messages/default_policy_expr_request.proto b/proto/hipcheck/v1/messages/default_policy_expr_request.proto new file mode 100644 index 00000000..6acc1a83 --- /dev/null +++ b/proto/hipcheck/v1/messages/default_policy_expr_request.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; +package hipcheck.v1; + +import "empty.proto"; + +/** + * Getting the default policy expression has no params, so we just wrap + * the empty message for maximal forward compatibility. + */ +message DefaultPolicyExprRequest { + Empty empty = 1; +} diff --git a/proto/hipcheck/v1/messages/default_policy_expr_response.proto b/proto/hipcheck/v1/messages/default_policy_expr_response.proto new file mode 100644 index 00000000..ef00ab36 --- /dev/null +++ b/proto/hipcheck/v1/messages/default_policy_expr_response.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; +package hipcheck.v1; + +/** + * The response from the DefaultPolicyExpr RPC call. + */ +message DefaultPolicyExprResponse { + /** + * A policy expression, if the plugin has a default policy. + * This MUST be filled in with any default values pulled from the plugin's + * configuration. Hipcheck will only request the default policy _after_ + * configuring the plugin. + */ + string policy_expression = 1; +} diff --git a/proto/hipcheck/v1/messages/empty.proto b/proto/hipcheck/v1/messages/empty.proto new file mode 100644 index 00000000..cc468416 --- /dev/null +++ b/proto/hipcheck/v1/messages/empty.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; +package hipcheck.v1; + +/** + * An empty message. 
+ */ +message Empty {} diff --git a/proto/hipcheck/v1/messages/explain_default_query_request.proto b/proto/hipcheck/v1/messages/explain_default_query_request.proto new file mode 100644 index 00000000..0a02c862 --- /dev/null +++ b/proto/hipcheck/v1/messages/explain_default_query_request.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package hipcheck.v1; + +import "empty.proto"; + +message ExplainDefaultQueryRequest { + Empty empty = 1; +} diff --git a/proto/hipcheck/v1/messages/explain_default_query_response.proto b/proto/hipcheck/v1/messages/explain_default_query_response.proto new file mode 100644 index 00000000..03ad17bf --- /dev/null +++ b/proto/hipcheck/v1/messages/explain_default_query_response.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; +package hipcheck.v1; + +/** + * The response from the ExplainDefaultQuery RPC call. + */ +message ExplainDefaultQueryResponse { + /** + * An unstructured description of the default query. + */ + string explanation = 1; +} diff --git a/proto/hipcheck/v1/messages/query.proto b/proto/hipcheck/v1/messages/query.proto new file mode 100644 index 00000000..41161b0e --- /dev/null +++ b/proto/hipcheck/v1/messages/query.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; +package hipcheck.v1; + +import "query_state.proto"; + +message Query { + // The ID of the request, used to associate requests and replies. + // Odd numbers = initiated by `hc`. + // Even numbers = initiated by a plugin. + int32 id = 1; + + // The state of the query, indicating if this is a request or a reply, + // and if it's a reply whether it's the end of the reply. + QueryState state = 2; + + // Publisher name and plugin name, when sent from Hipcheck to a plugin + // to initiate a fresh query, are used by the receiving plugin to validate + // that the query was intended for them. + // + // When a plugin is making a query to another plugin through Hipcheck, it's + // used to indicate the destination plugin, and to indicate the plugin that + // is replying when Hipcheck sends back the reply. + string publisher_name = 3; + string plugin_name = 4; + + // The name of the query being made, so the responding plugin knows what + // to do with the provided data. + string query_name = 5; + + // The key for the query, as a JSON object. This is the data that Hipcheck's + // incremental computation system will use to cache the response. + string key = 6; + + // The response for the query, as a JSON object. This will be cached by + // Hipcheck for future queries matching the publisher name, plugin name, + // query name, and key. + string output = 7; + + // Any "concerns" reported by a query. Concerns are *not* provided to + // other plugins calling a query, and are _only_ used by Hipcheck itself + // to provide the end-user with additional information about issues found + // during analysis. + // + // Concern chunking is the same as other fields. 
+ repeated string concern = 8; +} diff --git a/proto/hipcheck/v1/messages/query_request.proto b/proto/hipcheck/v1/messages/query_request.proto new file mode 100644 index 00000000..7c8af882 --- /dev/null +++ b/proto/hipcheck/v1/messages/query_request.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package hipcheck.v1; + +import "query.proto"; + +message QueryRequest { + Query query = 1; +} diff --git a/proto/hipcheck/v1/messages/query_response.proto b/proto/hipcheck/v1/messages/query_response.proto new file mode 100644 index 00000000..1fa86bac --- /dev/null +++ b/proto/hipcheck/v1/messages/query_response.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package hipcheck.v1; + +import "query.proto"; + +message QueryResponse { + Query query = 1; +} diff --git a/proto/hipcheck/v1/messages/query_schemas_request.proto b/proto/hipcheck/v1/messages/query_schemas_request.proto new file mode 100644 index 00000000..d04c2cef --- /dev/null +++ b/proto/hipcheck/v1/messages/query_schemas_request.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package hipcheck.v1; + +import "empty.proto"; + +message QuerySchemasRequest { + Empty empty = 1; +} diff --git a/proto/hipcheck/v1/messages/query_schemas_response.proto b/proto/hipcheck/v1/messages/query_schemas_response.proto new file mode 100644 index 00000000..58e70217 --- /dev/null +++ b/proto/hipcheck/v1/messages/query_schemas_response.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; +package hipcheck.v1; + +message QuerySchemasResponse { + // The name of the query being described by the schemas provided. + // + // If either the key and/or output schemas result in a message which is + // too big, they may be chunked across multiple replies in the stream. + // Replies with matching query names should have their fields concatenated + // in the order received to reconstruct the chunks. + string query_name = 1; + + // The key schema, in JSON Schema format. + string key_schema = 2; + + // The output schema, in JSON Schema format. + string output_schema = 3; +} diff --git a/proto/hipcheck/v1/messages/query_state.proto b/proto/hipcheck/v1/messages/query_state.proto new file mode 100644 index 00000000..1a093605 --- /dev/null +++ b/proto/hipcheck/v1/messages/query_state.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; +package hipcheck.v1; + +enum QueryState { + // Something has gone wrong. + QUERY_STATE_UNSPECIFIED = 0; + + // We are submitting a new query. + QUERY_STATE_SUBMIT = 1; + + // We are replying to a query and expect more chunks. + QUERY_STATE_REPLY_IN_PROGRESS = 2; + + // We are closing a reply to a query. If a query response is in one chunk, + // just send this. If a query is in more than one chunk, send this with + // the last message in the reply. This tells the receiver that all chunks + // have been received. + QUERY_STATE_REPLY_COMPLETE = 3; +} diff --git a/proto/hipcheck/v1/messages/set_config_request.proto b/proto/hipcheck/v1/messages/set_config_request.proto new file mode 100644 index 00000000..430ab075 --- /dev/null +++ b/proto/hipcheck/v1/messages/set_config_request.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package hipcheck.v1; + +message SetConfigRequest { + // JSON string containing configuration data expected by the plugin, + // pulled from the user's policy file. 
+ string configuration = 1; +} diff --git a/proto/hipcheck/v1/messages/set_config_response.proto b/proto/hipcheck/v1/messages/set_config_response.proto new file mode 100644 index 00000000..f4c0951a --- /dev/null +++ b/proto/hipcheck/v1/messages/set_config_response.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; +package hipcheck.v1; + +import "empty.proto"; + +message SetConfigResponse { + // No actual data returned. Errors handled with normal gRPC error system. + Empty empty = 1; +} diff --git a/proto/hipcheck/v1/plugin_service.proto b/proto/hipcheck/v1/plugin_service.proto new file mode 100644 index 00000000..201189c3 --- /dev/null +++ b/proto/hipcheck/v1/plugin_service.proto @@ -0,0 +1,67 @@ +syntax = "proto3"; +package hipcheck.v1; + +import "messages/query_schemas_request.proto"; +import "messages/query_schemas_response.proto"; +import "messages/set_config_request.proto"; +import "messages/set_config_response.proto"; +import "messages/default_policy_expr_request.proto"; +import "messages/default_policy_expr_response.proto"; +import "messages/explain_default_query_request.proto"; +import "messages/explain_default_query_response.proto"; +import "messages/query_request.proto"; +import "messages/query_response.proto"; + +/** + * Defines a Hipcheck plugin, able to interact with Hipcheck to provide + * support for additional analyses and sources of data. + */ +service PluginService { + /** + * Get schemas for all supported queries by the plugin. + * + * This is used by Hipcheck to validate that: + * + * - The plugin supports a default query taking a `target` type if used + * as a top-level plugin in the user's policy file. + * - That requests sent to the plugin and data returned by the plugin + * match the schema during execution. + */ + rpc QuerySchemas (QuerySchemasRequest) returns (stream QuerySchemasResponse); + + /** + * Hipcheck sends all child nodes for the plugin from the user's policy + * file to configure the plugin. + */ + rpc SetConfig (SetConfigRequest) returns (SetConfigResponse); + + /** + * Get the default policy for a plugin, which may additionally depend on + * the plugin's configuration. + */ + rpc DefaultPolicyExpr (DefaultPolicyExprRequest) returns (DefaultPolicyExprResponse); + + /** + * Get an explanation of what the default query returns, to use when + * reporting analysis results to users. + * + * Note that, because users can specify their own policy expression, this + * explanation *should not* assume the user has used the default policy + * expression, if one is provided by the plugin. + */ + rpc ExplainDefaultQuery (ExplainDefaultQueryRequest) + returns (ExplainDefaultQueryResponse); + + /** + * Open a bidirectional streaming RPC to enable a request/response + * protocol between Hipcheck and a plugin, where Hipcheck can issue + * queries to the plugin, and the plugin may issue queries to _other_ + * plugins through Hipcheck. + * + * Queries are cached by the publisher name, plugin name, query name, + * and key, and if a match is found for those four values, then + * Hipcheck will respond with the cached result of that prior matching + * query rather than running the query again. 
+ */ + rpc Query (stream QueryRequest) returns (stream QueryResponse); +} diff --git a/sdk/rust/Cargo.toml b/sdk/rust/Cargo.toml new file mode 100644 index 00000000..0571480d --- /dev/null +++ b/sdk/rust/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "hipcheck-sdk" +license = "Apache-2.0" +version = "0.1.0" +edition = "2021" + +[dependencies] +thiserror = "1.0.63" +futures = "0.3.30" +indexmap = "2.4.0" +prost = "0.13.1" +rand = "0.8.5" +serde_json = "1.0.125" +tokio = { version = "1.39.2", features = ["rt"] } +tokio-stream = "0.1.15" +tonic = "0.12.1" +schemars = "0.8.21" + +[build-dependencies] +anyhow = "1.0.86" +tonic-build = "0.12.1" diff --git a/sdk/rust/build.rs b/sdk/rust/build.rs new file mode 100644 index 00000000..1c0e3adb --- /dev/null +++ b/sdk/rust/build.rs @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 + +fn main() -> anyhow::Result<()> { + tonic_build::compile_protos("../../hipcheck/proto/hipcheck/v1/hipcheck.proto")?; + Ok(()) +} diff --git a/sdk/rust/src/error.rs b/sdk/rust/src/error.rs new file mode 100644 index 00000000..73139318 --- /dev/null +++ b/sdk/rust/src/error.rs @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: Apache-2.0 + +use crate::proto::{ConfigurationStatus, InitiateQueryProtocolResponse, SetConfigurationResponse}; +use std::{convert::Infallible, ops::Not, result::Result as StdResult}; +use tokio::sync::mpsc::error::SendError as TokioMpscSendError; +use tonic::Status as TonicStatus; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("unknown error; query is in an unspecified state")] + UnspecifiedQueryState, + + #[error("unexpected ReplyInProgress state for query")] + UnexpectedReplyInProgress, + + #[error("invalid JSON in query key")] + InvalidJsonInQueryKey(#[source] serde_json::Error), + + #[error("invalid JSON in query output")] + InvalidJsonInQueryOutput(#[source] serde_json::Error), + + #[error("session channel closed unexpectedly")] + SessionChannelClosed, + + #[error("failed to send query from session to server")] + FailedToSendQueryFromSessionToServer( + #[source] TokioMpscSendError>, + ), + + #[error("plugin sent QueryReply when server was expecting a request")] + ReceivedReplyWhenExpectingRequest, + + #[error("plugin sent QuerySubmit when server was expecting a reply chunk")] + ReceivedSubmitWhenExpectingReplyChunk, + + #[error("received additional message for ID '{id}' after query completion")] + MoreAfterQueryComplete { id: usize }, + + #[error("failed to start server")] + FailedToStartServer(#[source] tonic::transport::Error), + + #[error("unexpected JSON value from plugin")] + UnexpectedPluginQueryDataFormat, + + #[error("could not determine which plugin query to run")] + UnknownPluginQuery, + + #[error("invalid format for QueryTarget")] + InvalidQueryTarget, +} + +// this will never happen, but is needed to enable passing QueryTarget to PluginEngine::query +impl From for Error { + fn from(_value: Infallible) -> Self { + Error::UnspecifiedQueryState + } +} + +pub type Result = StdResult; + +#[derive(Debug)] +pub enum ConfigError { + InvalidConfigValue { + field_name: String, + value: String, + reason: String, + }, + + MissingRequiredConfig { + field_name: String, + field_type: String, + possible_values: Vec, + }, + + UnrecognizedConfig { + field_name: String, + field_value: String, + possible_confusables: Vec, + }, + + Unspecified { + message: String, + }, +} + +impl From for SetConfigurationResponse { + fn from(value: ConfigError) -> Self { + match value { + ConfigError::InvalidConfigValue { + field_name, + value, + 
reason, + } => SetConfigurationResponse { + status: ConfigurationStatus::InvalidConfigurationValue as i32, + message: format!("invalid value '{value}' for '{field_name}', reason: '{reason}'"), + }, + ConfigError::MissingRequiredConfig { + field_name, + field_type, + possible_values, + } => SetConfigurationResponse { + status: ConfigurationStatus::MissingRequiredConfiguration as i32, + message: { + let mut message = format!( + "missing required config item '{field_name}' of type '{field_type}'" + ); + + if possible_values.is_empty().not() { + message.push_str("; possible values: "); + message.push_str(&possible_values.join(", ")); + } + + message + }, + }, + ConfigError::UnrecognizedConfig { + field_name, + field_value, + possible_confusables, + } => SetConfigurationResponse { + status: ConfigurationStatus::UnrecognizedConfiguration as i32, + message: { + let mut message = + format!("unrecognized field '{field_name}' with value '{field_value}'"); + + if possible_confusables.is_empty().not() { + message.push_str("; possible field names: "); + message.push_str(&possible_confusables.join(", ")); + } + + message + }, + }, + ConfigError::Unspecified { message } => SetConfigurationResponse { + status: ConfigurationStatus::Unspecified as i32, + message: format!("unknown error; {message}"), + }, + } + } +} diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs new file mode 100644 index 00000000..1a637127 --- /dev/null +++ b/sdk/rust/src/lib.rs @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: Apache-2.0 + +mod proto { + include!(concat!(env!("OUT_DIR"), "/hipcheck.v1.rs")); +} + +pub mod error; +pub mod plugin_engine; +pub mod plugin_server; + +use crate::error::Error; +use crate::error::Result; +use error::ConfigError; +use plugin_engine::PluginEngine; +use schemars::schema::SchemaObject as JsonSchema; +use serde_json::Value as JsonValue; +use std::result::Result as StdResult; +use std::str::FromStr; + +// TODO: Do we want a prelude for the SDK? +// TODO: How do we want deal with third-party types exposed +// - schemars::schema::SchemaObject +// - serde_json::Value +// - tonic::async_trait +// ... + +// re-export of user facing third party dependencies +pub mod deps { + pub use schemars::schema::SchemaObject as JsonSchema; + pub use serde_json::{from_str, Value}; + pub use tonic::async_trait; +} + +#[derive(Debug, Clone)] +pub struct QueryTarget { + pub publisher: String, + pub plugin: String, + pub query: Option, +} + +impl FromStr for QueryTarget { + type Err = Error; + + fn from_str(s: &str) -> StdResult { + let parts: Vec<&str> = s.split('/').collect(); + match parts.as_slice() { + [publisher, plugin, query] => Ok(Self { + publisher: publisher.to_string(), + plugin: plugin.to_string(), + query: Some(query.to_string()), + }), + [publisher, plugin] => Ok(Self { + publisher: publisher.to_string(), + plugin: plugin.to_string(), + query: None, + }), + _ => Err(Error::InvalidQueryTarget), + } + } +} + +impl TryInto for &str { + type Error = Error; + fn try_into(self) -> StdResult { + QueryTarget::from_str(self) + } +} + +pub struct QuerySchema { + /// The name of the query being described. + query_name: &'static str, + + /// The query's input schema. + input_schema: JsonSchema, + + /// The query's output schema. + output_schema: JsonSchema, +} + +/// Query trait object. +pub type DynQuery = Box; + +pub struct NamedQuery { + /// The name of the query. + pub name: &'static str, + + /// The query object. + pub inner: DynQuery, +} + +impl NamedQuery { + /// Is the current query the default query? 
+ fn is_default(&self) -> bool { + self.name.is_empty() + } +} + +/// Defines a single query for the plugin. +/// +/// TODO: Is `Send` needed here +#[tonic::async_trait] +pub trait Query: Send { + /// Get the input schema for the query. + fn input_schema(&self) -> JsonSchema; + + /// Get the output schema for the query. + fn output_schema(&self) -> JsonSchema; + + /// Run the plugin, optionally making queries to other plugins. + async fn run(&self, engine: &mut PluginEngine, input: JsonValue) -> Result; +} + +pub trait Plugin: Send + Sync + 'static { + /// The name of the publisher of the pl∂ugin. + const PUBLISHER: &'static str; + + /// The name of the plugin. + const NAME: &'static str; + + /// Handles setting configuration. + fn set_config(&self, config: JsonValue) -> StdResult<(), ConfigError>; + + /// Gets the plugin's default policy expression. + fn default_policy_expr(&self) -> Result; + + /// Gets a description of what is returned by the plugin's default query. + fn explain_default_query(&self) -> Result>; + + /// Get all the queries supported by the plugin. + fn queries(&self) -> impl Iterator; + + /// Get the plugin's default query, if it has one. + fn default_query(&self) -> Option { + self.queries() + .find_map(|named| named.is_default().then_some(named.inner)) + } + + /// Get all schemas for queries provided by the plugin. + fn schemas(&self) -> impl Iterator { + self.queries().map(|query| QuerySchema { + query_name: query.name, + input_schema: query.inner.input_schema(), + output_schema: query.inner.output_schema(), + }) + } +} diff --git a/sdk/rust/src/plugin_engine.rs b/sdk/rust/src/plugin_engine.rs new file mode 100644 index 00000000..62dee88a --- /dev/null +++ b/sdk/rust/src/plugin_engine.rs @@ -0,0 +1,475 @@ +// SPDX-License-Identifier: Apache-2.0 + +use crate::proto::QueryState; +use crate::{ + error::{Error, Result}, + proto::{ + self, InitiateQueryProtocolRequest, InitiateQueryProtocolResponse, Query as PluginQuery, + }, + QueryTarget, +}; +use crate::{JsonValue, Plugin}; +use futures::Stream; +use serde_json::{json, Value}; +use std::sync::Arc; +use std::{ + collections::{HashMap, VecDeque}, + future::poll_fn, + ops::Not, + pin::Pin, + result::Result as StdResult, +}; +use tokio::sync::mpsc::{self, error::TrySendError}; +use tonic::Status; + +impl From for Error { + fn from(_value: Status) -> Error { + // TODO: higher-fidelity handling? 
+ Error::SessionChannelClosed + } +} + +#[derive(Debug)] +struct Query { + direction: QueryDirection, + publisher: String, + plugin: String, + query: String, + key: Value, + output: Value, + concerns: Vec, +} + +#[derive(Debug, PartialEq, Eq)] +enum QueryDirection { + Request, + Response, +} + +impl TryFrom for QueryDirection { + type Error = Error; + + fn try_from(value: QueryState) -> std::result::Result { + match value { + QueryState::Unspecified => Err(Error::UnspecifiedQueryState), + QueryState::Submit => Ok(QueryDirection::Request), + QueryState::ReplyInProgress => Err(Error::UnexpectedReplyInProgress), + QueryState::ReplyComplete => Ok(QueryDirection::Response), + } + } +} + +impl From for QueryState { + fn from(value: QueryDirection) -> Self { + match value { + QueryDirection::Request => QueryState::Submit, + QueryDirection::Response => QueryState::ReplyComplete, + } + } +} + +impl TryFrom for Query { + type Error = Error; + + fn try_from(value: PluginQuery) -> Result { + Ok(Query { + direction: QueryDirection::try_from(value.state())?, + publisher: value.publisher_name, + plugin: value.plugin_name, + query: value.query_name, + key: serde_json::from_str(value.key.as_str()).map_err(Error::InvalidJsonInQueryKey)?, + output: serde_json::from_str(value.output.as_str()) + .map_err(Error::InvalidJsonInQueryOutput)?, + concerns: value.concern, + }) + } +} + +type SessionTracker = HashMap>>; + +pub struct PluginEngine { + id: usize, + tx: mpsc::Sender>, + rx: mpsc::Receiver>, + // So that we can remove ourselves when we get dropped + drop_tx: mpsc::Sender, +} + +impl PluginEngine { + pub async fn query(&mut self, target: T, input: V) -> Result + where + T: TryInto>, + V: Into, + { + let input: JsonValue = input.into(); + let query_target: QueryTarget = target.try_into().map_err(|e| e.into())?; + + async fn query_inner( + engine: &mut PluginEngine, + target: QueryTarget, + input: JsonValue, + ) -> Result { + let query = Query { + direction: QueryDirection::Request, + publisher: target.publisher, + plugin: target.plugin, + query: target.query.unwrap_or_else(|| "".to_owned()), + key: input, + output: json!(Value::Null), + concerns: vec![], + }; + engine.send(query).await?; + let response = engine.recv().await?; + match response { + Some(response) => Ok(response.output), + None => Err(Error::SessionChannelClosed), + } + } + query_inner(self, query_target, input).await + } + + fn id(&self) -> usize { + self.id + } + + // Roughly equivalent to TryFrom, but the `id` field value + // comes from the QuerySession + fn convert(&self, value: Query) -> Result { + let state: QueryState = value.direction.into(); + let key = serde_json::to_string(&value.key).map_err(Error::InvalidJsonInQueryKey)?; + let output = + serde_json::to_string(&value.output).map_err(Error::InvalidJsonInQueryOutput)?; + + Ok(PluginQuery { + id: self.id() as i32, + state: state as i32, + publisher_name: value.publisher, + plugin_name: value.plugin, + query_name: value.query, + key, + output, + concern: value.concerns, + }) + } + + async fn recv_raw(&mut self) -> Result>> { + let mut out = VecDeque::new(); + + eprintln!("SDK: awaiting raw rx recv"); + + let opt_first = self.rx.recv().await.ok_or(Error::SessionChannelClosed)?; + + let Some(first) = opt_first else { + // Underlying gRPC channel closed + return Ok(None); + }; + out.push_back(first); + + // If more messages in the queue, opportunistically read more + loop { + match self.rx.try_recv() { + Ok(Some(msg)) => { + out.push_back(msg); + } + Ok(None) => { + eprintln!("warning: 
None received, gRPC channel closed. we may not close properly if None is not returned again"); + break; + } + // Whether empty or disconnected, we return what we have + Err(_) => { + break; + } + } + } + Ok(Some(out)) + } + + /// Send a gRPC query from plugin to the hipcheck server + async fn send(&self, query: Query) -> Result<()> { + let query = InitiateQueryProtocolResponse { + query: Some(self.convert(query)?), + }; + self.tx + .send(Ok(query)) + .await + .map_err(Error::FailedToSendQueryFromSessionToServer)?; + Ok(()) + } + + async fn recv(&mut self) -> Result> { + let Some(mut msg_chunks) = self.recv_raw().await? else { + return Ok(None); + }; + + let mut raw: PluginQuery = msg_chunks.pop_front().unwrap(); + // eprintln!("SDK: recv got raw {raw:?}"); + + let mut state: QueryState = raw + .state + .try_into() + .map_err(|_| Error::UnspecifiedQueryState)?; + + // If response is the first of a set of chunks, handle + if matches!(state, QueryState::ReplyInProgress) { + while matches!(state, QueryState::ReplyInProgress) { + // We expect another message. Pull it off the existing queue, + // or get a new one if we have run out + let next = match msg_chunks.pop_front() { + Some(msg) => msg, + None => { + // We ran out of messages, get a new batch + match self.recv_raw().await? { + Some(x) => { + msg_chunks = x; + } + None => { + return Ok(None); + } + }; + msg_chunks.pop_front().unwrap() + } + }; + + // By now we have our "next" message + state = next + .state + .try_into() + .map_err(|_| Error::UnspecifiedQueryState)?; + match state { + QueryState::Unspecified => return Err(Error::UnspecifiedQueryState), + QueryState::Submit => return Err(Error::ReceivedSubmitWhenExpectingReplyChunk), + QueryState::ReplyInProgress | QueryState::ReplyComplete => { + raw.output.push_str(next.output.as_str()); + raw.concern.extend_from_slice(next.concern.as_slice()); + } + }; + } + + // Sanity check - after we've left this loop, there should be no left over message + if msg_chunks.is_empty().not() { + return Err(Error::MoreAfterQueryComplete { id: self.id }); + } + } + + raw.try_into().map(Some) + } + + async fn handle_session

<P>(&mut self, plugin: Arc<P>
) -> crate::error::Result<()> + where + P: Plugin, + { + let Some(query) = self.recv().await? else { + return Err(Error::SessionChannelClosed); + }; + + if query.direction == QueryDirection::Response { + return Err(Error::ReceivedSubmitWhenExpectingReplyChunk); + } + + let name = query.query; + let key = query.key; + + // if we find the plugin by name, run it + // if not, check if there is a default plugin and run that one + // otherwise error out + let query = plugin + .queries() + .map(|x| x.inner) + .next() + .or_else(|| plugin.default_query()) + .ok_or_else(|| Error::UnknownPluginQuery)?; + + let value = query.run(self, key).await?; + + let query = proto::Query { + id: self.id() as i32, + state: QueryState::ReplyComplete as i32, + publisher_name: P::PUBLISHER.to_owned(), + plugin_name: P::NAME.to_owned(), + query_name: name, + key: json!(Value::Null).to_string(), + output: value.to_string(), + concern: vec![], + }; + + self.tx + .send(Ok(InitiateQueryProtocolResponse { query: Some(query) })) + .await + .map_err(Error::FailedToSendQueryFromSessionToServer)?; + + Ok(()) + } +} + +impl Drop for PluginEngine { + // Notify to have self removed from session tracker + fn drop(&mut self) { + while let Err(e) = self.drop_tx.try_send(self.id as i32) { + match e { + TrySendError::Closed(_) => { + break; + } + TrySendError::Full(_) => (), + } + } + } +} + +type PluginQueryStream = Box< + dyn Stream> + Send + Unpin + 'static, +>; + +pub(crate) struct HcSessionSocket { + tx: mpsc::Sender>, + rx: PluginQueryStream, + drop_tx: mpsc::Sender, + drop_rx: mpsc::Receiver, + sessions: SessionTracker, +} + +// This is implemented manually since the stream trait object +// can't impl `Debug`. +impl std::fmt::Debug for HcSessionSocket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HcSessionSocket") + .field("tx", &self.tx) + .field("rx", &"") + .field("drop_tx", &self.drop_tx) + .field("drop_rx", &self.drop_rx) + .field("sessions", &self.sessions) + .finish() + } +} + +impl HcSessionSocket { + pub(crate) fn new( + tx: mpsc::Sender>, + rx: impl Stream> + Send + Unpin + 'static, + ) -> Self { + // channel for QuerySession objects to notify us they dropped + // TODO: make this configurable + let (drop_tx, drop_rx) = mpsc::channel(10); + Self { + tx, + rx: Box::new(rx), + drop_tx, + drop_rx, + sessions: HashMap::new(), + } + } + + /// Clean up completed sessions by going through all drop messages. + fn cleanup_sessions(&mut self) { + while let Ok(id) = self.drop_rx.try_recv() { + match self.sessions.remove(&id) { + Some(_) => eprintln!("Cleaned up session {id}"), + None => eprintln!( + "WARNING: HcSessionSocket got request to drop a session that does not exist" + ), + } + } + } + + async fn message(&mut self) -> StdResult, Status> { + let fut = poll_fn(|cx| Pin::new(&mut *self.rx).poll_next(cx)); + + match fut.await { + Some(Ok(m)) => Ok(m.query), + Some(Err(e)) => Err(e), + None => Ok(None), + } + } + + pub(crate) async fn listen(&mut self) -> Result> { + loop { + let Some(raw) = self.message().await.map_err(Error::from)? else { + return Ok(None); + }; + let id = raw.id; + + // While we were waiting for a message, some session objects may have + // dropped, handle them before we look at the ID of this message. 
+ // The downside of this strategy is that once we receive our last message, + // we won't clean up any sessions that close after + self.cleanup_sessions(); + + match self.decide_action(&raw) { + Ok(HandleAction::ForwardMsgToExistingSession(tx)) => { + eprintln!("SDK: forwarding message to session {id}"); + + if let Err(_e) = tx.send(Some(raw)).await { + eprintln!("Error forwarding msg to session {id}"); + self.sessions.remove(&id); + }; + } + Ok(HandleAction::CreateSession) => { + eprintln!("SDK: creating new session {id}"); + + let (in_tx, rx) = mpsc::channel::>(10); + let tx = self.tx.clone(); + + let session = PluginEngine { + id: id as usize, + tx, + rx, + drop_tx: self.drop_tx.clone(), + }; + + in_tx.send(Some(raw)).await.expect( + "Failed sending message to newly created Session, should never happen", + ); + + eprintln!("RAND-listen: adding new session {id} to tracker"); + self.sessions.insert(id, in_tx); + + return Ok(Some(session)); + } + Err(e) => eprintln!("error: {}", e), + } + } + } + + fn decide_action(&mut self, query: &PluginQuery) -> Result> { + if let Some(tx) = self.sessions.get_mut(&query.id) { + return Ok(HandleAction::ForwardMsgToExistingSession(tx)); + } + + if query.state() == QueryState::Submit { + return Ok(HandleAction::CreateSession); + } + + Err(Error::ReceivedReplyWhenExpectingRequest) + } + + pub(crate) async fn run

<P>(&mut self, plugin: Arc<P>
) -> Result<()> + where + P: Plugin, + { + loop { + eprintln!("SHA256: Looping"); + + let Some(mut engine) = self + .listen() + .await + .map_err(|_| Error::SessionChannelClosed)? + else { + eprintln!("Channel closed by remote"); + break; + }; + + let cloned_plugin = plugin.clone(); + tokio::spawn(async move { + if let Err(e) = engine.handle_session(cloned_plugin).await { + panic!("handle_session failed: {e}"); + }; + }); + } + + Ok(()) + } +} + +enum HandleAction<'s> { + ForwardMsgToExistingSession(&'s mut mpsc::Sender>), + CreateSession, +} diff --git a/sdk/rust/src/plugin_server.rs b/sdk/rust/src/plugin_server.rs new file mode 100644 index 00000000..90023c58 --- /dev/null +++ b/sdk/rust/src/plugin_server.rs @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + error::{Error, Result}, + plugin_engine::HcSessionSocket, + proto::{ + plugin_service_server::{PluginService, PluginServiceServer}, + ConfigurationStatus, ExplainDefaultQueryRequest as ExplainDefaultQueryReq, + ExplainDefaultQueryResponse as ExplainDefaultQueryResp, + GetDefaultPolicyExpressionRequest as GetDefaultPolicyExpressionReq, + GetDefaultPolicyExpressionResponse as GetDefaultPolicyExpressionResp, + GetQuerySchemasRequest as GetQuerySchemasReq, + GetQuerySchemasResponse as GetQuerySchemasResp, + InitiateQueryProtocolRequest as InitiateQueryProtocolReq, + InitiateQueryProtocolResponse as InitiateQueryProtocolResp, + SetConfigurationRequest as SetConfigurationReq, + SetConfigurationResponse as SetConfigurationResp, + }, + Plugin, QuerySchema, +}; +use std::{result::Result as StdResult, sync::Arc}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream as RecvStream; +use tonic::{transport::Server, Code, Request as Req, Response as Resp, Status, Streaming}; + +/// Runs the Hipcheck plugin protocol based on the user's plugin definition. +/// +/// The key idea is that this implements the gRPC mechanics and handles all +/// the details of the query protocol, so that the user doesn't need to do +/// anything more than define queries as asynchronous functions with associated +/// input and output schemas. +pub struct PluginServer

<P> { + plugin: Arc<P>
, +} + +impl<P: Plugin> PluginServer<P>
{ + /// Create a new plugin server for the provided plugin. + pub fn register(plugin: P) -> PluginServer<P>
{ + PluginServer { + plugin: Arc::new(plugin), + } + } + + /// Run the plugin server on the provided port. + pub async fn listen(self, port: u16) -> Result<()> { + let service = PluginServiceServer::new(self); + let host = format!("127.0.0.1:{}", port).parse().unwrap(); + + Server::builder() + .add_service(service) + .serve(host) + .await + .map_err(Error::FailedToStartServer)?; + + Ok(()) + } +} + +/// The result of running a query. +pub type QueryResult = StdResult; + +#[tonic::async_trait] +impl PluginService for PluginServer
<P>
{ + type GetQuerySchemasStream = RecvStream>; + type InitiateQueryProtocolStream = RecvStream>; + + async fn set_configuration( + &self, + req: Req, + ) -> QueryResult> { + let config = serde_json::from_str(&req.into_inner().configuration) + .map_err(|e| Status::from_error(Box::new(e)))?; + match self.plugin.set_config(config) { + Ok(_) => Ok(Resp::new(SetConfigurationResp { + status: ConfigurationStatus::None as i32, + message: "".to_owned(), + })), + Err(e) => Ok(Resp::new(e.into())), + } + } + + async fn get_default_policy_expression( + &self, + _req: Req, + ) -> QueryResult> { + // The request is empty, so we do nothing. + match self.plugin.default_policy_expr() { + Ok(policy_expression) => Ok(Resp::new(GetDefaultPolicyExpressionResp { + policy_expression, + })), + Err(e) => Err(Status::new( + tonic::Code::NotFound, + format!( + "Error determining default policy expr for {}/{}: {}", + P::PUBLISHER, + P::NAME, + e + ), + )), + } + } + + async fn explain_default_query( + &self, + _req: Req, + ) -> QueryResult> { + match self.plugin.default_policy_expr() { + Ok(explanation) => Ok(Resp::new(ExplainDefaultQueryResp { explanation })), + Err(e) => Err(Status::new( + tonic::Code::NotFound, + format!( + "Error explaining default query expr for {}/{}: {}", + P::PUBLISHER, + P::NAME, + e + ), + )), + } + } + + async fn get_query_schemas( + &self, + _req: Req, + ) -> QueryResult> { + // Ignore the input, it's empty. + let query_schemas = self.plugin.schemas().collect::>(); + // TODO: does this need to be configurable? + let (tx, rx) = mpsc::channel(10); + tokio::spawn(async move { + for x in query_schemas { + let input_schema = serde_json::to_string(&x.input_schema); + let output_schema = serde_json::to_string(&x.output_schema); + + let schema_resp = match (input_schema, output_schema) { + (Ok(input_schema), Ok(output_schema)) => Ok(GetQuerySchemasResp { + query_name: x.query_name.to_string(), + key_schema: input_schema, + output_schema, + }), + (Ok(_), Err(e)) => Err(Status::new( + Code::FailedPrecondition, + format!("Error converting output schema to String: {}", e), + )), + (Err(_), Ok(e)) => Err(Status::new( + Code::FailedPrecondition, + format!("Error converting input schema to String: {}", e), + )), + (Err(e1), Err(e2)) => Err(Status::new( + Code::FailedPrecondition, + format!( + "Error converting input and output schema to String: {} {}", + e1, e2 + ), + )), + }; + + if tx.send(schema_resp).await.is_err() { + // TODO: handle this? + panic!(); + } + } + }); + Ok(Resp::new(RecvStream::new(rx))) + } + + async fn initiate_query_protocol( + &self, + req: Req>, + ) -> QueryResult> { + let rx = req.into_inner(); + // TODO: - make channel size configurable + let (tx, out_rx) = mpsc::channel::>(10); + + let cloned_plugin = self.plugin.clone(); + + tokio::spawn(async move { + let mut channel = HcSessionSocket::new(tx, rx); + if let Err(e) = channel.run(cloned_plugin).await { + panic!("Error: {e}"); + } + }); + Ok(Resp::new(RecvStream::new(out_rx))) + } +}
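
For reference, here is a minimal sketch (not part of the patch) of how the `QueryTarget` string form introduced in `sdk/rust/src/lib.rs` is expected to parse. It only exercises the `FromStr` impl shown above against the SDK's public `QueryTarget` type; the module and test names are illustrative.

#[cfg(test)]
mod query_target_examples {
	use std::str::FromStr;

	use hipcheck_sdk::QueryTarget;

	#[test]
	fn parses_publisher_plugin_query_strings() {
		// "publisher/plugin/query" selects a named query on a plugin.
		let t = QueryTarget::from_str("dummy/sha256/sha256").unwrap();
		assert_eq!(t.publisher, "dummy");
		assert_eq!(t.plugin, "sha256");
		assert_eq!(t.query.as_deref(), Some("sha256"));

		// "publisher/plugin" selects the plugin's default query.
		let t = QueryTarget::from_str("dummy/sha256").unwrap();
		assert!(t.query.is_none());

		// Anything else (e.g. a bare plugin name) is rejected.
		assert!(QueryTarget::from_str("sha256").is_err());
	}
}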