Skip to content

Commit

Permalink
diff_traj: add dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 19, 2024
1 parent 1924db9 commit 6db10ef
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
15 changes: 8 additions & 7 deletions torchdrive/tasks/diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from diffusers import EulerDiscreteScheduler
from safetensors.torch import load_model
from torch import nn
from torchvision.transforms.functional import InterpolationMode
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2
from torchdrive.amp import autocast
Expand Down Expand Up @@ -621,7 +622,7 @@ def loss(self, emb: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


class ConvNextPathPred(nn.Module):
def __init__(self, dim: int = 256, max_seq_len: int = 18, pool_size: int = 4):
def __init__(self, dim: int = 256, max_seq_len: int = 18, pool_size: int = 4, num_traj: int = 6, dropout: float = 0.1):
super().__init__()

from torchvision.models.convnext import convnext_base, ConvNeXt_Base_Weights
Expand All @@ -630,7 +631,7 @@ def __init__(self, dim: int = 256, max_seq_len: int = 18, pool_size: int = 4):
self.max_seq_len = max_seq_len
# [x, y, log_std1, log_std2, rho]
self.traj_size = 5
self.num_traj = 5
self.num_traj = num_traj

self.encoder = convnext_base(
weights=ConvNeXt_Base_Weights.IMAGENET1K_V1,
Expand All @@ -641,14 +642,17 @@ def __init__(self, dim: int = 256, max_seq_len: int = 18, pool_size: int = 4):
self.static_features_encoder = nn.Sequential(
nn.Linear(1, dim),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout),
nn.Linear(dim, dim),
)

self.decoder = nn.Sequential(
nn.Linear(enc_dim * pool_size * pool_size + dim, max_seq_len * dim),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout),
nn.Linear(max_seq_len * dim, max_seq_len * dim),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout),
nn.Linear(max_seq_len * dim, max_seq_len * self.num_traj * self.traj_size + self.num_traj),
)

Expand Down Expand Up @@ -775,7 +779,7 @@ def __init__(
self.batch_transform = Compose(
NormalizeCarPosition(start_frame=0),
ImageTransform(
v2.RandomRotation(15),
v2.RandomRotation(15, InterpolationMode.BILINEAR),
v2.RandomErasing(),
),
)
Expand Down Expand Up @@ -854,7 +858,7 @@ def forward(
writer=writer,
output=output,
start_frame=0,
weights=1,
weights=None,
scaler=None,
)

Expand Down Expand Up @@ -1088,7 +1092,6 @@ def forward(
fig = plt.figure()

# generate prediction
self.train()

"""
pred_traj = torch.randn_like(noise[:1]) / self.noise_scale
Expand Down Expand Up @@ -1140,8 +1143,6 @@ def forward(
pred_len,
)

self.eval()

fig.legend()
plt.gca().set_aspect("equal")
ctx.add_figure(
Expand Down
7 changes: 3 additions & 4 deletions torchdrive/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ class DatasetConfig:

# dataset only params
dataset: Datasets
train_dataset_path: str
test_dataset_path: str
dataset_path: str
mask_path: str
batch_size: int
num_workers: int
Expand All @@ -52,13 +51,13 @@ def create_dataset(self, smoke: bool = False) -> Tuple[Dataset, Optional[Dataset
from torchdrive.datasets.nuscenes_dataset import NuscenesDataset

dataset = NuscenesDataset(
data_dir=self.train_dataset_path,
data_dir=self.dataset_path,
version="v1.0-mini" if smoke else "v1.0-trainval",
lidar=False,
num_frames=self.num_frames,
)
test_dataset = NuscenesDataset(
data_dir=self.train_dataset_path,
data_dir=self.dataset_path,
version="v1.0-mini" if smoke else "v1.0-test",
lidar=False,
num_frames=self.num_frames,
Expand Down

0 comments on commit 6db10ef

Please sign in to comment.