Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add cANI estimation to branchwater #188

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ zip = "0.6"
tempfile = "3.9"
needletail = "0.5.1"
csv = "1.3.0"
statrs = "0.16.0"
ndarray = "0.15.6"
roots = "0.0.8"

[dev-dependencies]
assert_cmd = "2.0.13"
Expand Down
241 changes: 241 additions & 0 deletions src/c_ani.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
use roots::{find_root_brent, SimpleConvergency};
use statrs::distribution::{ContinuousCDF, Normal};
use std::error::Error;

#[derive(Debug)]
#[allow(dead_code)]
pub struct CiAniResult {
pub point_estimate: f64,
prob_nothing_in_common: f64,
dist_low: Option<f64>,
dist_high: Option<f64>,
p_threshold: f64,
}

fn exp_n_mutated(l: u64, k: u32, r1: f64) -> f64 {
let q = r1_to_q(k, r1);
l as f64 * q
}

fn var_n_mutated(l: u64, k: u32, r1: f64, q: Option<f64>) -> Result<f64, Box<dyn Error>> {
if r1 == 0.0 {
return Ok(0.0);
}

let q = q.unwrap_or_else(|| r1_to_q(k, r1));

let var_n = l as f64 * (1.0 - q) * (q * (2.0 * k as f64 + (2.0 / r1) - 1.0) - 2.0 * k as f64)
+ k as f64 * (k as f64 - 1.0) * (1.0 - q).powi(2)
+ (2.0 * (1.0 - q) / (r1.powi(2))) * ((1.0 + (k as f64 - 1.0) * (1.0 - q)) * r1 - q);

if var_n < 0.0 {
Err("Error: varN < 0.0!".into())
} else {
Ok(var_n)
}
}

fn exp_n_mutated_squared(l: u64, k: u32, p: f64) -> Result<f64, Box<dyn Error>> {
let var_n = var_n_mutated(l, k, p, None)?;
Ok(var_n + exp_n_mutated(l, k, p).powi(2))
}

fn probit(p: f64) -> f64 {
Normal::new(0.0, 1.0).unwrap().inverse_cdf(p)
}

fn r1_to_q(k: u32, r1: f64) -> f64 {
1.0 - (1.0 - r1).powi(k as i32)
}

fn get_exp_probability_nothing_common(
mutation_rate: f64,
kmer_size: u32,
scaled: u64,
n_unique_kmers: u64,
) -> f64 {
// Inverse of the scale factor, used in probability calculation.
let inverse_scaled = 1.0 / scaled as f64;

// Handle special cases for mutation rate.
if mutation_rate == 1.0 {
1.0
} else if mutation_rate == 0.0 {
0.0
} else {
// Calculate the expected log probability.
let expected_log_probability =
get_expected_log_probability(n_unique_kmers, kmer_size, mutation_rate, inverse_scaled);
// Return the exponential of the expected log probability.
expected_log_probability.exp()
}
}

fn get_expected_log_probability(
n_unique_kmers: u64,
ksize: u32,
mutation_rate: f64,
scaled_fraction: f64,
) -> f64 {
let exp_nmut = exp_n_mutated(n_unique_kmers, ksize, mutation_rate);
let result = (n_unique_kmers as f64 - exp_nmut) * (1.0 - scaled_fraction).ln();

if result.is_infinite() {
f64::NEG_INFINITY
} else {
result
}
}

fn handle_seqlen_nkmers(
ksize: u32,
sequence_len_bp: Option<u64>,
n_unique_kmers: Option<u64>,
) -> Result<u64, Box<dyn Error>> {
match n_unique_kmers {
Some(n) => Ok(n),
None => match sequence_len_bp {
Some(len) => Ok(len.saturating_sub(ksize as u64 - 1)),
None => Err(
"Error: distance estimation requires 'sequence_len_bp' or 'n_unique_kmers'".into(),
),
},
}
}

