Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added preliminary SDK support to "rand_data" plugin #427

Merged
merged 2 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ members = [
"hipcheck",
"hipcheck-macros",
"xtask",
"plugins/dummy_rand_data_sdk",
"plugins/dummy_rand_data",
"plugins/dummy_sha256",
"sdk/rust",
]

# Make sure Hipcheck is run with `cargo run`.
Expand All @@ -34,7 +36,12 @@ ci = "github"
# The installers to generate for each app
installers = ["shell", "powershell"]
# Target platforms to build apps for (Rust target-triple syntax)
targets = ["aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-unknown-linux-gnu", "x86_64-pc-windows-msvc"]
targets = [
"aarch64-apple-darwin",
"x86_64-apple-darwin",
"x86_64-unknown-linux-gnu",
"x86_64-pc-windows-msvc",
]
# Which actions to run on pull requests
pr-run-mode = "plan"
# Whether to install an updater program
Expand Down
2 changes: 2 additions & 0 deletions hipcheck/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,8 @@ impl TryFrom<Vec<String>> for RepoCacheDeleteScope {
pub struct PluginArgs {
#[arg(long = "async")]
pub asynch: bool,
#[arg(long = "sdk")]
pub sdk: bool,
}

/// The format to report results in.
Expand Down
23 changes: 14 additions & 9 deletions hipcheck/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,22 @@ fn query(
};
// Initiate the query. If remote closed or we got our response immediately,
// return
println!("Querying {plugin}::{query} with key {key:?}");
eprintln!("Querying {plugin}::{query} with key {key:?}");
let mut ar = match runtime.block_on(p_handle.query(query, key))? {
PluginResponse::RemoteClosed => {
return Err(hc_error!("Plugin channel closed unexpected"));
}
PluginResponse::Completed(v) => return Ok(v),
PluginResponse::AwaitingResult(a) => a,
PluginResponse::AwaitingResult(a) => {
eprintln!("awaiting result: {:?}", a);
a
}
};
// Otherwise, the plugin needs more data to continue. Recursively query
// (with salsa memo-ization) to get the needed data, and resume our
// current query by providing the plugin the answer.
loop {
println!("Query needs more info, recursing...");
eprintln!("Query needs more info, recursing...");
let answer = db
.query(
ar.publisher.clone(),
Expand All @@ -111,7 +114,7 @@ fn query(
ar.key.clone(),
)?
.value;
println!("Got answer {answer:?}, resuming");
eprintln!("Got answer {answer:?}, resuming");
ar = match runtime.block_on(p_handle.resume_query(ar, answer))? {
PluginResponse::RemoteClosed => {
return Err(hc_error!("Plugin channel closed unexpected"));
Expand All @@ -138,7 +141,7 @@ pub fn async_query(
};
// Initiate the query. If remote closed or we got our response immediately,
// return
println!("Querying: {query}, key: {key:?}");
eprintln!("Querying: {query}, key: {key:?}");
let mut ar = match p_handle.query(query, key).await? {
PluginResponse::RemoteClosed => {
return Err(hc_error!("Plugin channel closed unexpected"));
Expand All @@ -152,7 +155,7 @@ pub fn async_query(
// (with salsa memo-ization) to get the needed data, and resume our
// current query by providing the plugin the answer.
loop {
println!("Awaiting result, now recursing");
eprintln!("Awaiting result, now recursing");
let answer = async_query(
Arc::clone(&core),
ar.publisher.clone(),
Expand All @@ -162,12 +165,14 @@ pub fn async_query(
)
.await?
.value;
println!("Resuming query with answer {answer:?}");
eprintln!("Resuming query with answer {answer:?}");
ar = match p_handle.resume_query(ar, answer).await? {
PluginResponse::RemoteClosed => {
return Err(hc_error!("Plugin channel closed unexpected"));
}
PluginResponse::Completed(v) => return Ok(v),
PluginResponse::Completed(v) => {
return Ok(v);
}
PluginResponse::AwaitingResult(a) => a,
};
}
Expand Down Expand Up @@ -198,7 +203,7 @@ impl HcEngineImpl {
// independent of Salsa.
pub fn new(executor: PluginExecutor, plugins: Vec<PluginWithConfig>) -> Result<Self> {
let runtime = RUNTIME.handle();
println!("Starting HcPluginCore");
eprintln!("Starting HcPluginCore");
let core = runtime.block_on(HcPluginCore::new(executor, plugins))?;
let mut engine = HcEngineImpl {
storage: Default::default(),
Expand Down
10 changes: 9 additions & 1 deletion hipcheck/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,15 @@ fn cmd_plugin(args: PluginArgs) {
use tokio::task::JoinSet;

let tgt_dir = "./target/debug";
let entrypoint1 = pathbuf![tgt_dir, "dummy_rand_data"];

let entrypoint1 = match args.sdk {
true => {
pathbuf![tgt_dir, "dummy_rand_data_sdk"]
}
false => {
pathbuf![tgt_dir, "dummy_rand_data"]
}
};
let entrypoint2 = pathbuf![tgt_dir, "dummy_sha256"];
let plugin1 = Plugin {
name: "dummy/rand_data".to_owned(),
Expand Down
9 changes: 7 additions & 2 deletions hipcheck/src/plugin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,22 @@ impl ActivePlugin {

pub async fn query(&self, name: String, key: Value) -> Result<PluginResponse> {
let id = self.get_unique_id().await;

// TODO: remove this unwrap
let (publisher, plugin) = self.channel.name().split_once('/').unwrap();

// @Todo - check name+key valid for schema
let query = Query {
id,
request: true,
publisher: "".to_owned(),
plugin: self.channel.name().to_owned(),
publisher: publisher.to_owned(),
plugin: plugin.to_owned(),
query: name,
key,
output: serde_json::json!(null),
concerns: vec![],
};

Ok(self.channel.query(query).await?.into())
}

Expand Down
3 changes: 3 additions & 0 deletions hipcheck/src/plugin/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ impl PluginTransport {

// Send the query
let query: PluginQuery = query.try_into()?;

let id = query.id;
self.tx
.send(query)
Expand All @@ -499,6 +500,8 @@ impl PluginTransport {
while matches!(state, ReplyInProgress) {
// We expect another message. Pull it off the existing queue,
// or get a new one if we have run out
eprintln!("In progress");

let next = match msg_chunks.pop_front() {
Some(msg) => msg,
None => {
Expand Down
13 changes: 8 additions & 5 deletions plugins/dummy_rand_data/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::{
pin::Pin,
};
use tokio::sync::mpsc::{self, error::TrySendError};
use tonic::{Status, Streaming};
use tonic::Status;

#[derive(Debug)]
pub struct Query {
Expand Down Expand Up @@ -236,9 +236,12 @@ impl Drop for QuerySession {
}
}

type PluginQueryStream =
Box<dyn Stream<Item = Result<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static>;

pub struct HcSessionSocket {
tx: mpsc::Sender<Result<InitiateQueryProtocolResponse, Status>>,
rx: Streaming<InitiateQueryProtocolRequest>,
rx: PluginQueryStream,
drop_tx: mpsc::Sender<i32>,
drop_rx: mpsc::Receiver<i32>,
sessions: SessionTracker,
Expand All @@ -261,15 +264,15 @@ impl std::fmt::Debug for HcSessionSocket {
impl HcSessionSocket {
pub fn new(
tx: mpsc::Sender<Result<InitiateQueryProtocolResponse, Status>>,
rx: Streaming<InitiateQueryProtocolRequest>,
rx: impl Stream<Item = Result<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
) -> Self {
// channel for QuerySession objects to notify us they dropped
// @Todo - make this configurable
let (drop_tx, drop_rx) = mpsc::channel(10);

Self {
tx,
rx,
rx: Box::new(rx),
drop_tx,
drop_rx,
sessions: HashMap::new(),
Expand All @@ -289,7 +292,7 @@ impl HcSessionSocket {
}

async fn message(&mut self) -> Result<Option<PluginQuery>, Status> {
let fut = poll_fn(|cx| Pin::new(&mut self.rx).poll_next(cx));
let fut = poll_fn(|cx| Pin::new(&mut *self.rx).poll_next(cx));

match fut.await {
Some(Ok(m)) => Ok(m.query),
Expand Down
13 changes: 13 additions & 0 deletions plugins/dummy_rand_data_sdk/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "dummy_rand_data_sdk"
version = "0.1.0"
license = "Apache-2.0"
edition = "2021"
publish = false

[dependencies]
anyhow = "1.0.87"
clap = { version = "4.5.16", features = ["derive"] }
hipcheck-sdk = { path = "../../sdk/rust" }
rand = "0.8.5"
tokio = { version = "1.40.0", features = ["rt"] }
3 changes: 3 additions & 0 deletions plugins/dummy_rand_data_sdk/schema/query_schema_get_rand.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"type": "integer"
}
112 changes: 112 additions & 0 deletions plugins/dummy_rand_data_sdk/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// SPDX-License-Identifier: Apache-2.0

use anyhow::Result;
use clap::Parser;
use hipcheck_sdk::prelude::*;

static GET_RAND_KEY_SCHEMA: &str = include_str!("../schema/query_schema_get_rand.json");
static GET_RAND_OUTPUT_SCHEMA: &str = include_str!("../schema/query_schema_get_rand.json");

fn reduce(input: u64) -> u64 {
input % 7
}

/// Plugin that queries hipcheck takes a `Value::Number` as input and performs following steps:
/// - ensures input is u64
/// - % 7 of input
/// - queries `hipcheck` for sha256 of (% 7 of input)
/// - returns `Value::Number`, where Number is the first `u8` in the sha256
///
/// Goals of this plugin
/// - Verify `salsa` memoization is working (there should only ever be 7 queries made to `hipcheck`)
/// - Verify plugins are able to query `hipcheck` for additional information
#[derive(Clone, Debug)]
struct RandDataPlugin;

#[async_trait]
impl Query for RandDataPlugin {
fn input_schema(&self) -> JsonSchema {
from_str(GET_RAND_KEY_SCHEMA).unwrap()
}

fn output_schema(&self) -> JsonSchema {
from_str(GET_RAND_OUTPUT_SCHEMA).unwrap()
}

async fn run(
&self,
engine: &mut PluginEngine,
input: Value,
) -> hipcheck_sdk::error::Result<Value> {
let Value::Number(num_size) = input else {
return Err(Error::UnexpectedPluginQueryDataFormat);
};

let Some(size) = num_size.as_u64() else {
return Err(Error::UnexpectedPluginQueryDataFormat);
};

let reduced_num = reduce(size);

let value = engine
.query("dummy/sha256/sha256", vec![reduced_num])
.await?;

let Value::Array(mut sha256) = value else {
return Err(Error::UnexpectedPluginQueryDataFormat);
};

let Value::Number(num) = sha256.pop().unwrap() else {
return Err(Error::UnexpectedPluginQueryDataFormat);
};

match num.as_u64() {
Some(val) => return Ok(Value::Number(val.into())),
None => {
return Err(Error::UnexpectedPluginQueryDataFormat);
}
}
}
}

impl Plugin for RandDataPlugin {
const PUBLISHER: &'static str = "dummy";
const NAME: &'static str = "rand_data";

fn set_config(
&self,
_config: Value,
) -> std::result::Result<(), hipcheck_sdk::error::ConfigError> {
Ok(())
}

fn default_policy_expr(&self) -> hipcheck_sdk::error::Result<String> {
Ok("".to_owned())
}

fn explain_default_query(&self) -> hipcheck_sdk::error::Result<Option<String>> {
Ok(Some("generate random data".to_owned()))
}

fn queries(&self) -> impl Iterator<Item = hipcheck_sdk::NamedQuery> {
vec![NamedQuery {
name: "rand_data",
inner: Box::new(RandDataPlugin),
}]
.into_iter()
}
}

#[derive(Parser, Debug)]
struct Args {
#[arg(long)]
port: u16,
}

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), hipcheck_sdk::error::Error> {
let args = Args::try_parse().unwrap();
PluginServer::register(RandDataPlugin)
.listen(args.port)
.await
}
Loading