[RLlib] Issue 28849: DT fails with num_gpus=1. (#31297)
sven1977 authored and AmeerHajAli committed Jan 12, 2023
1 parent b0d6b92 commit 422e636
Showing 1 changed file with 7 additions and 4 deletions.
rllib/models/torch/mingpt.py (11 changes: 7 additions & 4 deletions)
@@ -135,14 +135,17 @@ def __init__(self, config: GPTConfig):
                 dropout=nn.Dropout(config.resid_pdrop),
             )
         )
-        m = self.mlp
-        # MLP forward
-        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))

     def forward(self, x, attention_masks=None):
+        # Multi-head attention sub-layer.
         x_att, att = self.attn(self.ln_1(x), attention_masks=attention_masks)
+        # Residual of multi-head attention sub-layer.
         x = x + x_att
-        x = x + self.mlpf(self.ln_2(x))
+
+        # Position-wise FFN sub-layer: fc + activation + fc + dropout
+        x_ffn = self.mlp.dropout(self.mlp.c_proj(self.mlp.act(self.mlp.c_fc(x))))
+        # Residual of position-wise FFN sub-layer.
+        x = x + x_ffn
         return x, att
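
The diff replaces a lambda stored as a module attribute (`self.mlpf`) with an inline composition of the `self.mlp` submodules in `forward()`. The commit title ties the failure to `num_gpus=1`; one plausible mechanism, an assumption here rather than something stated in the diff, is that RLlib must serialize the policy's model to ship it to a GPU-backed worker process, and Python's `pickle` cannot serialize a lambda held in a module's `__dict__`. Below is a minimal sketch of that failure mode, with hypothetical `WithLambda`/`Inlined` classes standing in for the pre- and post-fix `Block`:

```python
import pickle

import torch.nn as nn


class WithLambda(nn.Module):
    # Mirrors the pre-fix Block.__init__: a lambda closing over the MLP.
    def __init__(self):
        super().__init__()
        self.mlp = nn.ModuleDict(
            dict(c_fc=nn.Linear(4, 4), c_proj=nn.Linear(4, 4))
        )
        m = self.mlp
        self.mlpf = lambda x: m.c_proj(m.c_fc(x))  # unpicklable attribute


class Inlined(nn.Module):
    # Mirrors the post-fix Block: the FFN is composed inside forward().
    def __init__(self):
        super().__init__()
        self.mlp = nn.ModuleDict(
            dict(c_fc=nn.Linear(4, 4), c_proj=nn.Linear(4, 4))
        )

    def forward(self, x):
        return self.mlp.c_proj(self.mlp.c_fc(x))


try:
    pickle.dumps(WithLambda())
except Exception as exc:  # pickle.PicklingError: Can't pickle <lambda>
    print(f"lambda version fails to pickle: {exc}")

pickle.dumps(Inlined())  # serializes cleanly
print("inlined version pickles fine")
```

Inlining the calls keeps `forward()` reaching the submodules through `self`, so the module carries no non-serializable state while computing the same fc + activation + fc + dropout chain.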
