Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENV-Manager update #238

Merged
merged 12 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ engine = AsyncEmbeddingEngine.from_args(engine_args)
async def main():
async with engine:
predictions, usage = await engine.classify(sentences=sentences)
return predictions, usage
# or handle the async start / stop yourself.
await engine.astart()
predictions, usage = await engine.classify(sentences=sentences)
Expand Down
1 change: 0 additions & 1 deletion docs/docs/python_engine.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ engine = AsyncEmbeddingEngine.from_args(engine_args)
async def main():
async with engine:
predictions, usage = await engine.classify(sentences=sentences)
return predictions, usage
# or handle the async start / stop yourself.
await engine.astart()
predictions, usage = await engine.classify(sentences=sentences)
Expand Down
39 changes: 28 additions & 11 deletions libs/infinity_emb/infinity_emb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
import importlib.metadata
import os

import huggingface_hub.constants # type: ignore

### Check if HF_HUB_ENABLE_HF_TRANSFER is set, if not try to enable it
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
try:
# enable hf hub transfer if available
import hf_transfer # type: ignore # noqa

# Needs to be at the top of the file / before other
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
pass
huggingface_hub.constants.HF_HUB_DISABLE_PROGRESS_BARS = True


from infinity_emb import fastapi_schemas, inference, transformer # noqa: E402
from infinity_emb.args import EngineArgs # noqa: E402
from infinity_emb.engine import AsyncEmbeddingEngine # noqa: E402

# reexports
from infinity_emb.infinity_server import create_server # noqa: E402
from infinity_emb.log_handler import logger # noqa: E402

__version__ = importlib.metadata.version("infinity_emb")

__all__ = [
"transformer",
"inference",
Expand All @@ -8,14 +36,3 @@
"EngineArgs",
"__version__",
]
import importlib.metadata

from infinity_emb import fastapi_schemas, inference, transformer
from infinity_emb.args import EngineArgs
from infinity_emb.engine import AsyncEmbeddingEngine

# reexports
from infinity_emb.infinity_server import create_server
from infinity_emb.log_handler import logger

__version__ = importlib.metadata.version("infinity_emb")
1 change: 0 additions & 1 deletion libs/infinity_emb/infinity_emb/_optional_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def _raise_error(self) -> None:
CHECK_DISKCACHE = OptionalImports("diskcache", "cache")
CHECK_CTRANSLATE2 = OptionalImports("ctranslate2", "ctranslate2")
CHECK_FASTAPI = OptionalImports("fastapi", "server")
CHECK_HF_TRANSFER = OptionalImports("hf_transfer", "hf_transfer")
CHECK_ONNXRUNTIME = OptionalImports("optimum.onnxruntime", "optimum")
CHECK_OPTIMUM = OptionalImports("optimum", "optimum")
CHECK_OPTIMUM_NEURON = OptionalImports("optimum.neuron", "neuronx")
Expand Down
20 changes: 10 additions & 10 deletions libs/infinity_emb/infinity_emb/args.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import sys
from dataclasses import asdict, dataclass
from typing import Optional

