Skip to content

Commit

Permalink
Merge pull request #4 from Mapleshade20/archive
Browse files Browse the repository at this point in the history
* fix for the broken run_git calls

* Use fixed size for sub-quadratic chunking on MPS

Even if this causes chunks to be much smaller, performance isn't significantly impacted. This will usually reduce memory usage but should also help with poor performance when free memory is low.

* Make sub-quadratic the default for MPS

* `torch.empty` can create issues; use `torch.zeros`

For MPS, using a tensor created with `torch.empty()` can cause `torch.baddbmm()` to include NaNs in the tensor it returns, even though `beta=0`. However, with a tensor of shape [1,1,1], there should be a negligible performance difference between `torch.empty()` and `torch.zeros()` anyway, so it's better to just use `torch.zeros()` for this and avoid unnecessarily creating issues.

* `git checkout` with commit hash

* Change the repositories origin URLs when necessary

* Mac k-diffusion workarounds are no longer needed

* Remove duplicate code for torchsde randn

* Fix DDIM and PLMS samplers on MPS

* fix broken XYZ plot seeds
add new callback for scripts to be used before processing

* Fix typo in launch_utils.py

existance -> existence

* fix 2 for git code botched by previous PRs

* also use setup callback for the refiner instead of before_process

* add res(dpmdd 2m sde heun) and reorder the sampler list

* Fix MHA updown err and support ex-bias for no-bias layer

* repair DDIM/PLMS/UniPC batches

* make on_before_component/on_after_component possible earlier

* Put frequently used sampler back

* remove "if bias exist" check

* revert to applying mask before denoising for k-diffusion, like it was before

* repair /docs page

* Fix typo in shared_options.py

unperdictable -> unpredictable

* further repair the /docs page to not break styles with the attempted fix

* return seed controls UI to how it was before

* fix API always using -1 as seed

* hires prompt timeline: merge to latests, slightly simplify diff

* add second_order to samplers that mistakenly didn't have it

* separate Extra options

* Update hash for SD XL Repo

* revert changed inpainting mask conditioning calculation after AUTOMATIC1111#12311

* when refreshing cards in extra networks UI, do not discard user's custom resolution

* correctly add Eta DDIM to infotext when it's 1.0 and do not add it when it's 0.0.

* get XYZ plot to work with recent changes to refined specified in fields of p rather than in settings

* fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are AUTOMATIC1111#12509

* Add extra img2img noise

* Add NoCrypt/miku gradio theme

* update changelog file

* Add PR refs to changelog

* Use the new SD VAE override setting

* Update CHANGELOG.md

* full module with ex_bias

* store patches for Lora in a specialized module

* lint

* RNG: Make all elements of shape `int`s

* Fix inpaint upload for alpha masks, create reusable function

* CSS: Remove forced visible overflow for Gradio group child divs

* Remove wrong scale

* RAM optimization round 2

* send weights to target device instead of CPU memory

* Revert "send weights to target device instead of CPU memory"

This reverts commit 0815c45.

* send weights to target device instead of CPU memory

* auto add data-dir to gradio-allowed-path

* return empty list if extensions_dir not exist

* Add extra noise callback

* resolve the issue with loading fp16 checkpoints while using --no-half

* Attempt to resolve NaN issue with unstable VAEs in fp32 mk2

* Fix img2img background color not being used

* Add option for faster live interrupt

* Make image viewer actually fit the whole page

* fix issues with model refresh

* remove unused import

* image hash

* negative_prompt full_prompt hash

* fix model override logic

do not need extra logic to unload refine model

* Make results column sticky

* Gallery: Set preview to True, allow custom height

* Exit out of hires fix if interrupted earlier

* fix typo `txt2txt` -> `txt2img`

* run python unbuffered so output shows up in docker logs

* xformers update

* more grammar fixes

* refactor: Update ui.js

* api support get image from url

* second appearance

* revert xformers back to 0.0.20

* do not assign to vae_dict

* possible fix for dictionary changed size during iteration

* implement undo hijack for SDXL

* switch to PNG when images too large

* fix xyz swap axes

make csv_string_to_list_strip function

* Fix SD VAE switch error after model reuse

* Add resize-handler extension

* catch error when loading config_states

and save config_states with indent

* assert key created_at exist in config_states

* no need to use OrderedDict

* make it obvious that a config_status is corrupted

also format HTML removing unnecessary text blocks

* make live preview display work independently from progress bar

* make live previews play nice with window/slider resizes

* make mobile built-in extension actually do something

* Store base_vae and loaded_vae_file in sd_model

* Change to access sd_model attribute with dot

* fix for small images in live previews not being scaled up

* Change where VAE state are stored in model

* fix potential ssrf attack in AUTOMATIC1111#12663

