Add utility to inspect a model's parameters (to get dtype/device)
akx committed Dec 31, 2023
1 parent a84e842 commit 5768afc
Showing 8 changed files with 53 additions and 7 deletions.
3 changes: 2 additions & 1 deletion modules/devices.py
@@ -4,6 +4,7 @@

 import torch
 from modules import errors, shared
+from modules.torch_utils import get_param

 if sys.platform == "darwin":
     from modules import mac_specific
@@ -131,7 +132,7 @@ def cond_cast_float(input):


 def manual_cast_forward(self, *args, **kwargs):
-    org_dtype = next(self.parameters()).dtype
+    org_dtype = get_param(self).dtype
     self.to(dtype)
     args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
     kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
3 changes: 2 additions & 1 deletion modules/interrogate.py
@@ -11,6 +11,7 @@
 from torchvision.transforms.functional import InterpolationMode

 from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules.torch_utils import get_param

 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -131,7 +132,7 @@ def load(self):

         self.clip_model = self.clip_model.to(devices.device_interrogate)

-        self.dtype = next(self.clip_model.parameters()).dtype
+        self.dtype = get_param(self.clip_model).dtype

     def send_clip_to_ram(self):
         if not shared.opts.interrogate_keep_models_in_memory:
3 changes: 2 additions & 1 deletion modules/sd_models_xl.py
@@ -6,6 +6,7 @@
 import sgm.modules.diffusionmodules.denoiser_scaling
 import sgm.modules.diffusionmodules.discretizer
 from modules import devices, shared, prompt_parser
+from modules.torch_utils import get_param


 def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -90,7 +91,7 @@ def get_target_prompt_token_count(self, token_count):
 def extend_sdxl(model):
     """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""

-    dtype = next(model.model.diffusion_model.parameters()).dtype
+    dtype = get_param(model.model.diffusion_model).dtype
     model.model.diffusion_model.dtype = dtype
     model.model.conditioning_key = 'crossattn'
     model.cond_stage_key = 'txt'
17 changes: 17 additions & 0 deletions modules/torch_utils.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import torch.nn
+
+
+def get_param(model) -> torch.nn.Parameter:
+    """
+    Find the first parameter in a model or module.
+    """
+    if hasattr(model, "model") and hasattr(model.model, "parameters"):
+        # Unpeel a model descriptor to get at the actual Torch module.
+        model = model.model
+
+    for param in model.parameters():
+        return param
+
+    raise ValueError(f"No parameters found in model {model!r}")
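
For orientation, a minimal usage sketch of the new helper follows. It is not part of the commit, and DemoNet is a hypothetical module used only for illustration:

import torch

from modules.torch_utils import get_param


class DemoNet(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


net = DemoNet().to(dtype=torch.float16)

# Old call-site pattern: raises StopIteration on a module with no parameters
# and does not unwrap objects that keep the real module in a .model attribute.
dtype = next(net.parameters()).dtype

# New helper: returns the same first parameter, unwraps a .model wrapper if
# present, and raises a descriptive ValueError when no parameters exist.
param = get_param(net)
dtype, device = param.dtype, param.device
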
5 changes: 3 additions & 2 deletions modules/upscaler_utils.py
@@ -7,6 +7,7 @@
 from PIL import Image

 from modules import images, shared
+from modules.torch_utils import get_param

 logger = logging.getLogger(__name__)

@@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image):
     img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
     img = torch.from_numpy(img).float()

-    model_weight = next(iter(model.model.parameters()))
-    img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype)
+    param = get_param(model)
+    img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)

     with torch.no_grad():
         output = model(img)
5 changes: 4 additions & 1 deletion modules/xlmr.py
@@ -5,6 +5,9 @@
 from transformers import XLMRobertaModel,XLMRobertaTokenizer
 from typing import Optional

+from modules.torch_utils import get_param
+
+
 class BertSeriesConfig(BertConfig):
     def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):

@@ -62,7 +65,7 @@ def __init__(self, config=None, **kargs):
         self.post_init()

     def encode(self,c):
-        device = next(self.parameters()).device
+        device = get_param(self).device
         text = self.tokenizer(c,
             truncation=True,
             max_length=77,
5 changes: 4 additions & 1 deletion modules/xlmr_m18.py
@@ -5,6 +5,9 @@
 from transformers import XLMRobertaModel,XLMRobertaTokenizer
 from typing import Optional

+from modules.torch_utils import get_param
+
+
 class BertSeriesConfig(BertConfig):
     def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):

@@ -68,7 +71,7 @@ def __init__(self, config=None, **kargs):
         self.post_init()

     def encode(self,c):
-        device = next(self.parameters()).device
+        device = get_param(self).device
         text = self.tokenizer(c,
             truncation=True,
             max_length=77,
19 changes: 19 additions & 0 deletions test/test_torch_utils.py
@@ -0,0 +1,19 @@
+import types
+
+import pytest
+import torch
+
+from modules.torch_utils import get_param
+
+
+@pytest.mark.parametrize("wrapped", [True, False])
+def test_get_param(wrapped):
+    mod = torch.nn.Linear(1, 1)
+    cpu = torch.device("cpu")
+    mod.to(dtype=torch.float16, device=cpu)
+    if wrapped:
+        # more or less how spandrel wraps a thing
+        mod = types.SimpleNamespace(model=mod)
+    p = get_param(mod)
+    assert p.dtype == torch.float16
+    assert p.device == cpu
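
As a complement to the test above, a quick illustrative sketch (not part of the commit) of the failure mode the helper reports explicitly: a parameterless module raises a descriptive ValueError, instead of the bare StopIteration that next(module.parameters()) would produce.

import torch

from modules.torch_utils import get_param

empty = torch.nn.Sequential()  # a module that owns no parameters

try:
    get_param(empty)
except ValueError as exc:
    # With next(empty.parameters()) this would have surfaced as StopIteration;
    # get_param names the problem instead.
    print(exc)  # "No parameters found in model Sequential()"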
