Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep model_avg on cpu #348

Merged
merged 10 commits into from
May 7, 2022
10 changes: 5 additions & 5 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def greedy_search(
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)

device = model.device
device = next(model.parameters()).device

decoder_input = torch.tensor(
[blank_id] * context_size, device=device, dtype=torch.int64
Expand Down Expand Up @@ -350,7 +350,7 @@ def greedy_search_batch(
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)

device = model.device
device = next(model.parameters()).device

batch_size = encoder_out.size(0)
T = encoder_out.size(1)
Expand Down Expand Up @@ -580,7 +580,7 @@ def modified_beam_search(
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
B[i].add(
Expand Down Expand Up @@ -705,7 +705,7 @@ def _deprecated_modified_beam_search(
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size

device = model.device
device = next(model.parameters()).device

T = encoder_out.size(1)

Expand Down Expand Up @@ -813,7 +813,7 @@ def beam_search(
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size

device = model.device
device = next(model.parameters()).device

decoder_input = torch.tensor(
[blank_id] * context_size,
Expand Down
3 changes: 1 addition & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3

Expand Down Expand Up @@ -560,7 +560,6 @@ def main():

model.to(device)
model.eval()
model.device = device

if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
Expand Down
26 changes: 13 additions & 13 deletions egs/librispeech/ASR/pruned_transducer_stateless4/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def get_parser():
"--start-epoch",
type=int,
default=1,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
Expand Down Expand Up @@ -479,7 +479,7 @@ def load_checkpoint_if_available(

def save_checkpoint(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
Expand Down Expand Up @@ -529,7 +529,7 @@ def save_checkpoint(

def compute_loss(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
Expand All @@ -553,7 +553,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
Expand Down Expand Up @@ -609,7 +613,7 @@ def compute_loss(

def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
Expand Down Expand Up @@ -643,7 +647,7 @@ def compute_validation_loss(

def train_one_epoch(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
Expand Down Expand Up @@ -857,6 +861,7 @@ def run(rank, world_size, args):
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model)

assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
Expand All @@ -865,11 +870,6 @@ def run(rank, world_size, args):
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device

if rank == 0:
model_avg.to(device)
model_avg.device = device

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you tested with this change?
I wonder whether you need to do .to(torch.device('cpu')) when updating this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I have tested it. model_avg is on cpu default, since it is a deepcopy of model before the line model.to(device).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK.

optimizer = Eve(model.parameters(), lr=params.initial_lr)

Expand Down Expand Up @@ -990,7 +990,7 @@ def remove_short_and_long_utt(c: Cut):


def scan_pessimistic_batches_for_oom(
model: nn.Module,
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
Expand Down
4 changes: 3 additions & 1 deletion icefall/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,5 +467,7 @@ def average_state_dict(
uniqued_names = list(uniqued.values())
for k in uniqued_names:
state_dict_1[k] *= weight_1
state_dict_1[k] += state_dict_2[k] * weight_2
state_dict_1[k] += (
state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
)
state_dict_1[k] *= scaling_factor