* Update torch for Navi 31 (7900 XT/XTX) 

Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5).
so switch to nightly rocm5.6 without explicit versions this time

* Fix for consistency with shared.opts.sd_vae of UI

* add settings for http/https URLs in source images in api

* prevent API options from being changed via API

* make resize handle available to extensions

* also prevent changing API options via override_settings

* feat: replace threading.Lock() to FIFOLock

Signed-off-by: AnyISalIn <[email protected]>

* ditch --always-batch-cond-uncond in favor of an UI setting

* add citation

* citation mk2

* forbid Full live preview method for medvram and add a setting to undo the forbidding

* add RNG source to XYZ

* Reset columns on resize handle dblclick

* Make Gradio temp directory if it doesn't exist

Gradio normally creates the temp directory in `pil_to_temp_file()` (https://github.com/gradio-app/gradio/blob/861d752a83da0f95e9f79173069b69eababeed39/gradio/components/base.py#L313) but since the Gradio implementation of `pil_to_temp_file()` is replaced with `save_pil_to_file()`, the Gradio temp directory should also be created by `save_pil_to_file()` when necessary.

* remove unneeded example_inputs from gradio config

* attemped solution to the uncommon hanging problem that is seemingly caused by live previews working on the tensor as denoising

* Fix resize handle overflowing in Safari

* Expand the hit area of resize handle

* actual solution to the uncommon hanging problem that is seemingly caused by multiple progress requests working on same tensor

* Prevent text selection and cursor changes

* dump current stack traces when exiting with SIGINT

* Update README.md with Intel install instructions

* for live previews, only hide gallery after at least one live previews pic has been received
fix blinking for live previews
fix a clientside live previews exception that happens when you kill serverside during sampling
match the size of live preview image to gallery image

* Limit mouse detection to primary button only

* make it possible to localize tooltips and placeholders

* Replace tabs with spaces

* fix broken generate button if not using live previews

* eslint

* Fix double click event not firing

* Improve integration, fix for new gradio

* use an atomic operation to replace the cache with the new version

* add --medvram-sdxl

* add type annotations for extra fields of shared.sd_model

* fix endless progress requests

* Removed the old code

* lint

* set devices.dtype_unet correctly

* tell RealESRGANer which device to run on, could be cuda, M1, or other GPU

* fix memory leak when generation fails

* update doggettx cross attention optimization to not use an unreasonable amount of memory in some edge cases -- suggestion by MorkTheOrk

* update changelog

* draw extra network buttons above description

* Fixing and improving integration

* remove console.log

* bump gradio version

* add infotext for use_old_scheduling option

* update changelog

* update info about gradio in changelog file

* Zoom and Pan: Resize handler

* fix incorrect save/display of new values in Defaults page in settings

* fix defaults settings page breaking when any of main UI tabs are hidden

* fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt

* update gradio to 3.41.2

* update changelog

* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working

* Merge pull request AUTOMATIC1111#12795 from catboxanon/prevent-duplicate-resize-handler-mk2

Prevent duplicate resize handler

* Merge pull request AUTOMATIC1111#12797 from Madrawn/vae_resolve_bug

Small typo: vae resolve bug

* Merge pull request AUTOMATIC1111#12792 from catboxanon/image-cropper-hide

Hide broken image crop tool

* Merge pull request AUTOMATIC1111#12780 from catboxanon/xyz-hide-samplers

Don't show hidden samplers in dropdown for XYZ script

* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs

* update changelog

* lint

* hide --gradio-auth and --api-auth values from /internal/sysinfo report

* update changelog

* Merge pull request AUTOMATIC1111#12814 from AUTOMATIC1111/non-local-condition

non-local condition

* Merge pull request AUTOMATIC1111#12819 from catboxanon/fix/rng-infotext

Add missing infotext for RNG in options

* always show NV as RNG source in infotext

* Merge pull request AUTOMATIC1111#12842 from dhwz/dev

remove xformers Python version check

* Merge pull request AUTOMATIC1111#12837 from bluelovers/pr/file-metadata-break-001

style: file-metadata word-break

* Merge pull request AUTOMATIC1111#12818 from catboxanon/sgm

Add option to align with sgm repo's sampling implementation

* Merge pull request AUTOMATIC1111#12834 from catboxanon/fix/notification-tab-switch

Fix notification not playing when built-in webui tab is inactive

* Merge pull request AUTOMATIC1111#12832 from catboxanon/fix/skip-install-extensions

Honor `--skip-install` for extension installers

* Merge pull request AUTOMATIC1111#12833 from catboxanon/fix/dont-print-blank-stdout

Don't print blank stdout in extension installers

* revert SGM noise multiplier change for img2img because it breaks hires fix

* Merge pull request AUTOMATIC1111#12856 from catboxanon/extra-noise-noisy-latent

Add noisy latent to `ExtraNoiseParams` for callback

* Merge pull request AUTOMATIC1111#12855 from dhwz/dev

don't print empty lines

* Merge pull request AUTOMATIC1111#12854 from catboxanon/fix/quicksettings-dropdown-unfocus

Do not change quicksettings dropdown option when value returned is `None`

* Merge pull request AUTOMATIC1111#12839 from ibrainventures/patch-1

[RC 1.6.0 - zoom is partly hidden] Update style.css

* get progressbar to display correctly in extensions tab

* Merge pull request AUTOMATIC1111#12838 from bluelovers/pr/file-metadata-path-001

display file metadata `path` , `ss_output_name`

* go back to single path for filenames in extra networks metadata dialog

* Merge pull request AUTOMATIC1111#12851 from bluelovers/pr/extension-time-001

chore: change extension time format

* update changelog

* keep order in list of checkpoints when loading model that doesn't have a checksum

* Merge pull request AUTOMATIC1111#12864 from AUTOMATIC1111/extension-time-format-time-zone

patch Extension time format in systme time zone

* Merge pull request AUTOMATIC1111#12865 from AUTOMATIC1111/another-convert-to-system-time-zone

extension update time, convert to system time zone

* add an option to choose how to combine hires fix and refiner

* fix inpainting models in txt2img creating black pictures

* add information about Restore faces and Tiling into the changelog

* add --dump-sysinfo, a cmd arg to dump limited sysinfo file at startup

* update bug report template to include sysinfo and not include all other fields that are already covered by sysinfo

* fix an issue where VAE would remain in fp16 after an auto-switch to fp32

* fix an issue where using hires fix with refiner on first pass with medvram would cause an exception when generating

* Merge pull request AUTOMATIC1111#12876 from ljleb/fix-re

Fix generation params regex

---------

Signed-off-by: AnyISalIn <[email protected]>
Co-authored-by: AUTOMATIC1111 <[email protected]>
Co-authored-by: brkirch <[email protected]>
Co-authored-by: Ikko Eltociear Ashimine <[email protected]>
Co-authored-by: Kohaku-Blueleaf <[email protected]>
Co-authored-by: whitebell <[email protected]>
Co-authored-by: Robert Barron <[email protected]>
Co-authored-by: w-e-w <[email protected]>
Co-authored-by: catboxanon <[email protected]>
Co-authored-by: NoCrypt <[email protected]>
Co-authored-by: Cade Schlaefli <[email protected]>
Co-authored-by: S-Del <[email protected]>
Co-authored-by: Dan <[email protected]>
Co-authored-by: XDOneDude <[email protected]>
Co-authored-by: bluelovers <[email protected]>
Co-authored-by: SpenserCai <[email protected]>
Co-authored-by: Uminosachi <[email protected]>
Co-authored-by: akiba <[email protected]>
Co-authored-by: fraz0815 <[email protected]>
Co-authored-by: AnyISalIn <[email protected]>
Co-authored-by: MMP0 <[email protected]>
Co-authored-by: Ravi Panchumarthy <[email protected]>
Co-authored-by: Danil Boldyrev <[email protected]>
Co-authored-by: yajun <[email protected]>
  • Loading branch information
2 parents f1923ca + 5ef669d commit 6a0fd1f
Show file tree
Hide file tree
Showing 129 changed files with 7,116 additions and 3,751 deletions.
6 changes: 6 additions & 0 deletions .eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,11 @@ module.exports = {
modalNextImage: "readonly",
// token-counters.js
setupTokenCounters: "readonly",
// localStorage.js
localSet: "readonly",
localGet: "readonly",
localRemove: "readonly",
// resizeHandle.js
setupResizeHandle: "writable"
}
};
78 changes: 7 additions & 71 deletions .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ body:
id: steps
attributes:
label: Steps to reproduce the problem
description: Please provide us with precise step by step information on how to reproduce the bug
description: Please provide us with precise step by step instructions on how to reproduce the bug
value: |
1. Go to ....
2. Press ....
Expand All @@ -37,64 +37,14 @@ body:
id: what-should
attributes:
label: What should have happened?
description: Tell what you think the normal behavior should be
description: Tell us what you think the normal behavior should be
validations:
required: true
- type: input
id: commit
attributes:
label: Version or Commit where the problem happens
description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
validations:
required: true
- type: dropdown
id: py-version
attributes:
label: What Python version are you running on ?
multiple: false
options:
- Python 3.10.x
- Python 3.11.x (above, no supported yet)
- Python 3.9.x (below, no recommended)
- type: dropdown
id: platforms
attributes:
label: What platforms do you use to access the UI ?
multiple: true
options:
- Windows
- Linux
- MacOS
- iOS
- Android
- Other/Cloud
- type: dropdown
id: device
attributes:
label: What device are you running WebUI on?
multiple: true
options:
- Nvidia GPUs (RTX 20 above)
- Nvidia GPUs (GTX 16 below)
- AMD GPUs (RX 6000 above)
- AMD GPUs (RX 5000 below)
- CPU
- Other GPUs
- type: dropdown
id: cross_attention_opt
- type: textarea
id: sysinfo
attributes:
label: Cross attention optimization
description: What cross attention optimization are you using, Settings -> Optimizations -> Cross attention optimization
multiple: false
options:
- Automatic
- xformers
- sdp-no-mem
- sdp
- Doggettx
- V1
- InvokeAI
- "None "
label: Sysinfo
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
validations:
required: true
- type: dropdown
Expand All @@ -108,21 +58,7 @@ body:
- Brave
- Apple Safari
- Microsoft Edge
- type: textarea
id: cmdargs
attributes:
label: Command Line Arguments
description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
render: Shell
validations:
required: true
- type: textarea
id: extensions
attributes:
label: List of extensions
description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise.
validations:
required: true
- Other
- type: textarea
id: logs
attributes:
Expand Down
155 changes: 155 additions & 0 deletions CHANGELOG.md

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- given-names: AUTOMATIC1111
title: "Stable Diffusion Web UI"
date-released: 2022-08-22
url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui"
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Clip skip
- Hypernetworks
- Loras (same as Hypernetworks but more pretty)
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
- Can select to load a different VAE from settings screen
- Estimated completion time in progress bar
- API
Expand All @@ -88,12 +88,15 @@ A browser interface based on Gradio library for Stable Diffusion.
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
- Now without any bad letters!
- Load checkpoints in safetensors format
- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from settings screen

## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)
- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)

Alternatively, use online services (like Google Colab):

Expand All @@ -115,15 +118,15 @@ Alternatively, use online services (like Google Colab):
1. Install the dependencies:
```bash
# Debian-based:
sudo apt install wget git python3 python3-venv
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
# Red Hat-based:
sudo dnf install wget git python3
# Arch-based:
sudo pacman -S wget git python3
```
2. Navigate to the directory you would like the webui to be installed and execute the following command:
```bash
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh
```
3. Run `webui.sh`.
4. Check `webui-user.sh` for options.
Expand Down Expand Up @@ -169,5 +172,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- LyCORIS - KohakuBlueleaf
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
10 changes: 9 additions & 1 deletion extensions-builtin/Lora/extra_networks_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')

self.errors = {}
"""mapping of network names to the number of errors the network had during operation"""

def activate(self, p, params_list):
additional = shared.opts.sd_lora

self.errors.clear()

if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
Expand Down Expand Up @@ -56,4 +61,7 @@ def activate(self, p, params_list):
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

def deactivate(self, p):
pass
if self.errors:
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))

self.errors.clear()
31 changes: 31 additions & 0 deletions extensions-builtin/Lora/lora_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch

import networks
from modules import patches


class LoraPatches:
def __init__(self):
self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)

