SDXL training bug: AttributeError: module 'jax.random' has no attribute 'KeyArray' #40

Open
Chadius opened this issue Apr 14, 2024 · 0 comments

As I mentioned in #39, jax deprecated `jax.random.KeyArray` and removed it entirely in 0.4.24. The installed diffusers still uses it in a type annotation (see the traceback), so LECO fails at import time with a broken dependency.

```bash
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
2024-04-14 00:57:27.323177: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-14 00:57:27.323222: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-14 00:57:27.324499: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-14 00:57:28.392824: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
Traceback (most recent call last):
  File "/content/LECO/train_lora_xl.py", line 15, in <module>
    from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
  File "/content/LECO/lora.py", line 11, in <module>
    from diffusers import UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 38, in <module>
    from .models import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 36, in <module>
    from .controlnet_flax import FlaxControlNetModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 25, in <module>
    from .modeling_flax_utils import FlaxModelMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 46, in <module>
    class FlaxModelMixin(PushToHubMixin):
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 195, in FlaxModelMixin
    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'
```
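
One possible stopgap, which I have not verified against LECO, would be to restore the missing name before `diffusers` is imported. The sketch below uses a hypothetical helper, `ensure_alias`, that adds an attribute to a module only when it is missing. `fake_random` and `FakeArray` are stand-ins so the snippet runs without jax installed.

```python
import types


def ensure_alias(module, name, fallback):
    """Set module.<name> to fallback only if the attribute is missing.

    Returns True when the alias was added, False when the name already existed.
    """
    if hasattr(module, name):
        return False
    setattr(module, name, fallback)
    return True


# Stand-in for jax.random so the example runs without jax installed.
fake_random = types.ModuleType("fake_random")


class FakeArray:
    """Placeholder for the type the annotation would point at."""


added = ensure_alias(fake_random, "KeyArray", FakeArray)     # name missing: alias is added
added_again = ensure_alias(fake_random, "KeyArray", object)  # name present: left untouched
print(added, added_again, fake_random.KeyArray is FakeArray)
```

With a real install, the call would presumably be `ensure_alias(jax.random, "KeyArray", jax.Array)` near the top of `train_lora_xl.py`, before `from lora import ...`. Using `jax.Array` as the replacement is an assumption on my part; it only has to satisfy the annotation in `modeling_flax_utils.py`.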

I tried downgrading jax to 0.4.23, but the script now fails with a circular-import error instead.
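
Because the downgrade touched only jax, it may be worth listing which versions actually ended up installed side by side. This is a rough sketch using only the standard library. The package list comes from the tracebacks in this issue, except `jaxlib`, which I added on the assumption that it may not have been downgraded along with jax. The helper names are made up for illustration.

```python
import re
from importlib import metadata

# Packages that appear in the tracebacks in this issue, plus jaxlib (assumption).
PACKAGES = ("jax", "jaxlib", "diffusers", "transformers", "accelerate", "torch")


def parse_version(text):
    """Turn '0.4.23' or '2.2.1+cu121' into a tuple of up to three ints."""
    return tuple(int(part) for part in re.findall(r"\d+", text)[:3])


def keyarray_removed(jax_version):
    """True if this jax version is at or past 0.4.24, where KeyArray was removed."""
    return parse_version(jax_version) >= (0, 4, 24)


def installed_versions(names=PACKAGES):
    """Map each package name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:>12}: {version or 'not installed'}")
```

If jax and jaxlib report different versions, that mismatch would be the first thing I would look at, though I have not confirmed it is the cause here. The full traceback from the downgraded environment follows.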

```bash
/usr/local/lib/python3.10/dist-packages/torch/distributed/_functional_collectives.py:28: UserWarning: Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly
  warnings.warn(
Traceback (most recent call last):
  File "/content/LECO/train_lora_xl.py", line 15, in <module>
    from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
  File "/content/LECO/lora.py", line 11, in <module>
    from diffusers import UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 3, in <module>
    from .configuration_utils import ConfigMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/configuration_utils.py", line 34, in <module>
    from .utils import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/__init__.py", line 21, in <module>
    from .accelerate_utils import apply_forward_hook
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py", line 24, in <module>
    import accelerate
  File "/usr/local/lib/python3.10/dist-packages/accelerate/__init__.py", line 3, in <module>
    from .accelerator import Accelerator
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 35, in <module>
    from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
  File "/usr/local/lib/python3.10/dist-packages/accelerate/checkpointing.py", line 24, in <module>
    from .utils import (
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/__init__.py", line 135, in <module>
    from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/fsdp_utils.py", line 25, in <module>
    import torch.distributed.checkpoint as dist_cp
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/__init__.py", line 7, in <module>
    from .state_dict_loader import load_state_dict, load
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 12, in <module>
    from .default_planner import DefaultLoadPlanner
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/default_planner.py", line 14, in <module>
    from torch.distributed._tensor import DTensor
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/__init__.py", line 346, in <module>
    import torch.distributed._tensor._dynamo_utils
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/_dynamo_utils.py", line 1, in <module>
    from torch._dynamo import allow_in_graph
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/__init__.py", line 2, in <module>
    from . import allowed_functions, convert_frame, eval_frame, resume_execution
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 62, in <module>
    from .output_graph import OutputGraph
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 39, in <module>
    from . import config, logging as torchdynamo_logging, variables
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/__init__.py", line 26, in <module>
    from .higher_order_ops import TorchHigherOrderOperatorVariable
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/higher_order_ops.py", line 11, in <module>
    import torch.onnx.operators
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/__init__.py", line 46, in <module>
    from ._internal.exporter import (  # usort:skip. needs to be last to avoid circular import
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 44, in <module>
    from torch.onnx._internal.fx import (
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/__init__.py", line 1, in <module>
    from .patcher import ONNXTorchPatcher
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/patcher.py", line 11, in <module>
    import transformers  # type: ignore[import]
  File "/usr/local/lib/python3.10/dist-packages/transformers/__init__.py", line 26, in <module>
    from . import dependency_versions_check
  File "/usr/local/lib/python3.10/dist-packages/transformers/dependency_versions_check.py", line 16, in <module>
    from .utils.versions import require_version, require_version_core
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/__init__.py", line 31, in <module>
    from .generic import (
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py", line 33, in <module>
    import jax.numpy as jnp
  File "/usr/local/lib/python3.10/dist-packages/jax/__init__.py", line 39, in <module>
    from jax import config as _config_module
  File "/usr/local/lib/python3.10/dist-packages/jax/config.py", line 15, in <module>
    from jax._src.config import config as _deprecated_config  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/config.py", line 28, in <module>
    from jax._src import lib
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 75, in <module>
    jax_version=jax.version.__version__,
AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
```