diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs index 8f60d6ddb..ca2590f1d 100644 --- a/tokenizers/src/models/bpe/serialization.rs +++ b/tokenizers/src/models/bpe/serialization.rs @@ -21,6 +21,7 @@ impl Serialize for BPE { model.serialize_field("end_of_word_suffix", &self.end_of_word_suffix)?; model.serialize_field("fuse_unk", &self.fuse_unk)?; model.serialize_field("byte_fallback", &self.byte_fallback)?; + model.serialize_field("ignore_merges", &self.ignore_merges)?; // Then the large ones let mut merges: Vec<(&Pair, &u32)> = self @@ -57,6 +58,7 @@ impl<'de> Deserialize<'de> for BPE { "end_of_word_suffix", "fuse_unk", "byte_fallback", + "ignore_merges", "vocab", "merges", ], @@ -112,6 +114,11 @@ impl<'de> Visitor<'de> for BPEVisitor { builder = builder.byte_fallback(suffix); } } + "ignore_merges" => { + if let Some(suffix) = map.next_value()? { + builder = builder.ignore_merges(suffix); + } + } "vocab" => vocab = Some(map.next_value()?), "merges" => merges = Some(map.next_value()?), "type" => match map.next_value()? 
{ @@ -136,3 +143,49 @@ impl<'de> Visitor<'de> for BPEVisitor { } } } + +#[cfg(test)] +mod test { + use super::*; + use crate::models::bpe::Vocab; + + #[test] + fn test_serialization() { + let vocab: Vocab = [("".into(), 0), ("a".into(), 1), ("b".into(), 2)] + .iter() + .cloned() + .collect(); + let bpe = BpeBuilder::default() + .vocab_and_merges(vocab, vec![]) + .unk_token("".to_string()) + .ignore_merges(true) + .build() + .unwrap(); + + let data = serde_json::to_string(&bpe).unwrap(); + let reconstructed = serde_json::from_str(&data).unwrap(); + + assert_eq!(bpe, reconstructed); + } + + #[test] + fn test_serialization_ignore_merges() { + let vocab: Vocab = [("".into(), 0), ("a".into(), 1), ("b".into(), 2)] + .iter() + .cloned() + .collect(); + let mut bpe = BpeBuilder::default() + .vocab_and_merges(vocab, vec![]) + .unk_token("".to_string()) + .ignore_merges(true) + .build() + .unwrap(); + + let bpe_string = r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2},"merges":[]}"#; + assert_eq!(serde_json::from_str::<BPE>(bpe_string).unwrap(), bpe); + + bpe.ignore_merges = false; + let bpe_string = r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"vocab":{"":0,"a":1,"b":2},"merges":[]}"#; + assert_eq!(serde_json::from_str::<BPE>(bpe_string).unwrap(), bpe); + } +}