Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent OOM when IVF centroids are provided #1653

Merged
merged 1 commit into from
Nov 22, 2023

Conversation

wjones127
Copy link
Contributor

These are a small collection of fixes I needed in order to run a benchmark for partitioning while using CUDA acceleration.

  1. Regardless of whether IVF is going to be trained, we load in enough vectors to do the training. This causes OOM on large datasets, so I've changed this to skip loading those vectors if the IVF centroids have already been passed in. GPU training can handle larger than memory datasets thanks to it's data loader, so this makes large scale training with GPUs possible.
  2. Added some cast to ints, as it's easy to accidentally get floats in some cases.
  3. Handle older pyarrow versions in GPU training API

for col in batch.column_names:
for col in batch.schema.names:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to support older versions of pyarrow. pyarrow 12.0.0 doesn't have RecordBatch.column_names. (For some reason that's the version conda solved for when I needed to install pytorch.)

@wjones127 wjones127 marked this pull request as ready for review November 22, 2023 03:09
Comment on lines +634 to +645
let mut training_data = if ivf_params.centroids.is_none() {
let start = std::time::Instant::now();
log::info!(
"Loading training data for IVF. Sample size: {}",
sample_size_hint
);
let data = Some(maybe_sample_training_data(dataset, column, sample_size_hint).await?);
log::info!(
"Finished loading training data in {:02} seconds",
start.elapsed().as_secs_f32()
);
data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe just add the timer in maybe_sample_training_data?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the narrative logs, I like that they are all in 1 function. 🤷

@wjones127 wjones127 merged commit 8fc78d7 into main Nov 22, 2023
17 checks passed
@wjones127 wjones127 deleted the wjones127/training-fixes branch November 22, 2023 03:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants