Skip to content

Commit

Permalink
print model components for safetensors load
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <[email protected]>
  • Loading branch information
vladmandic committed Nov 2, 2024
1 parent bada60e commit da9550d
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 4 deletions.
File renamed without changes.
83 changes: 83 additions & 0 deletions cli/model-keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python
import os
import sys
from rich import print as pprint


def has(obj, attr, *args):
import functools
if not isinstance(obj, dict):
return False
def _getattr(obj, attr):
return obj.get(attr, args) if isinstance(obj, dict) else False
return functools.reduce(_getattr, [obj] + attr.split('.'))


def remove_entries_after_depth(d, depth, current_depth=0):
if current_depth >= depth:
return None
if isinstance(d, dict):
return {k: remove_entries_after_depth(v, depth, current_depth + 1) for k, v in d.items() if remove_entries_after_depth(v, depth, current_depth + 1) is not None}
return d


def list_to_dict(flat_list):
result_dict = {}
for item in flat_list:
keys = item.split('.')
d = result_dict
for key in keys[:-1]:
d = d.setdefault(key, {})
d[keys[-1]] = None
return result_dict


def guess_dct(dct: dict):
# if has(dct, 'model.diffusion_model.input_blocks') and has(dct, 'model.diffusion_model.label_emb'):
# return 'sdxl'
if has(dct, 'model.diffusion_model.input_blocks') and len(list(has(dct, 'model.diffusion_model.input_blocks'))) == 12:
return 'sd15'
if has(dct, 'model.diffusion_model.input_blocks') and len(list(has(dct, 'model.diffusion_model.input_blocks'))) == 9:
return 'sdxl'
if has(dct, 'model.diffusion_model.joint_blocks') and len(list(has(dct, 'model.diffusion_model.joint_blocks'))) == 24:
return 'sd35-medium'
if has(dct, 'model.diffusion_model.joint_blocks') and len(list(has(dct, 'model.diffusion_model.joint_blocks'))) == 38:
return 'sd35-large'
if has(dct, 'model.diffusion_model.double_blocks') and len(list(has(dct, 'model.diffusion_model.double_blocks'))) == 19:
return 'flux-dev'
return None


def read_keys(fn):
if not fn.lower().endswith(".safetensors"):
return
from safetensors.torch import safe_open
keys = []
try:
with safe_open(fn, framework="pt", device="cpu") as f:
keys = f.keys()
except Exception as e:
pprint(e)
dct = list_to_dict(keys)
pprint(f'file: {fn}')
pprint(remove_entries_after_depth(dct, 3))
pprint(remove_entries_after_depth(dct, 6))
guess = guess_dct(dct)
pprint(f'guess: {guess}')
return keys


def main():
if len(sys.argv) == 0:
print('metadata:', 'no files specified')
for fn in sys.argv:
if os.path.isfile(fn):
read_keys(fn)
elif os.path.isdir(fn):
for root, _dirs, files in os.walk(fn):
for file in files:
read_keys(os.path.join(root, file))

if __name__ == '__main__':
sys.argv.pop(0)
main()
23 changes: 21 additions & 2 deletions modules/model_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,32 @@
from modules import shared, devices, model_quant


def remove_entries_after_depth(d, depth, current_depth=0):
if current_depth >= depth:
return None
if isinstance(d, dict):
return {k: remove_entries_after_depth(v, depth, current_depth + 1) for k, v in d.items() if remove_entries_after_depth(v, depth, current_depth + 1) is not None}
return d


def list_to_dict(flat_list):
result_dict = {}
for item in flat_list:
keys = item.split('.')
d = result_dict
for key in keys[:-1]:
d = d.setdefault(key, {})
d[keys[-1]] = None
return result_dict


def get_safetensor_keys(filename):
keys = []
try:
with safetensors.torch.safe_open(filename, framework="pt", device="cpu") as f:
keys = f.keys()
except Exception as e:
shared.log.error(f'Load dict: path="{filename}" {e}')
except Exception:
pass
return keys


Expand Down
10 changes: 9 additions & 1 deletion modules/sd_detect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import time
import torch
import diffusers
from modules import shared, shared_items, devices, errors
from modules import shared, shared_items, devices, errors, model_tools


debug_load = os.environ.get('SD_LOAD_DEBUG', None)
Expand Down Expand Up @@ -103,6 +104,13 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
pipeline = shared_items.get_pipelines().get(guess, None) if pipeline is None else pipeline
if not quiet:
shared.log.info(f'Autodetect {op}: detect="{guess}" class={getattr(pipeline, "__name__", None)} file="{f}" size={size}MB')
t0 = time.time()
keys = model_tools.get_safetensor_keys(f)
if keys is not None:
modules = model_tools.list_to_dict(keys)
modules = model_tools.remove_entries_after_depth(modules, 3)
t1 = time.time()
shared.log.debug(f'Autodetect {op}: modules={modules} time={t1-t0:.2f}')
except Exception as e:
shared.log.error(f'Autodetect {op}: file="{f}" {e}')
if debug_load:
Expand Down
2 changes: 1 addition & 1 deletion wiki
Submodule wiki updated from b36c2e to 2b8868

0 comments on commit da9550d

Please sign in to comment.