Hi,
First of all, thank you for providing such nice work!
However, I have a question and really need your help:
In your MeLU.py lines 71-79:
```python
grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
# local update
for i in range(self.weight_len):
    if self.weight_name[i] in self.local_update_target_weight_name:
        self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
    else:
        self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
self.model.load_state_dict(self.fast_weights)
query_set_y_pred = self.model(query_set_x)
```
I understand this is the standard MAML inner loop.
However, `load_state_dict()` erases (breaks) the gradient graph (https://discuss.pytorch.org/t/loading-a-state-dict-seems-to-erase-grad/56676), so the global update will no longer backpropagate through the local update in the final optimization. As a result, `create_graph=True` may have no effect, and the algorithm may no longer be standard MAML. Am I missing some insight here?
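To make the concern concrete, here is a minimal sketch (not from the repo) showing that `load_state_dict()` drops the computation graph: the fast weights carry a `grad_fn` before loading, but after loading the module's parameters are plain leaf tensors again.

```python
import torch

# A toy module standing in for self.model.
model = torch.nn.Linear(2, 1)

# Build "fast weights" that are part of a computation graph
# (analogous to weight_for_local_update[i] - local_lr * grad[i]).
fast = {name: p * 2.0 for name, p in model.named_parameters()}
assert all(t.grad_fn is not None for t in fast.values())  # graph attached

# load_state_dict() only copies the values into the existing leaf
# parameters; the grad_fn of the fast weights is discarded.
model.load_state_dict(fast)
print(all(p.grad_fn is None for p in model.parameters()))  # True: graph gone
```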
Looking forward to your reply!
I believe you are right and the original code is wrong.
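For reference, the usual way to keep the inner update differentiable is to run the query forward pass with the fast weights functionally, without mutating the module. A hedged sketch (names like `support_x`/`query_x` are illustrative, and it assumes PyTorch >= 2.0 for `torch.func.functional_call`):

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(3, 1)
local_lr = 0.1

support_x, support_y = torch.randn(8, 3), torch.randn(8, 1)
query_x, query_y = torch.randn(8, 3), torch.randn(8, 1)

names = [n for n, _ in model.named_parameters()]
params = [p for _, p in model.named_parameters()]

# Inner loop: create_graph=True keeps the update differentiable.
support_loss = torch.nn.functional.mse_loss(model(support_x), support_y)
grads = torch.autograd.grad(support_loss, params, create_graph=True)
fast_weights = {n: p - local_lr * g for n, p, g in zip(names, params, grads)}

# Outer loop: forward with the fast weights WITHOUT load_state_dict(),
# so the graph through the inner update survives.
query_pred = functional_call(model, fast_weights, (query_x,))
query_loss = torch.nn.functional.mse_loss(query_pred, query_y)

# Meta-gradients w.r.t. the original parameters flow through the
# inner update (this call would fail if the graph were broken).
meta_grads = torch.autograd.grad(query_loss, params)
```

Older MAML implementations achieve the same thing by threading the fast-weight dict through a hand-written functional forward; the key point is simply never to copy the fast weights back into the module's parameters during the inner loop.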