Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

__deepcopy__ for DualGraphModule #8708

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion torchvision/models/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import inspect
import math
import re
Expand All @@ -10,7 +11,7 @@
import torch
import torchvision
from torch import fx, nn
from torch.fx.graph_module import _copy_attr
from torch.fx.graph_module import _copy_attr, _USER_PRESERVED_ATTRIBUTES_KEY, _CodeOnlyModule


__all__ = ["create_feature_extractor", "get_graph_node_names"]
Expand Down Expand Up @@ -330,6 +331,33 @@ def train(self, mode=True):
self.graph = self.eval_graph
return super().train(mode=mode)

def _deepcopy_init(self):
return DualGraphModule.__init__

def __deepcopy__(self, memo):
res = type(self).__new__(type(self))
memo[id(self)] = res
fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["train_graph"], fake_mod.__dict__["eval_graph"])

# (borrowed from fx.GraphModule):
extra_preserved_attrs = [
"_state_dict_hooks",
"_load_state_dict_pre_hooks",
"_load_state_dict_post_hooks",
"_replace_hook",
"_create_node_hooks",
"_erase_node_hooks"
]
for attr in extra_preserved_attrs:
if attr in self.__dict__:
setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
setattr(res, attr_name, attr)
return res


def create_feature_extractor(
model: nn.Module,
Expand Down