Skip to content

Commit

Permalink
chore: use torch.index_add to compute new centroids, to improve train…
Browse files Browse the repository at this point in the history
…ing performance on MPS (#1371)
  • Loading branch information
eddyxu authored Oct 7, 2023
1 parent 9fee111 commit 9a56cff
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
5 changes: 3 additions & 2 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def create_index(
The number of sub-vectors for PQ (Product Quantization).
accelerator : str, optional
If set, use an accelerator to speed up the training process.
Accepted accelerator: "cuda".
Accepted accelerator: "cuda" (Nvidia GPU) and "mps" (Apple Silicon GPU).
If not set, use the CPU.
kwargs :
Parameters passed to the index building process.
Expand Down Expand Up @@ -714,7 +714,8 @@ def create_index(
Experimental Accelerator (GPU) support:
- *accelerate*: use GPU to train IVF partitions.
Only supports CUDA (Nvidia) currently. Requires PyTorch being installed.
Only supports CUDA (Nvidia) or MPS (Apple) currently.
Requires PyTorch being installed.
.. code-block:: python
Expand Down
11 changes: 6 additions & 5 deletions python/python/lance/torch/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ def _new_centroids_mps(
# MPS does not have Torch.index_reduce_()
# See https://github.com/pytorch/pytorch/issues/77764

# Use CPU makes for loop faster
new_centroids = torch.zeros((k, data[0].shape[1]), device="cpu")
new_centroids = torch.zeros((k, data[0].shape[1]), device=data[0].device)
for ids, chunk in zip(part_ids, data):
for part_id, vector in zip(ids.cpu(), chunk.cpu()):
new_centroids[part_id, :] = new_centroids[part_id, :].add(vector)
new_centroids.index_add_(0, ids, chunk)
for idx, cnt in enumerate(cnts.cpu()):
if cnt > 0:
if cnt == 0:
new_centroids[idx, :] = torch.nan
else:
new_centroids[idx, :] = new_centroids[idx, :].div(cnt)

return new_centroids.to(data[0].device)


Expand Down

0 comments on commit 9a56cff

Please sign in to comment.