From f163a1b68ae7263b6443581f7054a1cbd5fa18a0 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Wed, 25 Sep 2024 08:17:08 +0000 Subject: [PATCH] fix (stateful)(utilities): adding `FlaxPreTrainedModel` as one of the bases. This is because `FlaxPreTrainedModel` is a standalone class that doesn't directly inherit from a native Flax module. --- ivy/stateful/utilities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/stateful/utilities.py b/ivy/stateful/utilities.py index a7260690da81..8ae563146657 100644 --- a/ivy/stateful/utilities.py +++ b/ivy/stateful/utilities.py @@ -23,7 +23,7 @@ def _is_submodule(obj, kw): "keras.src.engine.base_layer.Layer", "keras.src.layers.layer.Layer", ), - "flax": ("flax.nnx.nnx.module.Module",), + "flax": ("flax.nnx.nnx.module.Module","transformers.modeling_flax_utils.FlaxPreTrainedModel"), }[kw] try: for bc in type(obj).mro(): @@ -730,7 +730,7 @@ def forward(self, x): def _compute_module_dict_jax(model, prefix=""): _module_dict = dict() for key, value in model.__dict__.items(): - if isinstance(value, nnx.Module): + if isinstance(value, nnx.Module) and value != model: if not hasattr(value, "named_parameters"): _module_dict.update( _compute_module_dict_jax(value, prefix=f"{key}.")