diff --git a/Cargo.toml b/Cargo.toml
index c8c1f85..bfd17f0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,7 +4,7 @@ version = "0.1.0"
 authors = ["Ryan Walker "]

 [dependencies]
-docopt = "0.7.0"
-rustc-serialize = "0.3"
+docopt = "1"
+serde = { version = "1", features = ["derive"] }
 random_choice = "0.3.2"
 fnv = "1.0.3"
\ No newline at end of file
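Note on the dependency swap above: docopt 0.7 decoded its parsed arguments through rustc-serialize, while docopt 1.x goes through serde, which is why the "derive" feature is pulled in. A minimal sketch of the new pattern (a hypothetical one-argument tool; the real Args struct is in src/main.rs below):

    extern crate docopt;
    extern crate serde;

    use docopt::Docopt;
    use serde::Deserialize;

    const USAGE: &str = "
    Usage: tool <datafile>
    ";

    // docopt maps the <datafile> positional to the arg_datafile field.
    #[derive(Debug, Deserialize)]
    struct Args {
        arg_datafile: String,
    }

    fn main() {
        let args: Args = Docopt::new(USAGE)
            .and_then(|d| d.deserialize())
            .unwrap_or_else(|e| e.exit());
        println!("{:?}", args);
    }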
diff --git a/src/lib.rs b/src/lib.rs
index 830f7e1..2dfd1ec 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,48 +1,55 @@
 #![allow(non_snake_case)]
-extern crate random_choice;
 extern crate fnv;
+extern crate random_choice;

-use std::collections::{HashSet, HashMap};
-use std::cmp::max;
-use fnv::FnvHashMap;
 use self::random_choice::random_choice;
+use fnv::FnvHashMap;
+use std::cmp::max;
+use std::collections::{HashMap, HashSet};

 pub struct GSDMM {
     alpha: f64,
     beta: f64,
-    K:usize,
-    V:f64,
-    D:usize,
-    maxit:isize,
+    K: usize,
+    V: f64,
+    D: usize,
+    maxit: isize,
     clusters: Vec<usize>,
-    pub doc_vectors:Vec<Vec<usize>>,
+    pub doc_vectors: Vec<Vec<usize>>,
     pub labels: Vec<usize>,
     pub cluster_counts: Vec<u32>,
-    pub cluster_word_counts:Vec<u32>,
-    pub word_index_map:HashMap<String, usize>,
-    pub index_word_map:HashMap<usize, String>,
-    pub cluster_word_distributions: Vec<FnvHashMap<usize, u32>>
+    pub cluster_word_counts: Vec<u32>,
+    pub word_index_map: HashMap<String, usize>,
+    pub index_word_map: HashMap<usize, String>,
+    pub cluster_word_distributions: Vec<FnvHashMap<usize, u32>>,
 }

 impl GSDMM {
-    pub fn new(alpha:f64, beta:f64, K: usize, maxit:isize, vocab:HashSet<String>, docs:Vec<Vec<String>>) -> GSDMM {
+    pub fn new(
+        alpha: f64,
+        beta: f64,
+        K: usize,
+        maxit: isize,
+        vocab: &HashSet<String>,
+        docs: &Vec<Vec<String>>,
+    ) -> GSDMM {
         let D = docs.len();

         // compute utilized vocabulary size.
-        let mut word_index_map = HashMap::<String, usize>::with_capacity(vocab.len()/2);
-        let mut index_word_map = HashMap::<usize, String>::with_capacity(vocab.len()/2);
+        let mut word_index_map = HashMap::<String, usize>::with_capacity(vocab.len() / 2);
+        let mut index_word_map = HashMap::<usize, String>::with_capacity(vocab.len() / 2);
         let mut index = 0_usize;
         let mut doc_vectors = Vec::<Vec<usize>>::with_capacity(D);
-        for doc in &docs {
+        for doc in docs {
             let mut doc_vector = Vec::<usize>::with_capacity(doc.len());
             for word in doc {
                 if !word_index_map.contains_key(word) {
                     word_index_map.insert(word.clone(), index);
                     index_word_map.insert(index, word.clone());
-                    index+=1;
+                    index += 1;
                 }
-                doc_vector.push(word_index_map.get(word).unwrap().clone());
+                doc_vector.push(*word_index_map.get(word).unwrap());
             }

             // dedupe vector and compact
@@ -54,52 +61,58 @@ impl GSDMM {
             doc_vectors.push(doc_vector);
         }
         let V = index as f64;
-        println!("Fitting with alpha={}, beta={}, K={}, maxit={}, vocab size={}", alpha, beta, K, maxit, V as u32);
+        println!(
+            "Fitting with alpha={}, beta={}, K={}, maxit={}, vocab size={}",
+            alpha, beta, K, maxit, V as u32
+        );

         let clusters = (0_usize..K).collect::<Vec<usize>>();
         let mut d_z: Vec<usize> = (0_usize..D).map(|_| 0_usize).collect::<Vec<usize>>(); // doc labels
-        let mut m_z: Vec<u32> = GSDMM::zero_vector(K);  // cluster sizes
-        let mut n_z: Vec<u32> = GSDMM::zero_vector(K);  // cluster word counts
-        let mut n_z_w = Vec::<FnvHashMap<usize, u32>>::with_capacity(K);  // container for cluster word distributions
+        let mut m_z: Vec<u32> = GSDMM::zero_vector(K); // cluster sizes
+        let mut n_z: Vec<u32> = GSDMM::zero_vector(K); // cluster word counts
+        let mut n_z_w = Vec::<FnvHashMap<usize, u32>>::with_capacity(K); // container for cluster word distributions
         for _ in 0_usize..K {
-            let m = FnvHashMap::<usize, u32>::with_capacity_and_hasher(max(vocab.len() / 10, 100), Default::default());
-            &n_z_w.push(m);
+            let m = FnvHashMap::<usize, u32>::with_capacity_and_hasher(
+                max(vocab.len() / 10, 100),
+                Default::default(),
+            );
+            n_z_w.push(m);
         }

         // randomly initialize cluster assignment
         let p = (0..K).map(|_| 1_f64 / (K as f64)).collect::<Vec<f64>>();
-        let choices = random_choice().random_choice_f64(&clusters, &p, D) ;
+        let choices = random_choice().random_choice_f64(&clusters, &p, D);
         for i in 0..D {
-            let z = choices[i].clone();
-            let ref doc = doc_vectors[i];
+            let z = *choices[i];
+            let doc = &doc_vectors[i];
             d_z[i] = z;
             m_z[z] += 1;
             n_z[z] += doc.len() as u32;
-            let ref mut clust_words: FnvHashMap<usize, u32> = n_z_w[z];
+            let clust_words: &mut FnvHashMap<usize, u32> = &mut n_z_w[z];
             for word in doc {
                 if !clust_words.contains_key(word) {
-                    clust_words.insert(word.clone(), 0_u32);
+                    clust_words.insert(*word, 0_u32);
                 }
-                * clust_words.get_mut(word).unwrap() += 1_u32;
+                *clust_words.get_mut(word).unwrap() += 1_u32;
             }
         }

         GSDMM {
-            alpha: alpha,
-            beta: beta,
-            K: K,
-            V: V,
-            D: D,
-            maxit:maxit,
-            doc_vectors:doc_vectors,
-            clusters: clusters.clone(), // Don't totally get why we need the clone here!
+            alpha,
+            beta,
+            K,
+            V,
+            D,
+            maxit,
+            doc_vectors,
+            clusters: clusters.clone(), // Don't totally get why we need the clone here!
             labels: d_z,
             cluster_counts: m_z,
             cluster_word_counts: n_z,
-            word_index_map: word_index_map,
-            index_word_map: index_word_map,
-            cluster_word_distributions: n_z_w
+            word_index_map,
+            index_word_map,
+            cluster_word_distributions: n_z_w,
         }
     }
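For orientation: the initialization above is the standard GSDMM setup. Each document d starts in a uniformly random cluster,

    z_d \sim \mathrm{Uniform}\{0, \dots, K - 1\}, \qquad p_k = \frac{1}{K},

and m_z (documents per cluster), n_z (total words per cluster), and n_z_w (per-cluster word histograms) are the sufficient statistics that the Gibbs sweep in fit() below updates incrementally.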
@@ -108,7 +121,7 @@ impl GSDMM {
         for it in 0..self.maxit {
             let mut total_transfers = 0;
             for i in 0..self.D {
-                let ref doc = self.doc_vectors[i];
+                let doc = &self.doc_vectors[i];
                 let doc_size = doc.len() as u32;

                 // remove the doc from its current cluster
@@ -118,7 +131,8 @@ impl GSDMM {

                 // modify the map: enclose it in a block so we can borrow views again.
                 {
-                    let ref mut old_clust_words: FnvHashMap<usize, u32> = self.cluster_word_distributions[z_old];
+                    let old_clust_words: &mut FnvHashMap<usize, u32> =
+                        &mut self.cluster_word_distributions[z_old];
                     for word in doc {
                         *old_clust_words.get_mut(word).unwrap() -= 1_u32;

@@ -130,10 +144,10 @@ impl GSDMM {
                 }

                 // update the probability vector
-                let p = self.score(&doc);
+                let p = self.score(doc);

                 // choose the next cluster randomly according to the computed probability
-                let z_new: usize = random_choice().random_choice_f64(&self.clusters, &p, 1)[0].clone();
+                let z_new: usize = *random_choice().random_choice_f64(&self.clusters, &p, 1)[0];

                 // transfer document to the new cluster
                 if z_new != z_old {
@@ -144,40 +158,51 @@ impl GSDMM {
                     self.cluster_word_counts[z_new] += doc_size;

                     {
-                        let ref mut new_clust_words: FnvHashMap<usize, u32> = self.cluster_word_distributions[z_new];
+                        let new_clust_words: &mut FnvHashMap<usize, u32> =
+                            &mut self.cluster_word_distributions[z_new];
                         for word in doc {
                             if !new_clust_words.contains_key(word) {
-                                new_clust_words.insert(word.clone(), 0_u32);
+                                new_clust_words.insert(*word, 0_u32);
                             }
-                            *new_clust_words.get_mut(word).unwrap() += 1_u32;
+                            *new_clust_words.get_mut(word).unwrap() += 1_u32;
                         }
                     }
                 }
             }
-            let new_number_clusters = self.cluster_word_distributions.iter().map(|c| if c.len()>0 {1} else {0} ).sum();
-            println!("Iteration {}: {} docs transferred with {} clusters populated.", it, total_transfers, new_number_clusters);
+            let new_number_clusters = self
+                .cluster_word_distributions
+                .iter()
+                .map(|c| if !c.is_empty() { 1 } else { 0 })
+                .sum();
+            println!(
+                "Iteration {}: {} docs transferred with {} clusters populated.",
+                it, total_transfers, new_number_clusters
+            );

             // apply ad-hoc convergence test
-            if total_transfers==0 && new_number_clusters==number_clusters {
-                println!("Converged after {} iterations. Solution has {} clusters.", it, new_number_clusters);
-                break
+            if total_transfers == 0 && new_number_clusters == number_clusters {
+                println!(
+                    "Converged after {} iterations. Solution has {} clusters.",
+                    it, new_number_clusters
+                );
+                break;
             }
             number_clusters = new_number_clusters;
         }
     }

-    pub fn score(&self, doc:&Vec<usize>) -> Vec<f64> {
-        /// Score an input document using the formula of Yin and Wang 2014 (equation 3)
-        /// http://dbgroup.cs.tsinghua.edu.cn/wangjy/papers/KDD14-GSDMM.pdf
-        ///
-        /// # Arguments
-        ///
-        /// * `doc` - A vector of unique index tokens characterizing the document
-        ///
-        /// # Value
-        ///
-        /// Vec<f64> - A length K probability vector where each component represents the probability
-        /// of the doc belonging to a particular cluster.
-        ///
+    pub fn score(&self, doc: &Vec<usize>) -> Vec<f64> {
+        // Score an input document using the formula of Yin and Wang 2014 (equation 3)
+        // http://dbgroup.cs.tsinghua.edu.cn/wangjy/papers/KDD14-GSDMM.pdf
+        //
+        // # Arguments
+        //
+        // * `doc` - A vector of unique index tokens characterizing the document
+        //
+        // # Value
+        //
+        // Vec<f64> - A length K probability vector where each component represents the probability
+        // of the doc belonging to a particular cluster.
+        //

         // We break the formula into the following pieces
         // p = N1*N2/(D1*D2) = exp(lN1 - lD1 + lN2 - lD2)
@@ -188,41 +213,39 @@ impl GSDMM {
         let mut p = (0..self.K).map(|_| 0_f64).collect::<Vec<f64>>();
         let lD1 = ((self.D - 1) as f64 + (self.K as f64) * self.alpha).ln();
         let doc_size = doc.len() as u32;
-        for label in 0_usize..self.K {
+        for (label, item) in p.iter_mut().enumerate().take(self.K) {
             let lN1 = (self.cluster_counts[label] as f64 + self.alpha).ln();
             let mut lN2 = 0_f64;
             let mut lD2 = 0_f64;
-
-            let ref cluster: FnvHashMap<usize, u32> = self.cluster_word_distributions[label];
+            let cluster: &FnvHashMap<usize, u32> = &self.cluster_word_distributions[label];

             for word in doc {
                 lN2 += (*cluster.get(word).unwrap_or(&0_u32) as f64 + self.beta).ln();
             }
-            for j in 1_u32..(doc_size+1) {
-                lD2 += ((self.cluster_word_counts[label] + j) as f64 - 1_f64 + self.V * self.beta).ln();
+            for j in 1_u32..(doc_size + 1) {
+                lD2 += ((self.cluster_word_counts[label] + j) as f64 - 1_f64 + self.V * self.beta)
+                    .ln();
             }
-            p[label] = (lN1 - lD1 + lN2 - lD2).exp();
+            *item = (lN1 - lD1 + lN2 - lD2).exp();
         }

         // normalize the probability
         let pnorm: f64 = p.iter().sum();
-        if pnorm>0_f64 {
-            for label in 0_usize..self.K {
-                p[label] = p[label] / pnorm;
+        if pnorm > 0_f64 {
+            for item in p.iter_mut().take(self.K) {
+                *item /= pnorm;
             }
         }
         p
     }

-    fn zero_vector(size:usize) -> Vec<u32>
-    {
+    fn zero_vector(size: usize) -> Vec<u32> {
         let mut v = Vec::<u32>::with_capacity(size);
         for _ in 0_usize..size {
             v.push(0_u32)
         }
         v
     }
-
 }

 #[test]
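For reference, score() implements equation 3 of Yin and Wang 2014 (the paper linked in the comment). In LaTeX, writing m_k for cluster_counts[k], n_k for cluster_word_counts[k], n_k^w for cluster_word_distributions[k][w], and N_d for the document length (during fitting the counts have document d already removed):

    p(z_d = k \mid \vec{z}_{\neg d}, \vec{d}) \;\propto\;
        \underbrace{\frac{m_k + \alpha}{D - 1 + K\alpha}}_{N_1 / D_1}
        \cdot
        \underbrace{\frac{\prod_{w \in d} (n_k^w + \beta)}
                         {\prod_{j=1}^{N_d} (n_k + V\beta + j - 1)}}_{N_2 / D_2}

The code evaluates this in log space as lN1 - lD1 + lN2 - lD2 and exponentiates at the end; the pnorm > 0 guard in the normalization covers the case where every component underflows to zero.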
@@ -232,44 +255,50 @@ fn simple_run() {
     vocab.insert("B".to_string());
     vocab.insert("C".to_string());

-    let mut docs = Vec::<Vec<String>>::new();
-    docs.push(vec!("A".to_string()));
-    docs.push(vec!("A".to_string()));
-    docs.push(vec!("B".to_string()));
-    docs.push(vec!("B".to_string()));
-    docs.push(vec!("B".to_string()));
-    docs.push(vec!("B".to_string()));
-    docs.push(vec!("B".to_string()));
-    docs.push(vec!("B".to_string()));
-    docs.push(vec!("B".to_string()));
-    docs.push(vec!("B".to_string()));
-    docs.push(vec!("C".to_string()));
-    docs.push(vec!("C".to_string()));
-    docs.push(vec!("C".to_string()));
-    docs.push(vec!("C".to_string()));
-    docs.push(vec!("C".to_string()));
-    docs.push(vec!("C".to_string()));
-    docs.push(vec!("C".to_string()));
-    docs.push(vec!("C".to_string()));
-
-    let mut model = GSDMM::new(0.1, 0.00001, 10, 30, vocab, docs);
+    let docs = vec![
+        vec!["A".to_string()],
+        vec!["A".to_string()],
+        vec!["B".to_string()],
+        vec!["B".to_string()],
+        vec!["B".to_string()],
+        vec!["B".to_string()],
+        vec!["B".to_string()],
+        vec!["B".to_string()],
+        vec!["B".to_string()],
+        vec!["B".to_string()],
+        vec!["C".to_string()],
+        vec!["C".to_string()],
+        vec!["C".to_string()],
+        vec!["C".to_string()],
+        vec!["C".to_string()],
+        vec!["C".to_string()],
+        vec!["C".to_string()],
+        vec!["C".to_string()],
+    ];
+
+    let mut model = GSDMM::new(0.1, 0.00001, 10, 30, &vocab, &docs);
     model.fit();

     // check the total number across all partitions is equal to the number of docs
     assert_eq!(18, model.cluster_counts.iter().sum::<u32>());

     // check that we get three clusters
-    assert_eq!(3, model.cluster_counts.into_iter().filter(|x| x>&0_u32 ).collect::<Vec<u32>>().len());
+    assert_eq!(
+        3,
+        model
+            .cluster_counts
+            .into_iter()
+            .filter(|x| x > &0_u32)
+            .collect::<Vec<u32>>()
+            .len()
+    );

     // check that the clusters are pure
-    let mut check_map = HashMap::<usize,String>::new();
-    for (i,label) in vec!("A","A","B","B","B","B","B","B","B","B","C","C","C","C","C","C","C","C").into_iter().enumerate() {
-        if check_map.contains_key(&model.labels[i]) {
-            assert_eq!(check_map[&model.labels[i]], label);
-        } else {
-            check_map.insert(model.labels[i], label.to_string());
-        }
-    }
+    let mut check_map = HashMap::<usize, String>::new();
+    for (i, label) in vec![
+        "A", "A", "B", "B", "B", "B", "B", "B", "B", "B", "C", "C", "C", "C", "C", "C", "C", "C",
+    ]
+    .into_iter()
+    .enumerate()
+    {
+        match check_map.entry(model.labels[i]) {
+            std::collections::hash_map::Entry::Vacant(e) => {
+                e.insert(label.to_string());
+            }
+            std::collections::hash_map::Entry::Occupied(e) => {
+                assert_eq!(e.get().as_str(), label);
+            }
+        }
+    }
-
 }

@@ -281,12 +310,13 @@ fn indexing() {
     vocab.insert("C".to_string());
     vocab.insert("D".to_string());

-    let mut docs = Vec::<Vec<String>>::new();
-    docs.push(vec!("A".to_string(),"B".to_string()));
-    docs.push(vec!("D".to_string()));
-    docs.push(vec!("C".to_string()));
+    let docs = vec![
+        vec!["A".to_string(), "B".to_string()],
+        vec!["D".to_string()],
+        vec!["C".to_string()],
+    ];

-    let mut model = GSDMM::new(0.1, 0.00001, 10, 30, vocab, docs);
+    let model = GSDMM::new(0.1, 0.00001, 10, 30, &vocab, &docs);

     // test the index mapping
     assert_eq!("A", model.index_word_map.get(&0_usize).unwrap());
@@ -297,5 +327,4 @@ fn indexing() {
     assert_eq!(1_usize, *model.word_index_map.get("B").unwrap());
     assert_eq!(2_usize, *model.word_index_map.get("D").unwrap());
     assert_eq!(3_usize, *model.word_index_map.get("C").unwrap());
-
 }
diff --git a/src/main.rs b/src/main.rs
index be63e95..eee6695 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,15 +1,16 @@
 extern crate docopt;
-extern crate rustc_serialize;
 extern crate gsdmm;
+extern crate serde;

-use gsdmm::GSDMM;
 use docopt::Docopt;
-use std::io::{BufRead,BufReader};
-use std::fs::File;
+use gsdmm::GSDMM;
+use serde::Deserialize;
 use std::collections::HashSet;
+use std::fs::File;
 use std::io::Write;
+use std::io::{BufRead, BufReader};

-const USAGE: &'static str ="
+const USAGE: &str ="
 Gibbs sampling algorithm for a Dirichlet Mixture Model of Yin and Wang 2014.
 Usage:
@@ -30,7 +31,7 @@ Options:
 ";

-#[derive(Debug, RustcDecodable)]
+#[derive(Debug, Deserialize)]
 struct Args {
     // flag_mode: isize,
     arg_datafile: String,
@@ -39,43 +40,56 @@ struct Args {
     flag_k: usize,
     flag_alpha: f64,
     flag_beta: f64,
-    flag_maxit: isize
+    flag_maxit: isize,
 }

 fn main() {
-
     let args: Args = Docopt::new(USAGE)
-        .and_then(|d| d.decode())
+        .and_then(|d| d.deserialize())
        .unwrap_or_else(|e| e.exit());

     // get the data and vocabulary
-    let vocab:HashSet<String> = lines_from_file(&args.arg_vocabfile).into_iter().collect();
-    let docs:Vec<Vec<String>> = lines_from_file(&args.arg_datafile).into_iter().map(|line| {
-        let mut term_vector = line.to_owned()
-            .split_whitespace()
-            .map(|s| s.to_owned())
-            .filter(|s| (&vocab).contains(s))
-            .collect::<Vec<String>>();
-
-        // sort and dedupe: this implementation requires binary term counts
-        term_vector.sort();
-        term_vector.dedup();
-        term_vector
-    }).collect::<Vec<Vec<String>>>();
-
-    let mut model = GSDMM::new(args.flag_alpha, args.flag_beta, args.flag_k, args.flag_maxit, vocab, docs);
+    let vocab: HashSet<String> = lines_from_file(&args.arg_vocabfile).into_iter().collect();
+    let docs: Vec<Vec<String>> = lines_from_file(&args.arg_datafile)
+        .into_iter()
+        .map(|line| {
+            let mut term_vector = line
+                .split_whitespace()
+                .map(|s| s.to_owned())
+                .filter(|s| vocab.contains(s))
+                .collect::<Vec<String>>();
+
+            // sort and dedupe: this implementation requires binary term counts
+            term_vector.sort();
+            term_vector.dedup();
+            term_vector
+        })
+        .collect::<Vec<Vec<String>>>();
+
+    let mut model = GSDMM::new(
+        args.flag_alpha,
+        args.flag_beta,
+        args.flag_k,
+        args.flag_maxit,
+        &vocab,
+        &docs,
+    );
     model.fit();

     // write the labels
     {
-        let fname = (&args.arg_outprefix).clone() + "labels.csv";
-        let error_msg = format ! ("Could not write file {}!", fname);
-        let mut f = File::create( fname ).expect( & error_msg);
-        let mut scored = Vec::<(String,String)>::new();
+        let fname = args.arg_outprefix.clone() + "labels.csv";
+        let error_msg = format!("Could not write file {}!", fname);
+        let mut f = File::create(fname).expect(&error_msg);
+        let mut scored = Vec::<(String, String)>::new();

         // zip with the input data so we get clustered, raw input documents in the output set
-        for (doc,txt) in (&model.doc_vectors).iter().zip(lines_from_file(&args.arg_datafile).iter()) {
-            let p = model.score( & doc);
+        for (doc, txt) in model
+            .doc_vectors
+            .iter()
+            .zip(lines_from_file(&args.arg_datafile).iter())
+        {
+            let p = model.score(doc);
             let mut row = p.iter().enumerate().collect::<Vec<(usize, &f64)>>();
             if row_has_nan(&row, txt) {
                 scored.push(("-1".to_string(), "0".to_string()));
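A note on the "sort and dedupe" step in the hunk above: the model uses binary term counts, so a token contributes once per document however often it occurs, and Vec::dedup only removes consecutive duplicates, which is why the sort has to come first. A tiny illustration:

    let mut term_vector: Vec<String> = "b a b a"
        .split_whitespace()
        .map(|s| s.to_owned())
        .collect();
    term_vector.sort();
    term_vector.dedup();
    assert_eq!(term_vector, vec!["a".to_string(), "b".to_string()]);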
@@ -93,34 +107,40 @@ fn main() {

     // write the cluster descriptions
     {
-        let fname = (&args.arg_outprefix).clone() + "cluster_descriptions.txt";
+        let fname = args.arg_outprefix.clone() + "cluster_descriptions.txt";
         let error_msg = format!("Could not write file {}!", fname);
         let mut f = File::create(fname).expect(&error_msg);
         for k in 0..args.flag_k {
-            let ref word_dist = model.cluster_word_distributions[k];
+            let word_dist = &model.cluster_word_distributions[k];
             let mut line = k.to_string() + " ";
-            let mut dist_counts:Vec<String> = word_dist.iter().map(|(a,b)| model.index_word_map.get(a).unwrap().to_string() + ":" + &b.clone().to_string() ).collect();
+            let mut dist_counts: Vec<String> = word_dist
+                .iter()
+                .map(|(a, b)| {
+                    model.index_word_map.get(a).unwrap().to_string() + ":" + &b.to_string()
+                })
+                .collect();
             dist_counts.sort();
             line += &dist_counts.join(" ");
-            let _ = f.write((line+"\n").as_bytes());
+            let _ = f.write((line + "\n").as_bytes());
         }
     }

-    fn lines_from_file(filename: &str) -> Vec<String>
-    {
+    fn lines_from_file(filename: &str) -> Vec<String> {
         let error_msg = format!("Could not read file {}!", filename);
         let file = File::open(filename).expect(&error_msg);
         let buf = BufReader::new(file);
-        buf.lines().map(|l| l.expect("Could not parse line!")).collect()
+        buf.lines()
+            .map(|l| l.expect("Could not parse line!"))
+            .collect()
     }

-    fn row_has_nan(row:&Vec<(usize, &f64)>, doc:&String) -> bool {
+    fn row_has_nan(row: &Vec<(usize, &f64)>, doc: &String) -> bool {
         for entry in row {
             if entry.1.is_nan() {
                 println!("Cluster: {:?} has NaN score for document {:?}", entry, doc);
-                return true
+                return true;
             }
         }
-        return false;
+        false
     }
 }
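For completeness, a minimal (untested) sketch of driving the library directly after this change, with made-up vocabulary and hyperparameters; note that GSDMM::new now borrows vocab and docs instead of taking them by value:

    extern crate gsdmm;

    use gsdmm::GSDMM;
    use std::collections::HashSet;

    fn main() {
        let vocab: HashSet<String> = ["hot", "warm", "cold"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let docs = vec![
            vec!["hot".to_string(), "warm".to_string()],
            vec!["cold".to_string()],
        ];
        // alpha, beta, K (cluster cap), maxit, then the borrowed inputs.
        let mut model = GSDMM::new(0.1, 0.1, 4, 30, &vocab, &docs);
        model.fit();
        println!("labels: {:?}", model.labels);
    }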