Skip to content

Commit

Permalink
fix: str and int in arg parsing (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Aug 6, 2022
1 parent 1831e37 commit 8219965
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
2 changes: 1 addition & 1 deletion discoart/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def load_config(
int_keys.add('seed')

for k, v in cfg.items():
if k in int_keys and v is not None and not isinstance(v, int):
if k in int_keys and v is not None and not isinstance(v, (int, str)):
cfg[k] = int(v)
if k == 'width_height':
cfg[k] = [int(vv) for vv in v]
Expand Down
39 changes: 39 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,42 @@ def test_eval_schedule_string():
)
def test_chec_schedule_str(val, expected):
assert _is_valid_schedule_str(val) == expected


@pytest.mark.parametrize(
'field',
[
'cut_overview',
'cut_innercut',
'cut_icgray_p',
'cut_ic_pow',
'use_secondary_model',
'cutn_batches',
'clip_guidance_scale',
'tv_scale',
'range_scale',
'sat_scale',
'init_scale',
'clamp_grad',
'clamp_max',
],
)
@pytest.mark.parametrize(
'val',
[
True,
False,
1,
0.5,
'True',
'False',
'1',
'0.5',
'[100]*600+[200]*400',
'[True, False]*1000',
],
)
def test_eval_config(field, val):
cfg = load_config(default_args)
cfg[field] = val
assert load_config(cfg)[field] is not None

0 comments on commit 8219965

Please sign in to comment.