Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to specify prior process for row and column partitions #191

Merged
merged 44 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
f7087ae
Incremental (non-compiling) progress
Feb 26, 2024
dab8c62
More things compiling
Feb 28, 2024
6049aab
Move geweke examples to lace_cc; get view passing
Feb 28, 2024
4d5f27f
Get state geweke example to compile
Feb 28, 2024
9dcc954
incremental updates
Feb 29, 2024
edc4c4c
PriorProcess builder
Mar 4, 2024
99dbdbe
Trying to get Geweke and enumeration tests to pass
Mar 5, 2024
a712c68
state Gibbs passing enum tests but not Geweke
Mar 6, 2024
9fe0f3d
lace, but not pylace, compiling; untested
Mar 7, 2024
2db31e7
pylace compiling
Mar 7, 2024
4f693e8
Got pitman-yor P(parition) correct
Mar 11, 2024
5d0a047
All lace_cc tests pass with pitman-yor
Mar 12, 2024
bae9f3b
Clean up prior process methods
Mar 13, 2024
272d56a
making progress on tests
Mar 18, 2024
1e2e27a
lace_stats tests passing
Mar 18, 2024
c2d6f8b
lace_codebook tests passing
Mar 18, 2024
8f2bfc6
lace tests compile; don't pass
Mar 18, 2024
80d38c0
lace tests passing
Mar 18, 2024
b649e72
lace tests passing for --all-features
Mar 19, 2024
4142231
Lace CLI tests pass
Mar 19, 2024
ffc92c3
pylace tests passing
Mar 19, 2024
2ba83e8
Get the book working
Mar 19, 2024
a8dd136
Merge branch 'master' into feature/pitman-yor
BaxterEaves Mar 19, 2024
8fee02a
fix field in benchmarks
Mar 20, 2024
a92ca9d
Attempt to get prev version seed ctrl working
Mar 20, 2024
45f1cc4
Fix pylace dataset codebooks
Mar 20, 2024
04f1273
Nevermind
Mar 21, 2024
59f3de3
Fix engine.py doctests
Mar 29, 2024
82d1d6a
Get doctests to pass
Mar 29, 2024
612ee9e
remove lace_cc misc.rs
Mar 29, 2024
4bf68cb
Fixing doctests
Mar 30, 2024
6c7fc01
Fixed test in mdbook
Mar 31, 2024
2f8a4a5
Add metadata read/convert tests
Apr 1, 2024
9216856
Increase number of states for flaky test
Apr 1, 2024
39a51e2
Fix book and add CRP figures
Apr 2, 2024
0c15294
Update pitman-yor image in book
Apr 2, 2024
982c4d9
updating versions
Apr 9, 2024
c4bed75
feat(python): Added `remove_rows` to pylace CoreEngine which removes …
schmidmt Apr 25, 2024
d1f8915
feat(python): Added Codebook.with_index to create a new codebook from…
schmidmt Apr 25, 2024
1fb140f
feat(python): Added tests for Codebook.with_index and Engine.remove_rows
schmidmt May 2, 2024
9238cc5
chore: Updated CI's maturin, ARM building on MacOS 14, and book versions
schmidmt Apr 30, 2024
a063d09
feat: Updated pyo3
schmidmt May 1, 2024
4092fdf
chore: Updated changelog and citation
schmidmt May 2, 2024
0d405de
Added note about refactor to contributing
May 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions book/lace_preprocess_mdbook_yaml/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 21 additions & 44 deletions book/lace_preprocess_mdbook_yaml/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::HashMap;

use anyhow::anyhow;
use log::debug;
use mdbook::{
Expand All @@ -12,7 +10,15 @@ use regex::Regex;

use serde::Deserialize;

type GammaMap = HashMap<String, lace_stats::rv::dist::Gamma>;
/// Deserialization target with a single `view_prior_process` field, so that a
/// snippet keyed by that name can be test-parsed as a
/// `lace_codebook::PriorProcess`.
#[derive(Deserialize)]
struct ViewPriorProcess {
pub view_prior_process: lace_codebook::PriorProcess,
}

/// Deserialization target with a single `state_prior_process` field, so that a
/// snippet keyed by that name can be test-parsed as a
/// `lace_codebook::PriorProcess`.
#[derive(Deserialize)]
struct StatePriorProcess {
pub state_prior_process: lace_codebook::PriorProcess,
}

fn check_deserialize_yaml<T>(input: &str) -> anyhow::Result<()>
where
Expand Down Expand Up @@ -54,17 +60,14 @@ macro_rules! check_deserialize_arm {
}
}