fn check_distance(dist: f64) -> Result<f64, Box<dyn Error>> {
if dist >= 0.0 && dist <= 1.0 {
Ok(dist)
} else {
Err(format!("Error: distance value {:.4} is not between 0 and 1!", dist).into())
}
}

impl CiAniResult {
fn new(
point_estimate: f64,
prob_nothing_in_common: f64,
dist_low: Option<f64>,
dist_high: Option<f64>,
p_threshold: f64,
) -> Result<Self, Box<dyn Error>> {
let dist_low_checked = dist_low.map_or(Ok(None), |d| check_distance(d).map(Some))?;
let dist_high_checked = dist_high.map_or(Ok(None), |d| check_distance(d).map(Some))?;

Ok(Self {
point_estimate,
prob_nothing_in_common,
dist_low: dist_low_checked,
dist_high: dist_high_checked,
p_threshold,
})
}
}

pub fn containment_to_distance(
containment: f64,
ksize: u32,
scaled: u64,
n_unique_kmers: Option<u64>,
sequence_len_bp: Option<u64>,
confidence: Option<f64>,
estimate_ci: Option<bool>,
prob_threshold: Option<f64>,
) -> Result<CiAniResult, Box<dyn Error>> {
let scaled_f64 = scaled as f64;
let n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp, n_unique_kmers)?;

let confidence = confidence.unwrap_or(0.95);
let estimate_ci = estimate_ci.unwrap_or(false);
let prob_threshold = prob_threshold.unwrap_or(1e-3);

let point_estimate = if containment == 0.0 {
1.0
} else if containment == 1.0 {
0.0
} else {
1.0 - containment.powf(1.0 / ksize as f64)
};

let mut sol1 = None;
let mut sol2 = None;

if estimate_ci {
let alpha = 1.0 - confidence;
let z_alpha = probit(1.0 - alpha / 2.0);
let f_scaled = 1.0 / scaled_f64;
let bias_factor = 1.0 - (1.0 - f_scaled).powi(n_unique_kmers as i32);
let term_1 =
(1.0 - f_scaled) / (f_scaled * (n_unique_kmers as f64).powi(3) * bias_factor.powi(2));
let term_2 = |pest: f64| {
n_unique_kmers as f64 * exp_n_mutated(n_unique_kmers, ksize, pest)
- exp_n_mutated_squared(n_unique_kmers, ksize, pest).unwrap_or(0.0)
};
let term_3 = |pest: f64| {
var_n_mutated(n_unique_kmers, ksize, pest, None).unwrap_or(0.0)
/ (n_unique_kmers as f64).powi(2)
};

let var_direct = |pest: f64| term_1 * term_2(pest) + term_3(pest);

let f1 = |pest: f64| {
(1.0 - pest).powi(ksize as i32) + z_alpha * var_direct(pest).sqrt() - containment
};
let f2 = |pest: f64| {
(1.0 - pest).powi(ksize as i32) - z_alpha * var_direct(pest).sqrt() - containment
};

let mut convergency = SimpleConvergency {
eps: 1e-15,
max_iter: 1000,
};
sol1 = find_root_brent(0.0000001, 0.9999999, &f1, &mut convergency).ok();
sol2 = find_root_brent(0.0000001, 0.9999999, &f2, &mut convergency).ok();
}

let prob_nothing_in_common =
get_exp_probability_nothing_common(point_estimate, ksize, scaled, n_unique_kmers);

CiAniResult::new(
point_estimate,
prob_nothing_in_common,
sol1,
sol2,
prob_threshold,
)
}

