diff --git a/imagenet.py b/imagenet.py index 012b321..d88431a 100644 --- a/imagenet.py +++ b/imagenet.py @@ -482,11 +482,11 @@ def main(): '{:.2f}' .format((epoch + 1), train_loss_ema, test_loss, 100. * test_acc1)) - corruption_accs = test_c(net, test_transform) - for c in CORRUPTIONS: - print('\t'.join(map(str, [c] + corruption_accs[c]))) + corruption_accs = test_c(net, test_transform) + for c in CORRUPTIONS: + print('\t'.join(map(str, [c] + corruption_accs[c]))) - print('mCE (normalized by AlexNet):', compute_mce(corruption_accs)) + print('mCE (normalized by AlexNet):', compute_mce(corruption_accs)) if __name__ == '__main__':