
[GAN does not fit PaddleScience's wrapped training loop] #989

Closed
zhudequan9 opened this issue Sep 19, 2024 · 2 comments

@zhudequan9
Contributor

Feature Description

In GAN training, within a single epoch the discriminator must be updated first, and the generator only once every few batches. However, solver.train() wraps the training of all epochs, so this schedule cannot be configured from inside it.

Alternatives

Workaround:

  1. Set the Solver parameters epochs=1 and iters_per_epoch=1, so that solver.train() trains just one batch per call.
  2. Alternately call solver_dis.train() and solver_gen.train() to implement GAN training (see the sketch after the drawbacks list below).

Drawbacks:

  1. A checkpoint is forcibly saved after each epoch of training, so with this workaround a checkpoint is written after every single batch, which badly hurts training performance.
  2. A GAN trains two models and needs two Optimizers, but one solver holds only one Optimizer, so two solvers must be defined, each with its own DataLoader. (Memory usage doubles, the data is read twice, and after shuffling the two models see different data within the same batch.)
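
For concreteness, the alternating-solver workaround looks roughly like this (a minimal sketch; the models, constraints, and optimizers are assumed to be built elsewhere with the usual ppsci.* APIs, and N_CRITIC / TOTAL_STEPS are names introduced here for illustration, not PaddleScience parameters):

import ppsci

# Two solvers, each configured so one .train() call runs exactly one batch.
# model_* / constraint_* / opt_* are assumed to be constructed elsewhere.
solver_dis = ppsci.solver.Solver(
    model_dis, constraint_dis, output_dir="./output_dis",
    optimizer=opt_dis, epochs=1, iters_per_epoch=1,
)
solver_gen = ppsci.solver.Solver(
    model_gen, constraint_gen, output_dir="./output_gen",
    optimizer=opt_gen, epochs=1, iters_per_epoch=1,
)

N_CRITIC = 5        # hypothetical: discriminator steps per generator step
TOTAL_STEPS = 10000
for step in range(TOTAL_STEPS):
    solver_dis.train()            # one discriminator batch
    if step % N_CRITIC == 0:
        solver_gen.train()        # one generator batch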
@HydrogenSulfate
Collaborator

Thanks for the feedback! I will look into the general GAN training workflow soon; the plan is to add a def train_GAN_epoch_func function in train.py to handle GAN training.
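
A rough shape of such a function might be the following (purely hypothetical, it does not exist in PaddleScience; dis_step and gen_step stand for closures that each run one optimizer step):

def train_GAN_epoch_func(dis_step, gen_step, iters_per_epoch, n_critic=1):
    # Hypothetical sketch: within one epoch, update the discriminator every
    # iteration and the generator every `n_critic` iterations.
    for iter_id in range(1, iters_per_epoch + 1):
        dis_step(iter_id)
        if iter_id % n_critic == 0:
            gen_step(iter_id)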

@HydrogenSulfate
Collaborator

HydrogenSulfate commented Sep 25, 2024


@9zdq Hi, I looked at the existing tempoGAN example, the standard GAN training flow shown below,

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Generator model
class Generator(nn.Module):
    def __init__(self, z_dim=100, img_dim=1*28*28):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.gen(x)

# Discriminator model
class Discriminator(nn.Module):
    def __init__(self, img_dim=1*28*28):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)

# Hyperparameters
z_dim = 100
lr = 0.0002
batch_size = 64
num_epochs = 50
img_dim = 1*28*28

# Dataset and dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root='dataset/', transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize generator and discriminator
generator = Generator(z_dim, img_dim).to('cuda')
discriminator = Discriminator(img_dim).to('cuda')

# Loss function and optimizers
criterion = nn.BCELoss()
opt_gen = optim.Adam(generator.parameters(), lr=lr)
opt_disc = optim.Adam(discriminator.parameters(), lr=lr)

# GAN training loop
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.view(-1, img_dim).to('cuda')
        batch_size = real.size(0)

        # Train the discriminator (detach `fake` so the discriminator loss
        # does not backpropagate into the generator; this also removes the
        # need for retain_graph=True)
        noise = torch.randn(batch_size, z_dim).to('cuda')
        fake = generator(noise)
        disc_real = discriminator(real).view(-1)
        disc_fake = discriminator(fake.detach()).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        discriminator.zero_grad()
        lossD.backward()
        opt_disc.step()

        # Train the generator (re-score `fake`, this time letting gradients flow)
        output = discriminator(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        generator.zero_grad()
        lossG.backward()
        opt_gen.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {lossD:.4f}, loss G: {lossG:.4f}")

as well as how PyTorch Lightning supports GANs. My take is that GAN training is extremely flexible in form; wrapping it into PaddleScience and then asking users to fill in the corresponding code and configuration would actually increase development cost. Considering that, apart from ppsci.solver, the other modules under ppsci can basically each be used standalone, my suggestion for GAN-style models is to hand-write the training code directly in the example file using the two API families ppsci.* and paddle.*; that is the fastest approach.
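
For instance, the core of the loop above translates almost one-to-one into paddle.* APIs (a minimal sketch, assuming generator / discriminator are paddle.nn.Layer counterparts of the torch models above and dataloader yields real images):

import paddle

criterion = paddle.nn.BCELoss()
opt_gen = paddle.optimizer.Adam(learning_rate=2e-4, parameters=generator.parameters())
opt_disc = paddle.optimizer.Adam(learning_rate=2e-4, parameters=discriminator.parameters())

for epoch in range(num_epochs):
    for real, _ in dataloader:
        real = real.reshape([-1, img_dim])
        noise = paddle.randn([real.shape[0], z_dim])
        fake = generator(noise)

        # Discriminator step: detach `fake` so gradients stop at the generator.
        lossD = (criterion(discriminator(real), paddle.ones([real.shape[0], 1]))
                 + criterion(discriminator(fake.detach()), paddle.zeros([real.shape[0], 1]))) / 2
        opt_disc.clear_grad()
        lossD.backward()
        opt_disc.step()

        # Generator step: re-score `fake`, now letting gradients flow back.
        lossG = criterion(discriminator(fake), paddle.ones([real.shape[0], 1]))
        opt_gen.clear_grad()
        lossG.backward()
        opt_gen.step()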
