Commit
remove dumb things
cloneofsimo committed Jun 7, 2024
1 parent 8913d38 commit 261859e
Showing 3 changed files with 31 additions and 36 deletions.
26 changes: 15 additions & 11 deletions advanced/main_t2i_highres.py
@@ -137,7 +137,7 @@ def forward(self, x, cond, randomly_augment_x_latent=False):
         if randomly_augment_x_latent:
             # this will take a B, C, H, W latent and crop it so it is ~33% of the original size.
             b, c, h, w = x.size()
-            if random.random() < 0.6:
+            if random.random() < -1:
                 new_w = random.randint(int(w * 0.3333), w)
                 new_h = random.randint(int(h * 0.3333), h)
                 # We don't want a very small spatial size. We prioritize a uniform distribution on w, but h should be large enough.
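Note: random.random() returns a value in [0, 1), so the new `< -1` guard can never be true; this hunk effectively disables the latent-crop augmentation. For reference, a minimal, self-contained sketch of the augmentation being switched off (the crop-position policy and shapes are assumptions for illustration, not the repository's exact code):

import random
import torch

def maybe_crop_latent(x: torch.Tensor, p: float = -1.0) -> torch.Tensor:
    # Randomly crop a (B, C, H, W) latent down to roughly >= 33% of its width/height.
    # With p = -1 (as in this commit) the branch never fires and x is returned unchanged.
    b, c, h, w = x.size()
    if random.random() < p:
        new_w = random.randint(int(w * 0.3333), w)
        new_h = random.randint(int(h * 0.3333), h)
        top = random.randint(0, h - new_h)
        left = random.randint(0, w - new_w)
        x = x[:, :, top : top + new_h, left : left + new_w]
    return x

print(maybe_crop_latent(torch.randn(2, 4, 96, 96)).shape)  # torch.Size([2, 4, 96, 96])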
@@ -180,7 +180,7 @@ def forward(self, x, cond, randomly_augment_x_latent=False):
         else:
             t = torch.rand((b,)).to(x.device)
         texp = t.view([b, *([1] * len(x.shape[1:]))])
-        #texp = self.t_transform(texp)
+        texp = self.t_transform(texp)
         z1 = torch.randn_like(x)
         zt = (1 - texp) * x + texp * z1

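Un-commenting `texp = self.t_transform(texp)` means the uniformly sampled timestep is now reshaped before forming the rectified-flow interpolant zt = (1 - t) * x + t * z1. The actual `self.t_transform` is defined elsewhere in the file; the resolution-shift form below is only an assumption used to illustrate the effect:

import torch

def t_transform(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    # A common timestep shift for high-resolution training: pushes samples toward t = 1 (noise).
    return (shift * t) / (1 + (shift - 1) * t)

b = 4
x = torch.randn(b, 4, 96, 96)            # clean latent
z1 = torch.randn_like(x)                 # Gaussian noise
t = torch.rand(b)                        # uniform timesteps
texp = t_transform(t).view(b, 1, 1, 1)   # broadcastable over C, H, W
zt = (1 - texp) * x + texp * z1          # noised input passed to the model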
@@ -443,12 +443,14 @@ def main(
             n_layers=n_layers,
             n_heads=8,
             cond_seq_dim=cond_seq_dim,
+            max_seq= 96 * 96
         ),
         True,
     ).cuda()
     if True:
         statedict = torch.load(
             "/home/ubuntu/ckpts_36L_2/model_102401/ema1.pt",
+            #"/home/ubuntu/ckpts_36L_2_highres_freezemost/model_12288/ema1.pt",
             map_location="cpu",
         )
         # remove model.layers.23.modC.1.weight
@@ -590,14 +592,14 @@ def dequantize_t5(tensor):
     final_optimizer_settings = {}

     # requires grad for first 2 and last 2 layers
-    for n, p in rf.named_parameters():
-        if "layers" in n:
-            if any(layername in n for layername in ["layers.0.", "layers.1.", "layers.34.", "layers.35."]):
-                p.requires_grad = True
-            else:
-                p.requires_grad = False
-        else:
-            p.requires_grad = True
+    # for n, p in rf.named_parameters():
+    #     if "layers" in n:
+    #         if any(layername in n for layername in ["layers.0.", "layers.1.", "layers.34.", "layers.35."]):
+    #             p.requires_grad = True
+    #         else:
+    #             p.requires_grad = False
+    #     else:
+    #         p.requires_grad = True

     for n, p in rf.named_parameters():
         group_parameters = {}
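With the freezing loop commented out, every parameter of `rf` is trainable again (the default for freshly constructed modules). A quick, hypothetical sanity check for that, not part of the repository:

import torch.nn as nn

def count_trainable(model: nn.Module) -> tuple[int, int]:
    # Returns (trainable parameter count, total parameter count).
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total

# trainable, total = count_trainable(rf)  # expect trainable == total after this commit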
@@ -644,7 +646,9 @@ def dequantize_t5(tensor):
     AdamOptimizer = torch.optim.AdamW

     optimizer = AdamOptimizer(
-        optimizer_grouped_parameters, betas=(0.9, 0.95)
+        #optimizer_grouped_parameters,
+        rf.parameters(), lr=learning_rate * (32 / hidden_dim),
+        betas=(0.9, 0.95)
     )

     lr_scheduler = get_scheduler(
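The optimizer now takes `rf.parameters()` with a single width-scaled learning rate, `learning_rate * (32 / hidden_dim)`, instead of the per-group `optimizer_grouped_parameters`. A minimal sketch of the resulting AdamW setup (the base width of 32 is read off the diff; the stand-in model and values are assumptions):

import torch

learning_rate, hidden_dim = 0.018, 2560           # values from run_multi_node_resize.sh
scaled_lr = learning_rate * (32 / hidden_dim)     # 0.000225

model = torch.nn.Linear(hidden_dim, hidden_dim)   # stand-in for the RF/MMDiT model
optimizer = torch.optim.AdamW(model.parameters(), lr=scaled_lr, betas=(0.9, 0.95))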
35 changes: 13 additions & 22 deletions advanced/mmdit.py
@@ -80,16 +80,16 @@ def __init__(self, dim, n_heads, mh_qknorm=False):
         self.head_dim = dim // n_heads

         # this is for cond
-        self.w1q = nn.Linear(dim, n_heads * self.head_dim, bias=False)
-        self.w1k = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
-        self.w1v = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
-        self.w1o = nn.Linear(n_heads * self.head_dim, dim, bias=False)
+        self.w1q = nn.Linear(dim, dim, bias=False)
+        self.w1k = nn.Linear(dim, dim, bias=False)
+        self.w1v = nn.Linear(dim, dim, bias=False)
+        self.w1o = nn.Linear(dim, dim, bias=False)

         # this is for x
-        self.w2q = nn.Linear(dim, n_heads * self.head_dim, bias=False)
-        self.w2k = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
-        self.w2v = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
-        self.w2o = nn.Linear(n_heads * self.head_dim, dim, bias=False)
+        self.w2q = nn.Linear(dim, dim, bias=False)
+        self.w2k = nn.Linear(dim, dim, bias=False)
+        self.w2v = nn.Linear(dim, dim, bias=False)
+        self.w2o = nn.Linear(dim, dim, bias=False)

         self.q_norm1 = (
             MultiHeadLayerNorm((self.n_heads, self.head_dim))
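Since `self.head_dim = dim // n_heads`, `n_heads * self.head_dim` equals `dim` whenever `dim` is divisible by `n_heads`, so the new `nn.Linear(dim, dim)` projections are shape-for-shape identical to the old ones; the hunk is a simplification rather than a behavior change. A small check under that divisibility assumption (values chosen for illustration):

import torch.nn as nn

dim, n_heads = 2560, 8
head_dim = dim // n_heads
assert n_heads * head_dim == dim                        # holds when dim % n_heads == 0

old_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
new_proj = nn.Linear(dim, dim, bias=False)
assert old_proj.weight.shape == new_proj.weight.shape   # both (2560, 2560)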
@@ -267,7 +267,7 @@ def __init__(
         super().__init__()

         self.t_embedder = TimestepEmbedder(global_conddim)
-        # self.c_vec_embedder = MLP(cond_vector_dim, global_conddim)
+
         self.cond_seq_linear = nn.Linear(
             cond_seq_dim, dim, bias=False
         )  # linear for something like text sequence.
@@ -292,18 +292,12 @@ def __init__(
             nn.SiLU(),
             nn.Linear(global_conddim, 2 * dim, bias=False),
         )
-        # # init zero
         nn.init.constant_(self.final_linear.weight, 0)
-        # nn.init.constant_(self.final_linear.bias, 0)


         self.out_channels = out_channels
         self.patch_size = patch_size

         for pn, p in self.named_parameters():
-            # if pn.endswith("w1o.weight") or pn.endswith("w2o.weight"):
-            #     # this is muP
-            #     nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layers * dim))
-            # if its modulation
             if "mod" in pn:
                 nn.init.constant_(p, 0)
@@ -320,7 +314,6 @@ def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
         pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)

         # now we need to extend this to target_dim. for this we will use interpolation.
-        # we will use bilinear interpolation.
         # we will use torch.nn.functional.interpolate
         pe_as_2d = F.interpolate(
             pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
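Only the redundant comment is dropped here; `extend_pe` still resizes the learned positional-embedding grid with bilinear interpolation. A self-contained sketch of that resize (the (H*W, D) embedding layout is an assumption for illustration, not the repository's exact code):

import torch
import torch.nn.functional as F

def resize_pe(pe: torch.Tensor, init_dim=(16, 16), target_dim=(64, 64)) -> torch.Tensor:
    d = pe.shape[-1]
    pe_2d = pe.view(init_dim[0], init_dim[1], d).permute(2, 0, 1)        # (D, H, W)
    pe_2d = F.interpolate(pe_2d.unsqueeze(0), size=target_dim, mode="bilinear")
    return pe_2d.squeeze(0).permute(1, 2, 0).reshape(-1, d)              # (H'*W', D)

print(resize_pe(torch.randn(16 * 16, 256)).shape)  # torch.Size([4096, 256])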
@@ -375,13 +368,11 @@ def forward(self, x, t, conds, **kwargs):
         # process conditions for MMDiT Blocks
         c_seq = conds["c_seq"][0:b]  # B, T_c, D_c
         t = t[0:b]
-        # c_vec = conds["c_vec"] # B, D_gc
+
         c = self.cond_seq_linear(c_seq)  # B, T_c, D
         c = torch.cat([self.register_tokens.repeat(c.size(0), 1, 1), c], dim=1)

-        t_emb = self.t_embedder(t)  # B, D
-
-        global_cond = t_emb  # B, D
+        global_cond = self.t_embedder(t)  # B, D

         for layer in self.layers:
             c, x = layer(c, x, global_cond, **kwargs)
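The refactor folds the intermediate `t_emb` directly into `global_cond = self.t_embedder(t)`; the conditioning path is otherwise unchanged: project the text sequence, prepend the learned register tokens, and pass the timestep embedding as the global condition. A shape-level sketch of that path (all dimensions, including the 16 register tokens, are assumptions for illustration):

import torch
import torch.nn as nn

b, t_c, cond_seq_dim, dim = 2, 77, 4096, 2560
cond_seq_linear = nn.Linear(cond_seq_dim, dim, bias=False)
register_tokens = nn.Parameter(torch.randn(1, 16, dim))

c_seq = torch.randn(b, t_c, cond_seq_dim)                      # B, T_c, D_c
c = cond_seq_linear(c_seq)                                     # B, T_c, D
c = torch.cat([register_tokens.repeat(b, 1, 1), c], dim=1)     # B, 16 + T_c, D
# global_cond = self.t_embedder(t)                             # B, D, as in the diff
print(c.shape)  # torch.Size([2, 93, 2560])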
6 changes: 3 additions & 3 deletions advanced/run_multi_node_resize.sh
@@ -24,11 +24,11 @@ done

 deepspeed --hostfile=./hostfiles \
     main_t2i_highres.py \
-    --learning_rate 0.006 \
+    --learning_rate 0.018 \
     --hidden_dim 2560 \
     --n_layers 36 \
-    --run_name node-2-highres \
-    --save_dir "/home/ubuntu/ckpts_36L_2_highres_freezemost" \
+    --run_name node-2-highres-0.18 \
+    --save_dir "/home/ubuntu/ckpts_36L_2_highres_lr_0.006" \
     --num_train_epochs 200 \
     --train_batch_size 256 \
     --per_device_train_batch_size 4 \
