diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index 2e5e103c..59509c0c 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -94,7 +94,6 @@ def __init__( self.target = target # transform - self._eval_dyn_vars = False self._grad_transform = transform self._dyn_vars = VariableStack() self._transform = None @@ -198,20 +197,18 @@ def __call__(self, *args, **kwargs): ) return self._return(rets) - elif not self._eval_dyn_vars: # evaluate dynamical variables - stack = get_stack_cache(self.target) - if stack is None: - with VariableStack() as stack: - rets = eval_shape(self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs) - cache_stack(self.target, stack) - + # evaluate dynamical variables + stack = get_stack_cache(self.target) + if stack is None: + with VariableStack() as stack: + rets = eval_shape(self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs) + cache_stack(self.target, stack) self._dyn_vars = stack self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True # if not the outermost transformation if not stack.is_first_stack(): diff --git a/brainpy/_src/math/object_transform/tests/test_autograd.py b/brainpy/_src/math/object_transform/tests/test_autograd.py index 1cd7c7cd..bb4adf1d 100644 --- a/brainpy/_src/math/object_transform/tests/test_autograd.py +++ b/brainpy/_src/math/object_transform/tests/test_autograd.py @@ -86,6 +86,17 @@ def call(a, b, c): assert aux[1] == bm.exp(0.1) + def test_grad_jit(self): + def call(a, b, c): return bm.sum(a + b + c) + + bm.random.seed(1) + a = bm.ones(10) + b = bm.random.randn(10) + c = bm.random.uniform(size=10) + f_grad = bm.jit(bm.grad(call)) + assert (f_grad(a, b, c) == 1.).all() + + class TestObjectFuncGrad(unittest.TestCase): def test_grad_ob1(self): class Test(bp.BrainPyObject):