refactor(rust): Eliminate some uses of deprecated raw_entry (#19102)
orlp authored Oct 5, 2024
1 parent 5e47a91 commit 60a6465
Showing 5 changed files with 93 additions and 211 deletions.
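
The diffs below replace the deprecated `raw_entry_mut()` calls in two ways: the dictionary `ValueMap` moves to hashbrown's lower-level `HashTable`, which never hashes keys itself, and the group-by code moves to the ordinary `entry` API. As a rough sketch of the `HashTable` call shape (illustrative only, not code from this commit): every lookup passes an explicit hash, an equality closure, and a re-hash closure used when the table resizes.

use hashbrown::hash_table::Entry;
use hashbrown::HashTable;
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

// Minimal sketch (not taken from this diff): deduplicating values with
// `HashTable::entry`. The table never hashes for you; every lookup passes an
// explicit hash, an equality closure, and a re-hash closure for resizing.
fn main() {
    let random_state = RandomState::new();
    let mut seen: HashTable<u64> = HashTable::new();

    for x in [3u64, 7, 3, 7, 11] {
        let hash = random_state.hash_one(x);
        match seen.entry(hash, |&stored| stored == x, |&stored| random_state.hash_one(stored)) {
            Entry::Occupied(_) => println!("{x} is a duplicate"),
            Entry::Vacant(slot) => {
                slot.insert(x);
            },
        }
    }
    assert_eq!(seen.len(), 3);
}

Compared to `raw_entry_mut().from_hash(hash, eq)` followed by `insert_with_hasher(..)` or `insert_hashed_nocheck(..)`, the hash and both closures travel through a single `entry` call.
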
95 changes: 30 additions & 65 deletions crates/polars-arrow/src/array/dictionary/value_map.rs
@@ -1,9 +1,9 @@
use std::borrow::Borrow;
use std::fmt::{self, Debug};
use std::hash::{BuildHasherDefault, Hash, Hasher};
use std::hash::Hash;

use hashbrown::hash_map::RawEntryMut;
use hashbrown::HashMap;
use hashbrown::hash_table::Entry;
use hashbrown::HashTable;
use polars_error::{polars_bail, polars_err, PolarsResult};
use polars_utils::aliases::PlRandomState;

@@ -12,47 +12,10 @@ use crate::array::indexable::{AsIndexed, Indexable};
use crate::array::{Array, MutableArray};
use crate::datatypes::ArrowDataType;

/// Hasher for pre-hashed values; similar to `hash_hasher` but with native endianness.
///
/// We know that we'll only use it for `u64` values, so we can avoid endian conversion.
///
/// Invariant: hash of a u64 value is always equal to itself.
#[derive(Copy, Clone, Default)]
pub struct PassthroughHasher(u64);

impl Hasher for PassthroughHasher {
#[inline]
fn write_u64(&mut self, value: u64) {
self.0 = value;
}

fn write(&mut self, _: &[u8]) {
unreachable!();
}

#[inline]
fn finish(&self) -> u64 {
self.0
}
}

#[derive(Clone)]
pub struct Hashed<K> {
hash: u64,
key: K,
}

impl<K> Hash for Hashed<K> {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.hash.hash(state)
}
}

#[derive(Clone)]
pub struct ValueMap<K: DictionaryKey, M: MutableArray> {
pub values: M,
pub map: HashMap<Hashed<K>, (), BuildHasherDefault<PassthroughHasher>>, // NB: *only* use insert_hashed_nocheck() and no other hashmap API
pub map: HashTable<(u64, K)>,
random_state: PlRandomState,
}

@@ -63,7 +26,7 @@ impl<K: DictionaryKey, M: MutableArray> ValueMap<K, M> {
}
Ok(Self {
values,
map: HashMap::default(),
map: HashTable::default(),
random_state: PlRandomState::default(),
})
}
@@ -73,29 +36,29 @@ impl<K: DictionaryKey, M: MutableArray> ValueMap<K, M> {
M: Indexable,
M::Type: Eq + Hash,
{
let mut map = HashMap::<Hashed<K>, _, _>::with_capacity_and_hasher(
values.len(),
BuildHasherDefault::<PassthroughHasher>::default(),
);
let mut map: HashTable<(u64, K)> = HashTable::with_capacity(values.len());
let random_state = PlRandomState::default();
for index in 0..values.len() {
let key = K::try_from(index).map_err(|_| polars_err!(ComputeError: "overflow"))?;
// SAFETY: we only iterate within bounds
let value = unsafe { values.value_unchecked_at(index) };
let hash = random_state.hash_one(value.borrow());

let entry = map.raw_entry_mut().from_hash(hash, |item| {
// SAFETY: invariant of the struct, it's always in bounds since we maintain it
let stored_value = unsafe { values.value_unchecked_at(item.key.as_usize()) };
stored_value.borrow() == value.borrow()
});
let entry = map.entry(
hash,
|(_h, key)| {
// SAFETY: invariant of the struct, it's always in bounds.
let stored_value = unsafe { values.value_unchecked_at(key.as_usize()) };
stored_value.borrow() == value.borrow()
},
|(h, _key)| *h,
);
match entry {
RawEntryMut::Occupied(_) => {
Entry::Occupied(_) => {
polars_bail!(InvalidOperation: "duplicate value in dictionary values array")
},
RawEntryMut::Vacant(entry) => {
// NB: don't use .insert() here!
entry.insert_hashed_nocheck(hash, Hashed { hash, key }, ());
Entry::Vacant(entry) => {
entry.insert((hash, key));
},
}
}
@@ -137,19 +100,21 @@ impl<K: DictionaryKey, M: MutableArray> ValueMap<K, M> {
M::Type: Eq + Hash,
{
let hash = self.random_state.hash_one(value.as_indexed());
let entry = self.map.raw_entry_mut().from_hash(hash, |item| {
// SAFETY: we've already checked (the inverse) when we pushed it, so it should be ok?
let index = unsafe { item.key.as_usize() };
// SAFETY: invariant of the struct, it's always in bounds since we maintain it
let stored_value = unsafe { self.values.value_unchecked_at(index) };
stored_value.borrow() == value.as_indexed()
});
let entry = self.map.entry(
hash,
|(_h, key)| {
// SAFETY: invariant of the struct, it's always in bounds.
let stored_value = unsafe { self.values.value_unchecked_at(key.as_usize()) };
stored_value.borrow() == value.as_indexed()
},
|(h, _key)| *h,
);
let out = match entry {
RawEntryMut::Occupied(entry) => entry.key().key,
RawEntryMut::Vacant(entry) => {
Entry::Occupied(entry) => entry.get().1,
Entry::Vacant(entry) => {
let index = self.values.len();
let key = K::try_from(index).map_err(|_| polars_err!(ComputeError: "overflow"))?;
entry.insert_hashed_nocheck(hash, Hashed { hash, key }, ()); // NB: don't use .insert() here!
entry.insert((hash, key));
push(&mut self.values, value)?;
debug_assert_eq!(self.values.len(), index + 1);
key
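
The `ValueMap` refactor above keeps the dictionary values in the `MutableArray` and stores only `(hash, key)` pairs in the `HashTable`; because the hash sits next to the key, the re-hash closure can return it directly and table growth never has to touch the values. A self-contained sketch of the same pattern (hypothetical `Interner` type, with a plain `Vec<String>` standing in for the mutable array):

use hashbrown::hash_table::Entry;
use hashbrown::HashTable;
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

/// Hypothetical string interner following the same layout as `ValueMap`:
/// the table stores `(hash, index)` pairs, the actual values live outside it.
struct Interner {
    table: HashTable<(u64, usize)>,
    values: Vec<String>,
    random_state: RandomState,
}

impl Interner {
    fn new() -> Self {
        Self {
            table: HashTable::new(),
            values: Vec::new(),
            random_state: RandomState::new(),
        }
    }

    /// Returns the index of `s`, inserting it on first sight.
    fn intern(&mut self, s: &str) -> usize {
        let Self { table, values, random_state } = self;
        let hash = random_state.hash_one(s);
        match table.entry(
            hash,
            // eq: compare the probe against the value stored at the candidate index.
            |(_h, idx)| values[*idx] == s,
            // hasher: reuse the stored hash when the table grows, so resizing
            // never has to look at `values`.
            |(h, _idx)| *h,
        ) {
            Entry::Occupied(entry) => entry.get().1,
            Entry::Vacant(entry) => {
                let idx = values.len();
                entry.insert((hash, idx));
                values.push(s.to_owned());
                idx
            },
        }
    }
}

fn main() {
    let mut interner = Interner::new();
    assert_eq!(interner.intern("foo"), 0);
    assert_eq!(interner.intern("bar"), 1);
    assert_eq!(interner.intern("foo"), 0);
    assert_eq!(interner.values, vec!["foo", "bar"]);
}

Storing the hash costs eight bytes per entry but means resizing never re-reads the values, the same trade-off the removed `Hashed<K>` plus `PassthroughHasher` combination made, now without the wrapper types.
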
142 changes: 56 additions & 86 deletions crates/polars-core/src/frame/group_by/hashing.rs
@@ -1,10 +1,9 @@
use std::hash::{BuildHasher, Hash, Hasher};

use hashbrown::hash_map::RawEntryMut;
use hashbrown::hash_map::Entry;
use polars_utils::hashing::{hash_to_partition, DirtyHash};
use polars_utils::idx_vec::IdxVec;
use polars_utils::itertools::Itertools;
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::{ToTotalOrd, TotalHash};
use polars_utils::total_ord::{ToTotalOrd, TotalHash, TotalOrdWrap};
use polars_utils::unitvec;
use rayon::prelude::*;

@@ -73,50 +72,42 @@ fn finish_group_order(mut out: Vec<Vec<IdxItem>>, sorted: bool) -> GroupsProxy {
}
}

pub(crate) fn group_by<T>(a: impl Iterator<Item = T>, sorted: bool) -> GroupsProxy
pub(crate) fn group_by<K>(keys: impl Iterator<Item = K>, sorted: bool) -> GroupsProxy
where
T: TotalHash + TotalEq,
K: TotalHash + TotalEq,
{
let init_size = get_init_size();
let mut hash_tbl: PlHashMap<T, (IdxSize, IdxVec)> = PlHashMap::with_capacity(init_size);
let hasher = hash_tbl.hasher().clone();
let mut cnt = 0;
a.for_each(|k| {
let idx = cnt;
cnt += 1;

let mut state = hasher.build_hasher();
k.tot_hash(&mut state);
let h = state.finish();
let entry = hash_tbl.raw_entry_mut().from_hash(h, |k_| k.tot_eq(k_));

match entry {
RawEntryMut::Vacant(entry) => {
let tuples = unitvec![idx];
entry.insert_with_hasher(h, k, (idx, tuples), |k| {
let mut state = hasher.build_hasher();
k.tot_hash(&mut state);
state.finish()
});
},
RawEntryMut::Occupied(mut entry) => {
let v = entry.get_mut();
v.1.push(idx);
},
}
});
let (mut first, mut groups);
if sorted {
let mut groups = hash_tbl
.into_iter()
.map(|(_k, v)| v)
.collect_trusted::<Vec<_>>();
groups.sort_unstable_by_key(|g| g.0);
let mut idx: GroupsIdx = groups.into_iter().collect();
idx.sorted = true;
GroupsProxy::Idx(idx)
groups = Vec::with_capacity(get_init_size());
first = Vec::with_capacity(get_init_size());
let mut hash_tbl = PlHashMap::with_capacity(init_size);
for (idx, k) in keys.enumerate_idx() {
match hash_tbl.entry(TotalOrdWrap(k)) {
Entry::Vacant(entry) => {
let group_idx = groups.len() as IdxSize;
entry.insert(group_idx);
groups.push(unitvec![idx]);
first.push(idx);
},
Entry::Occupied(entry) => unsafe {
groups.get_unchecked_mut(*entry.get() as usize).push(idx)
},
}
}
} else {
GroupsProxy::Idx(hash_tbl.into_values().collect())
let mut hash_tbl = PlHashMap::with_capacity(init_size);
for (idx, k) in keys.enumerate_idx() {
match hash_tbl.entry(TotalOrdWrap(k)) {
Entry::Vacant(entry) => {
entry.insert((idx, unitvec![idx]));
},
Entry::Occupied(mut entry) => entry.get_mut().1.push(idx),
}
}
(first, groups) = hash_tbl.into_values().unzip();
}
GroupsProxy::Idx(GroupsIdx::new(first, groups, sorted))
}
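
The rewritten `group_by` above no longer hashes manually: wrapping keys in `TotalOrdWrap` gives them `Hash + Eq`, so the ordinary `entry` API applies, and in the sorted path the map stores only the group's position while the row indices live in side vectors. A simplified sketch of that path (hypothetical helper using plain `Hash + Eq` keys instead of `TotalOrdWrap`, and `u32` instead of `IdxSize`):

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::hash::Hash;

// The map stores only the group's position; the row indices themselves live
// in `groups`, and `first` records the first row index of each group.
fn group_indices<K: Hash + Eq>(keys: impl Iterator<Item = K>) -> (Vec<u32>, Vec<Vec<u32>>) {
    let mut first = Vec::new();
    let mut groups: Vec<Vec<u32>> = Vec::new();
    let mut table: HashMap<K, usize> = HashMap::new();

    for (idx, key) in keys.enumerate() {
        let idx = idx as u32;
        match table.entry(key) {
            Entry::Vacant(entry) => {
                entry.insert(groups.len());
                first.push(idx);
                groups.push(vec![idx]);
            },
            Entry::Occupied(entry) => groups[*entry.get()].push(idx),
        }
    }
    (first, groups)
}

fn main() {
    let (first, groups) = group_indices(["a", "b", "a", "c", "b"].into_iter());
    assert_eq!(first, vec![0, 1, 3]);
    assert_eq!(groups, vec![vec![0, 2], vec![1, 4], vec![3]]);
}

The unsorted path in the commit skips this indirection and keeps `(first, group)` directly as the map's value.
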

// giving the slice info to the compiler is much
@@ -128,8 +119,8 @@ pub(crate) fn group_by_threaded_slice<T, IntoSlice>(
sorted: bool,
) -> GroupsProxy
where
T: TotalHash + TotalEq + ToTotalOrd,
<T as ToTotalOrd>::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash,
T: ToTotalOrd,
<T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + DirtyHash,
IntoSlice: AsRef<[T]> + Send + Sync,
{
let init_size = get_init_size();
@@ -141,39 +132,28 @@ where
(0..n_partitions)
.into_par_iter()
.map(|thread_no| {
let mut hash_tbl: PlHashMap<T::TotalOrdItem, (IdxSize, IdxVec)> =
PlHashMap::with_capacity(init_size);
let mut hash_tbl = PlHashMap::with_capacity(init_size);

let mut offset = 0;
for keys in &keys {
let keys = keys.as_ref();
let len = keys.len() as IdxSize;
let hasher = hash_tbl.hasher().clone();

let mut cnt = 0;
keys.iter().for_each(|k| {
for (key_idx, k) in keys.iter().enumerate_idx() {
let k = k.to_total_ord();
let idx = cnt + offset;
cnt += 1;
let idx = key_idx + offset;

if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) {
let hash = hasher.hash_one(k);
let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, &k);

match entry {
RawEntryMut::Vacant(entry) => {
let tuples = unitvec![idx];
entry.insert_with_hasher(hash, k, (idx, tuples), |k| {
hasher.hash_one(*k)
});
match hash_tbl.entry(k) {
Entry::Vacant(entry) => {
entry.insert((idx, unitvec![idx]));
},
RawEntryMut::Occupied(mut entry) => {
let v = entry.get_mut();
v.1.push(idx);
Entry::Occupied(mut entry) => {
entry.get_mut().1.push(idx);
},
}
}
});
}
offset += len;
}
hash_tbl
@@ -194,8 +174,8 @@ pub(crate) fn group_by_threaded_iter<T, I>(
where
I: IntoIterator<Item = T> + Send + Sync + Clone,
I::IntoIter: ExactSizeIterator,
T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
<T as ToTotalOrd>::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash,
T: ToTotalOrd,
<T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + DirtyHash,
{
let init_size = get_init_size();

@@ -206,39 +186,29 @@ where
(0..n_partitions)
.into_par_iter()
.map(|thread_no| {
let mut hash_tbl: PlHashMap<T::TotalOrdItem, (IdxSize, IdxVec)> =
let mut hash_tbl: PlHashMap<T::TotalOrdItem, IdxVec> =
PlHashMap::with_capacity(init_size);

let mut offset = 0;
for keys in keys {
let keys = keys.clone().into_iter();
let len = keys.len() as IdxSize;
let hasher = hash_tbl.hasher().clone();

let mut cnt = 0;
keys.for_each(|k| {
for (key_idx, k) in keys.into_iter().enumerate_idx() {
let k = k.to_total_ord();
let idx = cnt + offset;
cnt += 1;
let idx = key_idx + offset;

if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) {
let hash = hasher.hash_one(k);
let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, &k);

match entry {
RawEntryMut::Vacant(entry) => {
let tuples = unitvec![idx];
entry.insert_with_hasher(hash, k, (idx, tuples), |k| {
hasher.hash_one(*k)
});
match hash_tbl.entry(k) {
Entry::Vacant(entry) => {
entry.insert(unitvec![idx]);
},
RawEntryMut::Occupied(mut entry) => {
let v = entry.get_mut();
v.1.push(idx);
Entry::Occupied(mut entry) => {
entry.get_mut().push(idx);
},
}
}
});
}
offset += len;
}
// iterating the hash tables locally
Expand All @@ -252,7 +222,7 @@ where
// indirection
hash_tbl
.into_iter()
.map(|(_k, v)| v)
.map(|(_k, v)| (unsafe { *v.first().unwrap_unchecked() }, v))
.collect_trusted::<Vec<_>>()
})
.collect::<Vec<_>>()
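
Both threaded group-bys rely on the same partitioning trick: every thread scans all keys but only inserts those whose hash lands in its own partition (`thread_no == hash_to_partition(k.dirty_hash(), n_partitions)`), so the per-thread maps are disjoint and can be concatenated without locking. A sequential sketch of that scheme (the real code runs the outer loop with rayon's `into_par_iter`, and `hash_to_partition` below is a simplified stand-in for the polars-utils helper):

use std::collections::hash_map::RandomState;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash};

// Simplified stand-in: map a 64-bit hash onto `0..n_partitions` with a
// multiply-shift instead of a modulo.
fn hash_to_partition(hash: u64, n_partitions: usize) -> usize {
    ((hash as u128 * n_partitions as u128) >> 64) as usize
}

// Each "thread" scans every key but only keeps those that hash into its own
// partition, so the per-partition maps are disjoint by construction.
fn partitioned_group_by<K: Hash + Eq + Clone>(
    keys: &[K],
    n_partitions: usize,
) -> Vec<HashMap<K, Vec<u32>>> {
    let random_state = RandomState::new();
    (0..n_partitions)
        .map(|thread_no| {
            let mut local: HashMap<K, Vec<u32>> = HashMap::new();
            for (idx, key) in keys.iter().enumerate() {
                if thread_no == hash_to_partition(random_state.hash_one(key), n_partitions) {
                    local.entry(key.clone()).or_default().push(idx as u32);
                }
            }
            local
        })
        .collect()
}

fn main() {
    let keys = ["a", "b", "a", "c", "b", "a"];
    let partitions = partitioned_group_by(&keys, 4);
    // Every row index ends up in exactly one partition.
    let total: usize = partitions
        .iter()
        .map(|m| m.values().map(Vec::len).sum::<usize>())
        .sum();
    assert_eq!(total, keys.len());
}
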
(Diffs for the remaining 3 changed files are not expanded here.)
