Doubt about dis_opt and gen_opt #68
Open
songbo0925 opened this issue Feb 23, 2021 · 2 comments

@songbo0925

In trainer.py, why are only the parameters of dis_a and gen_a updated, while the parameters of dis_b and gen_b are ignored?

DG-Net/trainer.py

Lines 242 to 248 in a067be1

dis_params = list(self.dis_a.parameters()) #+ list(self.dis_b.parameters())
gen_params = list(self.gen_a.parameters()) #+ list(self.gen_b.parameters())
self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])

@layumi
Contributor

layumi commented Feb 23, 2021

Hi @songbo0925
Thanks for your attention to our paper. The weights are shared between dis_a and dis_b, and likewise between gen_a and gen_b.
Therefore, only one optimizer is needed for each pair.
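
For illustration, a minimal sketch of the idea with a toy module (net_a and net_b are placeholder names, not the DG-Net classes): because both names refer to the same object, one optimizer built from net_a.parameters() updates "both" networks.

import torch
import torch.nn as nn

net_a = nn.Linear(4, 4)   # stands in for gen_a (or dis_a)
net_b = net_a             # alias: both names point to the same module

opt = torch.optim.Adam(net_a.parameters(), lr=1e-3)  # covers net_b as well

loss = net_b(torch.randn(2, 4)).sum()  # forward through the alias
loss.backward()
opt.step()                # updates the single shared parameter set

print(net_a is net_b)                                       # True
print(net_a.weight.data_ptr() == net_b.weight.data_ptr())   # True, same tensor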

@songbo0925
Author

Hi @layumi
Thanks for your wonderful work and reply.
In lines 180-181 of trainer.py, gen_b is set equal to gen_a only this once, at construction time. But if only the parameters of gen_a are updated in each iteration, then in the next forward pass gen_a and gen_b may have different parameters. So how do you achieve weight sharing? The same question applies to dis_a and dis_b. Maybe I missed some code; please advise.

DG-Net/trainer.py

Lines 180 to 181 in a067be1

self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16 = False) # auto-encoder for domain a
self.gen_b = self.gen_a # auto-encoder for domain b
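
To make my concern concrete, here is a toy example (hypothetical code, not from DG-Net) of what I would expect if gen_b were an independent copy of gen_a rather than shared:

import copy
import torch
import torch.nn as nn

gen_a = nn.Linear(3, 3)
gen_b = copy.deepcopy(gen_a)  # an independent copy, NOT what trainer.py does

with torch.no_grad():
    gen_a.weight.add_(1.0)    # update gen_a only, as the optimizer would

print(torch.equal(gen_a.weight, gen_b.weight))  # False: the two have diverged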
