Skip to content

Commit

Permalink
serde support
Browse files Browse the repository at this point in the history
  • Loading branch information
nmandery committed May 11, 2021
1 parent 141b6f1 commit ddc576b
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 41 deletions.
13 changes: 5 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
[package]
name = "extended-isolation-forest"
version = "0.1.0"
description = "rust port of the anomaly detection algorithm"
authors = ["Nico Mandery <[email protected]>"]
edition = "2018"

#[features]
#use-serde = ["serde", "num/serde"]

[dependencies]
num = "0.3"
num = "0.4"
rand = { version = "0.8", features = ["alloc"] }
rand_distr = "0.4"
#serde = { version = "1", optional = true, features = ["derive"] }

#[dev-dependencies]
#bincode = "1.3"
serde = { version = "1", features = ["derive"] }

[dev-dependencies]
serde_json = "1"
66 changes: 33 additions & 33 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,12 @@ use rand::{
},
};
use rand_distr::{Distribution, StandardNormal};
use serde::{Deserialize, Serialize};

pub use crate::error::Error;

//#[cfg(feature = "use-serde")]
//use serde::{Deserialize, Serialize};

mod error;
mod serde_array;

pub struct ForestOptions {
/// `n_trees` is the number of trees to be created.
Expand Down Expand Up @@ -100,20 +99,19 @@ impl Default for ForestOptions {
}
}

//#[cfg_attr(feature="use-serde", derive(Serialize, Deserialize))]
pub struct Forest<T: Float + SampleUniform + Default, const N: usize>
where
StandardNormal: Distribution<T>

#[derive(Serialize, Deserialize)]
pub struct Forest<T, const N: usize>
{
/// Multiplicative factor used in computing the anomaly scores.
avg_path_length_c: f64,

trees: Vec<Tree<T, N>>,
}

impl<T, const N: usize> Forest<T, N>
impl<'de, T, const N: usize> Forest<T, N>
where
T: Float + SampleUniform + Default,
T: Float + SampleUniform + Default + Serialize + Deserialize<'de>,
StandardNormal: Distribution<T>
{
/// Build a new forest from the given training data
Expand Down Expand Up @@ -163,14 +161,16 @@ impl<T, const N: usize> Forest<T, N>
}
}

//#[cfg_attr(feature="use-serde", derive(Serialize, Deserialize))]
enum Node<T: Float + Default, const N: usize> {
#[derive(Serialize, Deserialize)]
enum Node<T, const N: usize>
{
Ex(ExNode),
In(InNode<T, N>),
}

//#[cfg_attr(feature="use-serde", derive(Serialize, Deserialize))]
struct InNode<T: Float + Default, const N: usize> {
#[derive(Serialize, Deserialize)]
struct InNode<T, const N: usize>
{
/// Left child node.
left: Box<Node<T, N>>,

Expand All @@ -179,29 +179,29 @@ struct InNode<T: Float + Default, const N: usize> {

/// Normal vector at the root of this tree, which is used in
/// creating hyperplanes for splitting criteria
#[serde(with = "serde_array")]
n: [T; N],

/// Intercept point through which the hyperplane passes.
#[serde(with = "serde_array")]
p: [T; N],
}

//#[cfg_attr(feature="use-serde", derive(Serialize, Deserialize))]
#[derive(Serialize, Deserialize)]
struct ExNode {
/// Size of the dataset present at the node.
num_samples: usize,
}

//#[cfg_attr(feature="use-serde", derive(Serialize, Deserialize))]
struct Tree<T: Float + SampleUniform + Default, const N: usize>
where
StandardNormal: Distribution<T>
#[derive(Serialize, Deserialize)]
struct Tree<T, const N: usize>
{
root: Node<T, N>,
}

impl<T, const N: usize> Tree<T, N>
impl<'de, T, const N: usize> Tree<T, N>
where
T: Float + SampleUniform + Default,
T: Float + SampleUniform + Default + Serialize + Deserialize<'de>,
StandardNormal: Distribution<T>
{
pub fn new(samples: &[&[T; N]], rng: &mut ThreadRng, max_tree_depth: usize, extension_level: usize) -> Self {
Expand Down Expand Up @@ -357,9 +357,7 @@ mod tests {
Forest::from_slice(values.as_slice(), &options).unwrap()
}

#[test]
fn score_forest_3d_f64() {
let forest = make_f64_forest();
fn assert_anomalies_forest_3d_f64(forest: &Forest<f64, 3>) {
// no anomaly
assert!(forest.score(&[1.0, 3.0, 25.0]) < 0.5);
assert!(forest.score(&[1.0, 3.0, 35.0]) < 0.5);
Expand All @@ -369,16 +367,18 @@ mod tests {
assert!(forest.score(&[-1.0, 2.0, 60.0]) > 0.5);
assert!(forest.score(&[-1.0, 2.0, 0.0]) > 0.5);
}
/*
#[cfg(feature = "use-serde")]
#[test]
fn serialize_forest_2d_f64() {
let forest = make_f64_forest();

let forest_bytes = bincode::serialize(&forest).unwrap();
dbg!(&forest_bytes.len());
}
#[test]
fn score_forest_3d_f64() {
let forest = make_f64_forest();
assert_anomalies_forest_3d_f64(&forest);
}

*/
#[test]
fn serialize_forest_3d_f64() {
let forest = make_f64_forest();
let forest_json = serde_json::to_string(&forest).unwrap();
let forest2 = serde_json::from_str(forest_json.as_str()).unwrap();
assert_anomalies_forest_3d_f64(&forest2);
}
}
60 changes: 60 additions & 0 deletions src/serde_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use std::{convert::TryInto, marker::PhantomData};

use serde::{
de::{SeqAccess, Visitor},
Deserialize,
Deserializer, ser::SerializeTuple, Serialize, Serializer,
};

// from https://github.com/serde-rs/serde/issues/1937#issuecomment-812137971

pub fn serialize<S: Serializer, T: Serialize, const N: usize>(
data: &[T; N],
ser: S,
) -> Result<S::Ok, S::Error> {
let mut s = ser.serialize_tuple(N)?;
for item in data {
s.serialize_element(item)?;
}
s.end()
}

struct ArrayVisitor<T, const N: usize>(PhantomData<T>);

impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
where
T: Deserialize<'de>,
{
type Value = [T; N];

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str(&format!("an array of length {}", N))
}

#[inline]
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
// can be optimized using MaybeUninit
let mut data = Vec::with_capacity(N);
for _ in 0..N {
match (seq.next_element())? {
Some(val) => data.push(val),
None => return Err(serde::de::Error::invalid_length(N, &self)),
}
}
match data.try_into() {
Ok(arr) => Ok(arr),
Err(_) => unreachable!(),
}
}
}

pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error>
where
D: Deserializer<'de>,
T: Deserialize<'de>,
{
deserializer.deserialize_tuple(N, ArrayVisitor::<T, N>(PhantomData))
}

0 comments on commit ddc576b

Please sign in to comment.