diff --git a/bench/large_array_vs_numpy.py b/bench/large_array_vs_numpy.py
new file mode 100644
index 0000000..72219a1
--- /dev/null
+++ b/bench/large_array_vs_numpy.py
@@ -0,0 +1,152 @@
+#################################################################################
+# To mimic the scenario that computation is i/o bound and constrained by memory
+#
+# It's a much simplified version that the chunk is computed in a loop,
+# and expression is evaluated in a sequence, which is not true in reality.
+# Nevertheless, numexpr outperforms numpy.
+#################################################################################
+"""
+Benchmarking Expression 1:
+NumPy time (threaded over 32 chunks with 2 threads): 4.612313 seconds
+numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 0.951172 seconds
+numexpr speedup: 4.85x
+----------------------------------------
+Benchmarking Expression 2:
+NumPy time (threaded over 32 chunks with 2 threads): 23.862752 seconds
+numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 2.182058 seconds
+numexpr speedup: 10.94x
+----------------------------------------
+Benchmarking Expression 3:
+NumPy time (threaded over 32 chunks with 2 threads): 20.594895 seconds
+numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 2.927881 seconds
+numexpr speedup: 7.03x
+----------------------------------------
+Benchmarking Expression 4:
+NumPy time (threaded over 32 chunks with 2 threads): 12.834101 seconds
+numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 5.392480 seconds
+numexpr speedup: 2.38x
+----------------------------------------
+"""
+
+import os
+
+os.environ["NUMEXPR_NUM_THREADS"] = "16"
+import numpy as np
+import numexpr as ne
+import timeit
+import threading
+
+array_size = 10**8
+num_runs = 10
+num_chunks = 32  # Number of chunks
+num_threads = 2  # Number of threads constrained by how many chunks memory can hold
+
+a = np.random.rand(array_size).reshape(10**4, -1)
+b = np.random.rand(array_size).reshape(10**4, -1)
+c = np.random.rand(array_size).reshape(10**4, -1)
+
+chunk_size = array_size // num_chunks
+
+expressions_numpy = [
+    lambda a, b, c: a + b * c,
+    lambda a, b, c: a**2 + b**2 - 2 * a * b * np.cos(c),
+    lambda a, b, c: np.sin(a) + np.log(b) * np.sqrt(c),
+    lambda a, b, c: np.exp(a) + np.tan(b) - np.sinh(c),
+]
+
+expressions_numexpr = [
+    "a + b * c",
+    "a**2 + b**2 - 2 * a * b * cos(c)",
+    "sin(a) + log(b) * sqrt(c)",
+    "exp(a) + tan(b) - sinh(c)",
+]
+
+
+def benchmark_numpy_chunk(func, a, b, c, results, indices):
+    for index in indices:
+        start = index * chunk_size
+        end = (index + 1) * chunk_size
+        time_taken = timeit.timeit(
+            lambda: func(a[start:end], b[start:end], c[start:end]), number=num_runs
+        )
+        results.append(time_taken)
+
+
+def benchmark_numexpr_re_evaluate(expr, a, b, c, results, indices):
+    for index in indices:
+        start = index * chunk_size
+        end = (index + 1) * chunk_size
+        if index == 0:
+            # Evaluate the first chunk with evaluate
+            time_taken = timeit.timeit(
+                lambda: ne.evaluate(
+                    expr,
+                    local_dict={
+                        "a": a[start:end],
+                        "b": b[start:end],
+                        "c": c[start:end],
+                    },
+                ),
+                number=num_runs,
+            )
+        else:
+            # Re-evaluate subsequent chunks with re_evaluate
+            time_taken = timeit.timeit(
+                lambda: ne.re_evaluate(
+                    local_dict={"a": a[start:end], "b": b[start:end], "c": c[start:end]}
+                ),
+                number=num_runs,
+            )
+        results.append(time_taken)
+
+
+def run_benchmark_threaded():
+    chunk_indices = list(range(num_chunks))
+
+    for i in range(len(expressions_numpy)):
+        print(f"Benchmarking Expression {i+1}:")
+
+        results_numpy = []
+        results_numexpr = []
+
+        threads_numpy = []
+        for j in range(num_threads):
+            indices = chunk_indices[j::num_threads]  # Distribute chunks across threads
+            thread = threading.Thread(
+                target=benchmark_numpy_chunk,
+                args=(expressions_numpy[i], a, b, c, results_numpy, indices),
+            )
+            threads_numpy.append(thread)
+            thread.start()
+
+        for thread in threads_numpy:
+            thread.join()
+
+        numpy_time = sum(results_numpy)
+        print(
+            f"NumPy time (threaded over {num_chunks} chunks with {num_threads} threads): {numpy_time:.6f} seconds"
+        )
+
+        threads_numexpr = []
+        for j in range(num_threads):
+            indices = chunk_indices[j::num_threads]  # Distribute chunks across threads
+            thread = threading.Thread(
+                target=benchmark_numexpr_re_evaluate,
+                args=(expressions_numexpr[i], a, b, c, results_numexpr, indices),
+            )
+            threads_numexpr.append(thread)
+            thread.start()
+
+        for thread in threads_numexpr:
+            thread.join()
+
+        numexpr_time = sum(results_numexpr)
+        print(
+            f"numexpr time (threaded with re_evaluate over {num_chunks} chunks with {num_threads} threads): {numexpr_time:.6f} seconds"
+        )
+        print(f"numexpr speedup: {numpy_time / numexpr_time:.2f}x")
+        print("-" * 40)
+
+
+if __name__ == "__main__":
+    run_benchmark_threaded()
diff --git a/numexpr/necompiler.py b/numexpr/necompiler.py
index 1e60b5a..a693c4d 100644
--- a/numexpr/necompiler.py
+++ b/numexpr/necompiler.py
@@ -19,7 +19,7 @@ is_cpu_amd_intel = False
 
 # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE
 from numexpr import interpreter, expressions, use_vml
