[math] flow control updates (#396)
ztqakita authored Jun 20, 2023
2 parents d4e8da5 + 1d5c899 commit bb19b94
Showing 6 changed files with 240 additions and 96 deletions.
44 changes: 25 additions & 19 deletions brainpy/__init__.py
@@ -96,34 +96,30 @@
# Part 4: Training #
# ------------------ #

from ._src.train.base import (DSTrainer as DSTrainer, )
from ._src.train.back_propagation import (BPTT as BPTT,
BPFF as BPFF, )
from ._src.train.online import (OnlineTrainer as OnlineTrainer,
ForceTrainer as ForceTrainer, )
from ._src.train.offline import (OfflineTrainer as OfflineTrainer,
RidgeTrainer as RidgeTrainer, )
from brainpy._src.train.base import (DSTrainer as DSTrainer, )
from brainpy._src.train.back_propagation import (BPTT as BPTT,
BPFF as BPFF, )
from brainpy._src.train.online import (OnlineTrainer as OnlineTrainer,
ForceTrainer as ForceTrainer, )
from brainpy._src.train.offline import (OfflineTrainer as OfflineTrainer,
RidgeTrainer as RidgeTrainer, )

# Part 6: Others #
# ------------------ #

from . import running, testing, analysis
from ._src.visualization import (visualize as visualize)
from ._src import base, modes, train, dyn
from brainpy import running, testing, analysis
from brainpy._src.visualization import (visualize as visualize)
from brainpy._src import base, train, dyn

# Part 7: Deprecations #
# ---------------------- #

ode.__dict__['odeint'] = odeint
sde.__dict__['sdeint'] = sdeint
fde.__dict__['fdeint'] = fdeint

# deprecated
from brainpy._src import modes
from brainpy._src.math.object_transform.base import (Base as Base,
ArrayCollector,
Collector as Collector, )


# deprecated
from brainpy._src import checking
from brainpy._src.synapses import compat
from brainpy._src.deprecations import deprecation_getattr2
@@ -189,9 +185,11 @@
# neurons
'HH': ('brainpy.dyn.HH', 'brainpy.neurons.HH', neurons.HH),
'MorrisLecar': ('brainpy.dyn.MorrisLecar', 'brainpy.neurons.MorrisLecar', neurons.MorrisLecar),
'PinskyRinzelModel': ('brainpy.dyn.PinskyRinzelModel', 'brainpy.neurons.PinskyRinzelModel', neurons.PinskyRinzelModel),
'PinskyRinzelModel': ('brainpy.dyn.PinskyRinzelModel', 'brainpy.neurons.PinskyRinzelModel',
neurons.PinskyRinzelModel),
'FractionalFHR': ('brainpy.dyn.FractionalFHR', 'brainpy.neurons.FractionalFHR', neurons.FractionalFHR),
'FractionalIzhikevich': ('brainpy.dyn.FractionalIzhikevich', 'brainpy.neurons.FractionalIzhikevich', neurons.FractionalIzhikevich),
'FractionalIzhikevich': ('brainpy.dyn.FractionalIzhikevich', 'brainpy.neurons.FractionalIzhikevich',
neurons.FractionalIzhikevich),
'LIF': ('brainpy.dyn.LIF', 'brainpy.neurons.LIF', neurons.LIF),
'ExpIF': ('brainpy.dyn.ExpIF', 'brainpy.neurons.ExpIF', neurons.ExpIF),
'AdExIF': ('brainpy.dyn.AdExIF', 'brainpy.neurons.AdExIF', neurons.AdExIF),
@@ -217,5 +215,13 @@
}
dyn.__getattr__ = deprecation_getattr2('brainpy.dyn', dyn.__deprecations)

del deprecation_getattr2, checking
ode.__deprecations = {'odeint': ('brainpy.ode.odeint', 'brainpy.odeint', odeint)}
ode.__getattr__ = deprecation_getattr2('brainpy.ode', ode.__deprecations)

sde.__deprecations = {'sdeint': ('brainpy.sde.sdeint', 'brainpy.sdeint', sdeint)}
sde.__getattr__ = deprecation_getattr2('brainpy.sde', sde.__deprecations)

fde.__deprecations = {'fdeint': ('brainpy.fde.fdeint', 'brainpy.fdeint', fdeint)}
fde.__getattr__ = deprecation_getattr2('brainpy.fde', fde.__deprecations)

del deprecation_getattr2, checking, compat
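The per-module deprecation tables above replace the old direct assignments such as ode.__dict__['odeint'] = odeint: each deprecated attribute maps to a (deprecated path, recommended path, object) tuple and is served through a module-level __getattr__. As a point of reference, here is a minimal sketch of how such a hook is commonly written (a PEP 562 module-level __getattr__); it is an assumption for illustration only, and the real deprecation_getattr2 in brainpy._src.deprecations may differ.

# Illustrative sketch only; not the actual brainpy._src.deprecations code.
import warnings

def deprecation_getattr2(module_name, deprecations):
    def __getattr__(name):
        if name in deprecations:
            old_path, new_path, obj = deprecations[name]
            warnings.warn(f'{old_path} is deprecated, use {new_path} instead.',
                          DeprecationWarning, stacklevel=2)
            return obj
        raise AttributeError(f'module {module_name!r} has no attribute {name!r}')
    return __getattr__

With a hook of this kind in place, brainpy.ode.odeint keeps working but emits a DeprecationWarning pointing users to brainpy.odeint.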
43 changes: 29 additions & 14 deletions brainpy/_src/math/object_transform/autograd.py
@@ -22,6 +22,8 @@
from ._tools import (
dynvar_deprecation,
node_deprecation,
get_stack_cache,
cache_stack,
)
from .base import (
BrainPyObject,
@@ -201,8 +203,10 @@ def __call__(self, *args, **kwargs):
return self._return(rets)

elif not self._eval_dyn_vars: # evaluate dynamical variables
with new_transform(self):
with VariableStack() as stack:
stack = get_stack_cache(self.target)
if stack is None:
with new_transform(self):
with VariableStack() as stack:
if current_transform_number() > 1:
rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
@@ -218,6 +222,7 @@ def __call__(self, *args, **kwargs):
*args,
**kwargs
)
cache_stack(self.target, stack)

self._dyn_vars = stack
self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars])
@@ -266,7 +271,7 @@ def _make_grad(


def grad(
func: Callable = None,
func: Optional[Callable] = None,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
holomorphic: Optional[bool] = False,
Expand All @@ -278,7 +283,7 @@ def grad(
# deprecated
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> GradientTransform:
) -> Union[Callable, GradientTransform]:
"""Automatic gradient computation for functions or class objects.
  This gradient function only supports scalar returns. It creates a function
@@ -780,7 +785,7 @@ def grad_fun(*args, **kwargs):


def vector_grad(
func: Callable,
func: Optional[Callable] = None,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
return_value: bool = False,
@@ -789,7 +794,7 @@ def vector_grad(
# deprecated
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> ObjectTransform:
) -> Union[Callable, ObjectTransform]:
"""Take vector-valued gradients for function ``func``.
Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_,
@@ -850,14 +855,24 @@ def vector_grad(
child_objs = check.is_all_objs(child_objs, out_as='dict')
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')

return GradientTransform(target=func,
transform=_vector_grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
return_value=return_value,
has_aux=False if has_aux is None else has_aux)
if func is None:
return lambda f: GradientTransform(target=f,
transform=_vector_grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
return_value=return_value,
has_aux=False if has_aux is None else has_aux)
else:
return GradientTransform(target=func,
transform=_vector_grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
return_value=return_value,
has_aux=False if has_aux is None else has_aux)


def _check_callable(fun):
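The autograd changes above make func optional in grad and vector_grad, so calling the transform without a target now returns a decorator-style partial instead of failing. A usage sketch (the function f and the sample input are illustrative, not taken from the diff; exact return shapes follow brainpy.math.vector_grad semantics):

import brainpy.math as bm

def f(x):
    return x ** 2                     # element-wise, vector-valued output

# With func omitted, vector_grad returns a partial that can be applied later,
# which also enables decorator-style use.
f_grad = bm.vector_grad(argnums=0)(f)
print(f_grad(bm.arange(3.)))          # element-wise gradients: [0., 2., 4.]

The same GradientTransform object is produced either way; only the call pattern changes.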
145 changes: 104 additions & 41 deletions brainpy/_src/math/object_transform/controls.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-


import functools
from typing import Union, Sequence, Any, Dict, Callable, Optional
import numbers

import jax
import jax.numpy as jnp
@@ -426,17 +426,38 @@ def _check_f(f):
if callable(f):
return f
else:
return (lambda _: f)
return (lambda *args, **kwargs: f)


def _check_sequence(a):
return isinstance(a, (list, tuple))


def _cond_transform_fun(fun, dyn_vars):
@functools.wraps(fun)
def new_fun(dyn_vals, *static_vals):
for k, v in dyn_vars.items():
v._value = dyn_vals[k]
r = fun(*static_vals)
return {k: v.value for k, v in dyn_vars.items()}, r

return new_fun


def _get_cond_transform(dyn_vars, pred, true_fun, false_fun):
_true_fun = _cond_transform_fun(true_fun, dyn_vars)
_false_fun = _cond_transform_fun(false_fun, dyn_vars)

def call_fun(operands):
return jax.lax.cond(pred, _true_fun, _false_fun, dyn_vars.dict_data(), *operands)

return call_fun


def cond(
pred: bool,
true_fun: Union[Callable, jnp.ndarray, Array, float, int, bool],
false_fun: Union[Callable, jnp.ndarray, Array, float, int, bool],
true_fun: Union[Callable, jnp.ndarray, Array, numbers.Number],
false_fun: Union[Callable, jnp.ndarray, Array, numbers.Number],
operands: Any = (),

# deprecated
@@ -504,43 +525,54 @@ def cond(
dynvar_deprecation(dyn_vars)
node_deprecation(child_objs)

dyn_vars = get_stack_cache((true_fun, false_fun))
_transform = _get_cond_transform(VariableStack() if dyn_vars is None else dyn_vars,
pred,
true_fun,
false_fun)
if jax.config.jax_disable_jit:
dyn_vars = VariableStack()
dyn_values, res = _transform(operands)

else:
with new_transform('cond'):
dyn_vars, rets_true = evaluate_dyn_vars(true_fun, *operands)
dyn_vars2, rets_false = evaluate_dyn_vars(false_fun, *operands)
tree_true = jax.tree_util.tree_structure(dyn_vars)
tree_false = jax.tree_util.tree_structure(dyn_vars2)
if tree_true != tree_false:
raise TypeError('true_fun and false_fun output must have same type structure, '
f'got {tree_true} and {tree_false}.')
dyn_vars += dyn_vars2
del tree_true, tree_false, dyn_vars2

if current_transform_number() > 0:
return jax.lax.cond(pred, lambda: rets_true, lambda: rets_false)
if dyn_vars is None:
with new_transform('cond'):
dyn_vars, rets = evaluate_dyn_vars(
_transform,
operands,
use_eval_shape=current_transform_number() <= 1
)
cache_stack((true_fun, false_fun), dyn_vars)
if current_transform_number() > 0:
return rets[1]
dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands)
for k in dyn_values.keys():
dyn_vars[k]._value = dyn_values[k]
return res

# TODO: cache mechanism?
def _true_fun(dyn_vals, *static_vals):
for k, v in dyn_vars.items():
v._value = dyn_vals[k]
r = true_fun(*static_vals)
return {k: v.value for k, v in dyn_vars.items()}, r

def _false_fun(dyn_vals, *static_vals):
for k, v in dyn_vars.items():
v._value = dyn_vals[k]
r = false_fun(*static_vals)
return {k: v.value for k, v in dyn_vars.items()}, r
def _if_else_return1(conditions, branches, operands):
for i, pred in enumerate(conditions):
if pred:
return branches[i](*operands)
else:
return branches[-1](*operands)

old_values = {k: v.value for k, v in dyn_vars.items()}
dyn_values, res = jax.lax.cond(pred, _true_fun, _false_fun, old_values, *operands)
for k, v in dyn_vars.items():
v._value = dyn_values[k]

return res
def _if_else_return2(conditions, branches):
for i, pred in enumerate(conditions):
if pred:
return branches[i]
else:
return branches[-1]


def all_equal(iterator):
iterator = iter(iterator)
try:
first = next(iterator)
except StopIteration:
return True
return all(first == x for x in iterator)
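Before the ifelse changes below, here is a usage sketch of the reworked cond above: the branch functions are wrapped by _cond_transform_fun, the Variables they touch are collected once and cached for the (true_fun, false_fun) pair, their values are threaded through jax.lax.cond, and the selected branch's values are written back. The names v, true_branch, and false_branch are illustrative, not part of the diff.

import brainpy.math as bm

v = bm.Variable(bm.zeros(1))

def true_branch():
    v.value += 1.

def false_branch():
    v.value -= 1.

# Both branches mutate a Variable; cond threads v's value through
# jax.lax.cond and restores the selected branch's result afterwards.
bm.cond(True, true_branch, false_branch)
print(v.value)                        # [1.]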


def ifelse(
@@ -620,6 +652,10 @@ def ifelse(
raise ValueError(f'The numbers of branches and conditions do not match. '
f'Got len(conditions)={len(conditions)} and len(branches)={len(branches)}. '
f'We expect len(conditions) + 1 == len(branches). ')
if operands is None:
operands = tuple()
if not isinstance(operands, (tuple, list)):
operands = (operands,)

dynvar_deprecation(dyn_vars)
node_deprecation(child_objs)
@@ -631,21 +667,47 @@
branches[1],
operands)
else:
if jax.config.jax_disable_jit:
return _if_else_return1(conditions, branches, operands)

else:
dyn_vars = get_stack_cache(tuple(branches))
if dyn_vars is None:
with new_transform('ifelse'):
with VariableStack() as dyn_vars:
if current_transform_number() > 1:
rets = [branch(*operands) for branch in branches]
else:
rets = [jax.eval_shape(branch, *operands) for branch in branches]
trees = [jax.tree_util.tree_structure(ret) for ret in rets]
if not all_equal(trees):
msg = 'All returns in branches should have the same tree structure. But we got:\n'
for tree in trees:
msg += f'- {tree}\n'
raise TypeError(msg)
cache_stack(tuple(branches), dyn_vars)
if current_transform_number():
return _if_else_return2(conditions, rets)

branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches]

code_scope = {'conditions': conditions, 'branches': branches}
codes = ['def f(operands):',
codes = ['def f(dyn_vals, *operands):',
f' f0 = branches[{len(conditions)}]']
num_cond = len(conditions) - 1
code_scope['_cond'] = cond
code_scope['_cond'] = jax.lax.cond
for i in range(len(conditions) - 1):
codes.append(f' f{i + 1} = lambda r: _cond(conditions[{num_cond - i}], branches[{num_cond - i}], f{i}, r)')
codes.append(f' return _cond(conditions[0], branches[0], f{len(conditions) - 1}, operands)')
codes.append(f' f{i + 1} = lambda *r: _cond(conditions[{num_cond - i}], branches[{num_cond - i}], f{i}, *r)')
codes.append(f' return _cond(conditions[0], branches[0], f{len(conditions) - 1}, dyn_vals, *operands)')
codes = '\n'.join(codes)
if show_code:
print(codes)
exec(compile(codes.strip(), '', 'exec'), code_scope)
f = code_scope['f']
r = f(operands)
return r
dyn_values, res = f(dyn_vars.dict_data(), *operands)
for k in dyn_values.keys():
dyn_vars[k]._value = dyn_values[k]
return res
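The ifelse rewrite above now uses the same Variable-threading machinery: each branch is wrapped by _cond_transform_fun, the generated dispatcher chains jax.lax.cond calls, and the selected branch's Variable values are written back. A usage sketch with purely illustrative branch functions (pos, zero, neg) and input:

import brainpy.math as bm

def pos(): return 1.
def zero(): return 0.
def neg(): return -1.

x = 0.5
# len(branches) == len(conditions) + 1; the last branch is the fallback.
r = bm.ifelse(conditions=[x > 0, x == 0], branches=[pos, zero, neg])
print(r)                              # 1.0

With Python-bool conditions the dispatch still goes through jax.lax.cond, unless jax.config.jax_disable_jit is set, in which case the plain-Python _if_else_return1 path is taken instead.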


def _loop_abstractify(x):
@@ -764,6 +826,7 @@ def for_loop(
progress_bar: bool
Whether we use the progress bar to report the running progress.
.. versionadded:: 2.4.2
dyn_vars: Variable, sequence of Variable, dict
The instances of :py:class:`~.Variable`.