training with MPS - ARM64 Mac #93

Closed
stenczelt opened this issue Mar 26, 2023 · 5 comments
@stenczelt
Contributor

Describe the bug
Training cannot be initialised on an M2 Mac when using MPS acceleration. Apple GPUs don't support 64-bit floats, so one needs to set default_dtype=float32, which is likely where the problem comes from.
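A minimal sketch of the dtype constraint described above, assuming a PyTorch version that ships `torch.backends.mps` (1.12 or later). The helper name `pick_device_and_dtype` is hypothetical and not part of MACE; it only illustrates falling back to float32 whenever the MPS backend is selected.

```python
import torch


def pick_device_and_dtype(prefer_mps: bool = True):
    """Return a (device, dtype) pair that the chosen backend can represent.

    Hypothetical helper for illustration; not part of MACE.
    """
    if prefer_mps and torch.backends.mps.is_available():
        # The MPS backend has no float64 support, so use float32 there.
        return torch.device("mps"), torch.float32
    # On CPU, float64 is available.
    return torch.device("cpu"), torch.float64


device, dtype = pick_device_and_dtype()
x = torch.ones(3, dtype=dtype, device=device)
print(device, x.dtype)
```

In the training script the same choice is made by hand through the `--device` and `--default_dtype` flags.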

To Reproduce
Steps to reproduce the behavior:

  1. Try training a model with --device=mps --default_dtype=float32

Expected behavior
Training should "just work", as it does on CPU and on other devices.

Desktop (please complete the following information):

  • OS: macOS 13.2.1
  • M2 chip
  • Torch 2.0.0
  • Python 3.9

Additional context

training args used:

python ../scripts/run_train.py \
    --name="MACE_model" \
    --train_file="Al2O3_train.xyz" \
    --valid_fraction=0.05 \
    --test_file="Al2O3_test.xyz" \
    --config_type_weights='{"Default":1.0}' \
    --model="MACE" \
    --hidden_irreps='16x0e + 16x1o' \
    --r_max=5.0 \
    --batch_size=10 \
    --max_num_epochs=1500 \
    --swa \
    --start_swa=1200 \
    --ema \
    --ema_decay=0.99 \
    --amsgrad \
    --restart_latest \
    --device=mps \
    --default_dtype=float32

output:

2023-03-26 08:04:10.863 INFO: MACE version: 0.2.0
2023-03-26 08:04:10.863 INFO: Configuration: Namespace(name='MACE_model', seed=123, log_dir='logs', model_dir='.', checkpoints_dir='checkpoints', results_dir='results', downloads_dir='downloads', device='mps', default_dtype='float32', log_level='INFO', error_table='PerAtomRMSE', model='MACE', r_max=5.0, num_radial_basis=8, num_cutoff_basis=5, interaction='RealAgnosticResidualInteractionBlock', interaction_first='RealAgnosticResidualInteractionBlock', max_ell=3, correlation=3, num_interactions=2, MLP_irreps='16x0e', radial_MLP='[64, 64, 64]', hidden_irreps='16x0e + 16x1o', num_channels=None, max_L=None, gate='silu', scaling='rms_forces_scaling', avg_num_neighbors=1, compute_avg_num_neighbors=True, compute_stress=False, compute_forces=True, train_file='Al2O3_train.xyz', valid_file=None, valid_fraction=0.05, test_file='Al2O3_test.xyz', E0s=None, energy_key='energy', forces_key='forces', virials_key='virials', stress_key='stress', dipole_key='dipole', charges_key='charges', loss='weighted', forces_weight=100.0, swa_forces_weight=100.0, energy_weight=1.0, swa_energy_weight=1000.0, virials_weight=1.0, swa_virials_weight=10.0, stress_weight=1.0, swa_stress_weight=10.0, dipole_weight=1.0, swa_dipole_weight=1.0, config_type_weights='{"Default":1.0}', optimizer='adam', batch_size=10, valid_batch_size=10, lr=0.01, swa_lr=0.001, weight_decay=5e-07, amsgrad=True, scheduler='ReduceLROnPlateau', lr_factor=0.8, scheduler_patience=50, lr_scheduler_gamma=0.9993, swa=True, start_swa=1200, ema=True, ema_decay=0.99, max_num_epochs=1500, patience=2048, eval_interval=2, keep_checkpoints=False, restart_latest=True, save_cpu=False, clip_grad=10.0, wandb=False, wandb_project='', wandb_entity='', wandb_name='', wandb_log_hypers=['num_channels', 'max_L', 'correlation', 'lr', 'swa_lr', 'weight_decay', 'batch_size', 'max_num_epochs', 'start_swa', 'energy_weight', 'forces_weight'])
2023-03-26 08:04:10.883 INFO: Using MPS GPU acceleration
2023-03-26 08:04:10.908 INFO: Using isolated atom energies from training file
2023-03-26 08:04:10.909 INFO: Loaded 45 training configurations from 'Al2O3_train.xyz'
2023-03-26 08:04:10.909 INFO: Using random 5.0% of training set for validation
2023-03-26 08:04:10.918 INFO: Loaded 11 test configurations from 'Al2O3_test.xyz'
2023-03-26 08:04:10.918 INFO: Total number of configurations: train=43, valid=2, tests=[Default: 11]
2023-03-26 08:04:10.919 INFO: AtomicNumberTable: (8, 13)
2023-03-26 08:04:10.919 INFO: Atomic energies: [-422.9243, -105.9163]
2023-03-26 08:04:13.990 INFO: WeightedEnergyForcesLoss(energy_weight=1.000, forces_weight=100.000)
2023-03-26 08:04:13.999 INFO: Average number of neighbors: 58.728389739990234
2023-03-26 08:04:13.999 INFO: Selected the following outputs: {'energy': True, 'forces': True, 'virials': False, 'stress': False, 'dipoles': False}
2023-03-26 08:04:13.999 INFO: Building model
2023-03-26 08:04:13.999 INFO: Hidden irreps: 16x0e + 16x1o
/Users/tks32/research/mace-tmp/venv/lib/python3.9/site-packages/torch/jit/_check.py:172: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn("The TorchScript type system doesn't support "
Traceback (most recent call last):
  File "/Users/tks32/research/mace-tmp/Al2O3/../scripts/run_train.py", line 563, in <module>
    main()
  File "/Users/tks32/research/mace-tmp/Al2O3/../scripts/run_train.py", line 324, in main
    model.to(device)
  File "/Users/tks32/research/mace-tmp/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1145, in to
    return self._apply(convert)
  File "/Users/tks32/research/mace-tmp/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 844, in _apply
    self._buffers[key] = fn(buf)
  File "/Users/tks32/research/mace-tmp/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1143, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
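The traceback above stops inside `model.to(device)` while a buffer is being converted. One way to locate such a tensor is to list every parameter and buffer stored as float64 before moving the model. The sketch below does this on a toy module; `find_float64_tensors` and `Toy` are illustrative names, not MACE code, and the example runs on CPU so it doesn't need an MPS device.

```python
import torch
from torch import nn


def find_float64_tensors(module: nn.Module):
    """List parameters and buffers stored as float64 (illustrative helper)."""
    found = []
    for name, param in module.named_parameters():
        if param.dtype == torch.float64:
            found.append(f"parameter {name}")
    for name, buf in module.named_buffers():
        if buf.dtype == torch.float64:
            found.append(f"buffer {name}")
    return found


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2, dtype=torch.float32)
        # A buffer created with an explicit float64 dtype,
        # independent of whatever the default dtype is.
        self.register_buffer("scale", torch.ones(1, dtype=torch.float64))


model = Toy()
before = find_float64_tensors(model)
# Casting the module converts all floating-point parameters and buffers.
model = model.to(torch.float32)
after = find_float64_tensors(model)
print(before, after)
```

A tensor created with an explicit `dtype=torch.float64` ignores the default dtype, which is why setting `--default_dtype=float32` alone may not be enough.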
@stenczelt
Contributor Author

stenczelt commented Mar 26, 2023

After tracking down where a float64 buffer was used explicitly, I hit a further issue with the same training inputs:

2023-03-26 08:17:26.457 INFO: MACE version: 0.2.0
2023-03-26 08:17:26.457 INFO: Configuration: Namespace(name='MACE_model', seed=123, log_dir='logs', model_dir='.', checkpoints_dir='checkpoints', results_dir='results', downloads_dir='downloads', device='mps', default_dtype='float32', log_level='INFO', error_table='PerAtomRMSE', model='MACE', r_max=5.0, num_radial_basis=8, num_cutoff_basis=5, interaction='RealAgnosticResidualInteractionBlock', interaction_first='RealAgnosticResidualInteractionBlock', max_ell=3, correlation=3, num_interactions=2, MLP_irreps='16x0e', radial_MLP='[64, 64, 64]', hidden_irreps='16x0e + 16x1o', num_channels=None, max_L=None, gate='silu', scaling='rms_forces_scaling', avg_num_neighbors=1, compute_avg_num_neighbors=True, compute_stress=False, compute_forces=True, train_file='Al2O3_train.xyz', valid_file=None, valid_fraction=0.05, test_file='Al2O3_test.xyz', E0s=None, energy_key='energy', forces_key='forces', virials_key='virials', stress_key='stress', dipole_key='dipole', charges_key='charges', loss='weighted', forces_weight=100.0, swa_forces_weight=100.0, energy_weight=1.0, swa_energy_weight=1000.0, virials_weight=1.0, swa_virials_weight=10.0, stress_weight=1.0, swa_stress_weight=10.0, dipole_weight=1.0, swa_dipole_weight=1.0, config_type_weights='{"Default":1.0}', optimizer='adam', batch_size=10, valid_batch_size=10, lr=0.01, swa_lr=0.001, weight_decay=5e-07, amsgrad=True, scheduler='ReduceLROnPlateau', lr_factor=0.8, scheduler_patience=50, lr_scheduler_gamma=0.9993, swa=True, start_swa=1200, ema=True, ema_decay=0.99, max_num_epochs=1500, patience=2048, eval_interval=2, keep_checkpoints=False, restart_latest=True, save_cpu=False, clip_grad=10.0, wandb=False, wandb_project='', wandb_entity='', wandb_name='', wandb_log_hypers=['num_channels', 'max_L', 'correlation', 'lr', 'swa_lr', 'weight_decay', 'batch_size', 'max_num_epochs', 'start_swa', 'energy_weight', 'forces_weight'])
2023-03-26 08:17:26.479 INFO: Using MPS GPU acceleration
2023-03-26 08:17:26.505 INFO: Using isolated atom energies from training file
2023-03-26 08:17:26.506 INFO: Loaded 45 training configurations from 'Al2O3_train.xyz'
2023-03-26 08:17:26.506 INFO: Using random 5.0% of training set for validation
2023-03-26 08:17:26.515 INFO: Loaded 11 test configurations from 'Al2O3_test.xyz'
2023-03-26 08:17:26.515 INFO: Total number of configurations: train=43, valid=2, tests=[Default: 11]
2023-03-26 08:17:26.516 INFO: AtomicNumberTable: (8, 13)
2023-03-26 08:17:26.516 INFO: Atomic energies: [-422.9243, -105.9163]
2023-03-26 08:17:29.681 INFO: WeightedEnergyForcesLoss(energy_weight=1.000, forces_weight=100.000)
2023-03-26 08:17:29.697 INFO: Average number of neighbors: 58.728389739990234
2023-03-26 08:17:29.697 INFO: Selected the following outputs: {'energy': True, 'forces': True, 'virials': False, 'stress': False, 'dipoles': False}
2023-03-26 08:17:29.697 INFO: Building model
2023-03-26 08:17:29.697 INFO: Hidden irreps: 16x0e + 16x1o
/Users/tks32/research/mace-tmp/venv/lib/python3.9/site-packages/torch/jit/_check.py:172: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn("The TorchScript type system doesn't support "
2023-03-26 08:17:30.419 INFO: Using stochastic weight averaging (after 1200 epochs) with energy weight : 1000.0, forces weight : 100.0 and learning rate : 0.001
Traceback (most recent call last): 
  File "/Users/tks32/research/mace-tmp/Al2O3/../scripts/run_train.py", line 563, in <module>
    main()
  File "/Users/tks32/research/mace-tmp/Al2O3/../scripts/run_train.py", line 429, in main
    model=AveragedModel(model),
  File "/Users/tks32/research/mace-tmp/venv/lib/python3.9/site-packages/torch/optim/swa_utils.py", line 104, in __init__
    self.module = deepcopy(model)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/homebrew/Cellar/[email protected]/3.9.16/Frameworks/Python.framework/Versions/3.9/lib/python3.9/copy.py", line 272, in _reconstruct
    y.__setstate__(state)
  File "/Users/tks32/research/mace-tmp/venv/lib/python3.9/site-packages/e3nn/util/codegen/_mixin.py", line 109, in __setstate__
    smod = torch.jit.load(buffer)
  File "/Users/tks32/research/mace-tmp/venv/lib/python3.9/site-packages/torch/jit/_serialization.py", line 164, in load
    cpp_module = torch._C.import_ir_module_from_buffer(
RuntimeError: supported devices include CPU, CUDA and HPU, however got MPS
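
The traceback above shows `copy.deepcopy` reaching a `__setstate__` method, which in turn calls `torch.jit.load`. The stdlib-only sketch below illustrates that mechanism: `deepcopy` rebuilds an object by handing its saved state to `__setstate__`, so any error raised there surfaces from the copy itself. The `FakeCodegenModule` class and its device check are stand-ins invented for illustration; they are not e3nn's implementation.

```python
import copy


class FakeCodegenModule:
    """Stand-in for a module whose state is reloaded in __setstate__."""

    SUPPORTED = ("cpu", "cuda", "hpu")

    def __init__(self, device: str):
        self.device = device

    def __getstate__(self):
        return {"device": self.device}

    def __setstate__(self, state):
        # Emulates a loader that rejects devices it does not know about.
        if state["device"] not in self.SUPPORTED:
            raise RuntimeError(f"unsupported device: {state['device']}")
        self.__dict__.update(state)


# Copying works when the stored device is accepted by the loader.
ok = copy.deepcopy({"tp": FakeCodegenModule("cpu")})

# The same copy fails when the state refers to a rejected device.
try:
    copy.deepcopy({"tp": FakeCodegenModule("mps")})
    failed = False
except RuntimeError as err:
    failed = True
    message = str(err)
```

Under this reading, the failure depends on what state the nested object carries when it is copied, not on `AveragedModel` as such.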

@stenczelt
Contributor Author

The code above tries to make a copy of the model, which then breaks a lower-level loading function.

Oddly, if you construct a dummy ScaleShiftMACE model, move it to the MPS device, and then wrap it in AveragedModel, it works. So the training script may be adding something that breaks it?

This runs with no issues:

import numpy as np
import torch
from e3nn import o3
from torch.optim.swa_utils import AveragedModel

from mace.modules import ScaleShiftMACE, RealAgnosticInteractionBlock


def main_93():
    # check that the MPS backend is available before using it
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS backend is not available")
    mps_device = torch.device("mps")

    # construct model
    model = ScaleShiftMACE(
        r_max=3.0,
        num_bessel=10,
        num_polynomial_cutoff=10,
        max_ell=4,
        interaction_cls=RealAgnosticInteractionBlock,
        interaction_cls_first=RealAgnosticInteractionBlock,
        num_interactions=2,
        num_elements=2,
        hidden_irreps=o3.Irreps("16x0e"),
        MLP_irreps=o3.Irreps("16x0e"),
        atomic_energies=np.zeros(2),
        avg_num_neighbors=10.0,
        atomic_numbers=[1, 2],
        correlation=2,
        atomic_inter_scale=1.0,
        atomic_inter_shift=1.0,
        gate=None,
    )

    # move to device
    model.to(mps_device)

    # try AveragedModel
    average_model = AveragedModel(model)


if __name__ == "__main__":
    main_93()

@stenczelt
Contributor Author

Digging into the traceback, here is the state of the object it is complaining about.

The failure happens at this point in the traceback:

File "[...]/e3nn/util/codegen/_mixin.py", line 109, in __setstate__
    smod = torch.jit.load(buffer)
{'_backward_hooks': OrderedDict(),
 '_backward_pre_hooks': OrderedDict(),
 '_buffers': OrderedDict([('weight', tensor([], device='mps:0')),
                          ('output_mask',
                           tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='mps:0'))]),
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_in1_dim': 64,
 '_in2_dim': 16,
 '_is_full_backward_hook': None,
 '_load_state_dict_post_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_modules': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_optimize_einsums': True,
 '_parameters': OrderedDict(),
 '_profiling_str': 'TensorProduct(16x0e+16x1o x 1x0e+1x1o+1x2e+1x3o -> '
                   '32x0e+48x1o+48x2e+32x3o | 160 paths | 160 weights)',
 '_specialized_code': True,
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 'instructions': [Instruction(i_in1=0, i_in2=0, i_out=0, connection_mode='uvu', has_weight=True, path_weight=1.0, path_shape=(16, 1)),
                  Instruction(i_in1=1, i_in2=1, i_out=1, connection_mode='uvu', has_weight=True, path_weight=1.0, path_shape=(16, 1)),
                  Instruction(i_in1=0, i_in2=1, i_out=2, connection_mode='uvu', has_weight=True, path_weight=1.7320508075688772, path_shape=(16, 1)),
                  Instruction(i_in1=1, i_in2=0, i_out=3, connection_mode='uvu', has_weight=True, path_weight=1.7320508075688772, path_shape=(16, 1)),
                  Instruction(i_in1=1, i_in2=2, i_out=4, connection_mode='uvu', has_weight=True, path_weight=1.7320508075688772, path_shape=(16, 1)),
                  Instruction(i_in1=0, i_in2=2, i_out=5, connection_mode='uvu', has_weight=True, path_weight=2.23606797749979, path_shape=(16, 1)),
                  Instruction(i_in1=1, i_in2=1, i_out=6, connection_mode='uvu', has_weight=True, path_weight=2.23606797749979, path_shape=(16, 1)),
                  Instruction(i_in1=1, i_in2=3, i_out=7, connection_mode='uvu', has_weight=True, path_weight=2.23606797749979, path_shape=(16, 1)),
                  Instruction(i_in1=0, i_in2=3, i_out=8, connection_mode='uvu', has_weight=True, path_weight=2.6457513110645907, path_shape=(16, 1)),
                  Instruction(i_in1=1, i_in2=2, i_out=9, connection_mode='uvu', has_weight=True, path_weight=2.6457513110645907, path_shape=(16, 1))],
 'internal_weights': False,
 'irreps_in1': 16x0e+16x1o,
 'irreps_in2': 1x0e+1x1o+1x2e+1x3o,
 'irreps_out': 16x0e+16x0e+16x1o+16x1o+16x1o+16x2e+16x2e+16x2e+16x3o+16x3o,
 'shared_weights': False,
 'training': True,
 'weight_numel': 160}

@davkovacs
Collaborator

davkovacs commented May 16, 2023

This is fixed in 9f403e1, as part of PR #95.

@ilyes319
Contributor

Closed as merged in PR #95.
