As I mentioned in #39, the jax library has deprecated `jax.random.KeyArray`. It was removed entirely in jax 0.4.24, so LECO is now broken through its diffusers dependency, whose Flax code still references it.
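To make the breaking boundary concrete: per the note above, any jax release from 0.4.24 onward no longer exposes `jax.random.KeyArray`. The helpers `parse_version` and `keyarray_removed` below are hypothetical (not part of jax, diffusers, or LECO); they are a minimal sketch that compares a version string such as `jax.__version__` against that cutoff, keeping only the leading numeric components.

```python
# Hypothetical helpers (not part of jax, diffusers, or LECO).
from typing import Tuple

# Per this report, jax.random.KeyArray was removed entirely in jax 0.4.24.
KEYARRAY_REMOVED_IN: Tuple[int, ...] = (0, 4, 24)


def parse_version(version: str) -> Tuple[int, ...]:
    """Return the leading numeric components, e.g. '0.4.26.dev20240401' -> (0, 4, 26)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first component that does not start with a digit
        parts.append(int(digits))
    return tuple(parts)


def keyarray_removed(jax_version: str) -> bool:
    """True if this jax version no longer provides jax.random.KeyArray."""
    return parse_version(jax_version) >= KEYARRAY_REMOVED_IN
```

For example, `keyarray_removed("0.4.24")` is `True` and `keyarray_removed("0.4.23")` is `False`. The hand-rolled parsing avoids a dependency on `packaging`; `packaging.version.Version` would order pre-releases such as `0.4.24rc1` more rigorously.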
```bash
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
2024-04-14 00:57:27.323177: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-14 00:57:27.323222: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-14 00:57:27.324499: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-14 00:57:28.392824: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
Traceback (most recent call last):
  File "/content/LECO/train_lora_xl.py", line 15, in <module>
    from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
  File "/content/LECO/lora.py", line 11, in <module>
    from diffusers import UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 38, in <module>
    from .models import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 36, in <module>
    from .controlnet_flax import FlaxControlNetModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 25, in <module>
    from .modeling_flax_utils import FlaxModelMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 46, in <module>
    class FlaxModelMixin(PushToHubMixin):
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 195, in FlaxModelMixin
    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'
```

I tried downgrading jax to 0.4.23, but now it fails to run with a circular import error:

```bash
/usr/local/lib/python3.10/dist-packages/torch/distributed/_functional_collectives.py:28: UserWarning: Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly
  warnings.warn(
Traceback (most recent call last):
  File "/content/LECO/train_lora_xl.py", line 15, in <module>
    from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
  File "/content/LECO/lora.py", line 11, in <module>
    from diffusers import UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 3, in <module>
    from .configuration_utils import ConfigMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/configuration_utils.py", line 34, in <module>
    from .utils import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/__init__.py", line 21, in <module>
    from .accelerate_utils import apply_forward_hook
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py", line 24, in <module>
    import accelerate
  File "/usr/local/lib/python3.10/dist-packages/accelerate/__init__.py", line 3, in <module>
    from .accelerator import Accelerator
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 35, in <module>
    from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
  File "/usr/local/lib/python3.10/dist-packages/accelerate/checkpointing.py", line 24, in <module>
    from .utils import (
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/__init__.py", line 135, in <module>
    from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/fsdp_utils.py", line 25, in <module>
    import torch.distributed.checkpoint as dist_cp
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/__init__.py", line 7, in <module>
    from .state_dict_loader import load_state_dict, load
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 12, in <module>
    from .default_planner import DefaultLoadPlanner
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/default_planner.py", line 14, in <module>
    from torch.distributed._tensor import DTensor
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/__init__.py", line 346, in <module>
    import torch.distributed._tensor._dynamo_utils
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/_dynamo_utils.py", line 1, in <module>
    from torch._dynamo import allow_in_graph
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/__init__.py", line 2, in <module>
    from . import allowed_functions, convert_frame, eval_frame, resume_execution
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 62, in <module>
    from .output_graph import OutputGraph
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 39, in <module>
    from . import config, logging as torchdynamo_logging, variables
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/__init__.py", line 26, in <module>
    from .higher_order_ops import TorchHigherOrderOperatorVariable
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/higher_order_ops.py", line 11, in <module>
    import torch.onnx.operators
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/__init__.py", line 46, in <module>
    from ._internal.exporter import (  # usort:skip. needs to be last to avoid circular import
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 44, in <module>
    from torch.onnx._internal.fx import (
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/__init__.py", line 1, in <module>
    from .patcher import ONNXTorchPatcher
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/patcher.py", line 11, in <module>
    import transformers  # type: ignore[import]
  File "/usr/local/lib/python3.10/dist-packages/transformers/__init__.py", line 26, in <module>
    from . import dependency_versions_check
  File "/usr/local/lib/python3.10/dist-packages/transformers/dependency_versions_check.py", line 16, in <module>
    from .utils.versions import require_version, require_version_core
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/__init__.py", line 31, in <module>
    from .generic import (
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py", line 33, in <module>
    import jax.numpy as jnp
  File "/usr/local/lib/python3.10/dist-packages/jax/__init__.py", line 39, in <module>
    from jax import config as _config_module
  File "/usr/local/lib/python3.10/dist-packages/jax/config.py", line 15, in <module>
    from jax._src.config import config as _deprecated_config  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/config.py", line 28, in <module>
    from jax._src import lib
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 75, in <module>
    jax_version=jax.version.__version__,
AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
```
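The post-downgrade failure happens inside `jax/_src/lib/__init__.py` while it reads `jax.version`, so one thing worth checking (an assumption, not something this traceback proves) is whether `jaxlib` was left at a different version than `jax`. A minimal sketch for recording the installed versions of the packages that appear in the tracebacks uses `importlib.metadata`, which reads package metadata without importing the packages, so it does not itself trigger the circular import. The helper name `installed_versions` and the package list are illustrative.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names taken from the tracebacks in this issue (illustrative list).
PACKAGES = ("jax", "jaxlib", "diffusers", "transformers", "accelerate", "torch")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is not installed."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            # Reads dist-info metadata only; the package itself is not imported.
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Pasting this output into a report like this one makes version conflicts between jax, jaxlib, and the libraries that import them easier to spot.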
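Instead of pinning jax, another possible workaround (hypothetical, not an official jax, diffusers, or LECO fix) is to restore the missing name before diffusers is imported. In the first traceback, `jax.random.KeyArray` is only used as a type annotation on `FlaxModelMixin.init_weights`, so aliasing it to `jax.Array`, the type jax recommends for key annotations, should let the class definition evaluate. The helper `patch_keyarray` is illustrative.

```python
import types


def patch_keyarray(jax_module: types.ModuleType) -> bool:
    """Hypothetical shim: alias jax.random.KeyArray to jax.Array if it is missing.

    Must run before diffusers is imported, because diffusers evaluates the
    annotation `rng: jax.random.KeyArray` when FlaxModelMixin is defined.
    Returns True if the alias was added, False if KeyArray already existed.
    """
    random_mod = jax_module.random
    # On jax >= 0.4.24 this lookup falls through to jax's deprecation
    # module __getattr__ hook, which raises AttributeError, so hasattr() is False.
    if hasattr(random_mod, "KeyArray"):
        return False
    # A real module attribute takes precedence over the module-level __getattr__.
    random_mod.KeyArray = jax_module.Array
    return True
```

Usage, assuming it sits at the top of `lora.py` ahead of the diffusers import: `import jax`, then `patch_keyarray(jax)`, then `from diffusers import UNet2DConditionModel`. A cleaner long-term fix is likely a diffusers release whose Flax code no longer references `jax.random.KeyArray`; check the diffusers changelog before relying on that.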