from infinity_emb._optional_imports import CHECK_PYDANTIC
from infinity_emb.env import MANAGER
from infinity_emb.primitives import (
Device,
Dtype,
Expand Down Expand Up @@ -41,21 +41,21 @@ class EngineArgs:
served_model_name, str: Defaults to readable name of model_name_or_path.
"""

model_name_or_path: str = "michaelfeil/bge-small-en-v1.5"
batch_size: int = 32
revision: Optional[str] = None
trust_remote_code: bool = True
model_name_or_path: str = MANAGER.model_id[0]
batch_size: int = MANAGER.batch_size[0]
revision: Optional[str] = MANAGER.revision[0]
trust_remote_code: bool = MANAGER.trust_remote_code[0]
engine: InferenceEngine = InferenceEngine.torch
model_warmup: bool = False
model_warmup: bool = MANAGER.model_warmup[0]
vector_disk_cache_path: str = ""
device: Device = Device.auto
compile: bool = not os.environ.get("INFINITY_DISABLE_COMPILE", "Disable")
bettertransformer: bool = True
compile: bool = MANAGER.compile[0]
bettertransformer: bool = MANAGER.bettertransformer[0]
dtype: Dtype = Dtype.auto
pooling_method: PoolingMethod = PoolingMethod.auto
lengths_via_tokenize: bool = False
lengths_via_tokenize: bool = MANAGER.lengths_via_tokenize[0]
embedding_dtype: EmbeddingDtype = EmbeddingDtype.float32
served_model_name: str = ""
served_model_name: str = MANAGER.served_model_name[0]

def __post_init__(self):
# convert the following strings to enums
Expand Down
9 changes: 3 additions & 6 deletions libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,9 @@ def from_args(cls, engine_args_array: Iterable[EngineArgs]) -> "AsyncEngineArray
Args:
engine_args_array (list[EngineArgs]): EngineArgs object
"""
return cls(
engines=tuple(
AsyncEmbeddingEngine.from_args(engine_args)
for engine_args in engine_args_array
)
)
engines = map(AsyncEmbeddingEngine.from_args, engine_args_array)

return cls(engines=tuple(engines))

def __iter__(self) -> Iterator["AsyncEmbeddingEngine"]:
return iter(self.engines_dict.values())
Expand Down
218 changes: 218 additions & 0 deletions libs/infinity_emb/infinity_emb/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# cache
from __future__ import annotations

import os
from functools import cached_property

from infinity_emb.primitives import (
Device,
Dtype,
EmbeddingDtype,
InferenceEngine,
PoolingMethod,
)


class __Infinity_EnvManager:
    """Central reader for all ``INFINITY_*`` environment variables.

    Every setting is exposed as a ``cached_property`` so each variable is
    read from the environment exactly once. ``__init__`` eagerly touches
    every cached property, so all variables are read (and logged at
    debug/trace level) at construction time rather than lazily.

    Multi-model settings (model_id, batch_size, ...) are ``;``-separated
    in the environment and returned as lists; scalar settings (port, host,
    ...) are returned as single values.
    """

    def __init__(self):
        self._debug(f"Loading Infinity ENV variables.\nCONFIG:\n{'-'*10}")
        # Pre-cache: touching each cached_property forces the env read now,
        # so misconfiguration surfaces at startup, not at first use.
        for f_name in dir(self):
            if isinstance(getattr(type(self), f_name, None), cached_property):
                getattr(self, f_name)  # pre-cache
        self._debug(f"{'-'*10}\nENV variables loaded.")

    def _debug(self, message: str):
        """Print *message* when log level is debug/trace.

        Messages mentioning API_KEY are always masked (never leak the
        secret); messages mentioning LOG_LEVEL are dropped to break the
        recursion triggered by reading ``self.log_level`` below.
        """
        if "API_KEY" in message:
            print("INFINITY_API_KEY=not_shown")
            print(f"INFINITY_LOG_LEVEL={self.log_level}")
        elif "LOG_LEVEL" in message:
            return  # recursion
        elif self.log_level in {"debug", "trace"}:
            print(message)

    @staticmethod
    def _to_name(name: str) -> str:
        """Map a setting name to its env-var name, e.g. ``model-id`` -> ``INFINITY_MODEL_ID``."""
        return "INFINITY_" + name.upper().replace("-", "_")

    def _optional_infinity_var(self, name: str, default: str = ""):
        """Read a single-valued env var, falling back to *default*."""
        name = self._to_name(name)
        value = os.getenv(name)
        if value is None:
            self._debug(f"{name}=`{default}`(default)")
            return default
        self._debug(f"{name}=`{value}`")
        return value

    def _optional_infinity_var_multiple(
        self, name: str, default: list[str]
    ) -> list[str]:
        """Read a ``;``-separated env var as a list, falling back to *default*."""
        name = self._to_name(name)
        value = os.getenv(name)
        if value is None:
            self._debug(f"{name}=`{';'.join(default)}`(default)")
            return default
        # tolerate a trailing separator, e.g. "a;b;"
        if value.endswith(";"):
            value = value[:-1]
        value_list = value.split(";")
        self._debug(f"{name}=`{';'.join(value_list)}`")
        return value_list

    # single source of truth for boolean parsing; everything else is falsy
    _TRUTHY = frozenset({"true", "1"})

    @classmethod
    def _to_bool(cls, value: str) -> bool:
        return value.lower() in cls._TRUTHY

    @classmethod
    def _to_bool_multiple(cls, value: list[str]) -> list[bool]:
        return [cls._to_bool(v) for v in value]

    @staticmethod
    def _to_int_multiple(value: list[str]) -> list[int]:
        return [int(v) for v in value]

    @cached_property
    def api_key(self):
        return self._optional_infinity_var("api_key", default="")

    @cached_property
    def model_id(self):
        return self._optional_infinity_var_multiple(
            "model_id", default=["michaelfeil/bge-small-en-v1.5"]
        )

    @cached_property
    def served_model_name(self):
        return self._optional_infinity_var_multiple("served_model_name", default=[""])

    @cached_property
    def batch_size(self):
        return self._to_int_multiple(
            self._optional_infinity_var_multiple("batch_size", default=["32"])
        )

    @cached_property
    def revision(self):
        # NOTE(review): default is "" (not None) — callers treating "" as
        # "no revision" must map it themselves.
        return self._optional_infinity_var_multiple("revision", default=[""])

    @cached_property
    def trust_remote_code(self):
        return self._to_bool_multiple(
            self._optional_infinity_var_multiple("trust_remote_code", default=["true"])
        )

    @cached_property
    def model_warmup(self):
        return self._to_bool_multiple(
            self._optional_infinity_var_multiple("model_warmup", default=["true"])
        )

    @cached_property
    def vector_disk_cache(self):
        return self._to_bool_multiple(
            self._optional_infinity_var_multiple("vector_disk_cache", default=["false"])
        )

    @cached_property
    def lengths_via_tokenize(self):
        return self._to_bool_multiple(
            self._optional_infinity_var_multiple(
                "lengths_via_tokenize", default=["false"]
            )
        )

    @cached_property
    def compile(self):
        return self._to_bool_multiple(
            self._optional_infinity_var_multiple("compile", default=["false"])
        )

    @cached_property
    def bettertransformer(self):
        return self._to_bool_multiple(
            self._optional_infinity_var_multiple("bettertransformer", default=["true"])
        )

    @cached_property
    def preload_only(self):
        return self._to_bool(
            self._optional_infinity_var("preload_only", default="false")
        )

    @cached_property
    def permissive_cors(self):
        return self._to_bool(
            self._optional_infinity_var("permissive_cors", default="false")
        )

    @cached_property
    def url_prefix(self):
        return self._optional_infinity_var("url_prefix", default="")

    @cached_property
    def port(self):
        port = self._optional_infinity_var("port", default="7997")
        # raise (not assert): asserts are stripped under `python -O`,
        # which would silently disable this validation.
        if not port.isdigit():
            raise ValueError("INFINITY_PORT must be a number")
        return int(port)

    @cached_property
    def host(self):
        return self._optional_infinity_var("host", default="0.0.0.0")

    @cached_property
    def redirect_slash(self):
        route = self._optional_infinity_var("redirect_slash", default="/docs")
        # raise (not assert): asserts are stripped under `python -O`.
        # Empty string is allowed (disables the redirect).
        if route and not route.startswith("/"):
            raise ValueError("INFINITY_REDIRECT_SLASH must start with /")
        return route

    @cached_property
    def log_level(self):
        return self._optional_infinity_var("log_level", default="info")

    @cached_property
    def dtype(self) -> list[Dtype]:
        return [
            Dtype(v)
            for v in self._optional_infinity_var_multiple(
                "dtype", default=[Dtype.default_value()]
            )
        ]

    @cached_property
    def engine(self) -> list[InferenceEngine]:
        return [
            InferenceEngine(v)
            for v in self._optional_infinity_var_multiple(
                "engine", default=[InferenceEngine.default_value()]
            )
        ]

    @cached_property
    def pooling_method(self) -> list[PoolingMethod]:
        return [
            PoolingMethod(v)
            for v in self._optional_infinity_var_multiple(
                "pooling_method", default=[PoolingMethod.default_value()]
            )
        ]

    @cached_property
    def device(self) -> list[Device]:
        return [
            Device(v)
            for v in self._optional_infinity_var_multiple(
                "device", default=[Device.default_value()]
            )
        ]

    @cached_property
    def embedding_dtype(self) -> list[EmbeddingDtype]:
        return [
            EmbeddingDtype(v)
            for v in self._optional_infinity_var_multiple(
                "embedding_dtype", default=[EmbeddingDtype.default_value()]
            )
        ]


# Module-level singleton: constructed at import time, so every INFINITY_*
# environment variable is read (and optionally logged) exactly once.
MANAGER = __Infinity_EnvManager()
2 changes: 1 addition & 1 deletion libs/infinity_emb/infinity_emb/fastapi_schemas/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
FASTAPI_DESCRIPTION = ""


def startup_message(host: str, port: str, prefix: str) -> str:
def startup_message(host: str, port: int, prefix: str) -> str:
from infinity_emb import __version__

return f"""
Expand Down
5 changes: 0 additions & 5 deletions libs/infinity_emb/infinity_emb/inference/caching_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@
if CHECK_DISKCACHE.is_available:
import diskcache as dc # type: ignore[import-untyped]

INFINITY_CACHE_VECTORS = (
bool(os.environ.get("INFINITY_CACHE_VECTORS", False))
and CHECK_DISKCACHE.is_available
)


class Cache:
def __init__(self, cache_name: str, shutdown: threading.Event) -> None:
Expand Down
Loading