rework asym ae training step so batch works i hope
neggles committed Jan 24, 2024
1 parent 3327697 commit f723699
Showing 1 changed file with 17 additions and 6 deletions.
src/neurosis/models/autoencoder_asym.py: 17 additions & 6 deletions
@@ -143,21 +143,32 @@ def get_loss(self, model_output: Tensor, target: Tensor) -> Tensor:
         raise ValueError(f"loss type {self.loss_type} not supported")
 
     def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
-        x = self.get_input(batch)
-        recon = self.forward(x)
-        loss = self.get_loss(x, recon.sample)
+        x: Tensor = self.get_input(batch)
 
-        self.loss_ema.update(loss.mean().item())
+        posterior = self.vae.encode(x).latent_dist
+        z: Tensor = posterior.mode()
+        kl_loss: Tensor = posterior.kl()
+
+        recon: Tensor = self.vae.decode(z).sample
+        ae_loss = self.get_loss(x, recon)
+
+        # update EMA loss tracker
+        log_loss = ae_loss.mean().detach()
+        self.loss_ema.update(log_loss.item())
 
         self.log_dict(
-            {"train/loss": loss.mean(), "train/loss_ema": self.loss_ema.value},
+            {
+                "train/loss": log_loss,
+                "train/loss_ema": self.loss_ema.value,
+                "train/loss_kl": kl_loss.mean().detach(),
+            },
             on_step=True,
             on_epoch=False,
             prog_bar=True,
             logger=True,
+            batch_size=x.shape[0],
         )
-        return loss.mean()
+        return ae_loss.mean()
 
     def configure_optimizers(self) -> OptimizerLRScheduler:
         param_groups = []
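For readers without the rest of the file in front of them, here is a minimal standalone sketch of what the reworked step now computes, assuming `self.vae` is a diffusers `AsymmetricAutoencoderKL` (or any `AutoencoderKL`-style module whose `encode()` exposes a `latent_dist` with `mode()`/`kl()` and whose `decode()` returns an object with a `.sample` tensor). The function name `asym_ae_step` and the MSE stand-in for the module's own `get_loss()` are illustrative placeholders, not code from the repository.

```python
import torch
import torch.nn.functional as F
from diffusers import AsymmetricAutoencoderKL
from torch import Tensor


def asym_ae_step(vae: AsymmetricAutoencoderKL, x: Tensor) -> dict[str, Tensor]:
    # encode the batch; latent_dist is a DiagonalGaussianDistribution
    posterior = vae.encode(x).latent_dist
    z: Tensor = posterior.mode()  # deterministic latent: the distribution mode, not a sample
    kl_loss: Tensor = posterior.kl()  # per-sample KL against the unit-Gaussian prior

    # decode back to pixel space and score the reconstruction
    recon: Tensor = vae.decode(z).sample
    ae_loss = F.mse_loss(recon, x)  # stand-in for the module's configurable get_loss()

    # KL is tracked for monitoring only; the training objective is the reconstruction term
    return {"train/loss": ae_loss, "train/loss_kl": kl_loss.mean()}


if __name__ == "__main__":
    vae = AsymmetricAutoencoderKL()  # default config; real training would load pretrained weights
    metrics = asym_ae_step(vae, torch.randn(1, 3, 64, 64))
    print({k: float(v) for k, v in metrics.items()})
```

Two details worth noting from the diff itself: the latent is taken as `posterior.mode()` rather than a sample, and `kl_loss` is only logged under `train/loss_kl`, not added to the returned loss. The new `batch_size=x.shape[0]` argument to `self.log_dict` hands Lightning the batch size explicitly instead of letting it infer one from the dict-shaped batch, which is presumably the "so batch works" part of the commit message.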
