Skip to content

Commit

Permalink
Use just new Engine
Browse files Browse the repository at this point in the history
  • Loading branch information
quackzar committed Oct 1, 2024
1 parent ddaff6c commit 928152e
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 102 deletions.
8 changes: 8 additions & 0 deletions pycare/caring.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ class Engine:
"""


def sum(self, a: float | list[float]) -> float:
"""
Performs a summation with the connected parties.
Returns the sum of all the numbers.
:param a: number to summate with
"""


"""
Preprocess mult. triples and preshares
Expand Down
2 changes: 1 addition & 1 deletion pycare/examples/addm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def lasso_ADMM(engine: caring.Engine, A, b, max_iter=100, lam=1.):
n = len(u)
u_and_theta = u.tolist()
u_and_theta.extend(theta)
u_and_theta : list[float] = engine.sum_many(u_and_theta)
u_and_theta : list[float] = engine.sum(u_and_theta)
time_mpc_sum += (time.time() - t0)
u = np.array(u_and_theta[:n]) / 2.
theta = np.array(u_and_theta[n:]) / 2.
Expand Down
10 changes: 8 additions & 2 deletions pycare/examples/lasso1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@

# Processing (only for spdz)
print("Preprocessing...")
caring.preproc(100000, 0, "./ctx1.bin", "./ctx2.bin")
caring.preproc(100000, 10, "./context1.bin", "./context2.bin")

engine = caring.Engine(
scheme="spdz-25519",
address="localhost:1234",
peers=["localhost:1235"],
threshold=1,
preprocessed_path="./context1.bin"
)

print("Running...")
engine = caring.spdz("./ctx1.bin", "127.0.0.1:1234", "127.0.0.1:1235")
theta_1, func_vals = lasso_ADMM(engine, A_1, b_1)
# lets plot the objective values of the function
# to make sure it has converged
Expand Down
9 changes: 8 additions & 1 deletion pycare/examples/lasso2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
b_1, b_2 = np.array_split(b, 2)


engine = caring.spdz("./ctx2.bin", "127.0.0.1:1235", "127.0.0.1:1234")
engine = caring.Engine(
scheme="spdz-25519",
address="localhost:1235",
peers=["localhost:1234"],
threshold=1,
preprocessed_path="./context2.bin"
)

theta_1, func_vals = lasso_ADMM(engine, A_2, b_2)
# lets plot the objective values of the function
# to make sure it has converged
Expand Down
18 changes: 8 additions & 10 deletions pycare/examples/test1.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import caring
# TODO: We need to make som preprocessing. Non of the participating parties should be allowed to do this,
# as knowing the other parties preprocedsed values breaks privacy.
# However for starters, and testing purposses ONLY we will allow party 1 to do it
# and save it where both party one and party two can find it.
from caring import Engine

caring.preproc(12, 0, "./context1.bin", "./context2.bin")
# engine = caring.spdz("./context1.bin", "127.0.0.1:1234", "127.0.0.1:1235")
engine = caring.shamir(2, "127.0.0.1:1234", "127.0.0.1:1235")
engine = Engine(
scheme="shamir-25519",
address="localhost:1234",
peers=["localhost:1235"],
threshold=2,
)

res = engine.sum(2.5)
print(f"2.5 - 5 = {res}")

res = engine.sum_many([2.5, 3.5])
res = engine.sum([2.5, 3.5])
print(f"[2.5, 3.5] + [3.2, 0.5] = {res}")

res = engine.sum(3.14159265359)
Expand Down Expand Up @@ -41,4 +40,3 @@
res = engine.sum(8.0)
print(f"8.0 + 2.02 = {res}")

engine.takedown()
14 changes: 9 additions & 5 deletions pycare/examples/test2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import caring
# engine = caring.spdz("./context2.bin", "127.0.0.1:1235", "127.0.0.1:1234")
engine = caring.shamir(2, "127.0.0.1:1235", "127.0.0.1:1234")
from caring import Engine

engine = Engine(
scheme="shamir-25519",
address="localhost:1235",
peers=["localhost:1234"],
threshold=2
)

res = engine.sum(-5)
print(f"2.5 - 5 = {res}")

res = engine.sum_many([3.2, 0.5])
res = engine.sum([3.2, 0.5])
print(f"[2.5, 3.5] + [3.2, 0.5] = {res}")

res = engine.sum(3.14159265359)
Expand Down Expand Up @@ -35,4 +40,3 @@
res = engine.sum(2.02)
print(f"8.0 + 2.02 = {res}")

engine.takedown()
83 changes: 1 addition & 82 deletions pycare/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,69 +1,11 @@
use pyo3::{exceptions::PyIOError, prelude::*, types::PyTuple};
use pyo3::{prelude::*, types::PyTuple};

pub mod expr;
pub mod vm;

use std::fs::File;
use wecare::*;

