Skip to content

Commit

Permalink
chore: migrate to canals==0.7.0 (#5647)
Browse files Browse the repository at this point in the history
* add default_to_dict and default_from_dict placeholders to ease migration to canals 0.7.0

* canals==0.7.0

* whisper components

* add to_dict/from_dict stubs

* import serialization methods in init to hide canals imports

* reno

* export deserializationerror too

* Update haystack/preview/__init__.py

Co-authored-by: Silvano Cerza <[email protected]>

* serialization methods for LocalWhisperTranscriber (#5648)

* chore: serialization methods for `FileExtensionClassifier` (#5651)

* serialization methods for FileExtensionClassifier

* Update test_file_classifier.py

* chore: serialization methods for `SentenceTransformersDocumentEmbedder` (#5652)

* serialization methods for SentenceTransformersDocumentEmbedder

* fix device management

* serialization methods for SentenceTransformersTextEmbedder (#5653)

* serialization methods for TextFileToDocument (#5654)

* chore: serialization methods for `RemoteWhisperTranscriber` (#5650)

* serialization methods for RemoteWhisperTranscriber

* remove patches

* Add default to_dict and from_dict in document stores built with factory (#5674)

* fix tests (#5671)

* chore: simplify serialization methods for `MemoryDocumentStore` (#5667)

* simplify serialization for MemoryDocumentStore

* remove redundant tests

* pylint

* chore: serialization methods for `MemoryRetriever` (#5663)

* serialization method for MemoryRetriever

* more tests

* remove hash from default_document_store_to_dict

* remove diff in factory.py

* chore: serialization methods for `DocumentWriter` (#5661)

* serialization methods for DocumentWriter

* more tests

* use factory

* black

---------

Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
ZanSara and silvanocerza authored Aug 29, 2023
1 parent a613b1b commit b1daa7c
Show file tree
Hide file tree
Showing 27 changed files with 699 additions and 171 deletions.
2 changes: 2 additions & 0 deletions haystack/preview/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from canals import component, Pipeline
from canals.serialization import default_from_dict, default_to_dict
from canals.errors import DeserializationError
from haystack.preview.dataclasses import *
17 changes: 16 additions & 1 deletion haystack/preview/components/audio/whisper_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import whisper

from haystack.preview import component, Document
from haystack.preview import component, Document, default_to_dict, default_from_dict


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,6 +55,21 @@ def warm_up(self) -> None:
if not self._model:
self._model = whisper.load_model(self.model_name, device=self.device)

def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this transcriber.

    The model name is emitted under the `model_name_or_path` key and the device is
    emitted in its string form; both are passed to `default_to_dict` together with
    `whisper_params`.
    """
    init_parameters = {
        "model_name_or_path": self.model_name,
        "device": str(self.device),
        "whisper_params": self.whisper_params,
    }
    return default_to_dict(self, **init_parameters)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber":
    """
    Rebuild a LocalWhisperTranscriber from its dictionary representation.

    All the work is delegated to `default_from_dict`.
    """
    transcriber = default_from_dict(cls, data)
    return transcriber

@component.output_types(documents=List[Document])
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
"""
Expand Down
28 changes: 20 additions & 8 deletions haystack/preview/components/audio/whisper_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path

from haystack.preview.utils import request_with_retry
from haystack.preview import component, Document
from haystack.preview import component, Document, default_to_dict, default_from_dict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,17 +49,29 @@ def __init__(
if not api_key:
raise ValueError("API key is None.")

self.model_name = model_name
self.api_key = api_key
self.api_base = api_base
self.whisper_params = whisper_params or {}

self.model_name = model_name
self.init_parameters = {
"api_key": self.api_key,
"model_name": self.model_name,
"api_base": self.api_base,
"whisper_params": self.whisper_params,
}
def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this transcriber.

    NOTE: `api_key` is passed to `default_to_dict` unchanged, so the key ends up
    verbatim in the returned dictionary.
    """
    init_parameters = {
        "model_name": self.model_name,
        "api_key": self.api_key,
        "api_base": self.api_base,
        "whisper_params": self.whisper_params,
    }
    return default_to_dict(self, **init_parameters)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
    """
    Rebuild a RemoteWhisperTranscriber from its dictionary representation.

    All the work is delegated to `default_from_dict`.
    """
    transcriber = default_from_dict(cls, data)
    return transcriber

@component.output_types(documents=List[Document])
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
Expand Down
20 changes: 15 additions & 5 deletions haystack/preview/components/classifiers/file_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import mimetypes
from collections import defaultdict
from pathlib import Path
from typing import List, Union, Optional
from typing import List, Union, Optional, Dict, Any

from haystack.preview import component
from haystack.preview import component, default_from_dict, default_to_dict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,12 +38,22 @@ def __init__(self, mime_types: List[str]):
f"Unknown mime type: '{mime_type}'. Ensure you passed a list of strings in the 'mime_types' parameter"
)

# save the init parameters for serialization
self.init_parameters = {"mime_types": mime_types}

component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
self.mime_types = mime_types

def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this classifier.

    The only init parameter handed to `default_to_dict` is `mime_types`.
    """
    init_parameters = {"mime_types": self.mime_types}
    return default_to_dict(self, **init_parameters)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileExtensionClassifier":
    """
    Rebuild a FileExtensionClassifier from its dictionary representation.

    All the work is delegated to `default_from_dict`.
    """
    classifier = default_from_dict(cls, data)
    return classifier

def run(self, paths: List[Union[str, Path]]):
"""
Run the FileExtensionClassifier.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict, Any

from haystack.preview import component
from haystack.preview import Document
from haystack.preview import component, Document, default_to_dict, default_from_dict
from haystack.preview.embedding_backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
Expand Down Expand Up @@ -42,14 +41,37 @@ def __init__(

self.model_name_or_path = model_name_or_path
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device
self.device = device or "cpu"
self.use_auth_token = use_auth_token
self.batch_size = batch_size
self.progress_bar = progress_bar
self.normalize_embeddings = normalize_embeddings
self.metadata_fields_to_embed = metadata_fields_to_embed or []
self.embedding_separator = embedding_separator

def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this document embedder.

    Every attribute listed below is forwarded to `default_to_dict` under the
    keyword of the same name.
    """
    init_parameters = {
        "model_name_or_path": self.model_name_or_path,
        "device": self.device,
        "use_auth_token": self.use_auth_token,
        "batch_size": self.batch_size,
        "progress_bar": self.progress_bar,
        "normalize_embeddings": self.normalize_embeddings,
        "metadata_fields_to_embed": self.metadata_fields_to_embed,
        "embedding_separator": self.embedding_separator,
    }
    return default_to_dict(self, **init_parameters)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
    """
    Rebuild a SentenceTransformersDocumentEmbedder from its dictionary representation.

    All the work is delegated to `default_from_dict`.
    """
    embedder = default_from_dict(cls, data)
    return embedder

def warm_up(self):
"""
Load the embedding backend.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict, Any

from haystack.preview import component
from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview.embedding_backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
Expand Down Expand Up @@ -40,14 +40,37 @@ def __init__(

self.model_name_or_path = model_name_or_path
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device
self.device = device or "cpu"
self.use_auth_token = use_auth_token
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.normalize_embeddings = normalize_embeddings

def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this text embedder.

    Every attribute listed below is forwarded to `default_to_dict` under the
    keyword of the same name.
    """
    init_parameters = {
        "model_name_or_path": self.model_name_or_path,
        "device": self.device,
        "use_auth_token": self.use_auth_token,
        "prefix": self.prefix,
        "suffix": self.suffix,
        "batch_size": self.batch_size,
        "progress_bar": self.progress_bar,
        "normalize_embeddings": self.normalize_embeddings,
    }
    return default_to_dict(self, **init_parameters)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
    """
    Rebuild a SentenceTransformersTextEmbedder from its dictionary representation.

    All the work is delegated to `default_from_dict`.
    """
    embedder = default_from_dict(cls, data)
    return embedder

def warm_up(self):
"""
Load the embedding backend.
Expand Down
25 changes: 23 additions & 2 deletions haystack/preview/components/file_converters/txt.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
from pathlib import Path
from typing import Optional, List, Union, Dict
from typing import Optional, List, Union, Dict, Any

from canals.errors import PipelineRuntimeError
from tqdm import tqdm

from haystack.preview.lazy_imports import LazyImport
from haystack.preview import Document, component
from haystack.preview import Document, component, default_to_dict, default_from_dict

with LazyImport("Run 'pip install farm-haystack[preprocessing]'") as langdetect_import:
import langdetect
Expand Down Expand Up @@ -61,6 +61,27 @@ def __init__(
self.id_hash_keys = id_hash_keys or []
self.progress_bar = progress_bar

def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this converter.

    Every attribute listed below is forwarded to `default_to_dict` under the
    keyword of the same name.
    """
    init_parameters = {
        "encoding": self.encoding,
        "remove_numeric_tables": self.remove_numeric_tables,
        "numeric_row_threshold": self.numeric_row_threshold,
        "valid_languages": self.valid_languages,
        "id_hash_keys": self.id_hash_keys,
        "progress_bar": self.progress_bar,
    }
    return default_to_dict(self, **init_parameters)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TextFileToDocument":
    """
    Rebuild a TextFileToDocument from its dictionary representation.

    All the work is delegated to `default_from_dict`.
    """
    converter = default_from_dict(cls, data)
    return converter

@component.output_types(documents=List[Document])
def run(
self,
Expand Down
31 changes: 29 additions & 2 deletions haystack/preview/components/retrievers/memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, List, Any, Optional

from haystack.preview import component, Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError
from haystack.preview.document_stores import MemoryDocumentStore, document_store


@component
Expand Down Expand Up @@ -41,6 +41,33 @@ def __init__(
self.top_k = top_k
self.scale_score = scale_score

def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this retriever.

    The document store is serialized through its own `to_dict()` and nested under
    the `document_store` key; `filters`, `top_k` and `scale_score` are passed along
    unchanged.
    """
    return default_to_dict(
        self,
        document_store=self.document_store.to_dict(),
        filters=self.filters,
        top_k=self.top_k,
        scale_score=self.scale_score,
    )

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryRetriever":
    """
    Deserialize this component from a dictionary.

    The nested `document_store` entry is first turned back into a document store
    instance: its class is looked up by name in `document_store.registry` and the
    entry is handed to that class' `from_dict`. The retriever itself is then built
    by `default_from_dict`.

    The caller's `data` is left untouched, so the same dictionary can be
    deserialized more than once.

    :param data: the serialized component; its `init_parameters` must contain a
        serialized `document_store` carrying a `type` key.
    :returns: the deserialized component, as returned by `default_from_dict`.
    :raises DeserializationError: if `document_store` is missing, has no `type`,
        or its `type` is not present in the document store registry.
    """
    init_params = data.get("init_parameters", {})
    if "document_store" not in init_params:
        raise DeserializationError("Missing 'document_store' in serialization data")

    docstore_data = init_params["document_store"]
    if "type" not in docstore_data:
        raise DeserializationError("Missing 'type' in document store's serialization data")

    docstore_type = docstore_data["type"]
    if docstore_type not in document_store.registry:
        raise DeserializationError(f"DocumentStore type '{docstore_type}' not found")

    docstore = document_store.registry[docstore_type].from_dict(docstore_data)

    # Build a fresh payload rather than writing the instance into `data`: replacing
    # the serialized store in place would corrupt the caller's dictionary and make
    # a second from_dict() call on it fail.
    restored_params = {**init_params, "document_store": docstore}
    return default_from_dict(cls, {**data, "init_parameters": restored_params})

@component.output_types(documents=List[List[Document]])
def run(
self,
Expand Down
31 changes: 28 additions & 3 deletions haystack/preview/components/writers/document_writer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional
from typing import List, Optional, Dict, Any

from haystack.preview import component, Document
from haystack.preview.document_stores import DocumentStore, DuplicatePolicy
from haystack.preview import component, Document, default_from_dict, default_to_dict, DeserializationError
from haystack.preview.document_stores import DocumentStore, DuplicatePolicy, document_store


@component
Expand All @@ -19,6 +19,31 @@ def __init__(self, document_store: DocumentStore, policy: DuplicatePolicy = Dupl
self.document_store = document_store
self.policy = policy

def to_dict(self) -> Dict[str, Any]:
    """
    Build the dictionary representation of this writer.

    The document store is serialized through its own `to_dict()`, and the policy is
    stored by its `name`.
    """
    serialized_store = self.document_store.to_dict()
    policy_name = self.policy.name
    return default_to_dict(self, document_store=serialized_store, policy=policy_name)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DocumentWriter":
    """
    Deserialize this component from a dictionary.

    The nested `document_store` entry is first turned back into a document store
    instance: its class is looked up by name in `document_store.registry` and the
    entry is handed to that class' `from_dict`. The `policy` name, when present, is
    mapped back to its `DuplicatePolicy` member. The writer itself is then built by
    `default_from_dict`.

    The caller's `data` is left untouched, so the same dictionary can be
    deserialized more than once.

    :param data: the serialized component; its `init_parameters` must contain a
        serialized `document_store` carrying a `type` key, and may contain the
        `policy` name.
    :returns: the deserialized component, as returned by `default_from_dict`.
    :raises DeserializationError: if `document_store` is missing, has no `type`,
        or its `type` is not present in the document store registry.
    :raises KeyError: presumably when `policy` is not the name of a
        `DuplicatePolicy` member (name lookup on the class) - confirm against the
        `DuplicatePolicy` definition.
    """
    init_params = data.get("init_parameters", {})
    if "document_store" not in init_params:
        raise DeserializationError("Missing 'document_store' in serialization data")

    docstore_data = init_params["document_store"]
    if "type" not in docstore_data:
        raise DeserializationError("Missing 'type' in document store's serialization data")

    docstore_type = docstore_data["type"]
    if docstore_type not in document_store.registry:
        raise DeserializationError(f"DocumentStore of type '{docstore_type}' not found.")

    docstore = document_store.registry[docstore_type].from_dict(docstore_data)

    # Build a fresh payload rather than writing into `data`: replacing entries in
    # place would corrupt the caller's dictionary and make a second from_dict()
    # call on it fail.
    restored_params = {**init_params, "document_store": docstore}
    # `policy` is optional in the constructor, so only convert it when the
    # serialized data actually carries it; otherwise the constructor default applies.
    if "policy" in init_params:
        restored_params["policy"] = DuplicatePolicy[init_params["policy"]]
    return default_from_dict(cls, {**data, "init_parameters": restored_params})

def run(self, documents: List[Document], policy: Optional[DuplicatePolicy] = None):
"""
Run DocumentWriter on the given input data.
Expand Down
31 changes: 0 additions & 31 deletions haystack/preview/document_stores/decorator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from typing import Dict, Any, Type
import logging

from haystack.preview.document_stores.protocols import DocumentStore
from haystack.preview.document_stores.errors import DocumentStoreDeserializationError

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -40,30 +36,3 @@ def __call__(self, cls=None):


document_store = _DocumentStore()


def default_document_store_to_dict(store_: DocumentStore) -> Dict[str, Any]:
    """
    Turn a DocumentStore into its default dictionary representation.

    The result has three keys: `hash` (the `id()` of the instance), `type` (the
    name of the instance's class) and `init_parameters` (the instance's
    `init_parameters` attribute, or an empty dict when it has none).
    """
    serialized: Dict[str, Any] = {}
    serialized["hash"] = id(store_)
    serialized["type"] = store_.__class__.__name__
    serialized["init_parameters"] = getattr(store_, "init_parameters", {})
    return serialized


def default_document_store_from_dict(cls: Type[DocumentStore], data: Dict[str, Any]) -> DocumentStore:
    """
    Build a DocumentStore of class `cls` from its default dictionary representation.

    `data` must carry a `type` key equal to `cls.__name__`; the optional
    `init_parameters` entry is expanded into keyword arguments for `cls`.

    Raises DocumentStoreDeserializationError when `type` is missing or names a
    different class.
    """
    constructor_kwargs = data.get("init_parameters", {})
    if "type" not in data:
        raise DocumentStoreDeserializationError("Missing 'type' in DocumentStore serialization data")

    target_name = cls.__name__
    if data["type"] != target_name:
        message = f"DocumentStore '{data['type']}' can't be deserialized as '{target_name}'"
        raise DocumentStoreDeserializationError(message)

    return cls(**constructor_kwargs)
4 changes: 0 additions & 4 deletions haystack/preview/document_stores/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,3 @@ class DuplicateDocumentError(DocumentStoreError):

class MissingDocumentError(DocumentStoreError):
pass


class DocumentStoreDeserializationError(DocumentStoreError):
pass
Loading

0 comments on commit b1daa7c

Please sign in to comment.