Skip to content

Commit

Permalink
fix (stateful)(utilities): added a guard to handle cases wherein `jax…
Browse files Browse the repository at this point in the history
….Array's` might be present inside the `.v` dict. We also call `setattr` to ensure the update is reflected on the translated model as well.
  • Loading branch information
YushaArif99 committed Sep 23, 2024
1 parent 3219c1f commit d5247ae
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions ivy/stateful/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,16 @@ def _maybe_update_flax_layer_weights(
)
params2[name] = getattr(layer, flax_name)
continue

params2[name].value = jnp.asarray(
params1_np, dtype=params2[name].value.dtype
)

if isinstance(params2[name], nnx.Variable):
params2[name].value = jnp.asarray(
params1_np, dtype=params2[name].value.dtype
)
else:
params2[name] = jnp.asarray(
params1_np, dtype=params2[name].dtype
)
setattr(model2, name, params2[name])

for name in buffers1:
layer, weight_name = _retrive_layer(model2, key_mapping[name])
Expand Down Expand Up @@ -250,6 +256,7 @@ def _maybe_update_flax_layer_weights(

else:
buffers2[name] = jnp.asarray(buffers1_np, dtype=buffers2[name].dtype)
setattr(model2, name, buffers2[name])

# Check if the parameters and buffers are the same
for name in params1:
Expand Down

0 comments on commit d5247ae

Please sign in to comment.