Skip to content

Commit

Permalink
Adding tests for validate and noticed that re_evaluate tests usin…
Browse files Browse the repository at this point in the history
…g `local_dict` argument are flawed and do not actually work
  • Loading branch information
robbmcleod committed Jun 29, 2023
1 parent 0032150 commit 74d5973
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
3 changes: 2 additions & 1 deletion numexpr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
import os, os.path
import platform
from numexpr.expressions import E
from numexpr.necompiler import NumExpr, disassemble, evaluate, re_evaluate
from numexpr.necompiler import (NumExpr, disassemble, evaluate, re_evaluate,
validate)

from numexpr.utils import (_init_num_threads,
get_vml_version, set_vml_accuracy_mode, set_vml_num_threads,
Expand Down
32 changes: 30 additions & 2 deletions numexpr/tests/test_numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from numpy import shape, allclose, array_equal, ravel, isnan, isinf

import numexpr
from numexpr import E, NumExpr, evaluate, re_evaluate, disassemble, use_vml
from numexpr import E, NumExpr, evaluate, re_evaluate, validate, disassemble, use_vml
from numexpr.expressions import ConstantNode

import unittest
Expand Down Expand Up @@ -370,10 +370,38 @@ def test_re_evaluate(self):
assert_array_equal(x, array([86., 124., 168.]))

def test_re_evaluate_dict(self):
a1 = array([1., 2., 3.])
b1 = array([4., 5., 6.])
c1 = array([7., 8., 9.])
x = evaluate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1})
x = re_evaluate()
assert_array_equal(x, array([86., 124., 168.]))

def test_validate(self):
a = array([1., 2., 3.])
b = array([4., 5., 6.])
c = array([7., 8., 9.])
x = evaluate("2*a + 3*b*c", local_dict={'a': a, 'b': b, 'c': c})
retval = validate("2*a + 3*b*c")
assert(retval is None)
x = re_evaluate()
assert_array_equal(x, array([86., 124., 168.]))

def test_validate_missing_var(self):
a = array([1., 2., 3.])
b = array([4., 5., 6.])
retval = validate("2*a + 3*b*c")
assert(isinstance(retval, KeyError))

def test_validate_syntax(self):
retval = validate("2+")
assert(isinstance(retval, SyntaxError))

def test_validate_dict(self):
a1 = array([1., 2., 3.])
b1 = array([4., 5., 6.])
c1 = array([7., 8., 9.])
retval = validate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1})
assert(retval is None)
x = re_evaluate()
assert_array_equal(x, array([86., 124., 168.]))

Expand Down

0 comments on commit 74d5973

Please sign in to comment.