def undo(self):
self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')

7 changes: 5 additions & 2 deletions extensions-builtin/Lora/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def calc_scale(self):

return 1.0

def finalize_updown(self, updown, orig_weight, output_shape):
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
Expand All @@ -145,7 +145,10 @@ def finalize_updown(self, updown, orig_weight, output_shape):
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)

return updown * self.calc_scale() * self.multiplier()
if ex_bias is not None:
ex_bias = ex_bias * self.multiplier()

return updown * self.calc_scale() * self.multiplier(), ex_bias

def calc_updown(self, target):
raise NotImplementedError()
Expand Down
7 changes: 6 additions & 1 deletion extensions-builtin/Lora/network_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)

self.weight = weights.w.get("diff")
self.ex_bias = weights.w.get("diff_b")

def calc_updown(self, orig_weight):
output_shape = self.weight.shape
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
if self.ex_bias is not None:
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
else:
ex_bias = None

return self.finalize_updown(updown, orig_weight, output_shape)
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
28 changes: 28 additions & 0 deletions extensions-builtin/Lora/network_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import network


class ModuleTypeNorm(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["w_norm", "b_norm"]):
return NetworkModuleNorm(net, weights)

return None


class NetworkModuleNorm(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)

self.w_norm = weights.w.get("w_norm")
self.b_norm = weights.w.get("b_norm")

def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)

if self.b_norm is not None:
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
else:
ex_bias = None

return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
Loading

0 comments on commit 6a0fd1f

Please sign in to comment.