This repository has been archived by the owner on Jan 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 63
/
teaching_task.h
165 lines (126 loc) · 4.74 KB
/
teaching_task.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
// Copyright (c) 2017 Baidu Inc. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "simulator.h"
namespace simulator {
// Plain counters describing benchmark outcomes.
// Two results can be merged with operator+=, which adds the counters
// field by field.
struct BenchmarkRes {
    int successes = 0;
    int failures = 0;
    // presumably the steps accumulated over successful runs -- confirm
    // against the code that fills this in
    int success_steps = 0;

    // Every counter starts at zero.
    BenchmarkRes() = default;

    // Starts from explicit counter values.
    BenchmarkRes(int succ, int fail, int succ_steps)
            : successes{succ}, failures{fail}, success_steps{succ_steps} {}

    // Adds the counters of `br` to this result and returns *this so that
    // calls can be chained.
    BenchmarkRes& operator+=(const BenchmarkRes& br) {
        successes = successes + br.successes;
        failures = failures + br.failures;
        success_steps = success_steps + br.success_steps;
        return *this;
    }
};
/**
   A task is a Finite State Machine (FSM) that has several stages (states).
   Each stage can be programmed to jump to another stage.
   At each stage, the teacher performs an action, which is either a sentence
   sent to the agent or some change to the environment.
   Each stage corresponds to one time step.

   This class is a wrapper that uses an embedded Python class to define the
   task. The Python users are responsible for implementing
   1. all the stage functions
   2. the grammar for generating the sentences
   3. the reward of each stage
**/
class Task {
public:
    // Creates the task in the "idle" stage, then sets up the Python task
    // object and registers the stages.
    // Both arguments are taken by value and moved into the members, so a
    // caller passing a temporary does not pay for an extra copy.
    //   name: a string that identifies the task
    //   game: the environment the teacher and the agent reside in
    Task(std::string name,
         TeachingEnvPtr game)
            : game_(std::move(game)),
              name_(std::move(name)),
              current_stage_("idle") {
        init_py_task();
        register_stages();
    }

    // The string that identifies the task.
    const std::string& name() const { return name_; }

    // Calls "reset" on the Python task object and goes back to the "idle"
    // stage.
    void reset() {
        py_task_.attr("reset")();
        current_stage_ = "idle";
    }

    // True when the current stage is "idle".
    bool is_idle() const { return current_stage_ == "idle"; }

    // Label of the form "<task name>: <stage name>".
    std::string current_stage() const { return name_ + ": " + current_stage_; }

    // run the current stage
    void run_stage();

    // Defined out of line; presumably adds this task's results to `br` --
    // confirm against the implementation.
    void obtain_performance(BenchmarkRes& br);

    // Defined out of line; presumably the number of distinct sentences the
    // task can generate -- confirm against the implementation.
    size_t total_possible_sentences();

    // Defined out of line; presumably makes the teacher say `sentence` to
    // the agent -- confirm against the implementation.
    void teacher_speak(const std::string& sentence);

protected:
    // teacher calls this function to give the agent a reward
    // (forwards to the environment's add_teacher_reward)
    void give_reward(double reward) { game_->add_teacher_reward(reward); }

    TeachingEnvPtr game_;  // The environment the teacher and the agent reside in
    std::string name_;     // A string that identifies the task
    // stores all the stage functions, keyed by a string (presumably the
    // stage name); each function returns a string
    std::unordered_map<std::string, std::function<std::string()>> stages_;

private:
    // Called once from the constructor, before register_stages()
    void init_py_task();
    // decide which stages to register
    void register_stages();
    // Call a python stage function and process its outputs
    std::string py_stage(const std::string& stage_name);
    // Convert a simulator_entity to a dictionary in Python
    boost::python::dict convert_entity_to_py_entity(const Entity& e);

    // This object holds the Python task class
    boost::python::object py_task_;
    std::string current_stage_;  // Maintains the current stage of the task
};

using TaskPtr = std::shared_ptr<Task>;
/**
   A task group is a collection of tasks that have the same group id.
   There can only be one busy task in a group at any moment.
**/
class TaskGroup {
public:
    // Creates an empty group with no busy task.
    // All arguments are taken by value and moved into the members, so a
    // caller passing temporaries does not pay for extra copies.
    //   name:     name of the group
    //   schedule: scheduling mode; "weighted" is the one value mentioned in
    //             this header (see add_task)
    //   game:     environment handle stored for the group
    TaskGroup(std::string name,
              std::string schedule,
              TeachingEnvPtr game)
            : name_(std::move(name)),
              schedule_(std::move(schedule)),
              game_(std::move(game)),
              busy_task_(nullptr) {  // points to a task that is not in idle stage
    }

    // add a task with a sampling weight
    // when schedule="weighted", the tasks in a group are sampled in
    // proportion to their weights
    void add_task(const std::string& task, double weight);

    bool is_idle();
    void reset();

    // Returns a copy of the group's name. Declared const so that it can be
    // called through a const TaskGroup, like Task::name().
    std::string name() const { return name_; }

    // call run_stage of the busy task
    void run_stage();

    size_t total_possible_sentences();
    std::string current_stage();

    // fill in benchmark with task performance
    void report_task_performance(
        std::unordered_map<std::string, BenchmarkRes>& benchmark);

    // Factory functions keyed by a string (presumably the task name --
    // confirm against where this map is filled). Each factory receives the
    // environment and a list of strings and returns a new task.
    static std::unordered_map<
        std::string,
        std::function<TaskPtr(TeachingEnvPtr, const std::vector<std::string>&)>>
        create_tasks_;

private:
    std::vector<TaskPtr> task_list_;
    // presumably parallel to task_list_ (add_task takes one weight per
    // task) -- confirm against the implementation
    std::vector<double> task_weights_;
    std::string name_;
    std::string schedule_;
    TeachingEnvPtr game_;
    TaskPtr busy_task_;  // nullptr on construction
};

using TaskGroupPtr = std::shared_ptr<TaskGroup>;
} // namespace simulator