Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add serialization for ignore_merges #1504

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions tokenizers/src/models/bpe/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl Serialize for BPE {
model.serialize_field("end_of_word_suffix", &self.end_of_word_suffix)?;
model.serialize_field("fuse_unk", &self.fuse_unk)?;
model.serialize_field("byte_fallback", &self.byte_fallback)?;
model.serialize_field("ignore_merges", &self.ignore_merges)?;

// Then the large ones
let mut merges: Vec<(&Pair, &u32)> = self
Expand Down Expand Up @@ -57,6 +58,7 @@ impl<'de> Deserialize<'de> for BPE {
"end_of_word_suffix",
"fuse_unk",
"byte_fallback",
"ignore_merges",
"vocab",
"merges",
],
Expand Down Expand Up @@ -112,6 +114,11 @@ impl<'de> Visitor<'de> for BPEVisitor {
builder = builder.byte_fallback(suffix);
}
}
"ignore_merges" => {
if let Some(suffix) = map.next_value()? {
builder = builder.ignore_merges(suffix);
}
}
"vocab" => vocab = Some(map.next_value()?),
"merges" => merges = Some(map.next_value()?),
"type" => match map.next_value()? {
Expand All @@ -136,3 +143,45 @@ impl<'de> Visitor<'de> for BPEVisitor {
}
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::models::bpe::Vocab;

#[test]
fn test_serialization() {
let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)]
.iter()
.cloned()
.collect();
let bpe = BpeBuilder::default()
.vocab_and_merges(vocab, vec![])
.unk_token("<unk>".to_string())
.ignore_merges(true)
.build()
.unwrap();

let data = serde_json::to_string(&bpe).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();

assert_eq!(bpe, reconstructed);
}

#[test]
fn test_serialization_ignore_merges() {
let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)]
.iter()
.cloned()
.collect();
let bpe = BpeBuilder::default()
.vocab_and_merges(vocab, vec![])
.unk_token("<unk>".to_string())
.ignore_merges(true)
.build()
.unwrap();

let bpe_string = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2},"merges":[]}"#;
ArthurZucker marked this conversation as resolved.
Show resolved Hide resolved
assert_eq!(serde_json::from_str::<BPE>(bpe_string).unwrap(), bpe);
}
}
Loading