diff --git a/advanced/main_t2i_highres.py b/advanced/main_t2i_highres.py
index 918fd71..c6bad8c 100644
--- a/advanced/main_t2i_highres.py
+++ b/advanced/main_t2i_highres.py
@@ -137,7 +137,7 @@ def forward(self, x, cond, randomly_augment_x_latent=False):
         if randomly_augment_x_latent:
             # this will take B, C, H, W latent and crop it so they are ~ 33% of the original size.
             b, c, h, w = x.size()
-            if random.random() < 0.6:
+            if random.random() < -1:
                 new_w = random.randint(int(w * 0.3333), w)
                 new_h = random.randint(int(h * 0.3333), h)
                 # We dont want very small spatiality. We priotize uniform distibution on w, but h should be large enough.
@@ -180,7 +180,7 @@ def forward(self, x, cond, randomly_augment_x_latent=False):
         else:
             t = torch.rand((b,)).to(x.device)
         texp = t.view([b, *([1] * len(x.shape[1:]))])
-        #texp = self.t_transform(texp)
+        texp = self.t_transform(texp)
 
         z1 = torch.randn_like(x)
         zt = (1 - texp) * x + texp * z1
@@ -443,12 +443,14 @@ def main(
             n_layers=n_layers,
             n_heads=8,
             cond_seq_dim=cond_seq_dim,
+            max_seq= 96 * 96
         ),
         True,
     ).cuda()
     if True:
         statedict = torch.load(
             "/home/ubuntu/ckpts_36L_2/model_102401/ema1.pt",
+            #"/home/ubuntu/ckpts_36L_2_highres_freezemost/model_12288/ema1.pt",
             map_location="cpu",
         )
         # remove model.layers.23.modC.1.weight
@@ -590,14 +592,14 @@ def dequantize_t5(tensor):
     final_optimizer_settings = {}
 
     # requires grad for first 2 and last 2 layer
-    for n, p in rf.named_parameters():
-        if "layers" in n:
-            if any(layername in n for layername in ["layers.0.", "layers.1.", "layers.34.", "layers.35."]):
-                p.requires_grad = True
-            else:
-                p.requires_grad = False
-        else:
-            p.requires_grad = True
+    # for n, p in rf.named_parameters():
+    #     if "layers" in n:
+    #         if any(layername in n for layername in ["layers.0.", "layers.1.", "layers.34.", "layers.35."]):
+    #             p.requires_grad = True
+    #         else:
+    #             p.requires_grad = False
+    #     else:
+    #         p.requires_grad = True
 
     for n, p in rf.named_parameters():
         group_parameters = {}
@@ -644,7 +646,9 @@ def dequantize_t5(tensor):
     AdamOptimizer = torch.optim.AdamW
 
     optimizer = AdamOptimizer(
-        optimizer_grouped_parameters, betas=(0.9, 0.95)
+        #optimizer_grouped_parameters,
+        rf.parameters(), lr=learning_rate * (32 / hidden_dim),
+        betas=(0.9, 0.95)
     )
 
     lr_scheduler = get_scheduler(
diff --git a/advanced/mmdit.py b/advanced/mmdit.py
index 25278fa..49f2ec7 100644
--- a/advanced/mmdit.py
+++ b/advanced/mmdit.py
@@ -80,16 +80,16 @@ def __init__(self, dim, n_heads, mh_qknorm=False):
         self.head_dim = dim // n_heads
 
         # this is for cond
-        self.w1q = nn.Linear(dim, n_heads * self.head_dim, bias=False)
-        self.w1k = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
-        self.w1v = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
-        self.w1o = nn.Linear(n_heads * self.head_dim, dim, bias=False)
+        self.w1q = nn.Linear(dim, dim, bias=False)
+        self.w1k = nn.Linear(dim, dim, bias=False)
+        self.w1v = nn.Linear(dim, dim, bias=False)
+        self.w1o = nn.Linear(dim, dim, bias=False)
 
         # this is for x
-        self.w2q = nn.Linear(dim, n_heads * self.head_dim, bias=False)
-        self.w2k = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
-        self.w2v = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
-        self.w2o = nn.Linear(n_heads * self.head_dim, dim, bias=False)
+        self.w2q = nn.Linear(dim, dim, bias=False)
+        self.w2k = nn.Linear(dim, dim, bias=False)
+        self.w2v = nn.Linear(dim, dim, bias=False)
+        self.w2o = nn.Linear(dim, dim, bias=False)
 
         self.q_norm1 = (
             MultiHeadLayerNorm((self.n_heads, self.head_dim))
@@ -267,7 +267,7 @@ def __init__(
         super().__init__()
 
         self.t_embedder = TimestepEmbedder(global_conddim)
-        # self.c_vec_embedder = MLP(cond_vector_dim, global_conddim)
+
         self.cond_seq_linear = nn.Linear(
             cond_seq_dim, dim, bias=False
         )  # linear for something like text sequence.
@@ -292,18 +292,12 @@ def __init__(
             nn.SiLU(),
             nn.Linear(global_conddim, 2 * dim, bias=False),
         )
-        # # init zero
         nn.init.constant_(self.final_linear.weight, 0)
-        # nn.init.constant_(self.final_linear.bias, 0)
-
+
         self.out_channels = out_channels
         self.patch_size = patch_size
 
         for pn, p in self.named_parameters():
-            # if pn.endswith("w1o.weight") or pn.endswith("w2o.weight"):
-            #     # this is muP
-            #     nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layers * dim))
-
+            # if its modulation
             if "mod" in pn:
                 nn.init.constant_(p, 0)
@@ -320,7 +314,6 @@ def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
         pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)
         # now we need to extend this to target_dim. for this we will use interpolation.
-        # we will use bilinear interpolation.
         # we will use torch.nn.functional.interpolate
         pe_as_2d = F.interpolate(
             pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
         )
@@ -375,13 +368,11 @@ def forward(self, x, t, conds, **kwargs):
         # process conditions for MMDiT Blocks
         c_seq = conds["c_seq"][0:b]  # B, T_c, D_c
         t = t[0:b]
-        # c_vec = conds["c_vec"]  # B, D_gc
+
         c = self.cond_seq_linear(c_seq)  # B, T_c, D
         c = torch.cat([self.register_tokens.repeat(c.size(0), 1, 1), c], dim=1)
-
-        t_emb = self.t_embedder(t)  # B, D
-
-        global_cond = t_emb  # B, D
+
+        global_cond = self.t_embedder(t)  # B, D
 
         for layer in self.layers:
             c, x = layer(c, x, global_cond, **kwargs)
diff --git a/advanced/run_multi_node_resize.sh b/advanced/run_multi_node_resize.sh
index 9f2c35a..3462af8 100644
--- a/advanced/run_multi_node_resize.sh
+++ b/advanced/run_multi_node_resize.sh
@@ -24,11 +24,11 @@ done
 
 deepspeed --hostfile=./hostfiles \
     main_t2i_highres.py \
-    --learning_rate 0.006 \
+    --learning_rate 0.018 \
     --hidden_dim 2560 \
     --n_layers 36 \
-    --run_name node-2-highres \
-    --save_dir "/home/ubuntu/ckpts_36L_2_highres_freezemost" \
+    --run_name node-2-highres-0.18 \
+    --save_dir "/home/ubuntu/ckpts_36L_2_highres_lr_0.006" \
     --num_train_epochs 200 \
     --train_batch_size 256 \
     --per_device_train_batch_size 4 \