Skip to content

Commit

Permalink
fix (stateful)(utilities): adding FlaxPreTrainedModel as one of the…
Browse files Browse the repository at this point in the history
… bases. This is because `FlaxPreTrainedModel` is a standalone class that doesn't directly inherit from a native Flax module.
  • Loading branch information
YushaArif99 committed Sep 25, 2024
1 parent 9f76c2a commit f163a1b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ivy/stateful/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _is_submodule(obj, kw):
"keras.src.engine.base_layer.Layer",
"keras.src.layers.layer.Layer",
),
"flax": ("flax.nnx.nnx.module.Module",),
"flax": ("flax.nnx.nnx.module.Module","transformers.modeling_flax_utils.FlaxPreTrainedModel"),
}[kw]
try:
for bc in type(obj).mro():
Expand Down Expand Up @@ -730,7 +730,7 @@ def forward(self, x):
def _compute_module_dict_jax(model, prefix=""):
_module_dict = dict()
for key, value in model.__dict__.items():
if isinstance(value, nnx.Module):
if isinstance(value, nnx.Module) and value != model:
if not hasattr(value, "named_parameters"):
_module_dict.update(
_compute_module_dict_jax(value, prefix=f"{key}.")
Expand Down

0 comments on commit f163a1b

Please sign in to comment.