-
Notifications
You must be signed in to change notification settings - Fork 371
/
prompt_pg.py
210 lines (189 loc) · 9.05 KB
/
prompt_pg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
from ding.rl_utils import get_train_sample
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY, split_data_generator
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from ..model import model_wrap
@POLICY_REGISTRY.register('prompt_pg')
class PromptPGPolicy(Policy):
    r"""
    Overview:
        Policy class of Prompt Policy Gradient (PromptPG) algorithm.
        Link of the original paper: https://arxiv.org/abs/2209.14610
        The model receives a train sample (the question to be answered) together with a set of candidate
        prompts and produces a distribution over the candidates; ``shot_number`` of them are selected per step.
        Training uses a vanilla policy-gradient loss (log-probability weighted by the discounted return,
        no baseline) plus an entropy bonus.
    """
    config = dict(
        # (str) RL policy register name (refer to function "register_policy").
        type='prompt_pg',
        # (bool) Whether to use cuda for network.
        cuda=True,
        # (bool) Whether to use the on-policy training pipeline (behaviour policy and training policy are the same).
        on_policy=True,  # for pg strictly on policy algorithm, this line should not be modified by users
        # (bool) Whether to use deterministic action for evaluation.
        deterministic_eval=True,
        # (int) The number of actions (selected prompts) chosen simultaneously in one timestep.
        shot_number=1,
        learn=dict(
            # (int) The number of samples in one mini-batch update.
            batch_size=64,
            # (float) The step size of one gradient descent step.
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Loss weight of the entropy regularization; the weight of the policy loss is fixed to 1.
            entropy_weight=0.01,
            # (float) Max grad norm value used for gradient clipping.
            grad_norm=5,
            # (bool) Whether to ignore done signal for non-termination env (True is not supported yet).
            ignore_done=False,
        ),
        collect=dict(
            # (int) collect n_sample data, train model n_iteration times
            # n_episode=8,
            # (int) Trajectory unroll length.
            unroll_len=1,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Discount factor for future reward, should lie in [0, 1]; with 0 every step's return is
            # just its own reward.
            discount_factor=0,
            collector=dict(get_train_sample=True),
        ),
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return the registered name of the default model and the list of modules to import it from.
        """
        return 'language_transformer', ['ding.model.template.language_transformer']

    @staticmethod
    def _unwrap_candidates(cand_samples: List[Any]) -> List[Any]:
        r"""
        Overview:
            Replace every entry of ``cand_samples`` in place by its first element and return the same object.
            NOTE(review): presumably ``default_collate`` stacks each candidate field across the batch and only
            the first batch element is kept, i.e. all samples are assumed to share one candidate set — confirm
            against the env's observation layout.
        """
        for idx, cand in enumerate(cand_samples):
            cand_samples[idx] = cand[0]
        return cand_samples

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the Adam optimizer and cache the entropy weight and gradient-clipping threshold.
            No target model is used by this algorithm.
        """
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._grad_norm = self._cfg.learn.grad_norm
        self._learn_model = self._model  # for compatibility

    def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        r"""
        Overview:
            Forward and backward function of learn mode. ``data`` is split into consecutive mini-batches of
            ``learn.batch_size`` samples and one optimizer step is taken per mini-batch.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): Training samples, each including at least
                ['obs', 'action', 'return']; ``obs`` must hold 'train_sample' and 'candidate_samples'.
        Returns:
            - info_list (:obj:`List[Dict[str, Any]]`): One info dict per mini-batch, including current lr,
                total/policy/entropy loss, max absolute return and gradient norm.
        """
        self._model.train()
        batch_size = self._cfg.learn.batch_size
        return_infos = []
        for start in range(0, len(data), batch_size):
            batch = default_collate(data[start:start + batch_size])
            if self._cuda:
                batch = to_device(batch, self._device)
            # train_sample: the question to be answered; candidate_samples: the prompts to be selected from.
            train_samples = batch['obs']['train_sample']
            cand_samples = self._unwrap_candidates(batch['obs']['candidate_samples'])
            output = self._learn_model.forward(train_samples, cand_samples)
            dist = output['dist']
            return_ = batch['return']

            # Expected action shape: (B, shot_number); a 1-D action is promoted to (B, 1).
            real_act = batch['action']
            if len(real_act.shape) == 1:
                real_act = real_act.unsqueeze(-1)

            # Policy-gradient loss: each selected prompt is scored under the same distribution and its
            # negative log-likelihood is weighted by the return.
            total_policy_loss = 0
            for shot in range(self._cfg.shot_number):
                log_prob = dist.log_prob(real_act[:, shot])
                total_policy_loss += -(log_prob * return_).mean()
            # The distribution is shared by all shots, so the entropy bonus is applied once per mini-batch.
            total_entropy_loss = -self._entropy_weight * dist.entropy().mean()
            total_loss = total_entropy_loss + total_policy_loss

            # update
            self._optimizer.zero_grad()
            total_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                list(self._learn_model.parameters()),
                max_norm=self._grad_norm,
            )
            self._optimizer.step()

            # One info dict per mini-batch update.
            return_infos.append(
                {
                    'cur_lr': self._optimizer.param_groups[0]['lr'],
                    'total_loss': total_loss.item(),
                    'policy_loss': total_policy_loss.item(),
                    'entropy_loss': total_entropy_loss.item(),
                    'return_abs_max': return_.abs().max().item(),
                    # clip_grad_norm_ returns a tensor; log a plain float like the other entries.
                    'grad_norm': float(grad_norm),
                }
            )
        return return_infos

    def _forward_prompt_model(self, model: Any, data: Dict[Any, Any]) -> Dict[Any, Any]:
        r"""
        Overview:
            Shared inference path of collect and eval mode: collate the per-env observations, run ``model``
            without gradient tracking (passing ``shot_number``), and split the output back per env id.
        Arguments:
            - model (:obj:`Any`): The wrapped collect or eval model to run.
            - data (:obj:`Dict[Any, Any]`): Mapping from env id to observation; each observation must hold
                'train_sample' and 'candidate_samples'.
        Returns:
            - output (:obj:`Dict[Any, Any]`): Mapping from env id to the model output for that env.
        """
        data_id = list(data.keys())
        batch = default_collate(list(data.values()))
        self._model.eval()
        with torch.no_grad():
            # train_sample: the question to be answered; candidate_samples: the prompts to be selected from.
            cand_samples = self._unwrap_candidates(batch['candidate_samples'])
            output = model.forward(self._cfg.shot_number, batch['train_sample'], cand_samples)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return dict(zip(data_id, output))

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Read the unroll length and discount factor, and wrap the model with the
            ``combination_multinomial_sample`` wrapper so prompts are sampled stochastically.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.collect.discount_factor
        self._collect_model = model_wrap(self._model, wrapper_name='combination_multinomial_sample')

    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of collect mode: select ``shot_number`` prompts per env with the sampling model.
        Arguments:
            - data (:obj:`dict`): Mapping from env id to observation.
        Returns:
            - output (:obj:`dict`): Mapping from env id to the collect model output.
        """
        return self._forward_prompt_model(self._collect_model, data)

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        return {
            'obs': obs,
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Compute the discounted return of every step of the trajectory, walking backwards from its end
            without bootstrapping, store it in place under 'return', then split the trajectory into training
            samples of length ``unroll_len``.
        Arguments:
            - data (:obj:`list`): The trajectory's buffer list.
        Returns:
            - samples (:obj:`List[Any]`): The training samples generated by ``get_train_sample``.
        Raises:
            - NotImplementedError: If ``learn.ignore_done`` is enabled.
        """
        if self._cfg.learn.ignore_done:
            raise NotImplementedError
        ret = 0.
        for transition in reversed(data):
            ret = self._gamma * ret + transition['reward']
            transition['return'] = ret
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Eval mode init method. Wrap the model with the ``combination_argmax_sample`` wrapper so prompts are
            selected greedily.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='combination_argmax_sample')

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode: select ``shot_number`` prompts per env with the argmax model.
        Arguments:
            - data (:obj:`dict`): Mapping from env id to observation.
        Returns:
            - output (:obj:`dict`): Mapping from env id to the eval model output.
        """
        return self._forward_prompt_model(self._eval_model, data)

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the variable names monitored in learn mode: the base policy's variables plus the
            PG-specific keys of the info dicts produced by ``_forward_learn``.
        """
        return super()._monitor_vars_learn() + ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm']