Skip to content

Commit

Permalink
[Doc example fix] Add zero_grad to the second-order optim example
Browse files Browse the repository at this point in the history
  • Loading branch information
fKunstner committed Aug 14, 2020
1 parent c3171ba commit f28ed2b
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions docs_src/examples/use_cases/example_diag_ggn_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import matplotlib.pyplot as plt

BATCH_SIZE = 128
STEP_SIZE = 0.01
STEP_SIZE = 0.05
DAMPING = 1.0
MAX_ITER = 200
PRINT_EVERY = 50
Expand Down Expand Up @@ -96,7 +96,7 @@ def step(self):
for group in self.param_groups:
for p in group["params"]:
step_direction = p.grad / (p.diag_ggn_mc + group["damping"])
p.data.add_(-group["step_size"], step_direction)
p.data.add_(step_direction, alpha=-group["step_size"])


# %%
Expand All @@ -117,6 +117,9 @@ def step(self):
accuracies = []
for batch_idx, (x, y) in enumerate(mnist_loader):
x, y = x.to(DEVICE), y.to(DEVICE)

model.zero_grad()

outputs = model(x)
loss = loss_function(outputs, y)

Expand Down

0 comments on commit f28ed2b

Please sign in to comment.