
Scaling from 128x128, to 256x256, 512x512 and 1024x1024? #95

Open
tin-sely opened this issue Feb 5, 2024 · 4 comments

tin-sely commented Feb 5, 2024

hey,

loved your paper and thanks a bunch for providing the code!

i have a quick question: how do you scale and train the network (HDiT) for increased resolutions? i saw you mentioned in #14 (comment) that you first need to build the entire network and then skip layers, but i'm not sure if this also applies to this new architecture?

many thanks!

tin-sely commented Feb 6, 2024

it looks like it's not meant for progressive scaling? i guess the best option would be to train at a lower resolution and then copy the relevant weights into a higher-res network, something like the sketch below
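for reference, a rough (untested) sketch of that weight-copying idea, assuming the checkpoints are ordinary PyTorch state dicts; copy_matching_weights and the "model" key are just placeholder names, not things from this repo:

import torch

def copy_matching_weights(low_res_ckpt_path, high_res_model):
    # untested sketch: copy every tensor whose name and shape still match
    # from a low-res checkpoint into a freshly built high-res model; any new
    # or resized layers simply keep their fresh initialization.
    state = torch.load(low_res_ckpt_path, map_location="cpu")
    low_res_state = state.get("model", state)  # adjust to the actual checkpoint layout
    target_state = high_res_model.state_dict()
    copied = 0
    for name, tensor in low_res_state.items():
        if name in target_state and target_state[name].shape == tensor.shape:
            target_state[name] = tensor
            copied += 1
    high_res_model.load_state_dict(target_state)
    print(f"copied {copied} of {len(target_state)} tensors")
    return high_res_model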

another thing i was curious about was the inputs:

def forward(self, x, sigma, aug_cond=None, class_cond=None, mapping_cond=None):

x, sigma, and class_cond are clear, but do you have any more details on aug_cond and mapping_cond?

@madebyollin

@tin-sely I believe aug_cond is for non-leaky augmentations. When an input image is augmented during training, a description of how that image was augmented is also given to the generator (as aug_cond - augmentation conditioning), so that the generator eventually learns to generate either augmented or non-augmented images depending on the value of the aug_cond input.

I believe mapping_cond is an older name for aug_cond that is used in the non-transformer model configs (the ones that use KarrasAugmentWrapper, which takes the aug_cond tensor and gives it to the model as mapping_cond).
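To illustrate the idea (a toy sketch, not the repo's actual augmentation pipeline - the real one encodes the full set of augmentation parameters, not just a flip flag, and aug_dim=9 here is only a placeholder):

import torch

def toy_augment(x, aug_dim=9):
    # toy non-leaky augmentation: randomly h-flip each image in the batch and
    # record what was done in a per-sample conditioning vector.
    flip = torch.rand(x.shape[0]) < 0.5
    x_aug = torch.where(flip.view(-1, 1, 1, 1), x.flip(-1), x)
    aug_cond = torch.zeros(x.shape[0], aug_dim)
    aug_cond[:, 0] = flip.float()
    return x_aug, aug_cond

# training: x_aug, aug_cond = toy_augment(x); then call model(x_aug, sigma, aug_cond=aug_cond)
# sampling: pass an all-zeros aug_cond to ask for non-augmented images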

tin-sely commented Feb 7, 2024

thanks a bunch @madebyollin! ✨

@mnslarcher

My understanding is that you use aug_cond when you wish to provide the model with information about the augmentations using Fourier Features:

self.aug_emb = layers.FourierFeatures(9, mapping.width)

self.aug_in_proj = Linear(mapping.width, mapping.width, bias=False)

aug_emb = self.aug_in_proj(self.aug_emb(aug_cond))

On the other hand, if you use mapping_cond, the condition will be fed directly into a linear layer, as shown here:

self.mapping_cond_in_proj = Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None

mapping_emb = self.mapping_cond_in_proj(mapping_cond) if self.mapping_cond_in_proj is not None else 0

These embeddings are then both fed into the MappingNetwork:

cond = self.mapping(time_emb + aug_emb + class_emb + mapping_emb)

But getting more clarity on this would definitely help!
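A condensed (untested) sketch of those two paths, using a stand-in module rather than the repo's own layers.FourierFeatures:

import math
import torch
import torch.nn as nn

class ToyFourierFeatures(nn.Module):
    # stand-in for layers.FourierFeatures: project onto fixed random
    # frequencies, then concatenate cos and sin of the result.
    def __init__(self, in_features, out_features, std=1.0):
        super().__init__()
        self.register_buffer("freqs", torch.randn(out_features // 2, in_features) * std)

    def forward(self, x):
        f = 2 * math.pi * x @ self.freqs.T
        return torch.cat([f.cos(), f.sin()], dim=-1)

width, mapping_cond_dim, batch = 256, 16, 4

aug_emb_layer = ToyFourierFeatures(9, width)                           # aug_cond path: Fourier features first...
aug_in_proj = nn.Linear(width, width, bias=False)                      # ...then a bias-free linear projection
mapping_cond_in_proj = nn.Linear(mapping_cond_dim, width, bias=False)  # mapping_cond path: a single linear projection

aug_cond = torch.rand(batch, 9)
mapping_cond = torch.rand(batch, mapping_cond_dim)

aug_emb = aug_in_proj(aug_emb_layer(aug_cond))
mapping_emb = mapping_cond_in_proj(mapping_cond)
# both embeddings then get summed with time_emb and class_emb and passed
# through the mapping network, as in the line quoted above.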
