From 80bec0bd6f314ca591290b89308602122d779fe7 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 28 Nov 2022 21:32:04 +0000 Subject: [PATCH 1/5] [Train] Strip "module." from state dict Signed-off-by: Antoni Baum --- ...ert_existing_pytorch_code_to_ray_air.ipynb | 12 +++++------ .../examples/torch_image_example.ipynb | 10 ++++----- .../examples/torch_incremental_learning.ipynb | 3 --- .../datasets_train/datasets_train.py | 4 +--- doc/source/train/dl_guide.rst | 8 ------- python/ray/train/tests/test_torch_trainer.py | 18 ++++++++++++---- python/ray/train/torch/torch_checkpoint.py | 21 +++++++++++++++++-- 7 files changed, 44 insertions(+), 32 deletions(-) diff --git a/doc/source/ray-air/examples/convert_existing_pytorch_code_to_ray_air.ipynb b/doc/source/ray-air/examples/convert_existing_pytorch_code_to_ray_air.ipynb index effd3302b67b..e458fd3f53c0 100644 --- a/doc/source/ray-air/examples/convert_existing_pytorch_code_to_ray_air.ipynb +++ b/doc/source/ray-air/examples/convert_existing_pytorch_code_to_ray_air.ipynb @@ -791,13 +791,11 @@ " from ray.air import Checkpoint\n", "\n", " checkpoint = Checkpoint.from_dict(\n", - " dict(epoch=t, model=model.module.state_dict())\n", + " dict(epoch=t, model=model.state_dict())\n", " )\n", " session.report(dict(loss=test_loss), checkpoint=checkpoint)\n", "```\n", "\n", - "Note that the `model.module` part is needed because the model gets wrapped in `torch.nn.DistributedDataParallel` by `train.torch.prepare_model`.\n", - "\n", "### Move the data loader to the training function\n", "\n", "You may have noticed a warning: `Warning: The actor TrainTrainable is very large (52 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.`.\n", @@ -861,7 +859,7 @@ " train_epoch(train_dataloader, model, loss_fn, optimizer)\n", " test_loss = test_epoch(test_dataloader, model, loss_fn)\n", " checkpoint = Checkpoint.from_dict(\n", - " dict(epoch=t, model=model.module.state_dict())\n", + " dict(epoch=t, model=model.state_dict())\n", " )\n", " session.report(dict(loss=test_loss), checkpoint=checkpoint)\n", "\n", @@ -1302,7 +1300,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.12 ('.venv': venv)", + "display_name": "Python 3.8.10 ('venv': venv)", "language": "python", "name": "python3" }, @@ -1316,11 +1314,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.8.10" }, "vscode": { "interpreter": { - "hash": "a658351b4133f922c5967ed6133cfc05c9f16c53a5161e5843ace3f528fccaf5" + "hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f" } } }, diff --git a/doc/source/ray-air/examples/torch_image_example.ipynb b/doc/source/ray-air/examples/torch_image_example.ipynb index 7c8039b539cb..64f1ba04ca59 100644 --- a/doc/source/ray-air/examples/torch_image_example.ipynb +++ b/doc/source/ray-air/examples/torch_image_example.ipynb @@ -200,7 +200,7 @@ "\n", "train_dataset = train_dataset.map_batches(convert_batch_to_numpy)\n", "test_dataset = test_dataset.map_batches(convert_batch_to_numpy)" - ] + ] }, { "cell_type": "code", @@ -328,7 +328,7 @@ " running_loss = 0.0\n", "\n", " metrics = dict(running_loss=running_loss)\n", - " checkpoint = TorchCheckpoint.from_state_dict(model.module.state_dict())\n", + " checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())\n", " session.report(metrics, checkpoint=checkpoint)" ] }, @@ -810,7 +810,7 @@ ], "metadata": { "kernelspec": 
{ - "display_name": "Python 3.10.8 ('.venv': venv)", + "display_name": "Python 3.8.10 ('venv': venv)", "language": "python", "name": "python3" }, @@ -824,11 +824,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.8.10" }, "vscode": { "interpreter": { - "hash": "c704e19737f24b51bc631dadcac7a7e356bb35d1c5cd7766248d8a6946059909" + "hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f" } } }, diff --git a/doc/source/ray-air/examples/torch_incremental_learning.ipynb b/doc/source/ray-air/examples/torch_incremental_learning.ipynb index 03d5ab4f06f6..502d60954bcf 100644 --- a/doc/source/ray-air/examples/torch_incremental_learning.ipynb +++ b/doc/source/ray-air/examples/torch_incremental_learning.ipynb @@ -444,8 +444,6 @@ "from torch.optim import SGD\n", "from torch.nn import CrossEntropyLoss\n", "\n", - "from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present\n", - "\n", "def train_loop_per_worker(config: dict):\n", " num_epochs = config[\"num_epochs\"]\n", " learning_rate = config[\"learning_rate\"]\n", @@ -493,7 +491,6 @@ "\n", " # Checkpoint model after every epoch.\n", " state_dict = model.state_dict()\n", - " consume_prefix_in_state_dict_if_present(state_dict, \"module.\")\n", " checkpoint = Checkpoint.from_dict(dict(model=state_dict))\n", " session.report({\"loss\": running_loss}, checkpoint=checkpoint)" ] diff --git a/doc/source/ray-core/_examples/datasets_train/datasets_train.py b/doc/source/ray-core/_examples/datasets_train/datasets_train.py index 0f195612b65d..c2512f2383a5 100644 --- a/doc/source/ray-core/_examples/datasets_train/datasets_train.py +++ b/doc/source/ray-core/_examples/datasets_train/datasets_train.py @@ -23,7 +23,6 @@ import torch import torch.nn as nn import torch.optim as optim -from torch.nn.parallel import DistributedDataParallel import ray from ray import train @@ -455,8 +454,7 @@ def train_func(config): ) # Checkpoint model. - module = net.module if isinstance(net, DistributedDataParallel) else net - checkpoint = Checkpoint.from_dict(dict(model=module.state_dict())) + checkpoint = Checkpoint.from_dict(dict(model=model.state_dict())) # Record and log stats. print(f"session report on {session.get_world_rank()}") diff --git a/doc/source/train/dl_guide.rst b/doc/source/train/dl_guide.rst index 00ca8a8df6a1..ef31e82acc1b 100644 --- a/doc/source/train/dl_guide.rst +++ b/doc/source/train/dl_guide.rst @@ -527,7 +527,6 @@ appropriately in distributed training. import torch import torch.nn as nn - from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present from torch.optim import Adam import numpy as np @@ -552,12 +551,7 @@ appropriately in distributed training. 
optimizer.zero_grad() loss.backward() optimizer.step() - # To fetch non-DDP state_dict - # w/o DDP: model.state_dict() - # w/ DDP: model.module.state_dict() - # See: https://github.com/ray-project/ray/issues/20915 state_dict = model.state_dict() - consume_prefix_in_state_dict_if_present(state_dict, "module.") checkpoint = Checkpoint.from_dict( dict(epoch=epoch, model_weights=state_dict) ) @@ -714,7 +708,6 @@ Checkpoints can be loaded into the training function in 2 steps: import torch import torch.nn as nn - from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present from torch.optim import Adam import numpy as np @@ -751,7 +744,6 @@ Checkpoints can be loaded into the training function in 2 steps: loss.backward() optimizer.step() state_dict = model.state_dict() - consume_prefix_in_state_dict_if_present(state_dict, "module.") checkpoint = Checkpoint.from_dict( dict(epoch=epoch, model_weights=state_dict) ) diff --git a/python/ray/train/tests/test_torch_trainer.py b/python/ray/train/tests/test_torch_trainer.py index 02c08ae80fe8..318bfebc291a 100644 --- a/python/ray/train/tests/test_torch_trainer.py +++ b/python/ray/train/tests/test_torch_trainer.py @@ -62,9 +62,12 @@ def train_func(config): trainer.fit() -def test_torch_e2e(ray_start_4_cpus): +@pytest.mark.parametrize("prepare_model", (True, False)) +def test_torch_e2e(ray_start_4_cpus, prepare_model): def train_func(): model = torch.nn.Linear(3, 1) + if prepare_model: + model = train.torch.prepare_model(model) session.report({}, checkpoint=TorchCheckpoint.from_model(model)) scaling_config = ScalingConfig(num_workers=2) @@ -84,10 +87,15 @@ def train_func(): assert predictions.count() == 3 -def test_torch_e2e_state_dict(ray_start_4_cpus): +@pytest.mark.parametrize("prepare_model", (True, False)) +def test_torch_e2e_state_dict(ray_start_4_cpus, prepare_model): def train_func(): - model = torch.nn.Linear(3, 1).state_dict() - session.report({}, checkpoint=TorchCheckpoint.from_state_dict(model)) + model = torch.nn.Linear(3, 1) + if prepare_model: + model = train.torch.prepare_model(model) + session.report( + {}, checkpoint=TorchCheckpoint.from_state_dict(model.state_dict()) + ) scaling_config = ScalingConfig(num_workers=2) trainer = TorchTrainer( @@ -112,6 +120,8 @@ def train_func(): assert predictions.count() == 3 +# We can't really test for prepare_model here as we can't detect what the user +# has saved without loading (and thus triggering the exception anyway) def test_torch_e2e_dir(ray_start_4_cpus, tmpdir): def train_func(): model = torch.nn.Linear(3, 1) diff --git a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py index 9d040d99060c..8963362043fc 100644 --- a/python/ray/train/torch/torch_checkpoint.py +++ b/python/ray/train/torch/torch_checkpoint.py @@ -28,11 +28,28 @@ class TorchCheckpoint(Checkpoint): # Special encoding logic to avoid serialization errors with torch. 
def _encode_data_dict(self, data_dict: dict) -> dict: """Encode data_dict using torch.save.""" - from torch.nn.parallel import DistributedDataParallel + from torch.nn import Module + from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present for k, v in data_dict.items(): - if isinstance(v, DistributedDataParallel) and hasattr(v, "module"): + # Only check for attribute as we want to support + # DDP, FSDP and any future approaches + if isinstance(v, Module) and hasattr(v, "module"): data_dict[k] = v.module + elif isinstance(v, dict): + # We could limit this only to the MODEL_KEY, but we'd + # miss any extra user-specified keys. This should be a + # noop with anything but DDP/FSDP module state dicts. + + # This modifies in-place the first level of the dict + # and the _metadata nested dict. + # We are doing shallow copies here, so the performance + # impact should be negligible. + state_dict = v.copy() + if "_metadata" in state_dict: + state_dict["_metadata"] = state_dict["_metadata"].copy() + consume_prefix_in_state_dict_if_present(state_dict, "module.") + data_dict[k] = state_dict # Convert the checkpoint dict to bytes, so that any GPU tensors that # are in the checkpoint dict can be properly deserialized on the From 6fdb6333d1ec9d628999a5acdd05e01992dc3618 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 28 Nov 2022 23:18:15 +0000 Subject: [PATCH 2/5] Fixes Signed-off-by: Antoni Baum --- .../datasets_train/datasets_train.py | 2 +- python/ray/air/_internal/torch_utils.py | 106 +++++++++++++++++- python/ray/train/torch/torch_checkpoint.py | 9 +- 3 files changed, 112 insertions(+), 5 deletions(-) diff --git a/doc/source/ray-core/_examples/datasets_train/datasets_train.py b/doc/source/ray-core/_examples/datasets_train/datasets_train.py index c2512f2383a5..111e4d81a8dd 100644 --- a/doc/source/ray-core/_examples/datasets_train/datasets_train.py +++ b/doc/source/ray-core/_examples/datasets_train/datasets_train.py @@ -454,7 +454,7 @@ def train_func(config): ) # Checkpoint model. - checkpoint = Checkpoint.from_dict(dict(model=model.state_dict())) + checkpoint = Checkpoint.from_dict(dict(model=net.state_dict())) # Record and log stats. print(f"session report on {session.get_world_rank()}") diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index 0b3a0ea87f56..dc391820348a 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -1,5 +1,5 @@ import warnings -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Any import numpy as np import pandas as pd @@ -228,3 +228,107 @@ def contains_tensor(obj): if contains_tensor(v): return True return False + + +# From PyTorch: + +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, +# Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, +# Samy Bengio, Johnny Mariethoz) + +# From Caffe2: + +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. 
+ +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. + +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. + +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. + +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain + +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. + +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. + +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. + +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
+ +try: + from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present # noqa +except ImportError: + # Not present in torch<=1.7.0 + # Copied from https://github.com/pytorch/pytorch/blob/\ + # c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py + def consume_prefix_in_state_dict_if_present( + state_dict: Dict[str, Any], prefix: str + ) -> None: + keys = sorted(state_dict.keys()) + for key in keys: + if key.startswith(prefix): + newkey = key[len(prefix) :] + state_dict[newkey] = state_dict.pop(key) + + if "_metadata" in state_dict: + metadata = state_dict["_metadata"] + for key in list(metadata.keys()): + if len(key) == 0: + continue + newkey = key[len(prefix) :] + metadata[newkey] = metadata.pop(key) diff --git a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py index 8963362043fc..5d28b77efc6e 100644 --- a/python/ray/train/torch/torch_checkpoint.py +++ b/python/ray/train/torch/torch_checkpoint.py @@ -4,11 +4,16 @@ import torch import warnings +from torch.nn import Module + import ray.cloudpickle from ray.air.checkpoint import Checkpoint from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY from ray.train.data_parallel_trainer import _load_checkpoint_dict -from ray.air._internal.torch_utils import load_torch_model +from ray.air._internal.torch_utils import ( + load_torch_model, + consume_prefix_in_state_dict_if_present, +) from ray.util.annotations import PublicAPI if TYPE_CHECKING: @@ -28,8 +33,6 @@ class TorchCheckpoint(Checkpoint): # Special encoding logic to avoid serialization errors with torch. def _encode_data_dict(self, data_dict: dict) -> dict: """Encode data_dict using torch.save.""" - from torch.nn import Module - from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present for k, v in data_dict.items(): # Only check for attribute as we want to support From d2bc55dbb556804421b420f7f8dcbf7c9c1b2598 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 7 Dec 2022 17:43:10 +0000 Subject: [PATCH 3/5] Implement feedback Signed-off-by: Antoni Baum --- python/ray/air/_internal/torch_utils.py | 68 +++++++++++++++------- python/ray/train/torch/torch_checkpoint.py | 15 ++--- 2 files changed, 50 insertions(+), 33 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index dc391820348a..0453002c5280 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -310,25 +310,49 @@ def contains_tensor(obj): # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
-try:
-    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present  # noqa
-except ImportError:
-    # Not present in torch<=1.7.0
-    # Copied from https://github.com/pytorch/pytorch/blob/\
-    # c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
-    def consume_prefix_in_state_dict_if_present(
-        state_dict: Dict[str, Any], prefix: str
-    ) -> None:
-        keys = sorted(state_dict.keys())
-        for key in keys:
-            if key.startswith(prefix):
-                newkey = key[len(prefix) :]
-                state_dict[newkey] = state_dict.pop(key)
-
-        if "_metadata" in state_dict:
-            metadata = state_dict["_metadata"]
-            for key in list(metadata.keys()):
-                if len(key) == 0:
-                    continue
-                newkey = key[len(prefix) :]
-                metadata[newkey] = metadata.pop(key)
+# Not present in torch<=1.7.0
+# Adapted from https://github.com/pytorch/pytorch/blob/\
+# c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
+def consume_prefix_in_state_dict_if_present_not_in_place(
+    state_dict: Dict[str, Any], prefix: str
+) -> Dict[str, Any]:
+    """Strip ``prefix`` from ``state_dict`` keys, if present; return a new dict.
+
+    Adapted from https://github.com/pytorch/pytorch/blob/\
+c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
+    The original method modifies the dict in-place.
+
+    .. note::
+        Given a `state_dict` from a DP/DDP model, a local model can load it by applying
+        `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
+        :meth:`torch.nn.Module.load_state_dict`.
+
+    Args:
+        state_dict: The state dict to strip the prefix from.
+        prefix: The prefix to strip, e.g. ``"module."``.
+
+    """
+    copied = False
+
+    keys = sorted(state_dict.keys())
+    for key in keys:
+        if key.startswith(prefix):
+            newkey = key[len(prefix) :]
+            if not copied:
+                # We are doing shallow copies here, so the performance
+                # impact should be negligible anyway, but this is
+                # a simple optimization.
+                state_dict = state_dict.copy()
+                copied = True
+            state_dict[newkey] = state_dict.pop(key)
+
+    if "_metadata" in state_dict:
+        state_dict["_metadata"] = state_dict["_metadata"].copy()
+        metadata = state_dict["_metadata"]
+        for key in list(metadata.keys()):
+            if len(key) == 0:
+                continue
+            newkey = key[len(prefix) :]
+            metadata[newkey] = metadata.pop(key)
+
+    return state_dict
diff --git a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py
index 5d28b77efc6e..3e43d142ebbb 100644
--- a/python/ray/train/torch/torch_checkpoint.py
+++ b/python/ray/train/torch/torch_checkpoint.py
@@ -12,7 +12,7 @@
 from ray.train.data_parallel_trainer import _load_checkpoint_dict
 from ray.air._internal.torch_utils import (
     load_torch_model,
-    consume_prefix_in_state_dict_if_present,
+    consume_prefix_in_state_dict_if_present_not_in_place,
 )
 from ray.util.annotations import PublicAPI
 
@@ -43,16 +43,9 @@ def _encode_data_dict(self, data_dict: dict) -> dict:
             # We could limit this only to the MODEL_KEY, but we'd
             # miss any extra user-specified keys. This should be a
             # noop with anything but DDP/FSDP module state dicts.
-
-            # This modifies in-place the first level of the dict
-            # and the _metadata nested dict.
-            # We are doing shallow copies here, so the performance
-            # impact should be negligible.
-            state_dict = v.copy()
-            if "_metadata" in state_dict:
-                state_dict["_metadata"] = state_dict["_metadata"].copy()
-            consume_prefix_in_state_dict_if_present(state_dict, "module.")
-            data_dict[k] = state_dict
+            data_dict[k] = consume_prefix_in_state_dict_if_present_not_in_place(
+                v, "module."
+ ) # Convert the checkpoint dict to bytes, so that any GPU tensors that # are in the checkpoint dict can be properly deserialized on the From b931c2cf30e039bcce2d733a4a1c3e876aea7cce Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 7 Dec 2022 22:49:01 +0000 Subject: [PATCH 4/5] Trim Signed-off-by: Antoni Baum --- python/ray/air/_internal/torch_utils.py | 85 ------------------------- 1 file changed, 85 deletions(-) diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index ee148f464f29..f07d2dab9d82 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -229,86 +229,6 @@ def contains_tensor(obj): return False -# From PyTorch: - -# Copyright (c) 2016- Facebook, Inc (Adam Paszke) -# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) -# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) -# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) -# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) -# Copyright (c) 2011-2013 NYU (Clement Farabet) -# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, -# Leon Bottou, Iain Melvin, Jason Weston) -# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) -# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, -# Samy Bengio, Johnny Mariethoz) - -# From Caffe2: - -# Copyright (c) 2016-present, Facebook Inc. All rights reserved. - -# All contributions by Facebook: -# Copyright (c) 2016 Facebook Inc. - -# All contributions by Google: -# Copyright (c) 2015 Google Inc. -# All rights reserved. - -# All contributions by Yangqing Jia: -# Copyright (c) 2015 Yangqing Jia -# All rights reserved. - -# All contributions by Kakao Brain: -# Copyright 2019-2020 Kakao Brain - -# All contributions by Cruise LLC: -# Copyright (c) 2022 Cruise LLC. -# All rights reserved. - -# All contributions from Caffe: -# Copyright(c) 2013, 2014, 2015, the respective contributors -# All rights reserved. - -# All other contributions: -# Copyright(c) 2015, 2016 the respective contributors -# All rights reserved. - -# Caffe2 uses a copyright model similar to Caffe: each contributor holds -# copyright over their contributions to Caffe2. The project versioning records -# all such contribution and copyright details. If a contributor wants to further -# mark their specific copyright on a particular contribution, they should -# indicate their copyright solely in the commit message of the change when it is -# committed. - -# All rights reserved. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: - -# 1. Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. - -# 2. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. - -# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America -# and IDIAP Research Institute nor the names of its contributors may be -# used to endorse or promote products derived from this software without -# specific prior written permission. 
-
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
-# POSSIBILITY OF SUCH DAMAGE.
-
 # Not present in torch<=1.7.0
 # Adapted from https://github.com/pytorch/pytorch/blob/\
 # c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
@@ -321,11 +241,6 @@ def consume_prefix_in_state_dict_if_present_not_in_place(
 c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
     The original method modifies the dict in-place.
 
-    .. note::
-        Given a `state_dict` from a DP/DDP model, a local model can load it by applying
-        `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
-        :meth:`torch.nn.Module.load_state_dict`.
-
     Args:
         state_dict: The state dict to strip the prefix from.
         prefix: The prefix to strip, e.g. ``"module."``.

From 91496776448fb94cccb601566e855ea6b8dbc070 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 7 Dec 2022 23:00:47 +0000
Subject: [PATCH 5/5] Tweak

Signed-off-by: Antoni Baum
---
 python/ray/air/_internal/torch_utils.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py
index f07d2dab9d82..15be5b1d5d34 100644
--- a/python/ray/air/_internal/torch_utils.py
+++ b/python/ray/air/_internal/torch_utils.py
@@ -248,8 +248,7 @@ def consume_prefix_in_state_dict_if_present_not_in_place(
     """
     copied = False
 
-    keys = sorted(state_dict.keys())
-    for key in keys:
+    for key in state_dict:
         if key.startswith(prefix):
             newkey = key[len(prefix) :]
             if not copied:
                 # We are doing shallow copies here, so the performance
                 # impact should be negligible anyway, but this is
                 # a simple optimization.
                 state_dict = state_dict.copy()
                 copied = True
@@ -263,7 +262,7 @@ def consume_prefix_in_state_dict_if_present_not_in_place(
     if "_metadata" in state_dict:
         state_dict["_metadata"] = state_dict["_metadata"].copy()
         metadata = state_dict["_metadata"]
-        for key in list(metadata.keys()):
+        for key in list(metadata):
             if len(key) == 0:
                 continue
             newkey = key[len(prefix) :]
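
After the full series is applied, a training function can simply report
TorchCheckpoint.from_model(model) or TorchCheckpoint.from_state_dict(model.state_dict())
on a model wrapped by train.torch.prepare_model, and any "module." prefix is stripped at
checkpoint-encoding time. Below is a minimal sketch of the contract the final helper
provides. The import path matches the internal module touched by the diffs above; the
toy dict with string values is purely illustrative (the helper only manipulates keys
and never inspects the values, so real tensors behave the same).

    from ray.air._internal.torch_utils import (
        consume_prefix_in_state_dict_if_present_not_in_place,
    )

    # A DDP-style state dict: torch.nn.parallel.DistributedDataParallel
    # prefixes every parameter key with "module.".
    ddp_state_dict = {"module.linear.weight": "w", "module.linear.bias": "b"}

    stripped = consume_prefix_in_state_dict_if_present_not_in_place(
        ddp_state_dict, "module."
    )
    assert stripped == {"linear.weight": "w", "linear.bias": "b"}

    # Unlike torch's in-place consume_prefix_in_state_dict_if_present, the
    # input dict is left untouched; the helper shallow-copies before renaming.
    assert "module.linear.weight" in ddp_state_dict

    # A state dict without the prefix is returned as-is, without copying.
    plain = {"linear.weight": "w"}
    assert (
        consume_prefix_in_state_dict_if_present_not_in_place(plain, "module.") is plain
    )

The copy-on-first-match design keeps the common non-DDP path allocation-free while
guaranteeing that the caller's checkpoint dict is never mutated.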