Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance: Optimise HIR case folding #893

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 80 additions & 3 deletions regex-syntax/src/hir/interval.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::char;
use std::cmp;
use std::collections::hash_map::DefaultHasher;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::slice;
use std::u8;

Expand Down Expand Up @@ -32,9 +34,10 @@ use crate::unicode;
//
// Tests on this are relegated to the public API of HIR in src/hir.rs.

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug)]
pub struct IntervalSet<I> {
ranges: Vec<I>,
folded: bool,
}

impl<I: Interval> IntervalSet<I> {
Expand All @@ -44,7 +47,10 @@ impl<I: Interval> IntervalSet<I> {
/// The given ranges do not need to be in any specific order, and ranges
/// may overlap.
pub fn new<T: IntoIterator<Item = I>>(intervals: T) -> IntervalSet<I> {
let mut set = IntervalSet { ranges: intervals.into_iter().collect() };
let mut set = IntervalSet {
ranges: intervals.into_iter().collect(),
folded: false,
};
set.canonicalize();
set
}
Expand All @@ -53,8 +59,13 @@ impl<I: Interval> IntervalSet<I> {
pub fn push(&mut self, interval: I) {
// TODO: This could be faster. e.g., Push the interval such that
// it preserves canonicalization.

// don't collect hash if we're not going to use it
let before = if self.folded { self.get_hash() } else { 0 };

self.ranges.push(interval);
self.canonicalize();
self.folded = self.folded && before == self.get_hash();
}

/// Return an iterator over all intervals in this set.
Expand All @@ -79,6 +90,9 @@ impl<I: Interval> IntervalSet<I> {
/// This returns an error if the necessary case mapping data is not
/// available.
pub fn case_fold_simple(&mut self) -> Result<(), unicode::CaseFoldError> {
if self.folded {
return Ok(());
}
let len = self.ranges.len();
for i in 0..len {
let range = self.ranges[i];
Expand All @@ -88,14 +102,28 @@ impl<I: Interval> IntervalSet<I> {
}
}
self.canonicalize();
self.folded = true;
Ok(())
}

/// Union this set with the given set, in place.
pub fn union(&mut self, other: &IntervalSet<I>) {
if other.ranges.is_empty() {
return;
}

// don't collect hash if we're not going to use it
let before_self = if self.folded { self.get_hash() } else { 0 };
let before_other = if other.folded { other.get_hash() } else { 0 };

// This could almost certainly be done more efficiently.
self.ranges.extend(&other.ranges);
self.canonicalize();
self.folded = self.folded && other.folded || {
let current_hash = self.get_hash();
self.folded && before_self == current_hash
|| other.folded && before_other == current_hash
};
}

/// Intersect this set with the given set, in place.
Expand All @@ -105,9 +133,14 @@ impl<I: Interval> IntervalSet<I> {
}
if other.ranges.is_empty() {
self.ranges.clear();
self.folded = false;
return;
}

// don't collect hash if we're not going to use it
let before_self = if self.folded { self.get_hash() } else { 0 };
let before_other = if other.folded { other.get_hash() } else { 0 };

// There should be a way to do this in-place with constant memory,
// but I couldn't figure out a simple way to do it. So just append
// the intersection to the end of this range, and then drain it before
Expand All @@ -134,6 +167,11 @@ impl<I: Interval> IntervalSet<I> {
}
}
self.ranges.drain(..drain_end);
self.folded = self.folded && other.folded || {
let current_hash = self.get_hash();
self.folded && before_self == current_hash
|| other.folded && before_other == current_hash
};
}

/// Subtract the given set from this set, in place.
Expand All @@ -142,6 +180,10 @@ impl<I: Interval> IntervalSet<I> {
return;
}

// don't collect hash if we're not going to use it
let before_self = if self.folded { self.get_hash() } else { 0 };
let before_other = if other.folded { other.get_hash() } else { 0 };

// This algorithm is (to me) surprisingly complex. A search of the
// interwebs indicate that this is a potentially interesting problem.
// Folks seem to suggest interval or segment trees, but I'd like to
Expand Down Expand Up @@ -226,6 +268,11 @@ impl<I: Interval> IntervalSet<I> {
a += 1;
}
self.ranges.drain(..drain_end);
self.folded = self.folded && other.folded || {
let current_hash = self.get_hash();
self.folded && before_self == current_hash
|| other.folded && before_other == current_hash
};
}

/// Compute the symmetric difference of the two sets, in place.
Expand Down Expand Up @@ -276,6 +323,9 @@ impl<I: Interval> IntervalSet<I> {
self.ranges.push(I::create(lower, I::Bound::max_value()));
}
self.ranges.drain(..drain_end);

// we don't need to update foldedness here stays the same because, necessarily, any set of
// matching members is entirely present or entirely not present
}

/// Converts this set into a canonical ordering.
Expand Down Expand Up @@ -318,6 +368,33 @@ impl<I: Interval> IntervalSet<I> {
}
true
}

fn get_hash(&self) -> u64 {
let mut hasher = DefaultHasher::default();
self.hash(&mut hasher);
hasher.finish()
}
}

impl<I> PartialEq for IntervalSet<I>
where
I: Interval,
{
fn eq(&self, other: &Self) -> bool {
self.ranges.eq(&other.ranges)
}
}

impl<I> Eq for IntervalSet<I> where I: Interval {}

impl<I> Hash for IntervalSet<I>
where
I: Interval,
{
fn hash<H: Hasher>(&self, state: &mut H) {
// don't hash the foldedness
self.ranges.hash(state)
}
}

/// An iterator over intervals.
Expand All @@ -333,7 +410,7 @@ impl<'a, I> Iterator for IntervalSetIter<'a, I> {
}