#[pyclass]
struct OldEngine(Option<wecare::Engine>);

/// Setup a MPC addition engine connected to the given sockets using SPDZ.
#[pyfunction]
#[pyo3(signature = (path_to_pre, my_addr, *others))]
fn spdz(path_to_pre: &str, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult<OldEngine> {
let others: Vec<_> = others
.iter()
.map(|x| x.extract::<String>().unwrap().clone())
.collect();
let mut file = File::open(path_to_pre).unwrap();
match wecare::Engine::setup(my_addr)
.add_participants(&others)
.file_to_preprocessed(&mut file)
.build_spdz()
{
Ok(e) => Ok(OldEngine(Some(e))),
Err(e) => Err(PyIOError::new_err(e.0)),
}
}

/// Setup a MPC addition engine connected to the given sockets using shamir secret sharing.
#[pyfunction]
#[pyo3(signature = (threshold, my_addr, *others))]
fn shamir(threshold: u32, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult<OldEngine> {
let others: Vec<_> = others
.iter()
.map(|x| x.extract::<String>().unwrap().clone())
.collect();
match wecare::Engine::setup(my_addr)
.add_participants(&others)
.threshold(threshold as u64)
.build_shamir()
{
Ok(e) => Ok(OldEngine(Some(e))),
Err(e) => Err(PyIOError::new_err(e.0)),
}
}

/// Setup a MPC addition engine connected to the given sockets using shamir secret sharing.
#[pyfunction]
#[pyo3(signature = (threshold, my_addr, *others))]
fn feldman(threshold: u32, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult<OldEngine> {
let others: Vec<_> = others
.iter()
.map(|x| x.extract::<String>().unwrap().clone())
.collect();
match wecare::Engine::setup(my_addr)
.add_participants(&others)
.threshold(threshold as u64)
.build_feldman()
{
Ok(e) => Ok(OldEngine(Some(e))),
Err(e) => Err(PyIOError::new_err(e.0)),
}
}

/// Calculate and save the preprocessing
#[pyfunction]
#[pyo3(signature = (num_shares, num_triplets, *paths_to_pre, scheme="spdz-25519"))]
Expand Down Expand Up @@ -93,25 +35,6 @@ fn preproc(
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
}

#[pymethods]
impl OldEngine {
/// Run a sum procedure in which each party supplies a double floating point
fn sum(&mut self, a: f64) -> f64 {
self.0.as_mut().unwrap().mpc_sum(&[a]).unwrap()[0]
}

/// Run a sum procedure in which each party supplies a double floating point
fn sum_many(&mut self, a: Vec<f64>) -> Vec<f64> {
self.0.as_mut().unwrap().mpc_sum(&a).unwrap()
}

/// takedown engine
fn takedown(&mut self) {
let engine = self.0.take().unwrap();
engine.shutdown();
}
}

/// A Python module implemented in Rust.
#[pymodule]
fn caring(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand All @@ -123,11 +46,7 @@ fn caring(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
let filter = EnvFilter::from_default_env();
tracing_subscriber::fmt().with_env_filter(filter).init();

m.add_function(wrap_pyfunction!(spdz, m)?)?;
m.add_function(wrap_pyfunction!(shamir, m)?)?;
m.add_function(wrap_pyfunction!(feldman, m)?)?;
m.add_function(wrap_pyfunction!(preproc, m)?)?;
m.add_class::<OldEngine>()?;
m.add_class::<vm::Engine>()?;
m.add_class::<vm::Computed>()?;
m.add_class::<expr::Expr>()?;
Expand Down
25 changes: 24 additions & 1 deletion pycare/src/vm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{future::Future, ops::DerefMut, sync::Mutex, time::Duration};

use crate::expr::{Id, Opened};
use pyo3::{exceptions::PyValueError, prelude::*, types::PyList};
use pyo3::{exceptions::{PyTypeError, PyValueError}, prelude::*, types::PyList};
use wecare::vm;

#[pyclass(frozen)]
Expand Down Expand Up @@ -116,6 +116,29 @@ impl Engine {
Id(self.0.lock().unwrap().engine.id())
}


/// Sum
fn sum(&self, py: Python<'_>, num: &Bound<'_, PyAny>) -> PyResult<Vec<f64>> {
let mut this = self.0.lock().expect("Lock poisoned");
let EngineInner { engine, runtime } = this.deref_mut();
if let Ok(num) = num.extract::<f64>() {
runtime.block_on(check_signals(py, async {
engine
.sum(&[num])
.await
}))
} else if let Ok(nums) = num.extract::<Vec<f64>>() {
runtime.block_on(check_signals(py, async {
engine
.sum(&nums)
.await
}))
} else {
Err(PyTypeError::new_err("num is not a number"))
}
}


/// Your own Id
fn peers(&self) -> Vec<Id> {
self.0
Expand Down

0 comments on commit 928152e

Please sign in to comment.