chore: Remove attrdict from model hub #8554

Merged: 42 commits, Dec 15, 2023

Changes from all commits (42 commits)
62f0cac  chore: remove attrdict from model_hub (MikhailKardash, Dec 6, 2023)
405d255  use setattr (MikhailKardash, Dec 6, 2023)
93a1dbc  namespace is not a dictionary (MikhailKardash, Dec 6, 2023)
59e51f6  namespace is not a dictionary part 2 (MikhailKardash, Dec 6, 2023)
7471330  mmdet may not need attrdict at all (MikhailKardash, Dec 6, 2023)
8314984  this should be a dict (MikhailKardash, Dec 6, 2023)
f6a973f  data config is also a dict (MikhailKardash, Dec 6, 2023)
96c3661  linting (MikhailKardash, Dec 6, 2023)
cbf66ee  more linting (MikhailKardash, Dec 6, 2023)
1bc8de5  try this (MikhailKardash, Dec 7, 2023)
6da78eb  add renamed argument (MikhailKardash, Dec 7, 2023)
53f32a7  bad arg (MikhailKardash, Dec 7, 2023)
731795f  linting (MikhailKardash, Dec 7, 2023)
3c9d401  this can stay as a dict (MikhailKardash, Dec 7, 2023)
ca2fb3a  this can stay as a dict (MikhailKardash, Dec 7, 2023)
a8e5759  it thinks an object is a dict instead of namespace (MikhailKardash, Dec 7, 2023)
7f5cb57  hparams, data_config to namespace (MikhailKardash, Dec 8, 2023)
ca2d3fe  cast namespace to dict for a mmdet function (MikhailKardash, Dec 8, 2023)
75a5312  traceback for hparams (MikhailKardash, Dec 8, 2023)
05ba467  add from_namespace function (MikhailKardash, Dec 11, 2023)
8c369f8  oops (MikhailKardash, Dec 11, 2023)
50afdc3  add our own, better AttrDict (MikhailKardash, Dec 11, 2023)
0aaee7b  linting (MikhailKardash, Dec 11, 2023)
98d7f7d  linting and recursion (MikhailKardash, Dec 11, 2023)
52b5c77  bad algorithm (MikhailKardash, Dec 11, 2023)
28e98b1  wrong implementation of attrdict (MikhailKardash, Dec 11, 2023)
2c54320  ignore attr-defined mypy (MikhailKardash, Dec 11, 2023)
b529aeb  move ignore attr-defined to .ini (MikhailKardash, Dec 12, 2023)
6c47e2d  typo (MikhailKardash, Dec 12, 2023)
e7ec298  oops (MikhailKardash, Dec 12, 2023)
dd99c5d  remove attrdict from setup.py (MikhailKardash, Dec 12, 2023)
0381cfd  better AttrDict impl (MikhailKardash, Dec 12, 2023)
476b914  get around mypy (MikhailKardash, Dec 12, 2023)
84181a3  remove superfluous attribute setting (MikhailKardash, Dec 12, 2023)
624433e  linting (MikhailKardash, Dec 12, 2023)
d4949e6  proposed changes (MikhailKardash, Dec 12, 2023)
baf4c6e  review part 2 (MikhailKardash, Dec 12, 2023)
29e05e7  linting again (MikhailKardash, Dec 13, 2023)
c61fcc3  a few touch ups (MikhailKardash, Dec 14, 2023)
ef1a14b  a few more touchups (MikhailKardash, Dec 14, 2023)
78c9eac  typing (MikhailKardash, Dec 14, 2023)
b6ba8fe  comma (MikhailKardash, Dec 14, 2023)
2 changes: 1 addition & 1 deletion model_hub/docker/Dockerfile.transformers
@@ -4,7 +4,7 @@ FROM ${BASE_IMAGE}

ARG TRANSFORMERS_VERSION
ARG DATASETS_VERSION
RUN pip install transformers==${TRANSFORMERS_VERSION} datasets==${DATASETS_VERSION} attrdict
RUN pip install transformers==${TRANSFORMERS_VERSION} datasets==${DATASETS_VERSION}
RUN pip install sentencepiece!=0.1.92 protobuf scikit-learn conllu seqeval


22 changes: 11 additions & 11 deletions model_hub/model_hub/huggingface/_config_parser.py
@@ -1,7 +1,7 @@
import dataclasses
from typing import Any, Dict, Optional, Tuple, Union

import attrdict
from model_hub import utils


class FlexibleDataclass:
@@ -272,7 +272,7 @@ class LRSchedulerKwargs:

def parse_dict_to_dataclasses(
dataclass_types: Tuple[Any, ...],
args: Union[Dict[str, Any], attrdict.AttrDict],
args: Union[Dict[str, Any], utils.AttrDict],
as_dict: bool = False,
) -> Tuple[Any, ...]:
"""
@@ -295,29 +295,29 @@ def parse_dict_to_dataclasses(
obj = dtype(**inputs)
if as_dict:
try:
obj = attrdict.AttrDict(obj.as_dict())
obj = utils.AttrDict(obj.as_dict())
except AttributeError:
obj = attrdict.AttrDict(dataclasses.asdict(obj))
obj = utils.AttrDict(dataclasses.asdict(obj))
outputs.append(obj)
return (*outputs,)


def default_parse_config_tokenizer_model_kwargs(
hparams: Union[Dict, attrdict.AttrDict]
) -> Tuple[Dict, Dict, Dict]:
hparams: Union[Dict, utils.AttrDict],
) -> Tuple[utils.AttrDict, utils.AttrDict, utils.AttrDict]:
"""
This function will provided hparams into fields for the transformers config, tokenizer,
This function converts hparams into fields for the transformers config, tokenizer,
and model. See the defined dataclasses ConfigKwargs, TokenizerKwargs, and ModelKwargs for
expected fields and defaults.

Args:
hparams: hyperparameters to parse.

Returns:
One dictionary each for the config, tokenizer, and model.
One AttrDict each for the config, tokenizer, and model.
"""
if not isinstance(hparams, attrdict.AttrDict):
hparams = attrdict.AttrDict(hparams)
if not isinstance(hparams, utils.AttrDict):
hparams = utils.AttrDict(hparams)
config_args, tokenizer_args, model_args = parse_dict_to_dataclasses(
(ConfigKwargs, TokenizerKwargs, ModelKwargs), hparams, as_dict=True
)
@@ -340,7 +340,7 @@ def default_parse_config_tokenizer_model_kwargs(


def default_parse_optimizer_lr_scheduler_kwargs(
hparams: Union[Dict, attrdict.AttrDict]
hparams: Union[Dict, utils.AttrDict]
) -> Tuple[OptimizerKwargs, LRSchedulerKwargs]:
"""
Parse hparams relevant for the optimizer and lr_scheduler and fills in with
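For readers skimming this file's diff: with `attrdict` gone, the parsing helpers now hand back `model_hub.utils.AttrDict`. A minimal sketch of the updated behavior, mirroring the assertions in `model_hub/tests/test_hf.py` further down (the argument values are taken from that test and are otherwise illustrative):

```python
# Sketch only: parse_dict_to_dataclasses with as_dict=True now returns
# model_hub.utils.AttrDict instead of the third-party attrdict.AttrDict.
import model_hub.huggingface as hf
from model_hub import utils

args = {"pretrained_model_name_or_path": "xnli", "num_labels": 4}
(config,) = hf.parse_dict_to_dataclasses((hf.ConfigKwargs,), args, as_dict=True)

assert isinstance(config, utils.AttrDict)
# Attribute access and key access are interchangeable on the returned object.
assert config.num_labels == config["num_labels"] == 4
```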
29 changes: 15 additions & 14 deletions model_hub/model_hub/huggingface/_trial.py
@@ -2,14 +2,13 @@
import logging
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import attrdict
import datasets as hf_datasets
import torch
import transformers
import transformers.optimization as hf_opt

import determined.pytorch as det_torch
import model_hub.utils
from model_hub import utils
from model_hub.huggingface import _config_parser as hf_parse

MODEL_MODES = {
@@ -27,10 +26,10 @@


def build_using_auto(
config_kwargs: Union[Dict, attrdict.AttrDict],
tokenizer_kwargs: Union[Dict, attrdict.AttrDict],
config_kwargs: Union[Dict, utils.AttrDict],
tokenizer_kwargs: Union[Dict, utils.AttrDict],
model_mode: str,
model_kwargs: Union[Dict, attrdict.AttrDict],
model_kwargs: Union[Dict, utils.AttrDict],
use_pretrained_weights: bool = True,
) -> Tuple[
transformers.PretrainedConfig, # This is how it's named in transformers
@@ -146,7 +145,7 @@ def build_default_lr_scheduler(


def default_load_dataset(
data_config: Union[Dict, attrdict.AttrDict]
data_config_input: Union[Dict, utils.AttrDict],
) -> Union[
hf_datasets.Dataset,
hf_datasets.IterableDataset,
@@ -155,15 +154,17 @@ ]:
]:
"""
Creates the dataset using HuggingFace datasets' load_dataset method.
If a dataset_name is provided, we will use that long with the dataset_config_name.
If a dataset_name is provided, we will use that along with the dataset_config_name.
Otherwise, we will create the dataset using provided train_file and validation_file.

Args:
data_config: arguments for load_dataset. See DatasetKwargs for expected fields.
Returns:
Dataset returned from hf_datasets.load_dataset.
"""
(data_config,) = hf_parse.parse_dict_to_dataclasses((hf_parse.DatasetKwargs,), data_config)
(data_config,) = hf_parse.parse_dict_to_dataclasses(
(hf_parse.DatasetKwargs,), data_config_input
)
# This method is common in nearly all main HF examples.
if data_config.dataset_name is not None:
# Downloading and loading a dataset from the hub.
@@ -215,11 +216,11 @@ def __init__(self, context: det_torch.PyTorchTrialContext) -> None:
# A subclass of BaseTransformerTrial may have already set hparams and data_config
# attributes so we only reset them if they do not exist.
if not hasattr(self, "hparams"):
self.hparams = attrdict.AttrDict(context.get_hparams())
self.hparams = utils.AttrDict(context.get_hparams())
if not hasattr(self, "data_config"):
self.data_config = attrdict.AttrDict(context.get_data_config())
self.data_config = utils.AttrDict(context.get_data_config())
if not hasattr(self, "exp_config"):
self.exp_config = attrdict.AttrDict(context.get_experiment_config())
self.exp_config = utils.AttrDict(context.get_experiment_config())
# Check to make sure all expected hyperparameters are set.
self.check_hparams()

@@ -266,13 +267,13 @@ def __init__(self, context: det_torch.PyTorchTrialContext) -> None:

def check_hparams(self) -> None:
# We require hparams to be an AttrDict.
if not isinstance(self.hparams, attrdict.AttrDict):
self.hparams = attrdict.AttrDict(self.hparams)
if not isinstance(self.hparams, utils.AttrDict):
self.hparams = utils.AttrDict(self.hparams)

if "num_training_steps" not in self.hparams:
# Compute the total number of training iterations used to configure the
# learning rate scheduler.
self.hparams.num_training_steps = model_hub.utils.compute_num_training_steps(
self.hparams.num_training_steps = utils.compute_num_training_steps(
self.context.get_experiment_config(), self.context.get_global_batch_size()
)
if "use_pretrained_weights" not in self.hparams:
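As a usage sketch of the renamed `data_config_input` parameter above: `default_load_dataset` still accepts a plain dict (the `Union[Dict, utils.AttrDict]` annotation), and the field names follow `hf_parse.DatasetKwargs` as described in the docstring. The dataset and file names below are illustrative, and it is assumed that `default_load_dataset` remains re-exported from `model_hub.huggingface`:

```python
# Hedged sketch: two ways to drive default_load_dataset after the rename.
import model_hub.huggingface as hf

# Hub path: dataset_name (plus optional dataset_config_name) is forwarded to
# hf_datasets.load_dataset.
hub_datasets = hf.default_load_dataset(
    {"dataset_name": "glue", "dataset_config_name": "mrpc"}
)

# Local path: leave dataset_name unset and point at train/validation files instead.
local_datasets = hf.default_load_dataset(
    {
        "dataset_name": None,
        "train_file": "train.json",
        "validation_file": "validation.json",
    }
)
```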
14 changes: 7 additions & 7 deletions model_hub/model_hub/mmdetection/_trial.py
@@ -24,7 +24,6 @@
import os
from typing import Any, Dict, List

import attrdict
import mmcv
import mmcv.parallel
import mmcv.runner
@@ -36,10 +35,11 @@

import determined.pytorch as det_torch
from determined.common import set_logger
from model_hub import utils
from model_hub.mmdetection import _callbacks as callbacks
from model_hub.mmdetection import _data as data
from model_hub.mmdetection import _data_backends as data_backends
from model_hub.mmdetection import utils as utils
from model_hub.mmdetection import utils as mmdetutils


class MMDetTrial(det_torch.PyTorchTrial):
@@ -55,8 +55,8 @@ class MMDetTrial(det_torch.PyTorchTrial):

def __init__(self, context: det_torch.PyTorchTrialContext) -> None:
self.context = context
self.hparams = attrdict.AttrDict(context.get_hparams())
self.data_config = attrdict.AttrDict(context.get_data_config())
self.hparams = utils.AttrDict(context.get_hparams())
self.data_config = utils.AttrDict(context.get_data_config())
self.cfg = self.build_mmdet_config()
# We will control how data is moved to GPU.
self.context.experimental.disable_auto_to_device()
@@ -69,7 +69,7 @@ def __init__(self, context: det_torch.PyTorchTrialContext) -> None:

# If use_pretrained, try loading pretrained weights for the mmcv config if available.
if self.hparams.use_pretrained:
ckpt_path, ckpt = utils.get_pretrained_ckpt_path("/tmp", self.hparams.config_file)
ckpt_path, ckpt = mmdetutils.get_pretrained_ckpt_path("/tmp", self.hparams.config_file)
if ckpt_path is not None:
logging.info("Loading from pretrained weights.")
if "state_dict" in ckpt:
@@ -117,7 +117,7 @@ def build_mmdet_config(self) -> mmcv.Config:
if config_dir is not None:
config_file = os.path.join(config_dir, config_file)
if config_dir is None or not os.path.exists(config_file):
raise OSError(f"Config file {self.hparams.config_file} not found.")
raise OSError(f"Config file {config_file} not found.")
cfg = mmcv.Config.fromfile(config_file)
cfg.data.val.test_mode = True

@@ -151,7 +151,7 @@ def setup_torch_amp(self, fp16_cfg: mmcv.Config) -> None:
to see how to configure fp16 training.
"""
mmcv.runner.wrap_fp16_model(self.model)
loss_scaler = utils.build_fp16_loss_scaler(fp16_cfg.loss_scale)
loss_scaler = mmdetutils.build_fp16_loss_scaler(fp16_cfg.loss_scale)
self.loss_scaler = self.context.wrap_scaler(loss_scaler)
self.context.experimental._auto_amp = True

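Why this is safe for the mmdetection trial: the replacement `utils.AttrDict` subclasses `dict`, so mmcv/mmdet code that expects a plain mapping keeps working while `MMDetTrial` keeps attribute-style access on `self.hparams` and `self.data_config`. A small sketch with illustrative keys only:

```python
# Illustrative keys; not a complete MMDetTrial hyperparameter set.
from model_hub import utils

hparams = utils.AttrDict(
    {"config_file": "mask_rcnn_r50_fpn_1x_coco.py", "use_pretrained": True}
)

assert hparams.use_pretrained                 # attribute access, as used in the trial
assert isinstance(hparams, dict)              # still a real dict for mmcv consumers
assert dict(hparams)["config_file"] == "mask_rcnn_r50_fpn_1x_coco.py"
```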
20 changes: 19 additions & 1 deletion model_hub/model_hub/utils.py
@@ -1,7 +1,8 @@
import logging
import os
import typing
import urllib.parse
from typing import Dict, List, Union
from typing import Any, Dict, List, Union

import filelock
import numpy as np
@@ -86,3 +87,20 @@ def compute_num_training_steps(experiment_config: Dict, global_batch_size: int)
)
# Otherwise, max_length_unit=='records'
return int(max_length / global_batch_size)


class AttrDict(dict):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.__dict__ = self
for key in self.keys():
if isinstance(self[key], dict):
self[key] = AttrDict(self[key])

if typing.TYPE_CHECKING:

def __getattr__(self, item: Any) -> Any:
return True

def __setattr__(self, item: Any, value: Any) -> None:
return None
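A minimal usage sketch of the class added above (the keys are made up): nested dicts are converted recursively, and because `__dict__` is bound to the mapping itself, attribute writes and key writes stay in sync. The `typing.TYPE_CHECKING`-only `__getattr__`/`__setattr__` stubs exist solely so mypy accepts arbitrary attribute access; they are never defined at runtime.

```python
# Sketch only; keys are illustrative.
from model_hub.utils import AttrDict

cfg = AttrDict({"optimizer": {"lr": 1e-4, "name": "adamw"}, "epochs": 3})

# Nested dicts are converted recursively, so attribute access works at depth.
assert cfg.optimizer.lr == 1e-4

# __dict__ is the dict itself, so attribute writes are also key writes.
cfg.warmup_steps = 100
assert cfg["warmup_steps"] == 100

# It remains a plain dict underneath.
assert isinstance(cfg, dict) and "epochs" in cfg
```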
1 change: 0 additions & 1 deletion model_hub/setup.py
@@ -17,7 +17,6 @@
# libraries that are guaranteed to work with our code. Other versions
# may work with model-hub as well but are not officially supported.
install_requires=[
"attrdict",
"determined>=0.13.11", # We require custom reducers for PyTorchTrial.
],
)
7 changes: 3 additions & 4 deletions model_hub/tests/test_hf.py
@@ -1,12 +1,11 @@
import attrdict

import model_hub.huggingface as hf
from model_hub import utils


def test_config_parser() -> None:
args = {"pretrained_model_name_or_path": "xnli", "num_labels": 4}
config = hf.parse_dict_to_dataclasses((hf.ConfigKwargs,), args, as_dict=True)[0]
target = attrdict.AttrDict(
target = utils.AttrDict(
{
"pretrained_model_name_or_path": "xnli",
"revision": "main",
@@ -23,7 +22,7 @@ def test_nodefault_config_parser() -> None:
"pretrained_model_name_or_path": "xnli",
}
config = hf.parse_dict_to_dataclasses((hf.ConfigKwargs,), args, as_dict=True)[0]
target = attrdict.AttrDict(
target = utils.AttrDict(
{
"pretrained_model_name_or_path": "xnli",
"revision": "main",
5 changes: 2 additions & 3 deletions model_hub/tests/test_mmdetection.py
@@ -2,7 +2,6 @@
import shutil
from typing import Generator

import attrdict
import git
import pytest
import torch
@@ -128,7 +127,7 @@ def test_merge_config(
) -> None:
hparams = context.get_hparams()
hparams["merge_config"] = "./tests/fixtures/merge_config.py"
trial.hparams = attrdict.AttrDict(hparams)
trial.hparams = mh_utils.AttrDict(hparams)
new_cfg = trial.build_mmdet_config()
assert new_cfg.optimizer.type == "AdamW"
assert new_cfg.optimizer_config.grad_clip.max_norm == 0.1
@@ -142,7 +141,7 @@ def test_override_mmdet_config(
"optimizer_config.grad_clip.max_norm": 35,
"optimizer_config.grad_clip.norm_type": 2,
}
trial.hparams = attrdict.AttrDict(hparams)
trial.hparams = mh_utils.AttrDict(hparams)
new_cfg = trial.build_mmdet_config()
assert new_cfg.optimizer_config.grad_clip.max_norm == 35
assert new_cfg.optimizer_config.grad_clip.norm_type == 2