Skip to content

Commit

Permalink
Refine row sams constrainer trait
Browse files Browse the repository at this point in the history
  • Loading branch information
Baxter Eaves committed May 3, 2024
1 parent 5fe9be0 commit 92a85c7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
10 changes: 7 additions & 3 deletions lace/lace_cc/src/constrain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub struct RowSamsInfo {
pub j: usize,
pub z_i: usize,
pub z_j: usize,
pub z_split: usize,
}

pub trait RowSamsConstrainer: Sync {
Expand All @@ -58,20 +59,23 @@ pub trait RowSamsConstrainer: Sync {
/// split.
fn ln_sis_contstraints(&mut self, ix: usize) -> (f64, f64);

/// Assign ix to z during SIS
fn sis_assign(&mut self, ix: usize, to_proposed_cluster: bool);

/// Should return the log hastings ratio constraint, which is
/// ln p(x|z_proposed) - ln p (x|z_current)
fn ln_mh_constraint(&self, asgn_proposed: &Assignment) -> f64;
}

impl RowSamsConstrainer for () {
fn initialize(&mut self, _info: RowSamsInfo) {
()
}
fn initialize(&mut self, _info: RowSamsInfo) {}

fn ln_sis_contstraints(&mut self, _ix: usize) -> (f64, f64) {
(0.0, 0.0)
}

fn sis_assign(&mut self, _ix: usize, _to_proposed_cluster: bool) {}

fn ln_mh_constraint(&self, _asgn_proposed: &Assignment) -> f64 {
0.0
}
Expand Down
18 changes: 16 additions & 2 deletions lace/lace_cc/src/view/sams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,23 @@ impl View {
}
};

constrainer.initialize(RowSamsInfo { i, j, z_i, z_j });

if z_i == z_j {
constrainer.initialize(RowSamsInfo {
i,
j,
z_i,
z_j,
z_split: self.asgn_mut().n_cats,
});
self.sams_split(i, j, constrainer, rng);
} else {
constrainer.initialize(RowSamsInfo {
i,
j,
z_i,
z_j,
z_split: z_j,
});
assert!(z_i < z_j);
self.sams_merge(i, j, constrainer, rng);
}
Expand Down Expand Up @@ -219,11 +231,13 @@ impl View {
};

if assign_to_zi {
constrainer.sis_assign(ix, false);
logq += logp_z_i - lognorm;
self.force_observe_row(ix, z_i_tmp);
nk_i += 1.0;
tmp_z[ix] = z_i_tmp;
} else {
constrainer.sis_assign(ix, true);
logq += logp_z_j - lognorm;
self.force_observe_row(ix, z_j_tmp);
nk_j += 1.0;
Expand Down

0 comments on commit 92a85c7

Please sign in to comment.