Skip to content

Commit

Permalink
Merge pull request facebookresearch#12 from Enet4/multi-gpu
Browse files Browse the repository at this point in the history
Multi GPU support
  • Loading branch information
Enet4 authored Apr 21, 2021
2 parents cebd46e + 5ce5690 commit c872866
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub enum Error {
BadFilePath,
/// Invalid parameter name of index.
ParameterName,
/// The number of GPU resources and devices do not match.
GpuResourcesMatch,
}

impl fmt::Display for Error {
Expand All @@ -32,6 +34,7 @@ impl fmt::Display for Error {
Error::IndexDescription => fmt.write_str("Invalid index description"),
Error::BadFilePath => fmt.write_str("Invalid file path"),
Error::ParameterName => fmt.write_str("Invalid parameter name of index"),
Error::GpuResourcesMatch => fmt.write_str("Number of GPU resources and devices do not match"),
}
}
}
Expand Down
139 changes: 138 additions & 1 deletion src/index/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,43 @@ where
})
}
}

/// Build a GPU in from the given CPU native index.
/// Users will indirectly use this through [`to_gpu`] or [`into_gpu`].
///
/// # Error
///
/// The operation fails if the number of GPU resources and number of
/// devices do not match, or the index does not provide GPU support.
///
/// [`to_gpu`]: ../struct.IndexImpl.html#method.to_gpu
/// [`into_gpu`]: ../struct.IndexImpl.html#method.into_gpu
pub(crate) fn from_cpu_multiple<G>(index: &I, gpu_res: &[G], devices: &[i32]) -> Result<Self>
where
I: NativeIndex,
I: CpuIndex,
G: GpuResourcesProvider,
{
if gpu_res.len() != devices.len() {
return Err(crate::error::Error::GpuResourcesMatch);
}

let res_ptr: Vec<*mut _> = gpu_res.into_iter().map(|r| r.inner_ptr()).collect();
unsafe {
let mut gpuindex_ptr = ptr::null_mut();
faiss_try(faiss_index_cpu_to_gpu_multiple(
res_ptr.as_slice().as_ptr(),
devices.as_ptr(),
devices.len(),
index.inner_ptr(),
&mut gpuindex_ptr
))?;
Ok(GpuIndexImpl {
inner: gpuindex_ptr,
phantom: PhantomData,
})
}
}
}

impl IndexImpl {
Expand Down Expand Up @@ -125,6 +162,42 @@ impl IndexImpl {
self.to_gpu(gpu_res, device)
// let the CPU index drop naturally
}

/// Build a GPU index from the given CPU native index.
///
/// # Errors
///
/// The operation fails if the number of GPU resources and number of
/// devices do not match, or the index does not provide GPU support.
pub fn to_gpu_multiple<'gpu, G: 'gpu>(
&self,
gpu_res: &'gpu [G],
devices: &[i32],
) -> Result<GpuIndexImpl<'gpu, IndexImpl>>
where
G: GpuResourcesProvider,
{
GpuIndexImpl::from_cpu_multiple(&self, gpu_res, devices)
}

/// Build a GPU index from the given CPU native index. The index residing
/// in CPU memory is discarded in the process.
///
/// # Errors
///
/// The operation fails if the number of GPU resources and number of
/// devices do not match, or the index does not provide GPU support.
pub fn into_gpu_multiple<'gpu, G: 'gpu>(
self,
gpu_res: &'gpu [G],
devices: &[i32],
) -> Result<GpuIndexImpl<'gpu, IndexImpl>>
where
G: GpuResourcesProvider,
{
self.to_gpu_multiple(gpu_res, devices)
// let the CPU index drop naturally
}
}

impl<'gpu, I> GpuIndexImpl<'gpu, I>
Expand Down Expand Up @@ -303,6 +376,42 @@ impl FlatIndexImpl {
{
self.to_gpu(gpu_res, device)
}

/// Build a GPU index from the given CPU native index.
///
/// # Errors
///
/// The operation fails if the number of GPU resources and number of
/// devices do not match, or the index does not provide GPU support.
pub fn to_gpu_multiple<'gpu, G: 'gpu>(
&self,
gpu_res: &'gpu [G],
devices: &[i32],
) -> Result<GpuIndexImpl<'gpu, FlatIndexImpl>>
where
G: GpuResourcesProvider,
{
GpuIndexImpl::from_cpu_multiple(&self, gpu_res, devices)
}

/// Build a GPU index from the given CPU native index. The index residing
/// in CPU memory is discarded in the process.
///
/// # Errors
///
/// The operation fails if the number of GPU resources and number of
/// devices do not match, or the index does not provide GPU support.
pub fn into_gpu_multiple<'gpu, G: 'gpu>(
self,
gpu_res: &'gpu [G],
devices: &[i32],
) -> Result<GpuIndexImpl<'gpu, FlatIndexImpl>>
where
G: GpuResourcesProvider,
{
self.to_gpu_multiple(gpu_res, devices)
// let the CPU index drop naturally
}
}

#[cfg(test)]
Expand Down Expand Up @@ -334,14 +443,42 @@ mod tests {
let mut gpu_index = index.into_gpu(&res, 0).unwrap();
is_in_gpu(&gpu_index);
for _ in 0..3 {
let index = gpu_index.into_cpu().unwrap();
let index: FlatIndex = gpu_index.into_cpu().unwrap();
is_in_cpu(&index);
gpu_index = index.into_gpu(&res, 0).unwrap();
is_in_gpu(&gpu_index);
}
assert_eq!(gpu_index.ntotal(), 5); // indexed vectors should be retained
}


#[test]
fn flat_in_and_out_multiple() {
let mut res = [StandardGpuResources::new().unwrap()];
res[0].set_temp_memory(10).unwrap();
let devices = [0];

let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
assert_eq!(index.d(), 8);
let some_data = &[
7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 120., 100.,
100., 105., -100., 100., 100., 105.,
];
index.add(some_data).unwrap();
assert_eq!(index.ntotal(), 5);

let mut gpu_index = index.into_gpu_multiple(&res, &devices).unwrap();
is_in_gpu(&gpu_index);
for _ in 0..3 {
let index = gpu_index.into_cpu().unwrap();
is_in_cpu(&index);
gpu_index = index.into_gpu_multiple(&res, &devices).unwrap();
is_in_gpu(&gpu_index);
}
assert_eq!(gpu_index.ntotal(), 5); // indexed vectors should be retained
}

#[test]
fn flat_index_search_into_gpu() {
let res = StandardGpuResources::new().unwrap();
Expand Down

0 comments on commit c872866

Please sign in to comment.