From 3d7361c2d028473b601cc04f5eecd019e14eb4eb Mon Sep 17 00:00:00 2001 From: janner Date: Mon, 17 Oct 2022 13:41:29 -0700 Subject: [PATCH] use variance in gradient guide instead of std --- diffuser/sampling/functions.py | 3 ++- slurm/plan.sh | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/diffuser/sampling/functions.py b/diffuser/sampling/functions.py index 33f67b3a..d09eca9a 100644 --- a/diffuser/sampling/functions.py +++ b/diffuser/sampling/functions.py @@ -12,13 +12,14 @@ def n_step_guided_p_sample( ): model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) model_std = torch.exp(0.5 * model_log_variance) + model_var = torch.exp(model_log_variance) for _ in range(n_guide_steps): with torch.enable_grad(): y, grad = guide.gradients(x, cond, t) if scale_grad_by_std: - grad = model_std * grad + grad = model_var * grad grad[t < t_stopgrad] = 0 diff --git a/slurm/plan.sh b/slurm/plan.sh index f73151ac..c9672a4d 100755 --- a/slurm/plan.sh +++ b/slurm/plan.sh @@ -14,7 +14,7 @@ do python -u scripts/plan_guided.py \ --logbase logs/pretrained \ --dataset $env-$buffer-v2 \ - --prefix plans/reference \ + --prefix plans/reference_var \ --vis_freq 500 \ --verbose False \ --suffix {1} \