bfloat16 (bf16) support in faiss #3862

Open
mlomeli1 opened this issue Sep 16, 2024 · 4 comments

@mlomeli1
Contributor

Many LLMs are trained in bf16. If we want to use the hidden states of LLMs for retrieval, those vectors will be in the bf16 dtype. It would be helpful to support bf16 in Faiss so that we can use LLMs as retrievers or embedding models.

@asadoughi asadoughi changed the title bf16 support in faiss bfloat16 (bf16) support in faiss Sep 16, 2024
@mdouze
Contributor

mdouze commented Sep 16, 2024

Note that we cannot pass through the numpy wrapper, because numpy does not support bf16.
Adapting the gpu_knn code for pytorch should be easy:

def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False):

@mdouze
Contributor

mdouze commented Sep 17, 2024

@mlomeli1
Contributor Author

Example of how we currently would use a PQ codec to encode/decode pytorch bf16 tensors:

torch.from_numpy(codec.sa_decode(codec.sa_encode(x.to(device='cpu', dtype=torch.float32).numpy())))

This snippet shows all the extra steps involved: moving the tensor to the CPU, upcasting it to float32, and converting it to a numpy array. Then, after decoding, we need to convert the result back to a tensor. Ideally, we could avoid some of these steps if bf16 pytorch tensors were supported directly.
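For reference, the round trip above could be wrapped in a small helper. This is only a sketch of the current workaround, not a proposed Faiss API: `pq_roundtrip` is a hypothetical name, and `codec` is assumed to be any object exposing Faiss-style `sa_encode` / `sa_decode` methods that work on numpy arrays.

```python
import torch


def pq_roundtrip(codec, x: torch.Tensor) -> torch.Tensor:
    """Encode and decode `x` through `codec`, returning a tensor like `x`.

    Hypothetical helper: `codec` is assumed to expose Faiss-style
    sa_encode / sa_decode methods that take and return numpy arrays.
    """
    # numpy has no bf16 dtype, so move to the CPU and upcast to float32 first.
    x_np = x.detach().to(device="cpu", dtype=torch.float32).contiguous().numpy()
    codes = codec.sa_encode(x_np)     # compact codes as a numpy array
    decoded = codec.sa_decode(codes)  # reconstructed vectors as a numpy array
    # Convert back to a tensor with the original dtype and device.
    return torch.from_numpy(decoded).to(device=x.device, dtype=x.dtype)
```

With a trained codec and a bf16 tensor `x`, `pq_roundtrip(codec, x)` would return a bf16 tensor on the same device as `x`. The copies and casts are all still there; they are just collected in one place.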

@alexanderguzhva
Contributor

Just in case: there is a ScalarQuantizer implementation for bf16; maybe portions of it can be reused.
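For context on what a bf16 code holds: bfloat16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa bits of a float32, so a value can be stored as the upper 16 bits of its float32 bit pattern. The sketch below illustrates that layout in plain Python using simple truncation. The function names are made up for illustration, and this is not the Faiss ScalarQuantizer code, which may round differently.

```python
import struct


def float_to_bf16_bits(value: float) -> int:
    """Return the 16-bit bf16 pattern of `value` by truncating its float32 form."""
    (bits32,) = struct.unpack("<I", struct.pack("<f", value))
    # Keep sign, exponent, and the top 7 mantissa bits; drop the low 16 bits.
    return bits32 >> 16


def bf16_bits_to_float(bits16: int) -> float:
    """Expand a 16-bit bf16 pattern back to a Python float."""
    # The dropped low mantissa bits are filled with zeros.
    (value,) = struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))
    return value
```

Values that fit in 8 significant bits (such as 1.0 or -2.5) survive the round trip exactly; others lose precision, with a relative error below 2^-7 under truncation.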