/* I calculate the distance, then I do 1 - distance to get the ANI identity.

fn main() {
let containment = 0.1;
let ksize = 31;
let scaled = 1000; // u64
let n_unique_kmers = 1000;
let sequence_len_bp = n_unique_kmers * scaled;
let confidence = Some(0.95);
let estimate_ci = Some(true);
let prob_threshold = Some(1e-3);

let result = containment_to_distance(
containment,
ksize,
scaled,
Some(n_unique_kmers),
Some(sequence_len_bp),
confidence,
estimate_ci,
prob_threshold,
);


match result {
Ok(ci_result) => {
let ani_result = 1.0 - ci_result.point_estimate;
println!("ANI: {}", ani_result);
}

Err(e) => println!("Error occurred: {:?}", e),
}
}
*/
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ extern crate simple_error;
mod utils;
use crate::utils::build_template;
use crate::utils::is_revindex_database;
mod c_ani;
mod check;
mod fastgather;
mod fastmultigather;
Expand Down
49 changes: 45 additions & 4 deletions src/pairwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use std::sync::atomic::AtomicUsize;
use sourmash::signature::SigsTrait;
use sourmash::sketch::Sketch;

use crate::c_ani::containment_to_distance;

use crate::utils::{load_sketches_from_zip_or_pathlist, ReportType};

/// Perform pairwise comparisons of all signatures in a list.
Expand All @@ -37,12 +39,13 @@ pub fn pairwise<P: AsRef<Path>>(
};
let thrd = std::thread::spawn(move || {
let mut writer = BufWriter::new(out);
writeln!(&mut writer, "query_name,query_md5,match_name,match_md5,containment,max_containment,jaccard,intersect_hashes").unwrap();
for (query, query_md5, m, m_md5, cont, max_cont, jaccard, overlap) in recv.into_iter() {
writeln!(&mut writer, "query_name,query_md5,match_name,match_md5,containment,max_containment,jaccard,intersect_hashes,ani").unwrap();
for (query, query_md5, m, m_md5, cont, max_cont, jaccard, overlap, ani) in recv.into_iter()
{
writeln!(
&mut writer,
"\"{}\",{},\"{}\",{},{},{},{},{}",
query, query_md5, m, m_md5, cont, max_cont, jaccard, overlap
"\"{}\",{},\"{}\",{},{},{},{},{},{}",
query, query_md5, m, m_md5, cont, max_cont, jaccard, overlap, ani
)
.ok();
}
Expand All @@ -60,11 +63,48 @@ pub fn pairwise<P: AsRef<Path>>(
let query1_size = q1.minhash.size() as f64;
let query2_size = q2.minhash.size() as f64;

let k_size = q1.minhash.ksize() as u32;
let scaled = q1.minhash.scaled() as u64;

let containment_q1_in_q2 = overlap / query1_size;
let containment_q2_in_q1 = overlap / query2_size;
let max_containment = containment_q1_in_q2.max(containment_q2_in_q1);
let jaccard = overlap / (query1_size + query2_size - overlap);

// TODO: this is equivalent to query1_size and query2_size but un u64
let sig_1_size = q1.minhash.size() as u64;
let sig_2_size = q2.minhash.size() as u64;

let ani_1_in_2 = 1.00
- containment_to_distance(
containment_q1_in_q2,
k_size,
scaled,
Some(sig_2_size),
Some(sig_2_size * scaled),
Some(0.95),
Some(true),
Some(1e-3),
)
.unwrap()
.point_estimate as f64;

let ani_2_in_1 = 1.00
- containment_to_distance(
containment_q2_in_q1,
k_size,
scaled,
Some(sig_1_size),
Some(sig_1_size * scaled),
Some(0.95),
Some(true),
Some(1e-3),
)
.unwrap()
.point_estimate as f64;

let ani = (ani_1_in_2 + ani_2_in_1) / 2.0;

if containment_q1_in_q2 > threshold || containment_q2_in_q1 > threshold {
send.send((
q1.name.clone(),
Expand All @@ -75,6 +115,7 @@ pub fn pairwise<P: AsRef<Path>>(
max_containment,
jaccard,
overlap,
ani,
))
.unwrap();
}
Expand Down
Loading