fn check_deserialize_dyn(
input: &str,
type_name: &str,
format: &str,
) -> anyhow::Result<()> {
fn check_deserialize_dyn(input: &str, type_name: &str, format: &str) -> anyhow::Result<()> {
check_deserialize_arm!(
input,
type_name,
format,
[
GammaMap,
ViewPriorProcess,
StatePriorProcess,
lace_codebook::ColMetadata,
lace_codebook::ColMetadataList
]
Expand All @@ -80,35 +83,21 @@ impl YamlTester {
YamlTester
}

fn examine_chapter_content(
&self,
content: &str,
re: &Regex,
) -> anyhow::Result<()> {
fn examine_chapter_content(&self, content: &str, re: &Regex) -> anyhow::Result<()> {
let parser = Parser::new(content);
let mut code_block = Some(String::new());

for event in parser {
match event {
Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced(
ref code_block_string,
))) => {
Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced(ref code_block_string))) => {
if re.is_match(code_block_string) {
debug!(
"YAML Block Start, identifier string={}",
code_block_string
);
debug!("YAML Block Start, identifier string={}", code_block_string);
code_block = Some(String::new());
}
}
Event::End(Tag::CodeBlock(CodeBlockKind::Fenced(
ref code_block_string,
))) => {
Event::End(Tag::CodeBlock(CodeBlockKind::Fenced(ref code_block_string))) => {
if let Some(captures) = re.captures(code_block_string) {
debug!(
"Code Block End, identifier string={}",
code_block_string
);
debug!("Code Block End, identifier string={}", code_block_string);

let serialization_format = captures
.get(1)
Expand All @@ -119,21 +108,13 @@ impl YamlTester {
.get(2)
.ok_or(anyhow!("No deserialize type found"))?
.as_str();
debug!(
"Target deserialization type is {}",
target_type
);
debug!("Target deserialization type is {}", target_type);

let final_block = code_block.take();
let final_block =
final_block.ok_or(anyhow!("No YAML text found"))?;
let final_block = final_block.ok_or(anyhow!("No YAML text found"))?;
debug!("Code block ended up as\n{}", final_block);

check_deserialize_dyn(
&final_block,
target_type,
serialization_format,
)?;
check_deserialize_dyn(&final_block, target_type, serialization_format)?;
}
}
Event::Text(ref text) => {
Expand All @@ -154,11 +135,7 @@ impl Preprocessor for YamlTester {
"lace-yaml-tester"
}

fn run(
&self,
_ctx: &PreprocessorContext,
book: Book,
) -> anyhow::Result<Book> {
fn run(&self, _ctx: &PreprocessorContext, book: Book) -> anyhow::Result<Book> {
debug!("Starting the run");
let re = Regex::new(r"^(yaml|json).*,deserializeTo=([^,]+)").unwrap();
for book_item in book.iter() {
Expand Down
1 change: 1 addition & 0 deletions book/src/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- [Data Simulation](./pcc/simulate.md)
- [In- and out-of-table operations](./pcc/inouttable.md)
- [Adding data to a model](./pcc/add-data.md)
- [Prior processes](./pcc/prior-processes.md)
- [Preparing your data](./data/basics.md)
- [Codebook reference](./codebook-ref.md)
- [Appendix](./appendix/appendix.md)
Expand Down
52 changes: 36 additions & 16 deletions book/src/codebook-ref.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,50 @@ information about

String name of the table. For your reference.

### `state_alpha_prior`
### `state_prior_process`

A gamma prior on the Chinese Restaurant Process (CRP) alpha parameter assigning
columns to views.
The prior process used for assigning columns to views. Can either be a Dirichlet process with a Gamma prior on alpha:

Example with a gamma prior
```yaml,deserializeTo=StatePriorProcess
state_prior_process: !dirichlet
alpha_prior:
shape: 1.0
rate: 1.0
```

or a Pitman-Yor process with a Gamma prior on alpha and a Beta prior on d.

```yaml,deserializeTo=GammaMap
state_alpha_prior:
shape: 1.0
rate: 1.0
```yaml,deserializeTo=StatePriorProcess
state_prior_process: !pitman_yor
alpha_prior:
shape: 1.0
rate: 1.0
d_prior:
alpha: 0.5
beta: 0.5
```

### `view_alpha_prior`
### `view_prior_process`

A gamma prior on the Chinese Restaurant Process (CRP) alpha parameter assigning
rows within views to categories.
The prior process used for assigning rows to categories. Can either be a Dirichlet process with a Gamma prior on alpha:

```yaml,deserializeTo=ViewPriorProcess
view_prior_process: !dirichlet
alpha_prior:
shape: 1.0
rate: 1.0
```

Example with a gamma prior
or a Pitman-Yor process with a Gamma prior on alpha and a Beta prior on d.

```yaml,deserializeTo=GammaMap
view_alpha_prior:
shape: 1.0
rate: 1.0
```yaml,deserializeTo=ViewPriorProcess
view_prior_process: !pitman_yor
alpha_prior:
shape: 1.0
rate: 1.0
d_prior:
alpha: 0.5
beta: 0.5
```

### `col_metadata`
Expand Down
12 changes: 12 additions & 0 deletions book/src/pcc/prior-processes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Prior Processes

In Lace (and in Bayesian nonparametrics) we put a prior on the number of parameters. This *prior process* formalizes how instances are distributed to an unknown number of categories. Lace gives you two options:

- The one-parameter Dirichlet process, `DP(α)`
- The two-parameter Pitman-Yor process, `PYP(α, d)`

The Dirichlet process more heavily penalizes new categories with an exponential falloff, while the Pitman-Yor process has a power-law falloff in the number of categories. When d = 0, Pitman-Yor is equivalent to the Dirichlet process.

While Pitman-Yor may fit the data better, it will create more parameters, which will cause model training to take longer.

For those looking for an introduction to prior processes, [this slide deck](https://www.gatsby.ucl.ac.uk/~ywteh/teaching/probmodels/lecture5bnp.pdf) from Yee Whye Teh is a good resource.
2 changes: 1 addition & 1 deletion book/src/workflow/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ engine.update(
save_path="mydata.lace",
transitions=[
StateTransition.row_assignment(RowKernel.slice()),
StateTransition.view_alphas(),
StateTransition.view_prior_process_params(),
],
)
```
Expand Down
2 changes: 1 addition & 1 deletion cli/README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# Lace CLI
# Lace CLI
16 changes: 9 additions & 7 deletions cli/resources/datasets/animals/codebook.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
table_name: my_table
state_alpha_prior:
shape: 1.0
rate: 1.0
view_alpha_prior:
shape: 1.0
rate: 1.0
table_name: animals
state_prior_process: !dirichlet
alpha_prior:
shape: 1.0
rate: 1.0
view_prior_process: !dirichlet
alpha_prior:
shape: 1.0
rate: 1.0
col_metadata:
- name: black
coltype: !Categorical
Expand Down
16 changes: 9 additions & 7 deletions cli/resources/datasets/satellites/codebook.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
table_name: my_table
state_alpha_prior:
shape: 1.0
rate: 1.0
view_alpha_prior:
shape: 1.0
rate: 1.0
table_name: satellites
state_prior_process: !dirichlet
alpha_prior:
shape: 1.0
rate: 1.0
view_prior_process: !dirichlet
alpha_prior:
shape: 1.0
rate: 1.0
col_metadata:
- name: Country_of_Operator
coltype: !Categorical
Expand Down
27 changes: 12 additions & 15 deletions cli/src/opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ pub enum Transition {
ColumnAssignment,
ComponentParams,
RowAssignment,
StateAlpha,
ViewAlphas,
StatePriorProcessParams,
ViewPriorProcessParams,
FeaturePriors,
}

Expand All @@ -33,8 +33,8 @@ impl std::str::FromStr for Transition {
match s {
"column_assignment" => Ok(Self::ColumnAssignment),
"row_assignment" => Ok(Self::RowAssignment),
"state_alpha" => Ok(Self::StateAlpha),
"view_alphas" => Ok(Self::ViewAlphas),
"state_prior_process_params" => Ok(Self::StatePriorProcessParams),
"view_prior_process_params" => Ok(Self::ViewPriorProcessParams),
"feature_priors" => Ok(Self::FeaturePriors),
"component_params" => Ok(Self::ComponentParams),
_ => Err(format!("cannot parse '{s}'")),
Expand Down Expand Up @@ -142,17 +142,17 @@ impl RunArgs {
let transitions = match self.transitions {
None => vec![
StateTransition::ColumnAssignment(col_alg),
StateTransition::StateAlpha,
StateTransition::StatePriorProcessParams,
StateTransition::RowAssignment(row_alg),
StateTransition::ViewAlphas,
StateTransition::ViewPriorProcessParams,
StateTransition::FeaturePriors,
],
Some(ref ts) => ts
.iter()
.map(|t| match t {
Transition::FeaturePriors => StateTransition::FeaturePriors,
Transition::StateAlpha => StateTransition::StateAlpha,
Transition::ViewAlphas => StateTransition::ViewAlphas,
Transition::StatePriorProcessParams => StateTransition::StatePriorProcessParams,
Transition::ViewPriorProcessParams => StateTransition::ViewPriorProcessParams,
Transition::ComponentParams => StateTransition::ComponentParams,
Transition::RowAssignment => StateTransition::RowAssignment(row_alg),
Transition::ColumnAssignment => StateTransition::ColumnAssignment(col_alg),
Expand Down Expand Up @@ -264,9 +264,6 @@ pub struct CodebookArgs {
/// Parquet input filename
#[clap(long = "parquet", group = "src")]
pub parquet_src: Option<PathBuf>,
/// CRP alpha prior on columns and rows
#[clap(long, default_value = "1.0, 1.0")]
pub alpha_prior: GammaParameters,
/// Codebook out. May be either json or yaml
#[clap(name = "CODEBOOK_OUT")]
pub output: PathBuf,
Expand Down Expand Up @@ -343,16 +340,16 @@ mod tests {
#[test]
fn view_alphas_from_str() {
assert_eq!(
Transition::from_str("view_alphas").unwrap(),
Transition::ViewAlphas
Transition::from_str("view_prior_process_params").unwrap(),
Transition::ViewPriorProcessParams
);
}

#[test]
fn state_alpha_from_str() {
assert_eq!(
Transition::from_str("state_alpha").unwrap(),
Transition::StateAlpha
Transition::from_str("state_prior_process_params").unwrap(),
Transition::StatePriorProcessParams
);
}

Expand Down
Loading
Loading