From 0a60b16e4b3791429052e9d9f385cd85478db4cf Mon Sep 17 00:00:00 2001 From: ZanSara Date: Mon, 28 Aug 2023 18:35:40 +0200 Subject: [PATCH] serialization methods for LocalWhisperTranscriber --- .../preview/components/audio/whisper_local.py | 8 ++-- .../components/audio/test_whisper_local.py | 41 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/haystack/preview/components/audio/whisper_local.py b/haystack/preview/components/audio/whisper_local.py index ae04362ddb..35fdef229a 100644 --- a/haystack/preview/components/audio/whisper_local.py +++ b/haystack/preview/components/audio/whisper_local.py @@ -6,7 +6,7 @@ import torch import whisper -from haystack.preview import component, Document +from haystack.preview import component, Document, default_to_dict, default_from_dict logger = logging.getLogger(__name__) @@ -59,14 +59,16 @@ def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. """ - # return default_to_dict(self, model_name_or_path=self.model_name, device=self.device, whisper_params=self.whisper_params) + return default_to_dict( + self, model_name_or_path=self.model_name, device=str(self.device), whisper_params=self.whisper_params + ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber": """ Deserialize this component from a dictionary. """ - # return default_from_dict(cls, data) + return default_from_dict(cls, data) @component.output_types(documents=List[Document]) def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None): diff --git a/test/preview/components/audio/test_whisper_local.py b/test/preview/components/audio/test_whisper_local.py index 6dfcef8c1e..6132c2cd32 100644 --- a/test/preview/components/audio/test_whisper_local.py +++ b/test/preview/components/audio/test_whisper_local.py @@ -26,6 +26,47 @@ def test_init_wrong_model(self): with pytest.raises(ValueError, match="Model name 'whisper-1' not recognized"): LocalWhisperTranscriber(model_name_or_path="whisper-1") + @pytest.mark.unit + def test_to_dict(self): + transcriber = LocalWhisperTranscriber() + data = transcriber.to_dict() + assert data == { + "type": "LocalWhisperTranscriber", + "init_parameters": {"model_name_or_path": "large", "device": "cpu", "whisper_params": {}}, + } + + @pytest.mark.unit + def test_to_dict_with_custom_init_parameters(self): + transcriber = LocalWhisperTranscriber( + model_name_or_path="tiny", + device="cuda", + whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]}, + ) + data = transcriber.to_dict() + assert data == { + "type": "LocalWhisperTranscriber", + "init_parameters": { + "model_name_or_path": "tiny", + "device": "cuda", + "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}, + }, + } + + @pytest.mark.unit + def test_from_dict(self): + data = { + "type": "LocalWhisperTranscriber", + "init_parameters": { + "model_name_or_path": "tiny", + "device": "cuda", + "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}, + }, + } + transcriber = LocalWhisperTranscriber.from_dict(data) + assert transcriber.model_name == "tiny" + assert transcriber.device == torch.device("cuda") + assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]} + @pytest.mark.unit def test_warmup(self): with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper: