det,path: add more dropout
d4l3k committed Nov 14, 2023
1 parent b5d8962 commit 36413e9
Showing 4 changed files with 38 additions and 10 deletions.
13 changes: 10 additions & 3 deletions torchdrive/models/det_deform.py
@@ -43,7 +43,10 @@ def __init__(
         self.num_queries = num_queries

         self.query_embed = nn.Embedding(num_queries, dim)
-        self.reference_points_project = nn.Linear(dim, 2)
+        self.reference_points_project = nn.Sequential(
+            nn.Dropout(dropout),
+            nn.Linear(dim, 2),
+        )

         decoder_layer = DeformableTransformerDecoderLayer(
             d_model=dim,
@@ -58,8 +61,11 @@ def __init__(
             num_layers=num_layers,
         )

-        self.bbox_decoder = MLP(dim, dim, 9, num_layers=4)
-        self.class_decoder = nn.Conv1d(dim, num_classes + 1, 1)
+        self.bbox_decoder = MLP(dim, dim, 9, num_layers=4, dropout=dropout)
+        self.class_decoder = nn.Sequential(
+            nn.Dropout(dropout),
+            nn.Conv1d(dim, num_classes + 1, 1),
+        )

         bev_encoders = []
         for i in range(num_levels - 1, -1, -1):
@@ -78,6 +84,7 @@ def __init__(
                     group_width=dim,
                     bottleneck_multiplier=1.0,
                 ),
+                nn.Dropout(dropout),
                 nn.Conv2d(dim, dim, 1),
                 LearnedPositionalEncoding2d((h * 2**i, w * 2**i), dim),
             )
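Note: the pattern in this file is uniform. Each prediction head that used to be a bare nn.Linear or nn.Conv1d is wrapped in an nn.Sequential with an nn.Dropout in front of it, so the head's input features are randomly zeroed during training and passed through unchanged at eval time. A minimal, self-contained sketch of that wrapping follows; the class, names, and dimensions are illustrative, not torchdrive's own.

import torch
from torch import nn


class HeadWithDropout(nn.Module):
    """Toy detection-style heads with dropout applied to their input features."""

    def __init__(self, dim: int = 256, num_classes: int = 10, dropout: float = 0.1) -> None:
        super().__init__()
        # Same wrapping pattern as the commit: Dropout -> Linear inside nn.Sequential.
        self.bbox_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(dim, 9),  # 9 box parameters per query
        )
        # Conv1d head operates on (batch, channels, num_queries).
        self.class_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, num_classes + 1, 1),
        )

    def forward(self, queries: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # queries: (batch, num_queries, dim)
        bboxes = self.bbox_head(queries)
        classes = self.class_head(queries.permute(0, 2, 1))
        return bboxes, classes


head = HeadWithDropout()
bboxes, classes = head(torch.randn(2, 100, 256))
print(bboxes.shape, classes.shape)  # torch.Size([2, 100, 9]) torch.Size([2, 11, 100])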
13 changes: 10 additions & 3 deletions torchdrive/models/mlp.py
@@ -8,7 +8,12 @@ class MLP(nn.Module):
     """

     def __init__(
-        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        dropout: float = 0.0,
     ) -> None:
         super().__init__()

@@ -21,13 +26,15 @@ def __init__(

         for i in range(num_layers - 2):
             layers += [
+                nn.Dropout(dropout),
                 nn.Linear(hidden_dim, hidden_dim),
                 nn.ReLU(inplace=True),
             ]

-        layers.append(
+        layers += [
+            nn.Dropout(dropout),
             nn.Linear(hidden_dim, output_dim),
-        )
+        ]

         self.decoder = nn.Sequential(*layers)

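Note: with the new signature, every Linear after the first gets an nn.Dropout in front of it, and the final layers.append(...) becomes layers += [...] so the output projection can share that pattern. Below is a sketch of how the whole class plausibly reads after this commit; the first Linear + ReLU block and the forward method sit outside the visible hunks and are assumed here.

import torch
from torch import nn


class MLP(nn.Module):
    """Sketch of torchdrive.models.mlp.MLP after this change (unshown lines assumed)."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        # Assumed first block (not visible in the diff).
        layers: list[nn.Module] = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
        ]

        # Hidden blocks now carry a Dropout before each Linear.
        for _ in range(num_layers - 2):
            layers += [
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(inplace=True),
            ]

        # The output projection also gets a Dropout.
        layers += [
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        ]

        self.decoder = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # assumed, not shown in the diff
        return self.decoder(x)


out = MLP(8, 32, 4, num_layers=3, dropout=0.1)(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 4])

Because nn.Dropout(0.0) is an identity op, the dropout=0.0 default leaves existing callers unchanged.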
18 changes: 16 additions & 2 deletions torchdrive/models/path.py
@@ -136,6 +136,7 @@ def __init__(
                 group_width=bev_dim,
                 bottleneck_multiplier=1.0,
             ),
+            nn.Dropout(dropout),
             nn.Conv2d(bev_dim, dim, 1),
             LearnedPositionalEncoding2d(bev_shape, dim),
         )
@@ -147,22 +148,27 @@ def __init__(
             nn.Sequential(
                 nn.Linear(pos_dim, dim),
                 nn.ReLU(inplace=True),
+                nn.Dropout(dropout),
                 nn.Linear(dim, dim),
                 LearnedPositionalEncodingSeq(max_seq_len, dim),
             )
         )
         # pyre-fixme[4]: Attribute must be annotated.
         self.pos_decoder = compile_fn(
             nn.Sequential(
+                nn.Dropout(dropout),
                 nn.Linear(dim, dim),
                 nn.ReLU(inplace=True),
+                nn.Dropout(dropout),
                 nn.Linear(dim, pos_dim),
             )
         )

         static_features = 2 * pos_dim
         # pyre-fixme[4]: Attribute must be annotated.
-        self.static_encoder = compile_fn(MLP(static_features, dim, dim, num_layers=3))
+        self.static_encoder = compile_fn(
+            MLP(static_features, dim, dim, num_layers=3, dropout=dropout)
+        )

         self.transformer = StockTransformerDecoder(
             dim=dim,
@@ -267,6 +273,7 @@ def __init__(
         self.query_embed = nn.Embedding(num_queries, dim)

         self.bev_encoder = nn.Sequential(
+            nn.Dropout(dropout),
             models.regnet.AnyStage(
                 bev_dim,
                 bev_dim,
@@ -278,6 +285,7 @@ def __init__(
                 group_width=bev_dim,
                 bottleneck_multiplier=1.0,
             ),
+            nn.Dropout(dropout),
             nn.Conv2d(bev_dim, dim, 1),
             LearnedPositionalEncoding2d(bev_shape, dim),
         )
@@ -288,10 +296,13 @@ def __init__(
         # pyre-fixme[4]: Attribute must be annotated.
         self.pos_decoder = compile_fn(
             nn.Sequential(
+                nn.Dropout(dropout),
                 nn.Linear(decoder_dim, inter_dim),
                 nn.ReLU(inplace=True),
+                nn.Dropout(dropout),
                 nn.Linear(inter_dim, inter_dim),
                 nn.ReLU(inplace=True),
+                nn.Dropout(dropout),
                 nn.Linear(inter_dim, out_dim),
             )
         )
@@ -301,11 +312,14 @@ def __init__(
             dim_feedforward=dim * 4,
             nhead=num_heads,
             batch_first=True,
+            dropout=dropout,
         )
         self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

         # pyre-fixme[4]: Attribute must be annotated.
-        self.static_encoder = compile_fn(MLP(static_features, dim, dim, num_layers=3))
+        self.static_encoder = compile_fn(
+            MLP(static_features, dim, dim, num_layers=3, dropout=dropout)
+        )

         transformer_init(self)

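Note: besides the hand-inserted nn.Dropout modules and the rate now forwarded into MLP, the transformer decoder layer itself now receives dropout. nn.TransformerDecoderLayer applies it inside its attention and feed-forward sublayers and, when the argument is omitted as before, falls back to PyTorch's default of 0.1. A standalone sketch of that piece, with made-up dimensions:

import torch
from torch import nn

dim, num_heads, num_layers, dropout = 256, 8, 4, 0.1

decoder_layer = nn.TransformerDecoderLayer(
    d_model=dim,
    dim_feedforward=dim * 4,
    nhead=num_heads,
    batch_first=True,
    dropout=dropout,  # previously left at the nn.TransformerDecoderLayer default of 0.1
)
transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

tgt = torch.randn(2, 16, dim)     # decoder queries (illustrative)
memory = torch.randn(2, 64, dim)  # memory features, e.g. flattened BEV (illustrative)
print(transformer(tgt, memory).shape)  # torch.Size([2, 16, 256])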
4 changes: 2 additions & 2 deletions torchdrive/tasks/det.py
@@ -351,7 +351,7 @@ def forward(

         if len(unmatched_classes) > 0:
             losses["unmatched"] = (
-                F.cross_entropy(unmatched_classes, target_classes) * 40
+                F.cross_entropy(unmatched_classes, target_classes) * 20
             )

         if ctx.log_img:
@@ -366,7 +366,7 @@ def forward(
             ctx.add_scalar("classes/accuracy", self.accuracy.compute())
             self.accuracy.reset()

-        losses = {k: v for k, v in losses.items()}
+        losses = {k: v * 5 for k, v in losses.items()}

         return losses

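Note: two scalars move together in this file. The unmatched-class cross-entropy weight drops from 40 to 20, and every loss in the dict is then multiplied by 5, so the unmatched term ends up at an effective 100x while every other term goes from 1x to 5x; relative to the rest, the unmatched weight is halved (40:1 -> 20:1) even though all losses grow in absolute scale. The commit does not spell this out; it is just the arithmetic of the two edits. A toy illustration:

# Toy numbers, only to show the relative re-weighting between the two versions.
unmatched_ce, other_loss = 1.0, 1.0

old = {"unmatched": unmatched_ce * 40, "other": other_loss}
new = {k: v * 5 for k, v in {"unmatched": unmatched_ce * 20, "other": other_loss}.items()}
print(old)  # {'unmatched': 40.0, 'other': 1.0}
print(new)  # {'unmatched': 100.0, 'other': 5.0}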