Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/brainpy/BrainPy into numb…
Browse files Browse the repository at this point in the history
…a-mlir
  • Loading branch information
Routhleck committed May 14, 2024
2 parents e55a9b0 + 4d4eea5 commit 78c239c
Show file tree
Hide file tree
Showing 21 changed files with 181 additions and 116 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def find_fps_with_gd_method(
"""
# optimization settings
if optimizer is None:
optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
optimizer = optim.Adam(lr=optim.ExponentialDecayLR(0.2, 1, 0.9999),
beta1=0.9, beta2=0.999, eps=1e-8)
else:
if not isinstance(optimizer, optim.Optimizer):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/rates/tests/test_nvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Test_NVAR(parameterized.TestCase):
def test_NVAR(self,mode):
bm.random.seed()
input=bm.random.randn(1,5)
layer=bp.dnn.NVAR(num_in=5,
layer=bp.dyn.NVAR(num_in=5,
delay=10,
mode=mode)
if mode in [bm.NonBatchingMode()]:
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/initialize/tests/test_decay_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# visualization
def mat_visualize(matrix, cmap=None):
if cmap is None:
cmap = plt.cm.get_cmap('coolwarm')
plt.cm.get_cmap('coolwarm')
cmap = plt.colormaps.get_cmap('coolwarm')
plt.colormaps.get_cmap('coolwarm')
im = plt.matshow(matrix, cmap=cmap)
plt.colorbar(mappable=im, shrink=0.8, aspect=15)
plt.show()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def dV(self, V, t, h, n, Iext):

return dVdt

def update(self, tdi):
t, dt = tdi.t, tdi.dt
def update(self):
t, dt = bp.share['t'], bp.share['dt']
V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt=dt)
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/integrators/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import jax.numpy as jnp
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten

from brainpy import math as bm
Expand Down Expand Up @@ -245,7 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):

# progress bar
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())
jax.pure_callback(lambda *args: self._pbar.update(), ())

# return of function monitors
shared = dict(t=t + self.dt, dt=self.dt, i=i)
Expand Down
12 changes: 5 additions & 7 deletions brainpy/_src/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,27 +660,25 @@ def searchsorted(self, v, side='left', sorter=None):
"""
return _return(self.value.searchsorted(v=_as_jax_array_(v), side=side, sorter=sorter))

def sort(self, axis=-1, kind='quicksort', order=None):
def sort(self, axis=-1, stable=True, order=None):
"""Sort an array in-place.
Parameters
----------
axis : int, optional
Axis along which to sort. Default is -1, which means sort along the
last axis.
kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}
Sorting algorithm. The default is 'quicksort'. Note that both 'stable'
and 'mergesort' use timsort under the covers and, in general, the
actual implementation will vary with datatype. The 'mergesort' option
is retained for backwards compatibility.
stable : bool, optional
Whether to use a stable sorting algorithm. The default is True.
order : str or list of str, optional
When `a` is an array with fields defined, this argument specifies
which fields to compare first, second, etc. A single field can
be specified as a string, and not all fields need be specified,
but unspecified fields will still be used, in the order in which
they come up in the dtype, to break ties.
"""
self.value = self.value.sort(axis=axis, kind=kind, order=order)
self.value = self.value.sort(axis=axis, stable=stable, order=order)


def squeeze(self, axis=None):
"""Remove axes of length one from ``a``."""
Expand Down
13 changes: 6 additions & 7 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import jax
import jax.numpy as jnp
from jax.errors import UnexpectedTracerError
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten, tree_unflatten
from tqdm.auto import tqdm

Expand Down Expand Up @@ -421,14 +420,14 @@ def call(pred, x=None):
def _warp(f):
@functools.wraps(f)
def new_f(*args, **kwargs):
return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
return jax.tree.map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))

return new_f


def _warp_data(data):
def new_f(*args, **kwargs):
return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
return jax.tree.map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))

return new_f

Expand Down Expand Up @@ -727,7 +726,7 @@ def fun2scan(carry, x):
dyn_vars[k]._value = carry[k]
results = body_fun(*x, **unroll_kwargs)
if progress_bar:
id_tap(lambda *arg: bar.update(), ())
jax.pure_callback(lambda *arg: bar.update(), ())
return dyn_vars.dict_data(), results

if remat:
Expand Down Expand Up @@ -916,15 +915,15 @@ def fun2scan(carry, x):
dyn_vars[k]._value = dyn_vars_data[k]
carry, results = body_fun(carry, x)
if progress_bar:
id_tap(lambda *arg: bar.update(), ())
carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
jax.pure_callback(lambda *arg: bar.update(), ())
carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
return (dyn_vars.dict_data(), carry), results

if remat:
fun2scan = jax.checkpoint(fun2scan)

