ValueError: expected sequence of length 361 at dim 1 (got 208) [BUG] #190
Comments
Hi @nomawni, this tutorial has not been maintained for years. Nevertheless, your bug should be easy to fix: here in line 60, you currently have [...], but you should have [...] instead. You should follow the same pattern in line 70. I would rather recommend looking into the training script of PPO, since that one is maintained and working, and it is basically the same code as in the tutorial :) see here: https://github.com/ConvLab/ConvLab-3/blob/master/convlab/policy/ppo/train.py
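The pattern in the maintained training script is that policy.vector.state_vectorize(s) returns both a state vector and an action mask, so the return value has to be unpacked instead of being wrapped in torch.Tensor directly; this is presumably where the two mismatched lengths (361 and 208) in your error come from. A minimal sketch of that pattern, assuming the ConvLab-3 vectorizer interface (check train.py for the exact lines):

# instead of
# s_vec = torch.Tensor(policy.vector.state_vectorize(s))
# unpack the (state vector, action mask) tuple first:
s_vec, action_mask = policy.vector.state_vectorize(s)
s_vec = torch.Tensor(s_vec)
action_mask = torch.Tensor(action_mask)
a = policy.predict(s)

The same unpacking would apply to the call around line 70 (presumably the one vectorizing the next state).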
@ChrisGeishauser Thanks.
Hi @ChrisGeishauser,
@ChrisGeishauser could you help me with the error that I mentioned previously?
Hi, I am facing the following error while optimizing the PPO policy: ValueError: expected sequence of length 361 at dim 1 (got 208)
I first simulated dialogues between the user and the system, which I then gathered and used to train the policy with the code provided in the following link: https://github.com/ConvLab/ConvLab-3/blob/master/tutorials/Train_RL_Policies/example_train.py. Here is how to reproduce the error:
# assumed ConvLab-3 import paths; adjust to your installation if needed
from convlab.nlu.jointBERT.unified_datasets import BERTNLU
from convlab.dst.rule.multiwoz import RuleDST
from convlab.policy.rule.multiwoz import RulePolicy
from convlab.policy.ppo import PPO
from convlab.nlg.template.multiwoz import TemplateNLG
from convlab.dialog_agent import PipelineAgent, BiSession
from convlab.dialog_agent.env import Environment
from convlab.evaluator.multiwoz_eval import MultiWozEvaluator
from pprint import pprint
import random
import numpy as np
import torch

# user side: BERT NLU
user_nlu = BERTNLU()
# rule DST
user_dst = RuleDST()
# rule policy
user_policy = RulePolicy(character='usr')
# template NLG
user_nlg = TemplateNLG(is_user=True)
# assemble
user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')

# system side: PPO policy instead of the rule policy
sys_nlu = BERTNLU()
sys_nlg = TemplateNLG(is_user=False)
sys_dst = RuleDST()
# sys_policy = RulePolicy(character="sys")
sys_policy = PPO(is_train=True, seed=0, dataset_name='multiwoz21')
sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name="sys")

evaluator = MultiWozEvaluator()
session = BiSession(sys_agent, user_agent, kb_query=None, evaluator=evaluator)
simulator = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
env = Environment(sys_nlg, simulator, sys_nlu, sys_dst, evaluator=evaluator)
def set_seed(r_seed):
    random.seed(r_seed)
    np.random.seed(r_seed)
    torch.manual_seed(r_seed)

set_seed(20200131)

num_dialogues = 20
dialog_data = []

def determine_success(final_reward, success_threshold=0):
    return 1 if final_reward >= success_threshold else 0

sys_response = ''
success_threshold = 0
success_labels = []
for i in range(num_dialogues):
    session.init_session()
    dialog = []
    session_over = False
    cumulative_reward = 0
    # reset the system response at the start of each dialogue
    sys_response = ''
    print(f"Dialogue {i}")
    print()
    while not session_over:
        # next_turn returns (sys_response, user_response, session_over, reward);
        # the latest system response is fed back in on the next turn
        sys_response, user_response, session_over, reward = session.next_turn(sys_response)
        dialog.append((sys_response, user_response))
        cumulative_reward += reward
        print(f"User: {user_response}")
        print(f"System: {sys_response}")
        print(f"Reward: {reward}")
        print()
    success_label = determine_success(cumulative_reward, success_threshold)
    dialog_data.append((dialog, success_label))
    # get the success label from the evaluator
    success = session.evaluator.task_success()
    success_labels.append(success)
    print(f"Task success: {success}")
    print(f"Book rate: {session.evaluator.book_rate()}")
    print('inform precision/recall/f1:', session.evaluator.inform_F1())
    print('-'*50)
    print('final goal:')
    pprint(session.evaluator.goal)
    print('='*100)
I then use the code from the following link to train the policy, which reproduces the error: https://github.com/ConvLab/ConvLab-3/blob/master/tutorials/Train_RL_Policies/example_train.py

for i in range(epoch):
    print(f'We are currently in epoch {i} of the PPO training process')
    update(env, sys_policy, batchsz, i, process_num)
Here is the full error that I am facing:
ValueError Traceback (most recent call last)
Cell In[50], line 3
1 if __name__ == "__main__":
----> 3 main()
5 policy.save("trained_dialogue_policy")
Cell In[49], line 49
47 for i in range(epoch):
48 print(f'We are currently in epoch {i} in the training of PPO process')
---> 49 update(env, sys_policy, batchsz, i, process_num)
Cell In[10], line 189
186 def update(env, policy, batchsz, epoch, process_num):
187 # sample data asynchronously
188 print(f'In the update method, before the sampling')
--> 189 batch = sample(env, policy, batchsz, process_num)
191 print(f'In the update method after the sampling method')
193 # data in batch is : batch.state: ([1, s_dim], [1, s_dim]...)
194 # batch.action: ([1, a_dim], [1, a_dim]...)
195 # batch.reward/ batch.mask: ([1], [1]...)
Cell In[10], line 174
171 evt = Event() # This can be a simple threading event.
173 # Call the sampler function directly instead of starting new processes.
...
---> 68 s_vec = torch.Tensor(policy.vector.state_vectorize(s))
69 a = policy.predict(s)
71 # interact with env
ValueError: expected sequence of length 361 at dim 1 (got 208)