-from numexpr.utils import CacheDict
+from numexpr.utils import CacheDict, ContextDict
 
 # Declare a double type that does not exist in Python space
 double = numpy.double
@@ -776,11 +776,9 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2):
 # Dictionaries for caching variable names and compiled expressions
 _names_cache = CacheDict(256)
 _numexpr_cache = CacheDict(256)
-_numexpr_last = {}
+_numexpr_last = ContextDict()
 evaluate_lock = threading.Lock()
 
-# MAYBE: decorate this function to add attributes instead of having the
-# _numexpr_last dictionary?
 def validate(ex: str,
              local_dict: Optional[Dict] = None,
              global_dict: Optional[Dict] = None,
@@ -887,7 +885,7 @@ def validate(ex: str,
             compiled_ex = _numexpr_cache[numexpr_key] = NumExpr(ex, signature, sanitize=sanitize, **context)
         kwargs = {'out': out, 'order': order, 'casting': casting, 'ex_uses_vml': ex_uses_vml}
-        _numexpr_last = dict(ex=compiled_ex, argnames=names, kwargs=kwargs)
+        _numexpr_last.set(ex=compiled_ex, argnames=names, kwargs=kwargs)
     except Exception as e:
         return e
     return None
diff --git a/numexpr/tests/test_numexpr.py b/numexpr/tests/test_numexpr.py
index c69e61d..2aa56ca 100644
--- a/numexpr/tests/test_numexpr.py
+++ b/numexpr/tests/test_numexpr.py
@@ -1201,6 +1201,7 @@ def run(self):
         test.join()
 
     def test_multithread(self):
+        import threading
 
         # Running evaluate() from multiple threads shouldn't crash
 
@@ -1218,6 +1219,77 @@ def work(n):
         for t in threads:
             t.join()
 
+    def test_thread_safety(self):
+        """
+        Expected output
+
+        When not safe (before the PR this test is committed)
+        AssertionError: Thread-0 failed: result does not match expected
+
+        When safe (after the PR this test is committed)
+        Should pass without failure
+        """
+        import threading
+        import time
+
+        barrier = threading.Barrier(4)
+
+        # Function that each thread will run with different expressions
+        def thread_function(a_value, b_value, expression, expected_result, results, index):
+            validate(expression, local_dict={"a": a_value, "b": b_value})
+            # Wait for all threads to reach this point
+            # such that they all set _numexpr_last
+            barrier.wait()
+
+            # Simulate some work or a context switch delay
+            time.sleep(0.1)
+
+            result = re_evaluate(local_dict={"a": a_value, "b": b_value})
+            results[index] = np.array_equal(result, expected_result)
+
+        def test_thread_safety_with_numexpr():
+            num_threads = 4
+            array_size = 1000000
+
+            expressions = [
+                "a + b",
+                "a - b",
+                "a * b",
+                "a / b"
+            ]
+
+            a_value = [np.full(array_size, i + 1) for i in range(num_threads)]
+            b_value = [np.full(array_size, (i + 1) * 2) for i in range(num_threads)]
+
+            expected_results = [
+                a_value[i] + b_value[i] if expr == "a + b" else
+                a_value[i] - b_value[i] if expr == "a - b" else
+                a_value[i] * b_value[i] if expr == "a * b" else
+                a_value[i] / b_value[i] if expr == "a / b" else None
+                for i, expr in enumerate(expressions)
+            ]
+
+            results = [None] * num_threads
+            threads = []
+
+            # Create and start threads with different expressions
+            for i in range(num_threads):
+                thread = threading.Thread(
+                    target=thread_function,
+                    args=(a_value[i], b_value[i], expressions[i], expected_results[i], results, i)
+                )
+                threads.append(thread)
+                thread.start()
+
+            for thread in threads:
+                thread.join()
+
+            for i in range(num_threads):
+                if not results[i]:
+                    self.fail(f"Thread-{i} failed: result does not match expected")
+
+        test_thread_safety_with_numexpr()
+
 
 # The worker function for the subprocess (needs to be here because Windows
 # has problems pickling nested functions with the multiprocess module :-/)
diff --git a/numexpr/utils.py b/numexpr/utils.py
index 5d5ec9b..cc61833 100644
--- a/numexpr/utils.py
+++ b/numexpr/utils.py
@@ -13,6 +13,7 @@
 import os
 import subprocess
+import contextvars
 
 from numexpr.interpreter import _set_num_threads, _get_num_threads, MAX_THREADS
 from numexpr import use_vml
 
@@ -226,3 +227,83 @@ def __setitem__(self, key, value):
             super(CacheDict, self).__delitem__(k)
         super(CacheDict, self).__setitem__(key, value)
+
+class ContextDict:
+    """
+    A context-aware version of a dictionary
+    """
+    def __init__(self):
+        self._context_data = contextvars.ContextVar('context_data', default={})
+
+    def set(self, key=None, value=None, **kwargs):
+        data = self._context_data.get().copy()
+
+        if key is not None:
+            data[key] = value
+
+        for k, v in kwargs.items():
+            data[k] = v
+
+        self._context_data.set(data)
+
+    def get(self, key, default=None):
+        data = self._context_data.get()
+        return data.get(key, default)
+
+    def delete(self, key):
+        data = self._context_data.get().copy()
+        if key in data:
+            del data[key]
+        self._context_data.set(data)
+
+    def clear(self):
+        self._context_data.set({})
+
+    def all(self):
+        return self._context_data.get()
+
+    def update(self, *args, **kwargs):
+        data = self._context_data.get().copy()
+
+        if args:
+            if len(args) > 1:
+                raise TypeError(f"update() takes at most 1 positional argument ({len(args)} given)")
+            other = args[0]
+            if isinstance(other, dict):
+                data.update(other)
+            else:
+                for k, v in other:
+                    data[k] = v
+
+        data.update(kwargs)
+        self._context_data.set(data)
+
+    def keys(self):
+        return self._context_data.get().keys()
+
+    def values(self):
+        return self._context_data.get().values()
+
+    def items(self):
+        return self._context_data.get().items()
+
+    def __getitem__(self, key):
+        return self.get(key)
+
+    def __setitem__(self, key, value):
+        self.set(key, value)
+
+    def __delitem__(self, key):
+        self.delete(key)
+
+    def __contains__(self, key):
+        return key in self._context_data.get()
+
+    def __len__(self):
+        return len(self._context_data.get())
+
+    def __iter__(self):
+        return iter(self._context_data.get())
+
+    def __repr__(self):
+        return repr(self._context_data.get())