def call(init, operands):
init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
init = jax.tree.map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
return jax.lax.scan(f=fun2scan,
init=(dyn_vars.dict_data(), init),
xs=operands,
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/math/object_transform/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,8 @@ def call_fun(self, *args, **kwargs):

return call_fun


def _make_transform(fun, stack):
@wraps(fun)
# @wraps(fun)
def _transform_function(variable_data: Dict, *args, **kwargs):
for key, v in stack.items():
v._value = variable_data[key]
Expand Down
94 changes: 47 additions & 47 deletions brainpy/_src/math/object_transform/tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,52 +1172,52 @@ def f(a, b):



class TestHessian(unittest.TestCase):
def test_hessian5(self):
bm.set_mode(bm.training_mode)

class RNN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden):
super(RNN, self).__init__()
self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
self.out = bp.dnn.Dense(num_hidden, 1)

def update(self, x):
return self.out(self.rnn(x))

# define the loss function
def lossfunc(inputs, targets):
runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
predicts = runner.predict(inputs)
loss = bp.losses.mean_squared_error(predicts, targets)
return loss

model = RNN(1, 2)
data_x = bm.random.rand(1, 1000, 1)
data_y = data_x + bm.random.randn(1, 1000, 1)

bp.reset_state(model, 1)
losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
hess_matrix = losshess(data_x, data_y)

weights = model.train_vars().unique()

# define the loss function
def loss_func_for_jax(weight_vals, inputs, targets):
for k, v in weight_vals.items():
weights[k].value = v
runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
predicts = runner.predict(inputs)
loss = bp.losses.mean_squared_error(predicts, targets)
return loss

bp.reset_state(model, 1)
jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)

for k, v in hess_matrix.items():
for kk, vv in v.items():
self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))

bm.clear_buffer_memory()
# class TestHessian(unittest.TestCase):
# def test_hessian5(self):
# bm.set_mode(bm.training_mode)
#
# class RNN(bp.DynamicalSystem):
# def __init__(self, num_in, num_hidden):
# super(RNN, self).__init__()
# self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
# self.out = bp.dnn.Dense(num_hidden, 1)
#
# def update(self, x):
# return self.out(self.rnn(x))
#
# # define the loss function
# def lossfunc(inputs, targets):
# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
# predicts = runner.predict(inputs)
# loss = bp.losses.mean_squared_error(predicts, targets)
# return loss
#
# model = RNN(1, 2)
# data_x = bm.random.rand(1, 1000, 1)
# data_y = data_x + bm.random.randn(1, 1000, 1)
#
# bp.reset_state(model, 1)
# losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
# hess_matrix = losshess(data_x, data_y)
#
# weights = model.train_vars().unique()
#
# # define the loss function
# def loss_func_for_jax(weight_vals, inputs, targets):
# for k, v in weight_vals.items():
# weights[k].value = v
# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
# predicts = runner.predict(inputs)
# loss = bp.losses.mean_squared_error(predicts, targets)
# return loss
#
# bp.reset_state(model, 1)
# jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)
#
# for k, v in hess_matrix.items():
# for kk, vv in v.items():
# self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))
#
# bm.clear_buffer_memory()


7 changes: 7 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,13 @@ def test1(self):

tree = jax.tree.structure(hh)
leaves = jax.tree.leaves(hh)
# tree = jax.tree.structure(hh)
# leaves = jax.tree.leaves(hh)

print(tree)
print(leaves)
print(jax.tree.unflatten(tree, leaves))
# print(jax.tree.unflatten(tree, leaves))
print()


Expand Down Expand Up @@ -282,12 +285,16 @@ def all_close(x, y):
assert bm.allclose(x, y)

jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
# jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)

random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
# random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
# jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)

obj.load_state_dict(random_state)
jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
# jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_nodes():
A.pre = B
B.pre = A

net = bp.dyn.Network(A, B)
net = bp.Network(A, B)
abs_nodes = net.nodes(method='absolute')
rel_nodes = net.nodes(method='relative')
print()
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/math/object_transform/tests/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import brainpy as bp


class GABAa_without_Variable(bp.TwoEndConn):
class GABAa_without_Variable(bp.synapses.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
super(GABAa_without_Variable, self).__init__(pre=pre, post=post, **kwargs)
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_neu_nodes_1():
assert len(neu.nodes(method='relative', include_self=False)) == 1


class GABAa_with_Variable(bp.TwoEndConn):
class GABAa_with_Variable(bp.synapses.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
super(GABAa_with_Variable, self).__init__(pre=pre, post=post, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def f1():
branches=[f1,
lambda: 2, lambda: 3,
lambda: 4, lambda: 5],
dyn_vars=var_a,
# dyn_vars=var_a,
show_code=True)

self.assertTrue(f(11) == 1)
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class MyObj:
def __init__(self):
self.a = bm.Variable(bm.ones(2))

@bm.cls_jit(static_argnums=1)
@bm.cls_jit(static_argnums=0)
def f(self, b, c):
self.a.value *= b
self.a.value /= c
Expand Down
Loading

0 comments on commit 78c239c

Please sign in to comment.