Commit a42f7d2: fixed code formatting
Signed-off-by: Zeeshan Patel <[email protected]>
zpx01 committed Oct 11, 2024
1 parent: a1cffca · commit: a42f7d2
Showing 5 changed files with 17 additions and 10 deletions.
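
The hunks below are formatting-only: a space after the colon when a slice bound is a non-trivial expression, two blank lines before top-level definitions, one argument per line in multi-line calls, spaces around `=` in plain assignments, and alphabetically sorted imports in train.py. This is the style an auto-formatter such as black (with isort for the imports) typically produces; the snippet below is a hypothetical check of the slice rule and is not part of the commit.

import black  # assumes black is installed; illustrative only

before = "emb = emb[:cfg.padding_size]\n"  # cfg is a stand-in name
after = black.format_str(before, mode=black.Mode())
print(after)  # emb = emb[: cfg.padding_size]  (space added because the bound is not a simple name)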
nemo/collections/diffusion/data/diffusion_taskencoder.py (1 addition, 1 deletion)
@@ -133,7 +133,7 @@ def encode_sample(self, sample: dict) -> dict:
         t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]

         if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
-            t5_text_embeddings = t5_text_embeddings[:self.text_embedding_padding_size]
+            t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
         else:
             t5_text_embeddings = F.pad(
                 t5_text_embeddings,
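For context, the hunk above is the truncate-or-pad step for T5 text embeddings in the task encoder: sequences longer than text_embedding_padding_size are sliced down, shorter ones are padded out to that length. A minimal self-contained sketch of the pattern follows; the zero padding value and the (seq, dim) layout are assumptions, not taken from the file.

import torch
import torch.nn.functional as F

def pad_or_truncate(t5_text_embeddings: torch.Tensor, padding_size: int) -> torch.Tensor:
    """Force the sequence dimension (dim 0) to exactly padding_size rows."""
    seq_len = t5_text_embeddings.shape[0]
    if seq_len > padding_size:
        return t5_text_embeddings[:padding_size]
    # pad the end of the sequence dimension with zeros (padding value is an assumption)
    return F.pad(t5_text_embeddings, (0, 0, 0, padding_size - seq_len))

emb = torch.randn(300, 1024)
print(pad_or_truncate(emb, 512).shape)  # torch.Size([512, 1024])
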
nemo/collections/diffusion/models/dit/__init__.py (1 addition, 1 deletion)
@@ -10,4 +10,4 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
nemo/collections/diffusion/models/dit/dit_embeddings.py (2 additions, 0 deletions)
@@ -87,6 +87,7 @@ def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim):
     pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
     return pos_emb

+
 class SinCosPosEmb3D(MegatronModule):
     """
     SinCosPosEmb3D is a 3D sine-cosine positional embedding module.
@@ -136,6 +137,7 @@ def forward(self, pos_ids: torch.Tensor):
         pos_id = pos_ids[..., 0] * self.h * self.w + pos_ids[..., 1] * self.w + pos_ids[..., 2]
         return self.pos_embedding(pos_id)

+
 class FactorizedLearnable3DEmbedding(MegatronModule):
     def __init__(
         self,
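The forward hunk above flattens a (t, h, w) position triple into a single index before the embedding lookup. A minimal sketch of that indexing is below, with nn.Embedding standing in for the module's positional table; the real SinCosPosEmb3D builds a fixed sine-cosine table rather than a learned one, so the stand-in is an assumption.

import torch
import torch.nn as nn

class FlatPosEmb3D(nn.Module):
    """Toy stand-in: flatten (t, h, w) ids into one index, then look it up."""

    def __init__(self, t: int, h: int, w: int, dim: int):
        super().__init__()
        self.h, self.w = h, w
        self.pos_embedding = nn.Embedding(t * h * w, dim)

    def forward(self, pos_ids: torch.Tensor) -> torch.Tensor:
        # pos_ids: (..., 3) holding (t, h, w) per token, as in the hunk above
        pos_id = pos_ids[..., 0] * self.h * self.w + pos_ids[..., 1] * self.w + pos_ids[..., 2]
        return self.pos_embedding(pos_id)

emb = FlatPosEmb3D(t=8, h=16, w=16, dim=64)
grid = torch.meshgrid(torch.arange(8), torch.arange(16), torch.arange(16), indexing="ij")
ids = torch.stack(grid, dim=-1).view(-1, 3)
print(emb(ids).shape)  # torch.Size([2048, 64])
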
nemo/collections/diffusion/models/model.py (2 additions, 0 deletions)
@@ -191,12 +191,14 @@ class DiTXLConfig(DiTConfig):
     hidden_size: int = 1152
     num_attention_heads: int = 16

+
 @dataclass
 class DiT7BConfig(DiTConfig):
     num_layers: int = 32
     hidden_size: int = 3072
     num_attention_heads: int = 24

+
 @dataclass
 class DiTLlama30BConfig(DiTConfig):
     num_layers: int = 48
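The model.py hunk only inserts blank lines between the size variants; each variant is a dataclass that overrides a few fields of the base DiT config. A minimal sketch of that pattern follows; the DiT7B values come from the hunk above, while the base-class defaults here are placeholders (the real DiTConfig defines many more fields).

from dataclasses import dataclass

@dataclass
class DiTConfigSketch:
    # placeholder defaults, not the real DiTConfig values
    num_layers: int = 28
    hidden_size: int = 1152
    num_attention_heads: int = 16

@dataclass
class DiT7BConfigSketch(DiTConfigSketch):
    # sizes as shown in the hunk above
    num_layers: int = 32
    hidden_size: int = 3072
    num_attention_heads: int = 24

print(DiT7BConfigSketch())
# DiT7BConfigSketch(num_layers=32, hidden_size=3072, num_attention_heads=24)
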
nemo/collections/diffusion/train.py (11 additions, 8 deletions)
@@ -27,10 +27,10 @@
 from nemo.collections.diffusion.data.diffusion_taskencoder import BasicDiffusionTaskEncoder
 from nemo.collections.diffusion.models.model import (
     DiT7BConfig,
-    DiTLlama5BConfig,
-    DiTLlama30BConfig,
     DiTConfig,
     DiTLConfig,
+    DiTLlama5BConfig,
+    DiTLlama30BConfig,
     DiTModel,
     DiTXLConfig,
 )
@@ -81,10 +81,11 @@ def pretrain() -> run.Partial:
             context_parallel_size=1,
             sequence_parallel=False,
             pipeline_dtype=torch.bfloat16,
-            ddp=run.Config(DistributedDataParallelConfig,
+            ddp=run.Config(
+                DistributedDataParallelConfig,
                 check_for_nan_in_grad=True,
-                grad_reduce_in_fp32=True,
-            ),
+                grad_reduce_in_fp32=True,
+            ),
         ),
         plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
         num_sanity_val_steps=0,
@@ -153,12 +154,13 @@ def pretrain_7b() -> run.Partial:
     recipe.trainer.val_check_interval = 1000
     recipe.log.log_dir = 'nemo_experiments/dit7b'
     recipe.optim.lr_scheduler = run.Config(nl.lr_scheduler.WarmupHoldPolicyScheduler, warmup_steps=100, hold_steps=1e9)
-    recipe.optim.config.weight_decay=0.1
-    recipe.optim.config.adam_beta1=0.9
-    recipe.optim.config.adam_beta2=0.95
+    recipe.optim.config.weight_decay = 0.1
+    recipe.optim.config.adam_beta1 = 0.9
+    recipe.optim.config.adam_beta2 = 0.95

     return recipe

+
 @run.cli.factory(target=llm.train)
 def pretrain_ditllama5b() -> run.Partial:
     recipe = pretrain_7b()
@@ -167,6 +169,7 @@ def pretrain_ditllama5b() -> run.Partial:
     recipe.log.log_dir = 'nemo_experiments/ditllama5b'
     return recipe

+
 @run.cli.factory(target=llm.train)
 def pretrain_ditllama30b() -> run.Partial:
     recipe = pretrain_ditllama5b()
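The train.py factories chain: pretrain_ditllama5b starts from pretrain_7b, and pretrain_ditllama30b starts from pretrain_ditllama5b, each overriding a few recipe fields before returning. A minimal sketch of adding another variant in the same style follows; the function name, log directory, and overridden value are hypothetical.

import nemo_run as run

from nemo.collections import llm
from nemo.collections.diffusion.train import pretrain_7b

@run.cli.factory(target=llm.train)
def pretrain_my_dit_variant() -> run.Partial:
    # hypothetical variant: reuse the 7B recipe and override a few fields
    recipe = pretrain_7b()
    recipe.log.log_dir = 'nemo_experiments/my_dit_variant'
    recipe.optim.config.weight_decay = 0.05
    return recipe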
