[ENH] Added batch_size as parameter to SentenceTransformerEmbeddingFu… #2759

Closed · wants to merge 1 commit
```diff
@@ -17,6 +17,7 @@ def __init__(
         model_name: str = "all-MiniLM-L6-v2",
         device: str = "cpu",
         normalize_embeddings: bool = False,
+        batch_size: int = 32,
```
Contributor:
Can we make this `Optional[int] = 32`?
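
A minimal sketch of the suggested signature (hypothetical; the PR as submitted uses a plain `int`, and the handling of `None` is not specified in the review):

```python
from typing import Any, Optional

def __init__(
    self,
    model_name: str = "all-MiniLM-L6-v2",
    device: str = "cpu",
    normalize_embeddings: bool = False,
    batch_size: Optional[int] = 32,  # per the review suggestion; None would need a fallback before reaching encode
    **kwargs: Any,
):
    ...
```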

Contributor:
Also, technically, this is already supported by the `**kwargs` (though we don't currently pass them to the `encode` method), but making it explicit is arguably a better DX.
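
For context: per the docstring in this diff, `**kwargs` are forwarded to the `SentenceTransformer` model itself, not to `encode`, so an explicit parameter is how the value reaches the encoding call. A rough sketch of the two paths (illustrative, not the exact Chroma source):

```python
from sentence_transformers import SentenceTransformer

# kwargs path: extra arguments land on the model constructor only
model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")

# explicit path (what this PR does): batch_size reaches encode
embeddings = model.encode(["some text"], convert_to_numpy=True, batch_size=32)
```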

```diff
         **kwargs: Any,
     ):
         """Initialize SentenceTransformerEmbeddingFunction.
@@ -25,6 +26,7 @@ def __init__(
         model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2"
         device (str, optional): Device used for computation, defaults to "cpu"
         normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False
+        batch_size (int, optional): Batch size for encoding, defaults to 32
         **kwargs: Additional arguments to pass to the SentenceTransformer model.
         """
         if model_name not in self.models:
@@ -39,6 +41,7 @@ def __init__(
             )
         self._model = self.models[model_name]
         self._normalize_embeddings = normalize_embeddings
+        self._batch_size = batch_size

     def __call__(self, input: Documents) -> Embeddings:
         return cast(
@@ -47,5 +50,6 @@ def __call__(self, input: Documents) -> Embeddings:
             list(input),
             convert_to_numpy=True,
             normalize_embeddings=self._normalize_embeddings,
+            batch_size=self._batch_size,
         ).tolist(),
     )
```
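
For reference, a usage sketch of the new parameter, assuming the standard Chroma import path (the document values are illustrative):

```python
from chromadb.utils import embedding_functions

# Larger batches can speed up encoding (especially on GPU) at the cost of memory.
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2",
    device="cpu",
    normalize_embeddings=False,
    batch_size=64,  # new parameter from this PR; previously fixed at encode's default of 32
)

embeddings = ef(["first document", "second document"])
```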