Commit

Merge pull request #247 from michaelfeil/quantization-process

add embedding quantization interface

michaelfeil authored Jun 7, 2024
2 parents 2f87f23 + 3852e39 commit 2da1f32

Showing 13 changed files with 696 additions and 53 deletions.
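For context, the embedding_dtype option that this PR wires up is consumed through EngineArgs. Below is a minimal, hedged sketch of requesting int8-quantized embeddings through the Python engine API; it assumes the public AsyncEmbeddingEngine.from_args / engine.embed interface and uses an illustrative model id.

import asyncio

from infinity_emb import AsyncEmbeddingEngine
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import EmbeddingDtype


async def main() -> None:
    # Hedged sketch: assumes AsyncEmbeddingEngine.from_args and engine.embed as
    # documented for infinity_emb; the model id is an example only.
    engine = AsyncEmbeddingEngine.from_args(
        EngineArgs(
            model_name_or_path="BAAI/bge-small-en-v1.5",
            embedding_dtype=EmbeddingDtype.int8,  # int8 output becomes usable with this PR
        )
    )
    async with engine:
        embeddings, usage = await engine.embed(sentences=["embed this sentence"])
        print(len(embeddings), usage)


asyncio.run(main())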
524 changes: 524 additions & 0 deletions docs/assets/multilingual_calibration.utf8

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions libs/infinity_emb/infinity_emb/env.py
@@ -3,6 +3,7 @@

import os
from functools import cached_property
from pathlib import Path
from typing import TypeVar

from infinity_emb.primitives import (
@@ -143,6 +144,24 @@ def preload_only(self):
self._optional_infinity_var("preload_only", default="false")
)

@cached_property
def infinity_cache_dir(self) -> Path:
"""gets the cache directory for infinity_emb."""
cache_dir = None
hf_home = os.environ.get("HF_HOME")
inf_home = os.environ.get("INFINITY_HOME")
if inf_home:
cache_dir = Path(inf_home) / ".infinity_cache"
elif hf_home:
cache_dir = Path(hf_home) / ".infinity_cache"
else:
cache_dir = Path(".").resolve() / ".infinity_cache"

if not cache_dir.exists():
cache_dir.mkdir(parents=True, exist_ok=True)

return cache_dir

@cached_property
def permissive_cors(self):
return self._to_bool(
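A small usage sketch of the resolution order implemented by infinity_cache_dir above (INFINITY_HOME takes precedence over HF_HOME, falling back to ./.infinity_cache); the path is illustrative, and MANAGER is the singleton from infinity_emb.env imported elsewhere in this diff.

import os

# Example path only; set before the first access so the cached_property picks it up.
os.environ["INFINITY_HOME"] = "/tmp/infinity-home"

from infinity_emb.env import MANAGER  # noqa: E402

cache_dir = MANAGER.infinity_cache_dir
print(cache_dir)  # expected: /tmp/infinity-home/.infinity_cache, created if missing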
5 changes: 2 additions & 3 deletions libs/infinity_emb/infinity_emb/inference/caching_layer.py
@@ -1,11 +1,11 @@
import asyncio
import os
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Union

from infinity_emb._optional_imports import CHECK_DISKCACHE
from infinity_emb.env import MANAGER
from infinity_emb.inference.threading_asyncio import to_thread
from infinity_emb.log_handler import logger
from infinity_emb.primitives import EmbeddingReturnType, QueueItemInner
@@ -17,11 +17,10 @@
class Cache:
def __init__(self, cache_name: str, shutdown: threading.Event) -> None:
CHECK_DISKCACHE.mark_required()
from infinity_emb.transformer.utils import infinity_cache_dir

self._shutdown = shutdown
self._add_q: queue.Queue = queue.Queue()
dir = os.path.join(infinity_cache_dir(), "cache_vectors", f"cache_{cache_name}")
dir = MANAGER.infinity_cache_dir / "cache_vectors" / f"cache_{cache_name}"
logger.info(f"caching vectors under: {dir}")
self._cache = dc.Cache(dir, size_limit=2**28)
self.is_running = False
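For readers unfamiliar with the diskcache backend behind dc.Cache, a standalone sketch of the underlying key/value store follows; the path and key are illustrative, while the production Cache class above queues writes from a background thread.

import diskcache as dc
import numpy as np

# Illustrative path and key; size_limit mirrors the 2**28 (~256 MB) cap used above.
cache = dc.Cache("/tmp/infinity-example-cache", size_limit=2**28)
cache["example sentence"] = np.ones(13, dtype=np.float32)  # store an embedding vector
vector = cache.get("example sentence")  # returns None on a miss
print(vector)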
4 changes: 4 additions & 0 deletions libs/infinity_emb/infinity_emb/infinity_server.py
@@ -22,6 +22,7 @@
from infinity_emb.primitives import (
Device,
Dtype,
EmbeddingDtype,
InferenceEngine,
ModelNotDeployedError,
PoolingMethod,
@@ -455,6 +456,7 @@ def v2(
device: list[Device] = MANAGER.device, # type: ignore
lengths_via_tokenize: list[bool] = MANAGER.lengths_via_tokenize,
dtype: list[Dtype] = MANAGER.dtype, # type: ignore
embedding_dtype: list[EmbeddingDtype] = MANAGER.embedding_dtype, # type: ignore
pooling_method: list[PoolingMethod] = MANAGER.pooling_method, # type: ignore
compile: list[bool] = MANAGER.compile,
bettertransformer: list[bool] = MANAGER.bettertransformer,
@@ -492,6 +494,7 @@ def v2(
device, Device: device to use for inference. Defaults to Device.auto or "auto"
lengths_via_tokenize: bool: schedule by token usage. Defaults to False.
dtype, Dtype: data type to use for inference. Defaults to Dtype.auto or "auto"
embedding_dtype, EmbeddingDtype: data type to use for embeddings. Defaults to EmbeddingDtype.float32 or "float32"
pooling_method, PoolingMethod: pooling method to use. Defaults to PoolingMethod.auto or "auto"
compile, bool: compile model for faster inference. Defaults to False.
use_bettertransformer, bool: use bettertransformer. Defaults to True.
@@ -512,6 +515,7 @@
device=device,
lengths_via_tokenize=lengths_via_tokenize,
dtype=dtype,
embedding_dtype=embedding_dtype,
pooling_method=pooling_method,
compile=compile,
bettertransformer=bettertransformer,
6 changes: 4 additions & 2 deletions libs/infinity_emb/infinity_emb/primitives.py
@@ -94,8 +94,10 @@ def default_value():

class EmbeddingDtype(EnumType):
float32: str = "float32"
# int8: str = "int8"
# binary: str = "binary"
int8: str = "int8"
uint8: str = "uint8"
binary: str = "binary"
ubinary: str = "ubinary"

@staticmethod
def default_value():
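The four new members map onto the precision strings accepted by sentence-transformers' quantize_embeddings helper, which the quantization interface later in this diff calls. Below is a hedged sketch of what each precision does to a float32 matrix, assuming a sentence-transformers release that ships quantize_embeddings; the shapes in the comments are expectations, not guarantees.

import numpy as np
from sentence_transformers.quantization import quantize_embeddings

embeddings = np.random.randn(4, 13).astype(np.float32)  # stand-in model output

# int8/uint8: scale each dimension into an 8-bit range using per-dimension min/max.
ranges = np.vstack([embeddings.min(axis=0), embeddings.max(axis=0)])  # shape (2, dim)
int8_emb = quantize_embeddings(embeddings, precision="int8", ranges=ranges)

# binary/ubinary: threshold at zero and pack bits, shrinking 13 dims to ceil(13/8) bytes.
ubinary_emb = quantize_embeddings(embeddings, precision="ubinary")

print(int8_emb.dtype, int8_emb.shape)        # expected: int8 (4, 13)
print(ubinary_emb.dtype, ubinary_emb.shape)  # expected: uint8 (4, 2)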
12 changes: 11 additions & 1 deletion libs/infinity_emb/infinity_emb/transformer/abstract.py
@@ -3,6 +3,7 @@
from typing import Any, Set

from infinity_emb.primitives import (
EmbeddingDtype,
EmbeddingInner,
EmbeddingReturnType,
EmbeddingSingle,
@@ -12,6 +13,7 @@
ReRankInner,
ReRankSingle,
)
from infinity_emb.transformer.quantization.interface import quant_embedding_decorator

INPUT_FEATURE = Any
OUT_FEATURES = Any
@@ -46,12 +48,19 @@ def warmup(self, *, batch_size: int = 64, n_tokens=1) -> tuple[float, float, str
class BaseEmbedder(BaseTransformer): # Inherit from ABC(Abstract base class)
capabilities = {"embed"}

@property
def embedding_dtype(self) -> EmbeddingDtype:
"""returns the dtype of the embeddings"""
return self.engine_args.embedding_dtype # type: ignore

@abstractmethod # Decorator to define an abstract method
def encode_pre(self, sentences: list[str]) -> INPUT_FEATURE:
"""takes care of the tokenization and feature preparation"""

@abstractmethod
def encode_post(self, embedding: OUT_FEATURES) -> EmbeddingReturnType:
def encode_post(
self, embedding: OUT_FEATURES, skip_quantization=True
) -> EmbeddingReturnType:
"""runs post encoding such as normalization"""

def warmup(self, *, batch_size: int = 64, n_tokens=1) -> tuple[float, float, str]:
@@ -91,6 +100,7 @@ def encode_pre(self, queries_docs: list[tuple[str, str]]) -> INPUT_FEATURE:
"""takes care of the tokenization and feature preparation"""

@abstractmethod
@quant_embedding_decorator()
def encode_post(self, embedding: OUT_FEATURES) -> list[float]:
"""runs post encoding such as normalization"""

@@ -3,13 +3,15 @@
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import EmbeddingReturnType
from infinity_emb.transformer.abstract import BaseEmbedder
from infinity_emb.transformer.quantization.interface import quant_embedding_decorator


class DummyTransformer(BaseEmbedder):
"""fix-13 dimension embedding, filled with length of sentence"""

def __init__(self, *, engine_args: EngineArgs) -> None:
print(f"running DummyTransformer.__init__ with engine_args={engine_args}")
self.engine_args = engine_args

def encode_pre(self, sentences: list[str]) -> np.ndarray:
return np.asarray(sentences)
@@ -19,6 +21,7 @@ def encode_core(self, features: np.ndarray) -> EmbeddingReturnType:
# embedding of size 13
return np.ones([len(features), 13]) * lengths.T

@quant_embedding_decorator()
def encode_post(self, embedding: EmbeddingReturnType):
return embedding

3 changes: 3 additions & 0 deletions libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py
@@ -9,6 +9,7 @@
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import EmbeddingReturnType, PoolingMethod
from infinity_emb.transformer.abstract import BaseEmbedder
from infinity_emb.transformer.quantization.interface import quant_embedding_decorator
from infinity_emb.transformer.utils_optimum import (
cls_token_pooling,
mean_pooling,
@@ -103,6 +104,7 @@ def __init__(self, *, engine_args: EngineArgs):
else 512
),
}
self.engine_args = engine_args
self.model = NeuronModelForFeatureExtraction.from_pretrained(
model_id=engine_args.model_name_or_path,
revision=engine_args.revision,
@@ -136,6 +138,7 @@ def encode_core(self, input_dict: dict[str, np.ndarray]) -> dict:
"attention_mask": input_dict["attention_mask"][:actual_bsize],
}

@quant_embedding_decorator()
def encode_post(self, embedding: dict) -> EmbeddingReturnType:
embedding = self.pooling( # type: ignore
embedding["token_embeddings"].numpy(), embedding["attention_mask"].numpy()
@@ -6,6 +6,7 @@
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import EmbeddingReturnType, PoolingMethod
from infinity_emb.transformer.abstract import BaseEmbedder
from infinity_emb.transformer.quantization.interface import quant_embedding_decorator
from infinity_emb.transformer.utils_optimum import (
cls_token_pooling,
device_to_onnx,
@@ -70,6 +71,7 @@ def __init__(self, *, engine_args: EngineArgs):
trust_remote_code=engine_args.trust_remote_code,
)
self._infinity_tokenizer = copy.deepcopy(self.tokenizer)
self.engine_args = engine_args

def encode_pre(self, sentences: list[str]) -> dict[str, np.ndarray]:
encoded = self.tokenizer(
@@ -90,6 +92,7 @@ def encode_core(self, onnx_input: dict[str, np.ndarray]) -> dict:
"attention_mask": onnx_input["attention_mask"],
}

@quant_embedding_decorator()
def encode_post(self, embedding: dict) -> EmbeddingReturnType:
embedding = self.pooling( # type: ignore
embedding["token_embeddings"], embedding["attention_mask"]
@@ -14,7 +14,10 @@
from infinity_emb.primitives import Device, Dtype, EmbeddingReturnType
from infinity_emb.transformer.abstract import BaseEmbedder
from infinity_emb.transformer.acceleration import to_bettertransformer
from infinity_emb.transformer.quantization.interface import quant_interface
from infinity_emb.transformer.quantization.interface import (
quant_embedding_decorator,
quant_interface,
)

if TYPE_CHECKING:
from torch import Tensor
@@ -57,8 +60,7 @@ def __init__(self, *, engine_args=EngineArgs):
fm = self._first_module()
self._infinity_tokenizer = copy.deepcopy(fm.tokenizer)
self.eval()

self.embedding_dtype = engine_args.embedding_dtype
self.engine_args = engine_args

if not (self.device.type == "mps" or not engine_args.bettertransformer):
fm.auto_model = to_bettertransformer(
@@ -98,6 +100,7 @@ def encode_core(self, features: Mapping[str, "Tensor"]) -> "Tensor":

return out_features.detach().cpu()

@quant_embedding_decorator()
def encode_post(
self, out_features: "Tensor", normalize_embeddings: bool = True
) -> EmbeddingReturnType:
@@ -108,11 +111,6 @@ def encode_post(

embeddings_np: np.ndarray = embeddings.numpy()

if self.embedding_dtype.value != "float32":
raise NotImplementedError(
f"EmbeddingDtype for {self.embedding_dtype} not implemented"
)

return embeddings_np

def tokenize_lengths(self, sentences: list[str]) -> list[int]:
@@ -1,12 +1,20 @@
from typing import Any
from functools import cache, wraps
from typing import TYPE_CHECKING, Any

import numpy as np
import requests # type: ignore

from infinity_emb._optional_imports import CHECK_TORCH
from infinity_emb.log_handler import logger
from infinity_emb.primitives import Device, Dtype
from infinity_emb.primitives import Device, Dtype, EmbeddingDtype
from infinity_emb.transformer.quantization.quant import quantize

if TYPE_CHECKING:
from infinity_emb.transformer.abstract import BaseEmbedder

if CHECK_TORCH.is_available:
import torch
from sentence_transformers.quantization import quantize_embeddings # type: ignore


def quant_interface(model: Any, dtype: Dtype = Dtype.int8, device: Device = Device.cpu):
@@ -54,3 +62,61 @@ def quant_interface(model: Any, dtype: Dtype = Dtype.int8, device: Device = Devi
f"Quantization is not supported on {device} with dtype {dtype}."
)
return model


@cache
def _get_text_calibration_dataset() -> list[str]:
url = "https://raw.githubusercontent.com/turboderp/exllamav2/master/conversion/standard_cal_data/multilingual.utf8"
response = requests.get(url) # TODO: add local file caching
response.raise_for_status() # This will raise an exception if the request failed
return [line.strip() for line in response.text.splitlines()]


@cache
def _create_statistics_embedding(model: "BaseEmbedder", percentile=100) -> np.ndarray:
"""returns `ranges`, the min and max values of the embeddings for quantization."""

def _encode(model, dataset, batch_size=8):
"""batched encoding of the dataset"""
for i in range(0, len(dataset), batch_size):
yield model.encode_post(
model.encode_core(model.encode_pre(dataset[i : i + batch_size])),
# _internal_skip_quantization is a hack to skip quantization
# and avoid infinite recursion
_internal_skip_quantization=True,
)

dataset = _get_text_calibration_dataset()

calibration_embeddings = np.concatenate(list(_encode(model, dataset)))
assert (
percentile > 50 and percentile <= 100
), "percentile should be between 50 and 100"
return np.percentile(calibration_embeddings, [100 - percentile, percentile], axis=0)


def quant_embedding_decorator():
def decorator(func):
@wraps(func)
def wrapper(self: "BaseEmbedder", *args, **kwargs):
# Assume the first argument is the instance of BaseEmbedder or similar
skip_quantization = kwargs.pop("_internal_skip_quantization", False)
embeddings = func(self, *args, **kwargs)
if self.embedding_dtype == EmbeddingDtype.float32 or skip_quantization:
return embeddings
elif (
self.embedding_dtype == EmbeddingDtype.int8
or self.embedding_dtype == EmbeddingDtype.uint8
):
calibration_ranges = _create_statistics_embedding(self)
else:
calibration_ranges = None
return quantize_embeddings(
embeddings,
precision=self.embedding_dtype.value,
ranges=calibration_ranges,
)

return wrapper

return decorator
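To make the calibration step above concrete, here is a self-contained sketch of range-based int8 quantization using percentile ranges with the same (2, dim) shape that _create_statistics_embedding returns; the scaling formula is illustrative and may differ in detail from what quantize_embeddings does internally.

import numpy as np

rng = np.random.default_rng(0)
calibration = rng.normal(size=(1024, 13)).astype(np.float32)  # stand-in corpus embeddings
ranges = np.percentile(calibration, [0, 100], axis=0)          # shape (2, dim), as above


def int8_quantize(embeddings: np.ndarray, ranges: np.ndarray) -> np.ndarray:
    """Illustrative range-based int8 quantization; not the library's exact formula."""
    lo, hi = ranges[0], ranges[1]
    scale = (hi - lo) / 255.0  # map each dimension onto 256 levels
    return ((embeddings - lo) / scale - 128).round().clip(-128, 127).astype(np.int8)


query = rng.normal(size=(2, 13)).astype(np.float32)
print(int8_quantize(query, ranges))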
36 changes: 0 additions & 36 deletions libs/infinity_emb/infinity_emb/transformer/utils.py
@@ -1,6 +1,4 @@
import os
from enum import Enum
from pathlib import Path
from typing import Callable

from infinity_emb.primitives import InferenceEngine
@@ -20,7 +18,6 @@
__all__ = [
"length_tokenizer",
"get_lengths_with_tokenize",
"infinity_cache_dir",
]


@@ -83,36 +80,3 @@ def get_lengths_with_tokenize(
) -> tuple[list[int], int]:
_lengths = tokenize(_sentences)
return _lengths, sum(_lengths)


def infinity_cache_dir(overwrite=False):
"""gets the cache dir. If
Args:
overwrite (bool, optional): _description_. Defaults to True.
Returns:
_type_: _description_
"""
cache_dir = None
inf_home = os.environ.get("INFINITY_HOME")
st_home = os.environ.get("SENTENCE_TRANSFORMERS_HOME")
hf_home = os.environ.get("HF_HOME")
if inf_home:
cache_dir = inf_home
elif st_home:
cache_dir = st_home
elif hf_home:
cache_dir = hf_home
else:
cache_dir = str(Path(".").resolve() / ".infinity_cache")

if overwrite:
os.environ.setdefault("INFINITY_HOME", cache_dir)
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", cache_dir)
os.environ.setdefault("HF_HOME", cache_dir)

if not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)

return cache_dir