pub trait Interval:
Clone + Copy + Debug + Default + Eq + PartialEq + PartialOrd + Ord
Clone + Copy + Debug + Default + Eq + PartialEq + PartialOrd + Ord + Hash
{
type Bound: Bound;

Expand Down
43 changes: 28 additions & 15 deletions regex-syntax/src/hir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ impl<'a> Iterator for ClassUnicodeIter<'a> {
///
/// The range is closed. That is, the start and end of the range are included
/// in the range.
#[derive(Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord)]
#[derive(Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct ClassUnicodeRange {
start: char,
end: char,
Expand Down Expand Up @@ -1028,20 +1028,33 @@ impl Interval for ClassUnicodeRange {
}
let start = self.start as u32;
let end = (self.end as u32).saturating_add(1);
let mut next_simple_cp = None;
for cp in (start..end).filter_map(char::from_u32) {
if next_simple_cp.map_or(false, |next| cp < next) {
continue;
}
let it = match unicode::simple_fold(cp)? {
Ok(it) => it,
Err(next) => {
next_simple_cp = next;
continue;
let mut range = start..end;
let mut idx = 0;
while let Some(cp) = range.next() {
if let Some(c) = char::from_u32(cp) {
let it = match unicode::optimised_fold(idx, c)? {
Ok((it, next_idx)) => {
idx = next_idx;
it
}
Err(next) => {
if let Some((next, next_idx)) = next {
let next = next as u32;
range = next..end;
idx = next_idx;
}
continue;
}
};
for cp_folded in it {
if let Some(last) = ranges.last_mut() {
if last.end as u32 + 1 == cp_folded as u32 {
last.end = cp_folded;
continue;
}
}
ranges.push(ClassUnicodeRange::new(cp_folded, cp_folded));
}
};
for cp_folded in it {
ranges.push(ClassUnicodeRange::new(cp_folded, cp_folded));
}
}
Ok(())
Expand Down Expand Up @@ -1186,7 +1199,7 @@ impl<'a> Iterator for ClassBytesIter<'a> {
///
/// The range is closed. That is, the start and end of the range are included
/// in the range.
#[derive(Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord)]
#[derive(Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct ClassBytesRange {
start: u8,
end: u8,
Expand Down
61 changes: 53 additions & 8 deletions regex-syntax/src/unicode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::error;
use std::fmt;
use std::mem::size_of;
use std::result;

use crate::hir;
Expand Down Expand Up @@ -78,38 +79,82 @@ impl fmt::Display for UnicodeWordError {
/// to, since there is some cost to fetching the equivalence class.
///
/// This returns an error if the Unicode case folding tables are not available.
#[allow(dead_code)]
pub fn simple_fold(
c: char,
) -> FoldResult<result::Result<impl Iterator<Item = char>, Option<char>>> {
match optimised_fold(0, c) {
Ok(Ok((iter, _))) => Ok(Ok(iter)),
Ok(Err(Some((c, _)))) => Ok(Err(Some(c))),
Ok(Err(None)) => Ok(Err(None)),
Err(e) => Err(e),
}
}

pub fn optimised_fold(
start: usize,
c: char,
) -> FoldResult<
result::Result<(impl Iterator<Item = char>, usize), Option<(char, usize)>>,
> {
#[cfg(not(feature = "unicode-case"))]
fn imp(
_: usize,
_: char,
) -> FoldResult<result::Result<impl Iterator<Item = char>, Option<char>>>
{
) -> FoldResult<
result::Result<
(impl Iterator<Item = char>, usize),
Option<(char, usize)>,
>,
> {
use std::option::IntoIter;
Err::<result::Result<IntoIter<char>, _>, _>(CaseFoldError(()))
Err::<result::Result<(IntoIter<char>, usize), _>, _>(CaseFoldError(()))
}

#[cfg(feature = "unicode-case")]
fn imp(
start: usize,
c: char,
) -> FoldResult<result::Result<impl Iterator<Item = char>, Option<char>>>
{
) -> FoldResult<
result::Result<
(impl Iterator<Item = char>, usize),
Option<(char, usize)>,
>,
> {
use crate::unicode_tables::case_folding_simple::CASE_FOLDING_SIMPLE;
// this is the greatest number of steps before we are guaranteed to find our value
const DEPTH_MAX: usize = size_of::<usize>() * 8
- CASE_FOLDING_SIMPLE.len().leading_zeros() as usize
+ 1;

// first, see if we can find it in less than depth; it's likely that we've recently looked
// up an adjacent value if we've provided a start
for (i, &(other, foldings)) in
CASE_FOLDING_SIMPLE[start..].iter().take(DEPTH_MAX).enumerate()
{
if other == c {
return Ok(Ok((foldings.iter().copied(), start + i + 1)));
} else if other > c {
return Ok(Err(Some((other, start + i))));
}
}
if start + DEPTH_MAX >= CASE_FOLDING_SIMPLE.len() {
return Ok(Err(None));
}

Ok(CASE_FOLDING_SIMPLE
.binary_search_by_key(&c, |&(c1, _)| c1)
.map(|i| CASE_FOLDING_SIMPLE[i].1.iter().copied())
.map(|i| (CASE_FOLDING_SIMPLE[i].1.iter().copied(), i + 1))
.map_err(|i| {
if i >= CASE_FOLDING_SIMPLE.len() {
None
} else {
Some(CASE_FOLDING_SIMPLE[i].0)
Some((CASE_FOLDING_SIMPLE[i].0, i))
}
}))
}

imp(c)
imp(start, c)
}

/// Returns true if and only if the given (inclusive) range contains at least
Expand Down