Everything you know about loss is a LIE! #294
Replies: 28 comments 154 replies
-
Could it be that the loss is actually "loss per timestep", or maybe the average of the loss for each timestep? I think this would be different from the loss accumulated over all timesteps (not sure if what I said even makes sense).
-
@swfsql It has to do with the signal-to-noise ratio. Low timesteps have very little noise added, so when the trainer takes a sample (a single step down the timestep chain, IIUC) and predicts the noise, the error is a lot larger. There's a lot more math involved, but basically you can think of it as 5/4 being a lot bigger than 500/499. There's some more info in #308, because Min-SNR is designed precisely to counteract some of this high loss from low timesteps.
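To make the signal-to-noise point concrete, here is a minimal sketch in plain Python. It assumes the Stable Diffusion 1.x `scaled_linear` beta schedule (beta_start=0.00085, beta_end=0.012, 1000 steps) and the Min-SNR weight for noise prediction, min(SNR, gamma) / SNR, with gamma=5. The helper names are made up for illustration and are not taken from any trainer.

```python
import math

def alphas_cumprod(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative product of (1 - beta_t) for a scaled-linear beta schedule (assumed SD 1.x defaults)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    out, running = [], 1.0
    for i in range(num_steps):
        beta = (lo + (hi - lo) * i / (num_steps - 1)) ** 2
        running *= 1.0 - beta
        out.append(running)
    return out

def snr(ac):
    """Signal-to-noise ratio of the noised latent at a given alpha_cumprod."""
    return ac / (1.0 - ac)

def min_snr_weight(ac, gamma=5.0):
    """Loss weight that caps the contribution of high-SNR (low-timestep) samples."""
    s = snr(ac)
    return min(s, gamma) / s

acp = alphas_cumprod()
for t in (0, 10, 100, 500, 999):
    print(t, snr(acp[t]), min_snr_weight(acp[t]))
```

Under these assumptions the SNR at t=0 is in the thousands while at t=999 it is far below 1, so the weight strongly shrinks the low-timestep loss and leaves the high-timestep loss untouched.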
-
I took this idea and tinkered a bit. I set the timestep to a fixed (1000-global_step), and it produced a very clean U-curve, bottoming out at around min_timestep 400. Observationally, setting the timestep range higher does a better job of absorbing "broad" details, but misses the fine details. Once the timestep dropped under 400, the results rapidly became garbled and nonsensical; much worse than just normal overtraining.

Taking this information, I wondered how training would be affected if the lower timesteps weren't selected as frequently. I reimplemented the random timestep generation to use a standard normal distribution where -6 sigma is min_timestep and 6 sigma is max_timestep, with the mean at the midpoint of the range:

    sigma = 6
    # clip a standard normal to [-sigma, sigma], then rescale it to [min_timestep, max_timestep]
    timesteps = ((torch.randn((b_size,), device=latents.device).clip(-sigma, sigma) + sigma) / (2*sigma)) * (max_timestep - min_timestep) + min_timestep

Here's a comparison of loss between the standard implementation and my normal distribution (teal is the modified routine). In both cases, MinSNR=5.

Observationally, knocking out the lower timesteps results in significantly faster improvements to the samples over time. Setting min_timestep too high causes the model to not learn the finer details of the subject, so a balance is needed. A 6-sigma standard normal distribution using [100..1000] as my range (which should give me an average timestep of 550) results in the model learning significantly faster. In the dataset I'm working on, I typically get pretty decent results after 1500-2000 steps, but with this change I've been seeing it approach the same level of fidelity by ~400-500 steps, with significantly less overtraining "damage" to the underlying model.

Here's a comparison of one of my training images, the standard timestep selection routine (min/max range of [0..1000]), and my timestep selection routine with bounds set at [100..1000]. The two generated images were produced after 500 steps of training. All parameters other than the timestep range and random generation routine were held constant.

For what it's worth, I'm using the Prodigy optimizer with a CosineAnnealingLR scheduler; I suppose tests should also be run with the more standard Adam8bit and a constant learning rate, but the results were a significant-enough improvement that I felt the observation bore sharing. I have no theoretical basis for any of this, I'm just kind of experimenting and found that this had a massive impact.
Intuitively, one thing that might be worth trying is some combination of global step, learning rate, and timestep range scheduling: tilt the timestep ranges higher early on (or when the learning rate is higher), then reduce the lower bound of the range over time, to see if it can balance learning the finer details without overtraining on all that extra error from the low timesteps.
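One way the scheduling idea could look, as a rough sketch in plain Python rather than a tested training change. The linear schedule, the start and end bounds (400 and 100, borrowed from the numbers mentioned in this comment), and both function names are assumptions for illustration.

```python
import random

def min_timestep_at(step, total_steps, start_min=400, end_min=100):
    """Linearly lower the minimum timestep from start_min to end_min over training."""
    frac = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return round(start_min + (end_min - start_min) * frac)

def sample_timestep(step, total_steps, max_timestep=1000, sigma=6.0, rng=random):
    """Clipped-normal timestep sample over [min_timestep_at(step), max_timestep)."""
    lo = min_timestep_at(step, total_steps)
    z = min(max(rng.gauss(0.0, 1.0), -sigma), sigma)
    t = (z + sigma) / (2 * sigma) * (max_timestep - lo) + lo
    return min(int(t), max_timestep - 1)

rng = random.Random(0)
early = [sample_timestep(0, 2000, rng=rng) for _ in range(1000)]
late = [sample_timestep(2000, 2000, rng=rng) for _ in range(1000)]
print(sum(early) / len(early), sum(late) / len(late))
```

Early in training the samples cluster around the middle of [400, 1000); by the end they cluster around the middle of [100, 1000), so lower timesteps only come into play once the broad structure has had time to settle.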
-
It looks like #889 (and the linked paper) may have addressed this problem as well. I'll be running some tests, but if it's really that simple, then all the better!
-
I think I might have stumbled into something extraordinary, and want to throw this out there to get other brains on it.

One of my observations in my experiments is that forward noising applies noise as sqrt(alphas_cumprod) * sample + sqrt(1 - alphas_cumprod) * noise, and the sum of those two coefficients, plotted over timesteps, rises to a peak around step 400 and falls off on both sides. This shape has shown up in other experiments (you'll notice it in my post above in the loss explorations), which has made me wonder if there's a fundamental bias in the forward noising mechanism - there will be much more total information (sample + noise) in the latent at step 400 than there is at steps 10 or 800. The pixel values will always have their largest magnitude at the ~400ish step peak.

In the abstract, maybe this doesn't matter, because the unet should learn to predict noise regardless of the function used to generate the noise, so long as that function is differentiable, right? But what if it's actually a source of bias in training? At first I thought that this wouldn't change things that much, but that we might be able to debias it by dividing both noise and target by that coefficient sum:

    if args.apply_noise_compensation:
        # sqrt(alphas_cumprod) + sqrt(1 - alphas_cumprod), one value per timestep
        noise_comp = (noise_scheduler.alphas_cumprod.sqrt() + (1 - noise_scheduler.alphas_cumprod).sqrt()).to(device=accelerator.device)
        # reshape to (batch, 1, 1, 1) so the division broadcasts per sample
        noise_pred = noise_pred / noise_comp[timesteps].view(-1, 1, 1, 1)  # .sqrt()
    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")

My first samples were blurry. REALLY blurry. But they "felt" much more like my subject than samples of the same timestep did with the standard routine. Furthermore, as the epochs progressed, the samples retained significantly more of my subject's ineffable quality, but they sharpened up! The longer the training ran, the clearer my samples became. Around epoch 30 or so, they were essentially back to full resolution, but without the mutations, distortions, or burnout that were characteristic of overtraining in that range before, and the samples had astonishingly preserved most of the "shape" from the underlying model. This smells, in principle, much like what the […].

When I then pulled the LoRA into SD and tried it against various models other than the one I trained it on (RealisticVision), and in combination with other LoRAs, it felt like it generalized FAR better than anything I've accomplished before. This felt like a fundamentally superior output in terms of fidelity AND flexibility.

Training details: AdamW8bit (modified to use the AdaBelief term), constant LR of 8e-6, […]

What I'm now wondering is a) why does this work? and b) whether we can bias the model towards a "level of detail" by intentionally selecting the curve by which noise_pred is modified before loss calculation. What I'm envisioning is some kind of "10-band EQ" where you could drag the learning rate of certain LODs up or down.

Here's a video of my training run over 39 epochs (interpolated.mp4). As you can see, it learns the "large outline" of my features early on, but then improves incrementally on the level of detail in the image WITHOUT losing those "large features".

In many of my previous experiments, I've found that both large and small features got learned at the same time, and that frequently I'd end up with models where either there was too much "large detail" (and the shape of the outputs got distorted and nonsensical) or there wasn't enough "small detail" (and the fidelity of the subject didn't feel right). This feels like it sidestepped that issue entirely and gave me an unexpectedly flexible high-fidelity output. I would appreciate any insight into what might be happening here, and how it might be understood to better control training.
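For anyone who wants to look at the compensation curve without loading a model, here is a small sketch in plain Python. It assumes the Stable Diffusion 1.x `scaled_linear` schedule (beta_start=0.00085, beta_end=0.012, 1000 steps); `alphas_cumprod` is a stand-in helper, not the diffusers scheduler.

```python
import math

def alphas_cumprod(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative product of (1 - beta_t) for a scaled-linear schedule (assumed SD 1.x defaults)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    out, running = [], 1.0
    for i in range(num_steps):
        beta = (lo + (hi - lo) * i / (num_steps - 1)) ** 2
        running *= 1.0 - beta
        out.append(running)
    return out

acp = alphas_cumprod()
# Sum of the sample and noise coefficients at each timestep
# (their squares sum to 1 at every step; only the plain sum has a peak).
noise_comp = [math.sqrt(a) + math.sqrt(1.0 - a) for a in acp]
peak_t = max(range(len(noise_comp)), key=noise_comp.__getitem__)
print(peak_t, noise_comp[peak_t], noise_comp[0], noise_comp[-1])
```

With this schedule the curve starts and ends close to 1 and peaks near sqrt(2) where alphas_cumprod crosses 0.5, which falls in the mid-300s, in line with the "~400ish" peak described above.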
-
Hi @cheald, hope you are well. Thank you for the info about the timestep.
-
Hi @cheald, I have also been experimenting with this problem, and I believe I have the true ideal solution, which should end the need to tune timestep weightings. The traditional solution to this would be multi-task loss which is structured in such a way that timesteps would get extra weight depending on their difficulty (e.g. more difficult timesteps are weighted higher, less difficult ones are lower weight), so that all timesteps contribute the same amount towards the image. The weights would be trained as a tensor with one element for each timestep. The main problem with this, is that it is not very stable on smaller batch sizes that anyone who isn't training a foundation model would use -- instead of optimizing the parameters, what you'd end up doing with this is effectively smacking them across the room every now and then, since you're not going to touch every timestep every training step. It is not likely to converge properly unless using a batch size of at least 256 or so (very conservative estimate, it probably needs more), and is inefficient for learning since we would expect nearby timesteps to have similar difficulty and this doesn't capture that. The EDM2 paper (repo: https://github.com/NVlabs/edm2) includes a different form of this multi-task loss in the form of a single-layer MLP with no activation function that takes noise level (sigma) as input. While this was designed to be used with a continuous noise schedule like the EDM models use, I have found that it works extremely well on discrete timestep schedules, and most importantly for us, it allows multi-task loss to work on smaller batch sizes efficiently. From my testing (still ongoing), I have found that this drastically improved results over the debias schedule you noted that you used before, and interestingly, the weightings it chose did not look too much like other weightings I had been recommended before. 
I have also had wonderful results with scheduled pseudo-Huber loss in combination with this.

One remaining problem with the learned timestep weightings, though, is that they will most likely take longer to converge than most short training runs last. My tests so far have been with a full finetune with a virtual batch size of 64. I do get good, fast convergence with this method when I use the recently released Schedule Free optimizer (https://github.com/facebookresearch/schedule_free/tree/main), where I get fairly close to the final schedule within about 500 steps with an LR of 0.005. Regardless of whether it is viable for every training run duration in your use cases, I am sure that you could use it on a longer training run to discover better timestep weight schedules.

The MLP is also formulated in a way that it accepts a "baseline" timestep weighting of sorts (noted in the paper's formulas as […]).
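A toy illustration of why a learned weighting of this kind evens out timestep difficulty, written in plain Python. It assumes the objective has the form loss * exp(-u) + u, where u is a learned log-variance; that is my reading of the EDM2-style uncertainty weighting, and here u is a single scalar fitted to one fixed loss value rather than an MLP over noise levels.

```python
import math

def fit_log_variance(loss_value, lr=0.1, steps=500):
    """Gradient descent on u for the objective loss_value * exp(-u) + u."""
    u = 0.0
    for _ in range(steps):
        grad = 1.0 - loss_value * math.exp(-u)  # derivative of the objective w.r.t. u
        u -= lr * grad
    return u

for raw in (0.05, 0.5, 5.0):
    u = fit_log_variance(raw)
    # The weighted loss raw * exp(-u) ends up at ~1.0 regardless of the raw loss.
    print(raw, u, raw * math.exp(-u))
```

The optimum is u = log(loss), so the effective weight exp(-u) becomes 1/loss: an "easy" noise level with a small raw loss and a "hard" one with a large raw loss contribute about equally once u has converged.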
-
I've developed a way to do post hoc analysis of a LoRA to see WHERE it improved loss on your training set. The basic idea is simple: take a training dataset, a LoRA, and a model. Load each training sample, noise it at every 50th timestep, do noise prediction, and compute the standard loss.

Here are the samples at the 20th epoch. Observations: fine details, very warm overall tone, blurry and low-detail backgrounds. The color depth feels a bit flat, but the textures are decent.

And here's the loss ratio plot. What I've found is that the stock training regime is good at reducing loss at higher timesteps, but has a much harder time with lower timesteps. Forgive the lack of labels; the X axis is "timestep / 50" (ranging from 0-20, which expands to 0-1000), and the Y axis is the ratio of baseline loss:lora loss (higher means the LoRA reduced loss). (Edit: I realized this morning that I was using baseline/lora rather than lora/baseline, so that changes my interpretations, which I've updated.)

Additionally, I can plot statistics PER SAMPLE to find pathological samples in my dataset which are not converging! The red line is a ratio of 1.0, and the box plot shows the loss reductions across all 20 sampled timesteps, with the typical mean, median, and 1 SD. This is really useful for finding samples which the optimizer has outsized trouble converging on.

By comparison, here's a run where I experimented with using […]. Observations from these samples: much more neutral colors, better dynamic range, but the likeness isn't quite as good. The teeth are better (and this has held through my experiments; the teeth overtrain first, but with this technique they remain fine the whole time).

And here are the loss plots. The loss on the high end has tangibly improved. If I let this training run go for 60 epochs, it does a GREAT job at learning structure and form, but doesn't quite get details.

What's interesting here is that the tail end flipped, but the loss change as a percentage on the low end didn't change much at all. This might be due to the lower absolute values on the high end, but it's interesting that the first part of the curve didn't change much.

Here's the notebook. It should go in your sd-scripts directory, as it uses a few utility functions from sd-scripts to ease model loading. Right now it only works with SD1.5, but it shouldn't be hard to extend it to SDXL or whatnot. My hope is that lessons learned in SD15 land can be applied to SDXL, since SD15 is a lot faster to run experiments with.
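The bookkeeping behind those two plots can be sketched independently of the model. This is not the notebook's code: the function names, the choice of the median for the per-sample check, and the numbers at the bottom are all made up to exercise the functions.

```python
from statistics import mean, median

def loss_ratio_by_timestep(baseline, lora):
    """baseline/lora map timestep -> list of per-sample losses; returns timestep -> mean ratio."""
    return {t: mean(baseline[t]) / mean(lora[t]) for t in sorted(baseline)}

def samples_not_improving(baseline, lora, threshold=1.0):
    """Indices of samples whose median baseline/lora ratio across timesteps is below threshold."""
    timesteps = sorted(baseline)
    n_samples = len(baseline[timesteps[0]])
    flagged = []
    for i in range(n_samples):
        ratios = [baseline[t][i] / lora[t][i] for t in timesteps]
        if median(ratios) < threshold:
            flagged.append(i)
    return flagged

# Made-up losses for two samples at three timesteps, purely illustrative.
baseline = {0: [0.40, 0.50], 500: [0.10, 0.12], 950: [0.02, 0.02]}
lora = {0: [0.38, 0.60], 500: [0.08, 0.15], 950: [0.01, 0.03]}
ratios = loss_ratio_by_timestep(baseline, lora)
flagged = samples_not_improving(baseline, lora)
print(ratios, flagged)
```

A ratio above 1.0 means the LoRA lowered the loss at that timestep; a sample that sits below 1.0 at most timesteps is a candidate for the "not converging" bucket.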
-
I am pretty sure that I've directly identified a cause of the original observation in this issue. Essentially: given a static noise, forward noising a latent with that noise, and then predicting the noise from that noised latent, earlier timesteps consistently end up with a lower overall magnitude of predicted noise.

    import torch
    import matplotlib.pyplot as plt

    # encode_path_sd15, prompt_to_cond, unet, noise_scheduler, image and dev are defined elsewhere (not shown)
    with torch.no_grad():
        timesteps = torch.arange(0, 1000, 25, device=dev)
        latent = encode_path_sd15(image, 128)
        latent = latent.expand(timesteps.shape[0], *latent.shape[1:])
        # a single noise sample, shared by every timestep
        noise = torch.randn((1, *latent.shape[1:]), device=latent.device).expand(latent.shape)
        _, text_embeddings = prompt_to_cond(["man"], latent.unsqueeze(0))
        noisy_latents = noise_scheduler.add_noise(latent, noise, timesteps)
        noise_pred = unet(noisy_latents, timesteps, text_embeddings.cuda()).sample

    fig = plt.figure(figsize=(20, 50))
    for i, timestep in enumerate(timesteps):
        ax = plt.subplot(10, 4, i + 1)
        ax.set_title(f"Timestep {timestep.item()}")
        n = noise[i].flatten().cpu()
        ax.hist(n, alpha=0.5, bins=100, color="red")
        ax.hist(noise_pred[i].flatten().cpu(), alpha=0.5, bins=100)
    plt.show()

Red is the true noise (held constant), and blue is the predicted noise. You'll notice that at t=0, the predicted noise histogram has a significantly narrower distribution, and the magnitude of the predicted noise increases as the timestep increases. This will plainly lead to significant differences in loss for lower timesteps.

If we plot […], we get essentially the difference in the magnitude of the true noise and predicted noise at each timestep for a given static noise. I regressed this curve to roughly […].

I tried scaling noise_pred directly first, and this does something very interesting: it causes what feels like a relative "hyperfocus" on detail, resulting in over-sharpened (one might even say "overtrained") images. This is after only 4 epochs, but the pattern holds.

But, okay, if the issue is just magnitude, we can directly normalize the noise_pred to the standard deviation of noise:

    noise_pred = noise_pred / noise_pred.std() * noise.std()  # noise.std() should be pretty close to 1 here, so maybe unnecessary?

This results in a significantly higher level of subject detail (and feels like perhaps the most photoreal result I've achieved yet), but the background basically entirely disappeared in all my samples. My first thought is that I wonder if it's related to masked training, but I wouldn't think so, since noising and noise_pred normalization are both applied without respect to masks. At any rate, there's a dial to play with here. I suspect there's more to the nature of noise vs noise_pred than just the difference in noise, but the noise is definitely a good clue.

Perplexingly, normalizing noise to the std of noise_pred (which should produce similar loss values, I think?) does NOT produce similar results:

    noise_pred = unet.call(...)
    noise = noise / noise.std() * noise_pred.std()

I don't have an explanation for this, so there's clearly something else in play here that I'm missing.
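To isolate how much loss a too-narrow prediction produces on its own, here is a toy check in plain Python. It assumes the extreme case where the prediction is exactly the true noise shrunk by a constant factor (0.8 is an arbitrary, hypothetical value); real predictions also differ in direction, so this only covers the magnitude part.

```python
import random
from statistics import pstdev

def mse(a, b):
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(10_000)]
scale = 0.8  # hypothetical: right shape, narrower spread
noise_pred = [scale * n for n in noise]

mse_before = mse(noise_pred, noise)

# Rescale the prediction to the standard deviation of the true noise.
ratio = pstdev(noise) / pstdev(noise_pred)
rescaled = [p * ratio for p in noise_pred]
mse_after = mse(rescaled, noise)

mean_square = sum(n * n for n in noise) / len(noise)
print(mse_before, (1 - scale) ** 2 * mean_square, mse_after)
```

In this idealized case the loss before rescaling is (1 - scale)^2 times the mean square of the noise, and matching the standard deviations removes it completely. That suggests why a narrower noise_pred at low timesteps shows up as extra loss, though it does not explain the asymmetry between the two normalizations noted above.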
-
Okay, so initial runs show extreme promise. I've added multiple additional loss objectives (std, skew, kurtosis, and KL divergence between the predicted and true noise).

(Edit: I'm doing more testing, and it might be that kl_div loss alone is sufficient for this effect; it keeps us in the right "neighborhood" but allows more flexibility.)

They can be tested individually, or combined. Each has a weight, and the individual objectives are summed and added to the overall loss. The results are really, REALLY good.

    # in train_util.py
    def noise_stats(noise):
        # per-sample skewness and excess kurtosis, reduced over (C, H, W) so batches > 1 work
        mean = noise.mean(dim=(1,2,3)).view(-1, 1, 1, 1)
        std = noise.std(dim=(1,2,3)).view(-1, 1, 1, 1)
        skew = ((noise - mean)**3 / std**3).mean(dim=(1,2,3))
        kurt = ((noise - mean)**4 / std**4).mean(dim=(1,2,3)) - 3
        return skew, kurt

    def stat_losses(noise, noise_pred, std_loss_weight=0.5, kl_loss_weight=3e-3, skew_loss_weight=0, kurtosis_loss_weight=0):
        std_loss = torch.nn.functional.mse_loss(
            noise_pred.std(dim=(1,2,3)),
            noise.std(dim=(1,2,3)),
            reduction="none") * std_loss_weight
        skew_pred, kurt_pred = noise_stats(noise_pred)
        skew_true, kurt_true = noise_stats(noise)
        skew_loss = torch.nn.functional.mse_loss(skew_pred, skew_true, reduction="none") * skew_loss_weight
        kurt_loss = torch.nn.functional.mse_loss(kurt_pred, kurt_true, reduction="none") * kurtosis_loss_weight
        p1s = []
        p2s = []
        for i, v in enumerate(noise_pred):
            n = noise[i]
            # histogram the prediction over the range of the true noise
            p1s.append(torch.histc(v.float(), bins=500, min=n.min().item(), max=n.max().item()) + 1e-6)
            p2s.append(torch.histc(n.float(), bins=500) + 1e-6)
        p1 = torch.stack(p1s)
        p2 = torch.stack(p2s)
        kl_loss = torch.nn.functional.kl_div(p1.log(), p2, reduction="none").mean(dim=1) * kl_loss_weight
        return std_loss, skew_loss, kurt_loss, kl_loss

    # in train_network.py
    std_loss, skew_loss, kurt_loss, kl_loss = train_util.stat_losses(noise, noise_pred)
    loss = loss + std_loss + kl_loss + skew_loss + kurt_loss
    loss = loss.mean()  # it's a mean, so there is no need to divide by batch_size

Charting all those metrics (std, skew, kurtosis, and kl_div) shows that despite the classic loss objective improving, various metrics go wonky as training continues. But we KNOW that the desired noise target has a consistent std, skew, and kurtosis. My hunch here is that models overtrain by learning noise predictions which do not resemble IID Gaussian noise, and using metrics like std, skew, and kurtosis as objectives keeps them from getting too far afield.

(Edit: I'm using torch.histc for divergence, which isn't differentiable, so it's not working quite correctly WRT the backwards pass. I'm reimplementing with a soft histogram instead.)

This DOES unfortunately add four (!) new hyperparameters - weights for each new loss type - but so far the values I've used are producing astonishingly good results. After 20 epochs, I'm getting remarkably good samples, with no sign of overtraining yet. I'm still running a training run, and don't have the VRAM to run analytics on a trained LoRA while it's going, but I'll get some timestep error distribution graphs once this run completes. I would very much like others to give this a go and let me know what kinds of results you end up with.

Edit: Here's the timestep error graph after 36 epochs. It's still a little funny at the extremes, but this is a SIGNIFICANT improvement overall.
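For reference, one common way to build a soft histogram is to let every value spread Gaussian weight over the bin centers instead of incrementing a single bin. The sketch below is plain Python and is not the reimplementation mentioned above; the kernel, the default bandwidth of one bin width, the bin range, and the 0.7 shrink factor are all assumptions. Written with tensor ops, the same expression is built only from smooth functions, which is what lets gradients flow through it.

```python
import math
import random

def soft_histogram(values, lo, hi, bins=50, bandwidth=None):
    """Normalized histogram where each value spreads Gaussian weight over nearby bin centers."""
    width = (hi - lo) / bins
    bandwidth = bandwidth or width
    centers = [lo + (i + 0.5) * width for i in range(bins)]
    hist = [1e-6] * bins  # small floor so the log in the KL term is always defined
    for v in values:
        for i, c in enumerate(centers):
            hist[i] += math.exp(-0.5 * ((v - c) / bandwidth) ** 2)
    total = sum(hist)
    return [h / total for h in hist]

def kl_divergence(p, q):
    """KL(p || q) for two normalized histograms with strictly positive bins."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

rng = random.Random(0)
true_noise = [rng.gauss(0.0, 1.0) for _ in range(2000)]
narrow_pred = [0.7 * n for n in true_noise]  # hypothetical too-narrow prediction

p_true = soft_histogram(true_noise, -4.0, 4.0)
p_pred = soft_histogram(narrow_pred, -4.0, 4.0)
print(kl_divergence(p_pred, p_true))
```

The divergence is zero when the two histograms match and grows as the predicted distribution narrows or widens relative to the true noise.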
-
Hey, you might want to check out this recent NVIDIA paper: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/ It looks like it could effectively be a way to handle this issue on the inference end -- or at the very least, you could gain some useful insights that are relevant to your problem by reading the paper and experimenting with the schedule they have (keeping in mind that it is dataset- and model-dependent). They unfortunately haven't released "training" code (this isn't really training, it's a zeroth-order optimization), but I've collaborated with someone to replicate what I am pretty sure is a valid and correct implementation, and I am experimenting with it.
-
@cheald Thanks for your diligent work on this avenue of research.

Based on my empirical results and my interpretation of what is being trained, in particular by LoRA fine-tunings, masked loss does not ignore backgrounds. For a region of noisy pixels, it will learn a function that makes that region trend towards "gray". Another interpretation is that error is increasing for the backgrounds after every iteration and simply being ignored. My hypothesis for why masked training works well is that, for many of the subjects trained by the community, the number of steps needed to achieve decent results does not add "too much" error to backgrounds. Specifically, the conditional UNet LoRAs for a face that already looks like a celebrity will have small changes from identity (aka 0s) for good performance, and near small untrained / random values, the amount of "damage" done to the weights computing backgrounds is relatively small. If you use masked training for many other subjects, it tends to blow up backgrounds, simply because the parameters (via whatever fine-tuning method) have to "actually" learn something.
-
Observations from tonight's tests:
I've updated my branch with most of those changes, and am very pleased with the results. As an aside: charting the effects of […]
-
I am trying to understand: how does the masking, which is in 2D pixel space, get expressed in the latent space?
-
Some observations from a few training runs (SDXL Lora) with very high number of steps, using the proposed changes:
At this stage I am unsure of the root causes; it could be a poor choice of hyperparameters, or these could be overtraining artifacts beyond what the proposed changes are able to compensate for.
-
Here is an example, taken with model RealVisXL_4.0 (I get the same artifacts with SDXL base). Same prompt, without LoRA:
-
I've tried different runs with […]. I also wonder if it would make sense to have some kind of schedule for this parameter, using very low values on low timesteps.
-
My further experimentation is turning up that:
For reasons I don't quite understand yet, increasing the variance of the noise (which is what the first loss term does) causes a loss of detail. This is somewhat confusing to me, because I'd intuitively guess that a wider range of noise would result in a more diverse set of color - and essentially detail. However, it improves - by my eye - subject fidelity. The second term there keeps the variance of any individual channel in the latent in check, which should help clamp down on "wild outlier" values. By itself, it does a really, really good job of causing the model to preserve detail, but hurts in terms of subject fidelity. (Thought: Perhaps a wider range of negative values is leading to that color flattening? It might be interesting to play with tweaking the positive and negative sides of the noise separately!) Armed with that information, it might be that we could schedule some combination or crossfade of those terms, with the model learning to push noise_pred variance towards noise.std() early on, and then switching to minimizing channel loss variance later on. Compressing the noise_pred variance too much results in "hyper-detailed" results, but extending it too far out (even towards 1, which is what true noise variance is) results in the model losing too much detail and becoming very flat. The reason for this is not immediately clear to me; intuitively, I would expect noise_pred most like true noise to produce the highest fidelity results, but it seems that this is not the case. If anyone has a theoretical explanation for that, I'd be very interested in exploring what's going on there. This does kind of explain the noise variance across timesteps, though; at high timesteps, the variance learned is closer to true, and produces "low detail" results, but as you get closer to t=0, the noise_pred variance drops, resulting in increased detail in the output. 
Alternately, scheduling the various loss factors by some timestep-variant scheduler might result in better overall results, too.
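A crossfade between the two terms could be as simple as the following sketch in plain Python. The cosine shape, the function names, and the use of training progress (rather than diffusion timestep) as the driver are assumptions; the same function could just as well be fed `t / 1000` for the timestep-variant variant.

```python
import math

def crossfade_weights(step, total_steps):
    """Cosine crossfade: (w_std, w_channel) moves from (1, 0) at step 0 to (0, 1) at the end."""
    progress = min(max(step / max(total_steps, 1), 0.0), 1.0)
    w_channel = 0.5 * (1.0 - math.cos(math.pi * progress))
    return 1.0 - w_channel, w_channel

def combined_aux_loss(std_term, channel_var_term, step, total_steps):
    """Blend the noise_pred std term and the per-channel variance term by training progress."""
    w_std, w_channel = crossfade_weights(step, total_steps)
    return w_std * std_term + w_channel * channel_var_term

for step in (0, 250, 500, 750, 1000):
    print(step, crossfade_weights(step, 1000))
```

The two weights always sum to 1, so the overall scale of the auxiliary loss stays roughly constant while the emphasis shifts from matching noise.std() to keeping the per-channel variance in check.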
-
I've got an alternate loss form which is working remarkably well for me, and which I'd like some input on:

    alphas_cumprod = noise_scheduler.alphas_cumprod.to(accelerator.device)
    # reshape to (batch, 1, 1, 1) so it broadcasts against the per-channel statistics
    ac = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
    mae_loss = F.l1_loss(noise_pred, target, reduction="none")
    base_loss = 1 - torch.exp(-mae_loss)  # same value as 1/-mae_loss.exp() + 1
    loss = base_loss.mean(dim=(2, 3), keepdim=True) * ac
    loss = loss + base_loss.std(dim=(2, 3), keepdim=True) * (1 - ac)

The basic idea here is:
The operating theories here are:
My overnight test results with this are really interesting. I tried both "variance at low t" and "raw loss at low t", and I think the "raw loss at low t" form is generally better, though both are impressive. The suggested formulation here DOES seem to be learning my actual underlying dataset more easily, which suggests better convergence (and perhaps the need to actually reduce my LoRA rank!)

Results after 50 epochs. Normally, I'd see severe overtraining by now (visible in the wrinkles around the eyes and/or the teeth), or a loss of detail and dynamic range. Both seem to be preserved.
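To see what the two pieces of this loss do numerically, here is a scalar sketch in plain Python. It mirrors the structure of the snippet above on a flat list of absolute errors instead of (H, W) maps; the function names and the sample errors are made up.

```python
import math
from statistics import mean, stdev

def squash(abs_error):
    """1 - exp(-|e|): close to |e| for small errors, saturating toward 1 for large ones."""
    return 1.0 - math.exp(-abs_error)

def blended_loss(abs_errors, alpha_cumprod_t):
    """Mean squashed error weighted by alpha_cumprod, plus its spread weighted by (1 - alpha_cumprod)."""
    squashed = [squash(e) for e in abs_errors]
    return mean(squashed) * alpha_cumprod_t + stdev(squashed) * (1.0 - alpha_cumprod_t)

errors = [0.05, 0.2, 0.4, 1.5]  # made-up absolute errors
print(squash(0.01), squash(10.0))
print(blended_loss(errors, 0.999), blended_loss(errors, 0.005))
```

Because alphas_cumprod is near 1 at low timesteps and near 0 at high ones, the raw (squashed) error dominates at low t and its spatial spread dominates at high t, which corresponds to the "raw loss at low t" form. The squashing also bounds any single element's contribution below 1, so outliers cannot dominate the mean.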
-
What I do not get is how I can have average to fantastic loss, YET nothing (lora/locon/loha/dreambooth) learns the data?
-
Do you want to try this method?
-
I'm back, and I think I have some really cool stuff to share. (Samples: SD 1.5, 20 steps, Euler a.)

I've been tinkering with this over the last few weeks, and I think I've gained some genuine insight into the problem that has massive implications for training fidelity and convergence time.

TL;DR: match the per-step standard deviations and means of your input noise per channel to the model you're trying to train, and things happen.

A huge portion of the variant loss is actually legit - it's caused by the fact that we generate noise at (mean=0, std=1), but for whatever reason, the model learns to predict noise which consistently has a std not of 1, especially closer to t=0. If we actually measure the per-channel std and mean of noise predictions made by the unmodified model, some obvious patterns appear. This appears to vary per-model and per-architecture. (Plots: Realistic Vision 5.1 (SD1.5); SDXL models are wildly different: DreamshaperXL.)

It was clear that the standard training mechanism isn't fully correct - in particular, we know that Stable Diffusion, left unmodified, learns a mean of 0 for images, resulting in the need for the famous "noise offset" regularization scheme. This has been corrected in downstream models (like RV), which include noise offsets, but we still train with a "blind" noise offset, which may or may not match the underlying model. Worse yet, the actual true mean and std vary per channel and per timestep. The good news is that they appear to broadly follow a curve. This means that we can measure and interpolate, and use those observed values to affect training.

I'm playing with this at two locations:
Rule of thumb: Higher std weight = more detail. Too much = "oversharp". Higher mean weight = more light/shadow depth. Too much = color imbalance and contrast blowout.

I've got this implemented in my autostats branch if you want to try it. The important part is the addition of a couple of parameters, which indicate that the process should collect model noise statistics and persist them to a file. This is done so that subsequent runs can just reuse the collected stats. (This could probably be extracted to a separate utility script with a minimum of fuss, too.)

If you specify an […] This currently runs 16 inferences at 64 (non-uniform) steps per inference, which is probably higher resolution than is needed, but provides robust numbers. This only takes a couple of minutes on SD1.5, but takes 15-20 minutes for SDXL models on my RTX 3090. However, it only has to be done once per model. More inferences and more varied prompts (particularly prompting for various levels of detail and light/shadow) may help improve stats collection, but I haven't played with it too much.

@recris This series of experiments did very clearly identify the cause of the halftone pattern. In SDXL models, the lower you drag the mean of channel 3 of the latent, the more detail you get, but you ALSO get the halftoning. Lowering channel 3 (the last one; the channels are numbered [0, 1, 2, 3]) manifests largely as a "sharpen" slider, and too sharp results in imagined detail like what you're seeing in your examples. It is easily corrected by dampening how far channel 3 is pulled from a mean of 0. Also in SDXL, channel 0 is largely "luminosity"; increasing the mean brightens the image, and reducing it darkens it. As you get more extreme, this tends to have "over-contrasting" effects. If you just want to increase color depth, jittering the mean of channel 0 noise (ala […]

Here's an example of an SDXL training (1 epoch) with mean weights at […] (Kazam_screencast_00007.mp4).

I have successfully mitigated this in SDXL by dampening the effects of the early-timestep mean offsets on channel 3:

    if self.is_sdxl:
        ts = 500
        # linear ramp from 0 to 1 over the first `ts` timesteps, applied to channel 3 only
        mean_target_by_ts[:ts, 3] = mean_target_by_ts[:ts, 3] * torch.arange(0, 1.0, 1 / ts, device=mean_target_by_ts.device).view(-1, 1, 1)

This very much does seem to prevent the halftone over-sharpening while maintaining the majority of the detail. However, this is SDXL-specific, and I didn't love adding it as a general term. I might still add it as an SDXL-specific term, but it feels like a hack that wallpapers over some deeper understanding of what's happening.
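The effect of that ramp is easy to see on a flat list. This is a plain-Python sketch, assuming a single channel and a hypothetical constant mean target of -0.2; `damp_early_timesteps` is a made-up name and the real code operates on a tensor indexed by timestep and channel.

```python
def damp_early_timesteps(mean_target_by_ts, ramp_steps=500):
    """Scale entries for t < ramp_steps by a linear ramp t / ramp_steps; leave later ones unchanged."""
    return [
        m * (t / ramp_steps) if t < ramp_steps else m
        for t, m in enumerate(mean_target_by_ts)
    ]

# Hypothetical constant mean target for channel 3 at every timestep.
targets = [-0.2] * 1000
damped = damp_early_timesteps(targets)
print(damped[0], damped[125], damped[250], damped[499], damped[500])
```

The mean offset is fully suppressed at t=0, reaches half strength at t=250, and is back to its measured value from t=500 onward, so the pull on channel 3 is weakest exactly where the halftoning was reported to come from.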
-
Potentially of interest here
-
I've gone through a whole host of experiments this weekend, but the most promising is this: adding a loss function for the norm of the text encoder conds REALLY helps. I'm experimenting with SD1.5 (Realistic Vision 5.1, specifically). I've been chasing a whole bunch of various manipulations of the noise, but I've been unable to come up with fixes which generalize. But I started messing with the text encoder, and things clicked. For some conceptual overview:
Okay, so with that theory understood, during training, we embed a given caption, then we get a loss value from the unet that tells us how far away we were from correctly predicting the noise. The unet's prediction is conditioned on the text encoder output, so if you're training the text encoder (or a Lora which modifies it), this causes the training loop to change the embedding that the text encoder produces to try to give the unet updated conditionals which improve its guess next time. If left unconstrained, the text encoder can learn to improve loss by pushing the embedding for a given caption out of the "highly populated" concept space into a more unique part of the embedding space which the unet can more easily learn. The problem with this, though, is that SD seems have a fairly narrow "aesthetic" range of embeddings clustered around vectors of a certain length. By learning an embedding further out of the "normal" range, we can more easily reduce loss (because it's learning an uncontested part of the embedding space with less concept bleed), but it does it by removing the embedding further from our learned concepts. Here's some examples using a simple custom ComfyUI node that I can use to extend or shrink the length of an embedding:
At "natural" embedding length:
Embedding * 0.5 (same direction - pointing at the same concepts - but only half the length)
Embedding * 0.75
Embedding * 1.25
Embedding * 1.5
You can see that we're keeping all the same concepts, but changing the vector length has marked impacts on both prompt cohesion and aesthetic quality. I think this is a large part of the problem with training subjects which don't look like existing subjects in the model. The trainer learns:
This is desirable behavior when training a new model, or when fine-tuning it on a lot of novel data, but this is less desirable when just trying to integrate a new subject. But, we can easily tell SD to "learn this subject, but keep it aesthetic" by just constraining the TE norm!
# Prior to the training loop
def embed_caption(captions):
    return get_weighted_text_embeddings(
        tokenizer,
        text_encoder,
        captions,
        accelerator.device,
        args.max_token_length // 75 if args.max_token_length else 1,
        clip_skip=args.clip_skip,
    )

# Embed every caption for 10 passes over the dataloader to collect norm statistics
all_caps = []
with torch.no_grad(), tqdm(total=10 * len(train_dataloader), desc="Collecting text encoder stats") as pbar:
    for i in range(0, 10):
        for batch in train_dataloader:
            caps = embed_caption(batch["captions"])
            pbar.update(1)
            all_caps.append(caps)
# Mean L2 norm per token position, averaged over all collected embeddings
cap_norm_mean = torch.cat(all_caps).norm(p=2, dim=(-1)).mean(dim=0).unsqueeze(0)
# Inside the training loop
deadzone = 0.0
te_loss_weight = 1.0
te_nrm_loss = (F.mse_loss(text_encoder_conds.norm(dim=(-1)), cap_norm_mean, reduction="none") - cap_norm_mean*deadzone**2).clamp(min=0).mean(dim=1).view(-1, 1, 1, 1)
loss = loss + te_nrm_loss * te_loss_weight
This graphs the length of the vector between the original embedding of "chris" and the learned embedding. This will include both angular distance (change in concept relatedness) and length. Here, the red line was a te_loss_weight=1.0, and the line indicates that the embedding stabilizes at a distance of ~8 from original. If you think of the embedding space as a sphere, this means that the "radius" of the sphere is being held constant, and the trainer is finding a new spot on that sphere. The pink line is a weight of 0.1, and you can see that this distance is divergent (and keeps diverging!) - this weight is probably too low. It's substantially improved the aesthetic quality of my training, but I can probably bring it up a bit.
With a high-enough weight (keeping the embedding norm constrained) you could read that metric as a "how good your captions are" metric, too. There might also be some gains to be had by combining this with noise std prediction loss terms, but for the time being, I think this shows significantly more promise for resolving trainings for stubborn datasets.
I think this might substantially change the recommendations for text encoder/unet learning rate ratios. By restraining the embedding norm, once it's at the right "angle", continued noise losses (due to the unet still learning) won't have the effect of altering the embedding length, and should result in both faster convergence and better aesthetics. |
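To make the arithmetic of the norm penalty above a bit more concrete, here is a minimal plain-Python sketch for a single caption. The function name `te_norm_loss` and the toy vectors are made up for illustration; it only mirrors the squared-error-minus-deadzone computation from the snippet and is not the trainer code.

```python
import math

def te_norm_loss(conds, target_norms, deadzone=0.0):
    """Mean per-token penalty on the L2 norm of text-encoder outputs.

    conds: list of token vectors for ONE caption (hypothetical toy input).
    target_norms: per-token reference norms, playing the role of cap_norm_mean.
    """
    penalties = []
    for vec, target in zip(conds, target_norms):
        norm = math.sqrt(sum(x * x for x in vec))
        sq_err = (norm - target) ** 2
        # Subtract the deadzone allowance, then clamp at zero (as in the snippet)
        penalties.append(max(sq_err - target * deadzone ** 2, 0.0))
    return sum(penalties) / len(penalties)

conds = [[3.0, 4.0], [0.0, 2.0]]   # token norms are 5.0 and 2.0
targets = [5.0, 1.0]               # only the second token is off-target
print(te_norm_loss(conds, targets))                # 0.5
print(te_norm_loss(conds, targets, deadzone=1.0))  # 0.0
```

With deadzone=0.0 the term is just the mean squared difference between each token's norm and its reference norm, so it does not care which direction the embedding points, only how long it is.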
-
Until this lands in the Kohya trainer it's all moot, sadly, and I highly doubt Kohya will do much with the trainer these days given the level of work this would require. |
-
A couple of new techniques to try:
Rank estimation via SVD of the model layers
def get_rank(w, cutoff=0.3):
    # Singular values of the weight, flattened to 2D, computed in float32 on the GPU
    U, S, V = torch.svd(w.flatten(start_dim=1).to(device="cuda", dtype=torch.float32))
    # Cumulative share of the total singular-value mass
    cumsum = S.cumsum(0) / S.sum()
    # Index of the first singular value at which the share exceeds the cutoff
    rank = (cumsum > cutoff).nonzero()[0].item()
    del S, V, U, cumsum
    # Never go below rank 4
    return max(rank, 4)

class LoRAModule(torch.nn.Module):
    # ...
    self.lora_dim = get_rank(org_module.weight, target_pct)
At cutoff=0.35, I'm getting (as a quick small subset of layers):
These are ~900 MB float32 checkpoints (from an SD1.5 model), so obviously on the larger side for a typical LoRA, but the cutoff value could be moved up or down easily enough.
Dynamic alpha and pre-computed per-layer alphas
After the separate layer scale experiments, I've moved back to just training the alpha parameter. If we think of this in terms of a ratio of the lora_dim, then the general algorithm is:
This gives us a way to actually estimate relative importance per layer. After running for 24 epochs, I get something like this. The blue bar is the final alpha ratio for that layer (essentially, what you would multiply lora_dim by to get the actual alpha for that layer). Orange is the shift from initial (1.0, in this case).
This is interesting for a couple of reasons. First, it's obvious that not all layers are contributing equally to the learning task. Layers under a given final alpha could actually probably be dropped to conserve parameters trained; some experimentation is warranted, but it's likely that we could slim down the layers trained, or reallocate their parameters to the more impactful layers.
Second, and WAY more interestingly, this essentially gives me an estimate for the optimal LR per layer. Remember, my initial LRs were 1e-4 and initial alpha was 1.0. alpha * lr is approximable as the effective learning rate. What this technique does is effectively perform per layer automatic LR adjustment! For example, I can see that my TE layers broadly settle on an alpha of 0.05-0.1, which suggests that the 1e-4 learning rate is 10-20x the ideal LR for these layers relative to the learning rate of the other layers. However, I think this will let me essentially select whatever global LR I want, and to automatically scale layer LRs to each other.
By running multiple 24-epoch trainings, taking the checkpoint from the 24th epoch, computing these alphas, and then using them as the initial alphas for the next run, I get dramatic improvements in training quality without adjusting any other hyperparameters. Additionally, the network gets "good" much, much faster with each iterative adjustment of alphas, and the offsets from initial alphas drop, suggesting that there is in fact an ideal set of alphas per layer. The general idea here is just:
class LoRAModule:
    # ...
    alpha = ALPHAS.get(self.lora_name, 1.0) * self.lora_dim
    self.alpha = torch.nn.Parameter(torch.tensor(alpha).float())
I did have to crank the LR for alphas way up - 1000-5000x the base rate (which works out to 0.1-0.5 for my use case). Despite that, the alphas DO converge. After 24 epochs on the second run:
It's worth noting that this is without training the LayerNorm/GroupNorm layers, or the linear/conv2d bias layers. I am training the text embedding layer here, but I'm going to run some tests without it, too.
Here's some quick examples of samples from training. First, two of my ground truth images:
And here are samples from 2 runs, using the learned alphas from run 1 for the initial alphas of run 2. No other parameters were changed between runs. Columns are samples at epoch 6 and 24, rows are run 1/2.
As with most of these experiments, this introduces a potential confounder for all the other lessons learned; if various training problems are caused by certain layers over or under-training relative to each other, then learning alphas first might mitigate many of them. |
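The arithmetic behind the "alpha ratio as a per-layer LR multiplier" reading can be sketched in a few lines of plain Python. The helper names and the layer names below are hypothetical, and the learned ratios are illustrative values picked from the ranges quoted in the post, not measured results.

```python
import math

BASE_LR = 1e-4  # global learning rate quoted in the post

def effective_lr(alpha_ratio, base_lr=BASE_LR):
    """Approximate a layer's effective LR as alpha_ratio * base_lr."""
    return alpha_ratio * base_lr

def initial_alpha(alphas, lora_name, lora_dim):
    """Initial alpha for the next run: learned ratio (default 1.0) times the rank."""
    return alphas.get(lora_name, 1.0) * lora_dim

# Hypothetical learned ratios; the layer names are made up for illustration.
learned = {"te_layer_example": 0.05, "unet_layer_example": 1.3}

for name, ratio in learned.items():
    print(name, effective_lr(ratio), initial_alpha(learned, name, 16))

# Alpha LR at 1000x and 5000x the base rate
print(BASE_LR * 1000, BASE_LR * 5000)
```

A ratio of 0.05 at a base LR of 1e-4 corresponds to an effective rate of about 5e-6, i.e. 20x lower than the global rate, which matches the 10-20x figure above for ratios of 0.1-0.05.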
-
I've got one particularly interesting note for people to play with here.
tl;dr I think that LoRA training is fundamentally flawed and biases heavily towards doing most of its learning in the largest (by element count) layers of the network.
LoRAs consist of, per layer, a project-down matrix and a project-up matrix, of shapes [in, lora_dim] and [lora_dim, out], respectively. In Kohya, the project-down matrix is initialized with
However, I think we have a much larger problem: Under the current training mechanism, effective learning rates are inconsistent across layers because of disparity in the sizes of the up and down matrices between layers. This is made worse by the fact that rank selection mutates the effective learning rate of the layer relative to other layers, making it more difficult to evaluate rank selection independently.
Let me explain. First, set aside Adam's learning rate adaptation for the moment (I think that's a whole 'nother mess of trouble). When we fine tune a model with full weights, we start with a weight matrix
However, when we learn a decomposed approximation of
The equivalent update to
Here's an exaggerated example:
import torch
from torch.nn import functional as F
torch.manual_seed(1)
input_t = torch.randn(960, dtype=torch.float64)
A = torch.randn(64, 960, requires_grad=True, dtype=torch.float64)
B = torch.randn(320, 64, requires_grad=True, dtype=torch.float64)
t = (B@A).clone().detach().requires_grad_(True)
lora_output = F.linear(F.linear(input_t, A), B)
full_output = t @ input_t
print("Matrices match?", torch.allclose(lora_output, full_output))
lora_output.mean().backward()
full_output.mean().backward()
with torch.no_grad():
    lr = 100.0
    t_u = t.grad * lr
    a_u = A.grad * lr
    b_u = B.grad * lr
    print("t_u.norm", t_u.norm())
    print("A_u @ B_u.norm", (b_u @ a_u).norm())
# Matrices match? True
# t_u.norm tensor(176.6249, dtype=torch.float64)
# A_u @ B_u.norm tensor(130243.2076, dtype=torch.float64)
You can see that the implied weight update from the LoRA approach is massively larger than the fine-tune update. This is, in short, because we're stepping both A and B by
In isolation, this isn't a problem - you just pick a new lr that fits your problem domain. However, this is a very big problem for practical LoRA training, because we're training a whole bunch of layers with different geometry and norms. The effect of this is that the matrices which produce gradients with larger norms will make changes to the output of the model at a significantly faster rate - orders of magnitude, perhaps - than the smaller layers. This essentially guarantees that LoRA training will concentrate most of the learning in those large layers, and will overtrain long before the small layer can begin to exert any significant influence.
I'm trying to work out how to compensate for this, but I'm running short of ideas. However, my intuition is that if we can scale the grads of A and B correctly, we can help prevent large layers from dominating training. I've tried just dividing grads by grad.numel() (and grad.numel().sqrt()), but I don't think those are correct yet. Furthermore, because A and B are different parameters, Adam is going to learn different adaptive learning rates for each of them, which I suspect further muddles the problem. Ideally, we would use an Adam variant which takes its first and second moment estimations from |
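For what it's worth, the implied update can be written out directly. This is only a sketch under stated assumptions: plain SGD with learning rate \eta, W = BA, G = \partial L / \partial W, and no alpha/rank scaling. The symbols \eta, G, A' and B' are introduced here for illustration and are not from the post.

```latex
\begin{aligned}
\frac{\partial L}{\partial A} &= B^{\top} G, \qquad
\frac{\partial L}{\partial B} = G A^{\top} \\
A' &= A - \eta\, B^{\top} G, \qquad
B' = B - \eta\, G A^{\top} \\
\Delta W_{\text{LoRA}} = B'A' - BA
  &= -\eta \left( B B^{\top} G + G A^{\top} A \right)
     + \eta^{2}\, G A^{\top} B^{\top} G \\
\Delta W_{\text{full}} &= -\eta\, G
\end{aligned}
```

To first order in \eta, the LoRA step is the fine-tune gradient multiplied by B B^T on one side and A^T A on the other, so its size depends on the current scale and shape of A and B in each layer. Under these assumptions, the product b_u @ a_u printed in the example corresponds to the \eta^2 term, which grows quickly at lr = 100.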
-
Recently I came across this approach to calculating loss with multiple objectives: https://github.com/TorchJD/torchjd I wonder if this could improve results when combining MSE loss with std loss instead of a straight sum. Reddit discussion: https://www.reddit.com/r/MachineLearning/comments/1fbvuhs/r_training_models_with_multiple_losses/ |
-
I've been experimenting with different noising strategies, inspired in part by Noise Offset and Pyramid Noise.
This is the standard implementation of timesteps, which tells the noise scheduler how much of the noise to add to the latents.
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
A sampling from the Uniform distribution [0,1000)
But something very interesting happens when you replace those random timesteps with a constant value: your loss variability drops to almost nothing!
(Deterministic training at timestep intervals from [100-900], note the inverse exponential effect on loss)
Judging by our previous expectations of loss, very little training is expected to have occurred, but that is not the case.
(Timesteps 500 [center] and 600 are closest to my subject, with 200 coming in as a surprising third)
I'm still running tests to see what more I can glean from this, but in general I'm experiencing an unprecedented stability in training that I have a hard time explaining.
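The difference between the two sampling modes can be sketched in plain Python. The function `sample_timesteps` and its parameters are hypothetical names for illustration; the uniform branch is meant to mirror the half-open [0, num_train_timesteps) range of the torch.randint call above, not to replace it.

```python
import random

def sample_timesteps(batch_size, num_train_timesteps=1000, fixed=None, rng=None):
    """Return one timestep per batch item.

    fixed=None -> uniform integers in [0, num_train_timesteps)
    fixed=int  -> the same timestep for every item (deterministic training)
    """
    if fixed is not None:
        if not 0 <= fixed < num_train_timesteps:
            raise ValueError("fixed timestep out of range")
        return [fixed] * batch_size
    rng = rng or random.Random()
    return [rng.randrange(num_train_timesteps) for _ in range(batch_size)]

print(sample_timesteps(4, fixed=500))             # [500, 500, 500, 500]
print(sample_timesteps(4, rng=random.Random(0)))  # four values in [0, 1000)
```

In a trainer, the fixed case would presumably amount to filling the timesteps tensor with a single value (for example with torch.full) instead of calling torch.randint, leaving the rest of the noising step unchanged.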