From 6982e88a963ccc8517867b8acb1948e49cf60063 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jan 2020 11:14:34 -0800 Subject: [PATCH] torch.autograd.profiler.emit_nvtx to show operators --- .gitignore | 2 ++ tools/training-benchmark-nsys-profile.py | 27 ++++++++++++++---------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 6c3e2fc81..527c05ed2 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,5 @@ dist *.swo /download /download.tar.xz +*.qdrep +*.qdstrm diff --git a/tools/training-benchmark-nsys-profile.py b/tools/training-benchmark-nsys-profile.py index a4fc363de..6456689a9 100644 --- a/tools/training-benchmark-nsys-profile.py +++ b/tools/training-benchmark-nsys-profile.py @@ -147,7 +147,9 @@ def enable_timers(model): enable_timers(model) torch.cuda.cudart().cudaProfilerStart() - if total_batch_counter >= WARM_UP_BATCHES: + PROFILING_STARTED = (total_batch_counter >= WARM_UP_BATCHES) + + if PROFILING_STARTED: torch.cuda.nvtx.range_push("batch{}".format(total_batch_counter)) true_energies = batch_y['energies'].to(parser.device) @@ -155,14 +157,15 @@ def enable_timers(model): num_atoms = [] for j, (chunk_species, chunk_coordinates) in enumerate(batch_x): - if total_batch_counter >= WARM_UP_BATCHES: + if PROFILING_STARTED: torch.cuda.nvtx.range_push("chunk{}".format(j)) chunk_species = chunk_species.to(parser.device) chunk_coordinates = chunk_coordinates.to(parser.device) num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1)) - _, chunk_energies = model((chunk_species, chunk_coordinates)) + with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED, record_shapes=True): + _, chunk_energies = model((chunk_species, chunk_coordinates)) predicted_energies.append(chunk_energies) - if total_batch_counter >= WARM_UP_BATCHES: + if PROFILING_STARTED: torch.cuda.nvtx.range_pop() num_atoms = torch.cat(num_atoms) @@ -170,21 +173,23 @@ def enable_timers(model): loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean() rmse = hartree2kcal((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy() - if total_batch_counter >= WARM_UP_BATCHES: + if PROFILING_STARTED: torch.cuda.nvtx.range_push("backward") - loss.backward() - if total_batch_counter >= WARM_UP_BATCHES: + with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED, record_shapes=True): + loss.backward() + if PROFILING_STARTED: torch.cuda.nvtx.range_pop() - if total_batch_counter >= WARM_UP_BATCHES: + if PROFILING_STARTED: torch.cuda.nvtx.range_push("optimizer.step()") - optimizer.step() - if total_batch_counter >= WARM_UP_BATCHES: + with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED, record_shapes=True): + optimizer.step() + if PROFILING_STARTED: torch.cuda.nvtx.range_pop() progbar.update(i, values=[("rmse", rmse)]) - if total_batch_counter >= WARM_UP_BATCHES: + if PROFILING_STARTED: torch.cuda.nvtx.range_pop() total_batch_counter += 1