diff --git a/README.md b/README.md index 87f881b1440d5..9358d3b3f825b 100644 --- a/README.md +++ b/README.md @@ -68,11 +68,17 @@ We encapsulate each compression and search method to a compression strategy clas ### Knowledge Distillation -- PaddleSlim supports the following losses added on any paired layers between teacher and student models: - - Flow of the solution procedure (FSP) loss. - - L2 loss. +- **Naive knowledge distillation**: transfers dark knowledge by merging the teacher and student model into the same Program, and supports the following losses added on any paired layers between teacher and student models: + - Flow of the solution procedure (FSP) loss; + - L2 loss; - Softmax with cross-entropy loss. +- **Paddle large-scale scalable knowledge distillation framework [Pantheon](paddleslim/pantheon)**: a universal solution for knowledge distillation, more flexible than the naive knowledge distillation, and easier to scale to the large-scale applications. + - Decouple the teacher and student models --- they run in different processes in the same or different nodes, and transfer knowledge via TCP/IP ports or local files; + - Friendly to assemble multiple teacher models and each of them can work in either online or offline mode independently; + - Merge knowledge from different teachers and make batch data for the student model automatically; + - Support the large-scale knowledge prediction of teacher models on multiple devices. + ### Lightweight Network Architecture Search (Light-NAS) - PaddleSlim provides Simulated Annealing (SA)-based lightweight network architecture search method. diff --git a/demo/pantheon/README.md b/demo/pantheon/README.md new file mode 100644 index 0000000000000..3cc55c3387673 --- /dev/null +++ b/demo/pantheon/README.md @@ -0,0 +1,2 @@ + +The toy examples for Pantheon, see details in [PaddleSlim/Pantheon](../../paddleslim/pantheon). diff --git a/demo/pantheon/run_student.py b/demo/pantheon/run_student.py new file mode 100644 index 0000000000000..19d9b20665a41 --- /dev/null +++ b/demo/pantheon/run_student.py @@ -0,0 +1,103 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from paddleslim.pantheon import Student + +from utils import str2bool + + +def parse_args(): + parser = argparse.ArgumentParser(__doc__) + parser.add_argument( + "--in_address0", + type=str, + default=None, + help="Input address for teacher 0. (default: %(default)s)") + parser.add_argument( + "--in_path0", + type=str, + default=None, + help="Input file path for teacher 0. (default: %(default)s)") + parser.add_argument( + "--in_address1", + type=str, + default=None, + help="Input address for teacher 1. (default: %(default)s)") + parser.add_argument( + "--in_path1", + type=str, + default=None, + help="Input file path for teacher 1. (default: %(default)s)") + parser.add_argument( + "--test_send_recv", + type=str2bool, + default=False, + help="Whether to test send/recv interfaces. (default: %(default)s)") + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="The batch size of student model. (default: %(default)s)") + args = parser.parse_args() + return args + + +def run(args): + if args.in_address0 and args.in_path0: + raise ValueError( + "args.in_address0 and args.in_path0 should not be valid " + "at the same time!") + if not args.in_address0 and not args.in_path0: + raise ValueError( + "One of args.in_address0 and args.in_path0 must be valid!") + + if args.in_address1 and args.in_path1: + raise ValueError( + "args.in_address1 and args.in_path1 should not be valid " + "at the same time!") + if not args.in_address1 and not args.in_path1: + raise ValueError( + "One of args.in_address1 and args.in_path1 must be valid") + + student = Student(merge_strategy={"result": "sum"}) + + student.register_teacher( + in_address=args.in_address0, in_path=args.in_path0) + student.register_teacher( + in_address=args.in_address1, in_path=args.in_path1) + student.start() + + if args.test_send_recv: + for t in xrange(2): + for i in xrange(3): + print(student.recv(t)) + student.send("message from student!") + + knowledge_desc = student.get_knowledge_desc() + data_generator = student.get_knowledge_generator( + batch_size=args.batch_size, drop_last=False) + for batch_data in data_generator(): + batch_size = list(batch_data.values())[0].shape[0] + keys = batch_data.keys() + for i in range(batch_size): + data = {} + for key in keys: + data[key] = batch_data[key][i] + print(data) + + +if __name__ == '__main__': + args = parse_args() + run(args) diff --git a/demo/pantheon/run_teacher1.py b/demo/pantheon/run_teacher1.py new file mode 100644 index 0000000000000..bbe94310b1de8 --- /dev/null +++ b/demo/pantheon/run_teacher1.py @@ -0,0 +1,80 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid + +from utils import parse_args, sample_generator, sample_list_generator, batch_generator +from paddleslim.pantheon import Teacher + + +def run(args): + if args.out_path and args.out_port: + raise ValueError("args.out_path and args.out_port should not be valid " + "at the same time") + if not args.out_path and not args.out_port: + raise ValueError("One of args.out_path and args.out_port be valid") + + # user-defined program: y = 2*x - 1 + startup = fluid.Program() + program = fluid.Program() + with fluid.program_guard(program, startup): + inp_x = fluid.layers.data(name='x', shape=[-1, 1], dtype="int64") + y = inp_x * 2 - 1 + result = fluid.layers.assign(y) + + place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup) + + teacher = Teacher(out_path=args.out_path, out_port=args.out_port) + teacher.start() + + if args.generator_type == "sample_generator": + reader_config = { + "sample_generator": sample_generator(max_n=1000), + "batch_size": args.batch_size, + "drop_last": False + } + elif args.generator_type == "sample_list_generator": + reader_config = { + "sample_list_generator": sample_list_generator( + max_n=1000, batch_size=args.batch_size) + } + else: + reader_config = { + "batch_generator": batch_generator( + max_n=1000, batch_size=args.batch_size) + } + + if args.test_send_recv: + teacher.send("greetings from teacher1") + teacher.send({"x": 1, "y": 2}) + teacher.send({3, 5}) + print("recved {}".format(teacher.recv())) + + teacher.start_knowledge_service( + feed_list=[inp_x.name], + schema={"x": inp_x, + "2x-1": y, + "result": result}, + program=program, + reader_config=reader_config, + exe=exe, + times=args.serving_times) + + +if __name__ == '__main__': + args = parse_args() + run(args) diff --git a/demo/pantheon/run_teacher2.py b/demo/pantheon/run_teacher2.py new file mode 100644 index 0000000000000..5d45fec92bbce --- /dev/null +++ b/demo/pantheon/run_teacher2.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid + +from utils import parse_args, sample_generator, sample_list_generator, batch_generator +from paddleslim.pantheon import Teacher + + +def run(args): + if args.out_path and args.out_port: + raise ValueError("args.out_path and args.out_port should not be valid " + "at the same time") + if not args.out_path and not args.out_port: + raise ValueError("One of args.out_path and args.out_port be valid") + + # user-defined program: y = 2*x + 1 + startup = fluid.Program() + program = fluid.Program() + with fluid.program_guard(program, startup): + inp_x = fluid.layers.data(name='x', shape=[-1, 1], dtype="int64") + y = inp_x * 2 + 1 + result = fluid.layers.assign(y) + + place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup) + + teacher = Teacher(out_path=args.out_path, out_port=args.out_port) + teacher.start() + + if args.generator_type == "sample_generator": + reader_config = { + "sample_generator": sample_generator(max_n=1000), + "batch_size": args.batch_size, + "drop_last": False + } + elif args.generator_type == "sample_list_generator": + reader_config = { + "sample_list_generator": sample_list_generator( + max_n=1000, batch_size=args.batch_size) + } + else: + reader_config = { + "batch_generator": batch_generator( + max_n=1000, batch_size=args.batch_size) + } + + if args.test_send_recv: + teacher.send("greetings from teacher2") + teacher.send([1]) + teacher.send({1, 2, 3}) + print("recved {}".format(teacher.recv())) + + teacher.start_knowledge_service( + feed_list=[inp_x.name], + schema={"2x+1": y, + "result": result}, + program=program, + reader_config=reader_config, + exe=exe, + times=args.serving_times) + + +if __name__ == '__main__': + args = parse_args() + run(args) diff --git a/demo/pantheon/utils.py b/demo/pantheon/utils.py new file mode 100644 index 0000000000000..af88d2a699db0 --- /dev/null +++ b/demo/pantheon/utils.py @@ -0,0 +1,91 @@ +import numpy as np +import argparse + + +def str2bool(v): + return v.lower() in ("true", "t", "1") + + +def parse_args(): + parser = argparse.ArgumentParser(__doc__) + parser.add_argument( + "--out_port", + type=int, + default=None, + help="IP port number for sending out data. (default: %(default)s)") + parser.add_argument( + "--out_path", + type=str, + default=None, + help="The file path to dump knowledge data. (default: %(default)s)") + parser.add_argument( + "--use_cuda", + type=str2bool, + default=False, + help="Whether to use GPU for prediction. (default: %(default)s)") + parser.add_argument( + "--test_send_recv", + type=str2bool, + default=False, + help="Whether to test send/recv interfaces. (default: %(default)s)") + parser.add_argument( + "--generator_type", + type=str, + choices=[ + "sample_generator", "sample_list_generator", "batch_generator" + ], + default="batch_generator", + help="Which data generator to use. (default: %(default)s)") + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="The batch size per device for data generators. (default: %(default)s)" + ) + parser.add_argument( + "--serving_times", + type=int, + default=1, + help="The maximum times of teacher serving knowledge. (default: %(default)s)" + ) + args = parser.parse_args() + return args + + +def sample_generator(max_n): + def wrapper(): + for i in range(max_n): + yield [i] + + return wrapper + + +def sample_list_generator(max_n, batch_size=500): + def wrapper(): + sample_list = [] + for sample in sample_generator(max_n)(): + if len(sample_list) < batch_size: + sample_list.append(sample) + if len(sample_list) == batch_size: + yield sample_list + sample_list = [] + if len(sample_list) > 0: + yield sample_list + + return wrapper + + +# data_generator +def batch_generator(max_n, batch_size=500): + def wrapper(): + batch = [] + for sample in sample_generator(max_n)(): + if len(batch) < batch_size: + batch.append(sample) + if len(batch) == batch_size: + yield [np.array(batch).astype('int64').reshape((-1, 1))] + batch = [] + if len(batch) > 0: + yield [np.array(batch).astype('int64').reshape((-1, 1))] + + return wrapper diff --git a/docs/docs/api/api_guide.md b/docs/docs/api/api_guide.md index 79910a06f3bc5..650bfc3bee57f 100644 --- a/docs/docs/api/api_guide.md +++ b/docs/docs/api/api_guide.md @@ -8,6 +8,8 @@ - [单进程蒸馏](./single_distiller_api.md) +- [大规模可扩展知识蒸馏框架 Pantheon](./pantheon_api.md) + - [通道剪裁](./prune_api.md) ### [量化](./quantization_api.md) diff --git a/docs/docs/api/pantheon_api.md b/docs/docs/api/pantheon_api.md new file mode 100644 index 0000000000000..b78c506987ab6 --- /dev/null +++ b/docs/docs/api/pantheon_api.md @@ -0,0 +1,256 @@ +## Teacher + +pantheon.Teacher()[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/teacher.py#L78) + +: The class defined for the teacher model. Generate knowledge data and transfer them to the student model. + +**Args:** + +- **out\_path (str|None)** - The path to dump knowledge data for offline mode. + +- **out\_port (int|None)** - The IP port number to send out knowledge for online mode, should be unique when launching multiple teachers in the same node. + +**Return:** An object of class Teacher + + +pantheon.Teacher.start()[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/teacher.py#L133) + +: Start teacher service, sychronize with student and launch the thread + to monitor commands from student. + +**Args:** None + +**Return:** None + + +pantheon.Teacher.send(data)[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/teacher.py#L181) + +: Send one data object to student. + +**Args:** + +- **data (Python data):** - The data to be sent, can be any type of Python data object. + +**Return:** None + + +pantheon.Teacher.recv()[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/teacher.py#L196) + +: Recieve one data object from student. + +**Args:** None + +**Return:** + +- The received data, can be any type of Python data object. + + +pantheon.Teacher.dump(knowledge)[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/teacher.py#L214) + +: Dump one batch knowledge data into the output file, only used in the offline mode. + +**Args:** + +- **knowledge (dict):** - The knowledge data to be dumped. + +**Return:** None + + +pantheon.Teacher.start\_knowledge\_service(feed\_list, schema, program, reader\_config, exe, buf\_size=10, times=1)[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/teacher.py#L259) + +: Start the knowledge service to generate and transfer knowledge data. In GPU mode, the devices to execute knowledge prediction will be determined by the + environment variable **FLAGS\_selected\_gpus**, or by **CUDA\_VISIBLE\_DEVICES** if it is not set, and by **CPU\_NUM** (default 1) in CPU mode. Only supported in static graph. + + **Args:** + + - **feed\_list (list):** - A list of feed Variables or their names for the + input teacher Program. + - **schema (dict):** - A dictionary to specify keys and fetched Variables + to generate knowledge. + - **program (fluid.Program):** - Inference Program of the teacher model. + - **reader\_config (dict):** - The config for data reader. Support all the three types of generators used by [fluid.io.PyReader](https://www.paddlepaddle.org.cn/documentation/docs/en/api/io/PyReader.html) and [fluid.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/en/api/io/DataLoader.html#dataloader), and their configs contain the key-value pair of the generator type and a generator object, plus other necessary argument pairs. See the following: + + - 1) sample generator: + + reader\_config={"sample\_generator": #some\_sample\_generator, + "batch\_size": #batch\_size, "drop\_last": #drop\_last}, + + 'drop\_last' set to True by default, + - 2) sample list generator: + + reader\_config={"sample\_list\_generator": #some\_sample\_list\_generator}, + - 3) batch generator: + + reader\_config={"batch\_generator": #some\_batch\_genrator}. + + The trial to parse config will be in the order of 1) -> 3), and any other unrelated keys in these configs will be ignored. + +- **exe (fluid.Executor):** The executor to run the input program. +- **buf\_size (int):** The size of buffers for data reader and knowledge + writer on each device. +- **times (int):** The maximum repeated serving times, default 1. Whenever + the public method **get\_knowledge\_generator()** in Student + object called once, the serving times will be added one, + until reaching the maximum and ending the service. + +**Return:** None + +**Examples:** + +Note: this example should be run with the example of class **Student**. + +```python +import paddle +import paddle.fluid as fluid +from paddleslim.pantheon import Teacher + +startup = fluid.Program() +program = fluid.Program() +with fluid.program_guard(program, startup): + images = fluid.data( + name='pixel', shape=[None, 3 * 32 * 32], dtype='float32') + labels = fluid.data(name='label', shape=[None, 1], dtype='int64') + logits = fluid.layers.fc(input=images, size=10) + loss = fluid.layers.softmax_with_cross_entropy(logits, labels) + +place = fluid.CPUPlace() +exe = fluid.Executor(place) +exe.run(startup) + +train_reader = paddle.batch( + paddle.dataset.cifar.train10(), batch_size=32) + +teacher = Teacher(out_path="example_knowledge.dat", # offline mode + #out_port=5000 # online mode + ) +teacher.start() + +teacher.start_knowledge_service( + feed_list=[images, labels], + schema={"logits": logits, + "labels": labels}, + program=program, + reader_config={"sample_list_generator": train_reader}, + exe=exe) +``` + + +## Student + +pantheon.Student(merge_strategy=None)[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/student.py#L34) + +: The class defined for the student model. Receive knowledge data from + teacher model and carry out knowledge merging. + + **Args:** + + - **merge\_strategy (dict|None):** - A dictionary whose keys are the common schemas shared by different teachers, and each corresponding value specifies the merging strategy for different schemas respectively, supporting **sum** and **mean** now. + +**Return:** An object of class Student. + + +pantheon.Student.register\_teacher(in\_path=None, in\_address=None)[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/student.py#L72) + +: Register one teacher model and assign the order number to it as its id, with the file path (offline mode) or IP address (online mode) that the teacher model writes knowledge data to. + +**Args:** + +- **in\_path (str|None):** The input file path. Default None. +- **in\_address (str|None):** The input IP address, in the format "\:\" (e.g. "127.0.0.1:8080"). Default None. + +**Return:** None + + +pantheon.Student.start()[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/student.py#L213) + +: End teachers' registration and synchronize with all of them. + +**Args:** None + +**Return:** None + +pantheon.Student.send(self, data, teacher_ids=None)[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/student.py#L240) + +: Send data to teachers. + +**Args:** + +- **data (Python data):** - A Python data object to be sent. +- **teacher_ids (list|None):** - A list of teacher ids to send data. If set to None, send the data to all teachers. Default None. + +**Return:** None + +pantheon.Student.recv(teacher_id)[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/student.py#L262) + +: Receive data from one teacher. + + **Args:** + +- **teacher\_id (int):** - The id of teacher that receives data from. + +**Return:** + +- The received data object. + +pantheon.Student.get\_knowledge\_desc()[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/student.py#L283) + + : Get description for knowledge, including shape, data type and lod level for each schema. + + **Args:** None + + **Return:** + + - Knowledge description, which is a dict. + + +pantheon.Student.get\_knowledge\_qsize()[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/student.py#L318) + + : Get the real-time size of knowledge queue. If this size is denoted as + **qsize**, it means that there are **qsize** batch knowledge data + already pushed into knowledge queue and waiting for the knowledge + generator to pop out. It's dynamic and limited up to 100, the capacity + of the knowledge queue. + + **Args:** None + + **Return:** + + - The real-time size of knowledge queue. + +pantheon.Student.get\_knowledge\_generator(batch\_size, drop\_last=False)[source code](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/pantheon/student.py#L334) + +: Get the generator for knowledge data, return None if last generator doesn't finish yet. + +**Args:** + +- **batch\_size (int):** - The batch size of returned knowledge data. +- **drop\_last (bool):** - Whether to drop the last batch if its size is less than batch size. + +**Return:** + +- The wrapper of knowledge data generator. + +**Examples:** + +Note: this example should be run with the example of class **Teacher**. + +```python +from paddleslim.pantheon import Student + +student = Student() + +student.register_teacher(in_path="example_knowledge.dat", # offline mode + #in_address="127.0.0.1:5000" # online mode + ) +student.start() + +knowledge_desc = student.get_knowledge_desc() +data_generator = student.get_knowledge_generator( + batch_size=128, drop_last=True) + +# get knowledge data +for knowledge in data_generator(): + print("knowledge queue size: {}".format(student.get_knowledge_qsize())) + + # do something else +``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index d5f725a8c3331..ad81dea7374bd 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -16,7 +16,8 @@ nav: - 量化: api/quantization_api.md - 剪枝与敏感度: api/prune_api.md - 模型分析: api/analysis_api.md - - 知识蒸馏: api/single_distiller_api.md + - 简单知识蒸馏: api/single_distiller_api.md + - 大规模可扩展知识蒸馏框架 Pantheon: api/pantheon_api.md - SA搜索: api/nas_api.md - One-shot搜索: api/one_shot_api.md - 搜索空间: search_space.md diff --git a/paddleslim/pantheon/README.md b/paddleslim/pantheon/README.md new file mode 100644 index 0000000000000..a9f564dc73ccd --- /dev/null +++ b/paddleslim/pantheon/README.md @@ -0,0 +1,252 @@ +# Pantheon: Paddle large-scale scalable knowledge distillation framework + +Pantheon is a universal solution for knowledge distillation in Paddle Fluid. Its design takes account of many possible behaviors of teacher models. Every teacher and student model in Pantheon works in different processes and they communicate with each other via local files or TCP/IP ports. The knowledge can be easily transferred to the student model from a single teacher model or the ensemble of multiple teacher models, in which each teacher model can work in online or offline mode independently. And Pantheon also provides a highly optimized interface for the large-scale prediction of teacher models. Beneficial from the low coupling of teachers and the student, users can allocate computation resources for different roles dependent on their computation complexity, and build a large-scale and practical knowledge distillation learning system on Pantheon. + +The illustration below shows an application of Pantheon, where the sudent model is trained with knowledge from multiple online teachers. These teachers may work on the same node but different devices, or different nodes with the student model, as long as they can communicate with each other via the Internet. The student model can send queries to teachers, and the latter take these queries as input and generate streaming knowledge data for the former. Or in a simpler way, the student model can read the training data in the **same order** with the teachers, avoiding the procedure of sending queryies. + + +
+
+ The architecture for one online knowledge distillation system based on Pantheon +
+ +## Prerequisites + +- Python 2.7.x or 3.x +- PaddlePaddle >= 1.6.0 + +## APIs + +Pantheon defines two classes **Teacher** and **Student** for the communication and knowledge transfer between teacher and student. + +- **Teacher**: used by the teacher model. Can receive queries from student and write out the knowledge from teacher model via TCP/IP port (online mode) or into a local file (offline mode). +- **Student**: used by the student model. Can receive and merge the knowledge from teachers, and feed the student model along with local data for training. + +Usually, the public methods of these two classes work in the pairwise way. Their mapping relations and suitable working modes are listed in the following table. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
TeacherStudentSupported GraphModeremarks
staticdynamiconlineoffline
__init__(
    out_path=None,
    out_port=None)
__init__(
    merge_strategy=None)
[1]
register_teacher( +
    in_path=None, +
    in_address=None) +
[2]
start()start()
[3]
send(data)recv(teacher_id)
[4]
recv()send(data,
    +  teacher_ids=None) +
[5]
dump(knowledge)
[6]
start_knowledge_service( +
    feed_list, +
    schema, +
    program, +
    reader_config, +
    exe, +
    buf_size=10, +
    times=1)
get_knowledge_desc()
[7]
get_knowledge_qsize()
get_knowledge_generator(
    batch_size, +
    drop_last=False)
+ +**Remarks:** + + - [1] Decalre the teacher object for teacher model with **out\_path** or **out\_port**, and the student for student model with **merge\_strategy** for knowledge from different teachers. + - [2] Register a teacher, and allocate an id for it which starts from zero in the order of registration. **register\_teacher()** can be called many times for multiple-teacher mode. + - [3] Estabish TCP/IP link between teachers and the student, and synchronize all of them. + - [4] Send one data from teacher to student. + - [5] Send one data from student to teacher. + - [6] Dump one batch knowledge data into the output file. + - [7] Highly optimized high-level interfaces to build service for knowledge transfer: + - **start\_knowledge\_service()** can perform large-scale prediction of teacher model on multiple devices; + - Support auto merging of knowledge from different teachers; + - Support auto reconnection of student and teachers. + +### About the data format + +- **Knowledge**: A dictionary with the keys specified by users and the values that are numpy ndarray tensors predicted by teacher models. The first dimension of tensors should be batch size and LoDTensor is not supported yet. One can call **get\_knowledge\_desc()** to get the description of knowledge, which is also a dictionary, including the shape, data type and LoD level about knowledge data. +- **Offline knowledge file**: The first line is knowledge description, and the following lines are knowledge data, one line for one batch samples, all dumped by cPickle. + + + +### Usage + +If separately runnable teacher models and the student model +have been ready, basically one can build the trainable system with knowledge +distillation by following two simple steps. + +1) Instantiate a **Teacher** object for the teacher model, and launch knowledge serving + +```python + +from paddleslim.pantheon import Teacher +... + +teacher = Teacher(out_path=args.out_path, out_port=args.out_port) +teacher.start() + +teacher.start_knowledge_service( + feed_list=[inp_x.name], + schema={"x": inp_x, + "y": y}, + program=program, + reader_config={"batch_generator": batch_generator}, + exe=exe, + buf_size=100, + times=1) +``` + +2) Instantiate a **Student** object, specify the way to merge knowledge, register teachers, + and get knowledge description and data generator for the student model + +```python +from paddleslim.pantheon import Student +... + +student = Student(merge_strategy={"result": "sum"}) + +student.register_teacher( + in_address=args.in_address0, in_path=args.in_path0) +student.register_teacher( + in_address=args.in_address1, in_path=args.in_path1) +student.start() + +knowledge_desc = student.get_knowledge_desc() +data_generator = student.get_knowledge_generator( + batch_size=32, drop_last=False) +``` + +### Example + +Here provide a toy example to show how the knowledge data is transferred from teachers to the student model and merged. + +In the directory [demo/pantheon/](../../demo/pantheon/), there implement two teacher models (not trainable, just for demo): teacher1 takes an integer **x** as input and predicts value **2x-1**, see in [run_teacher1.py](../../demo/pantheon/run_teacher1.py); teacher2 also takes **x** as input and predicts **2x+1**, see in [run_teacher2.py](../../demo/pantheon/run_teacher2.py). They two share a data reader to read a sequence of increasing natural numbers from zero to some positive inter **max_n** as input and generate different knowledge. And the schema keys for knowledge in teacher1 is [**"x", "2x-1", "result"**], and [**"2x+1", "result"**] for knowledge in teacher2, in which **"result"** is the common schema and the copy of two predictions respectively. On instantiating the **Student** object, the merging strategy for the common schema **"result"** should be specified, and the schema keys for the merged knowledge will be [**"x", "2x-1", "2x+1", "result"**], with the merged **"result"** equal to **"2x"** when the merging strategy is **"mean"** and **"4x"** when merging strategy is **"sum"**. The student model gets merged knowledge from teachers and prints them out, see in [run_student.py](../../demo/pantheon/run_student.py). + +The toy "knowledge distillation" system can be launched in three different modes, i.e., offline, online and their hybrid. All three modes should have the same outputs, and the correctness of results can be verified by checking the order and values of outputs. + +1) **Offline** + + The two teachers work in offline mode, and start them with given local file paths. + + ```shell +export PYTHONPATH=../../:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1 +nohup python -u run_teacher1.py --use_cuda true --out_path teacher1_offline.dat > teacher1_offline.log 2>&1& +export CUDA_VISIBLE_DEVICES=2 +nohup python -u run_teacher2.py --use_cuda true --out_path teacher2_offline.dat > teacher2_offline.log 2>&1& + ``` + After the two executions both finished, start the student model with the two generated knowledge files. + + ```shell +export PYTHONPATH=../../:$PYTHONPATH + python -u run_student.py \ + --in_path0 teacher1_offline.dat \ + --in_path1 teacher2_offline.dat + ``` + + +2) **Online** + +The two teachers work in online mode, and start them with given TCP/IP ports. Please make sure that the ICP/IP ports are available. + +```shell +export PYTHONPATH=../../:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +nohup python -u run_teacher1.py --use_cuda true --out_port 8080 > teacher1_online.log 2>&1& +export CUDA_VISIBLE_DEVICES=1,2 +nohup python -u run_teacher2.py --use_cuda true --out_port 8081 > teacher2_online.log 2>&1& +``` +Start the student model with the IP addresses that can reach the ports of the two teacher models, e.g., in the same node + +```shell +export PYTHONPATH=../../:$PYTHONPATH +python -u run_student.py \ + --in_address0 127.0.0.1:8080 \ + --in_address1 127.0.0.1:8081 \ +``` +**Note:** in online mode, the starting order of teachers and the sudent doesn't matter, and they will wait for each other to establish connection. + +3) **Hybrid of offline and online** + +One teacher works in offline mode and another one works in online mode. This time, start the offline teacher first. After the offline knowledge file gets well prepared, start the online teacher and the student at the same time. diff --git a/paddleslim/pantheon/__init__.py b/paddleslim/pantheon/__init__.py new file mode 100644 index 0000000000000..bcc99e781b6d2 --- /dev/null +++ b/paddleslim/pantheon/__init__.py @@ -0,0 +1,4 @@ +from .teacher import Teacher +from .student import Student + +__all__ = teacher.__all__ + student.__all__ diff --git a/paddleslim/pantheon/images/pantheon_arch.png b/paddleslim/pantheon/images/pantheon_arch.png new file mode 100644 index 0000000000000..d88fc11f144e2 Binary files /dev/null and b/paddleslim/pantheon/images/pantheon_arch.png differ diff --git a/paddleslim/pantheon/student.py b/paddleslim/pantheon/student.py new file mode 100644 index 0000000000000..073e186585435 --- /dev/null +++ b/paddleslim/pantheon/student.py @@ -0,0 +1,484 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import time +if six.PY2: + import cPickle as pickle +else: + import pickle + +import numpy as np +from collections import OrderedDict +from multiprocessing import Process, Manager +from multiprocessing.managers import BaseManager + +from threading import Thread + +from paddleslim.pantheon.utils import EndSignal, SyncSignal, StartSignal, public_authkey + +__all__ = ["Student"] + + +class Student(object): + """ + The class defined for the student model. Receive knowledge data from + teacher model and carry out knowledge merging. + + Args: + merge_strategy (dict|None): A dictionary whose keys are common + schemas shared by different teachers, and each corresponding + value specifies the merging strategy for different schemas + respectively, supporting 'sum' and 'mean' now. + """ + + def __init__(self, merge_strategy=None): + if merge_strategy: + for strategy in merge_strategy.values(): + if strategy not in ["sum", "mean"]: + raise ValueError( + "Merging strategy must be 'sum' or 'mean'!") + + self._merge_strategy = merge_strategy + self._common_schema = merge_strategy.keys() if merge_strategy else [] + + self._knowledge_desc = OrderedDict() + self._knowledge_queue = Manager().Queue(100) + self._teacher_knowledge_queues = [] + self._t2s_queues = [] + self._s2t_queues = [] + self._cmd_queues = [] + + self._num_teachers = 0 + + self._in_paths = [] + self._in_addresses = [] + + self._started = False + self._is_knowledge_desc_ready = False + self._is_knowledge_gen_locked = False + + def register_teacher(self, in_path=None, in_address=None): + """Register one teacher model and assign the order number to it as + its id, with the file path (offline mode) or IP address (online + mode) that the teacher model wrote knowledge data to. + + Args: + in_path (str|None): The input file path. Default None. + in_address (str|None): The input IP address, in the format + ":" (e.g. "127.0.0.1:8080"). Default None. + """ + if self._started: + raise ValueError( + "The student has been started and cannot register " + "teacher no longer!") + if in_path and in_address: + raise ValueError("Input path and input address should not " + "be given at the same time!") + if not in_path and not in_address: + raise ValueError("One of input path and input address should " + "be given when registering teacher!") + if in_address: + if in_address in self._in_addresses: + print("WARNING: the teacher with input address {} has been " + "registered, and ignored this time!".format(in_path)) + return + ip, port = in_address.strip().split(":") + BaseManager.register("get_knowledge_queue") + BaseManager.register("get_s2t_queue") + BaseManager.register("get_t2s_queue") + BaseManager.register("get_cmd_queue") + manager = BaseManager( + address=(ip, int(port)), authkey=public_authkey.encode()) + + # Wait for teacher model started to establish connection + print("Connecting to {}, with public key {} ...".format( + in_address, public_authkey)) + while True: + try: + manager.connect() + break + except: + time.sleep(1.0) + + knowledge_queue = manager.get_knowledge_queue() + self._t2s_queues.append(manager.get_t2s_queue()) + self._s2t_queues.append(manager.get_s2t_queue()) + self._cmd_queues.append(manager.get_cmd_queue()) + self._in_addresses.append(in_address) + self._in_paths.append(None) + print("Registered teacher {} with input address {}.".format( + self._num_teachers, in_address)) + else: + if in_path in self._in_paths: + print("WARNING: th teacher with input path {} has been " + "registered, and ignored this time!".format(in_path)) + return + + def read_offline(in_path, cmd_queue, out_queue): + end_recved = False + + def get_cmd(): + cmd, end_recved = None, False + try: + if not cmd_queue.empty(): + cmd = cmd_queue.get() + cmd_queue.task_done() + if isinstance(cmd, EndSignal): + end_recved = True + except IOError: + end_recved = True + return cmd, end_recved + + # wait for the sync in start + while not end_recved: + cmd, end_recved = get_cmd() + if isinstance(cmd, SyncSignal): + out_queue.put(SyncSignal()) + break + # for multiple-times offline serving + while not end_recved: + # wait for the sync in get_knowledge_desc() + while not end_recved: + cmd, end_recved = get_cmd() + if isinstance(cmd, SyncSignal): + out_queue.put(SyncSignal()) + break + + if end_recved: + break + with open(in_path, 'r') as fin: + # get knowledge desc + desc = pickle.load(fin) + out_queue.put(desc) + # wait for the data accessing signal + while not end_recved: + cmd, end_recved = get_cmd() + if isinstance(cmd, StartSignal): + break + # get knowledge data + while not end_recved: + try: + data = pickle.load(fin) + out_queue.put(data) + _, end_recved = get_cmd() + except EOFError: + break + if end_recved: + break + out_queue.put(EndSignal()) + out_queue.join() + + knowledge_queue = Manager().Queue(100) + cmd_queue = Manager().Queue(5) + p = Process( + target=read_offline, + args=(in_path, cmd_queue, knowledge_queue)) + p.daemon = True + p.start() + + self._t2s_queues.append(None) + self._s2t_queues.append(None) + self._cmd_queues.append(cmd_queue) + self._in_addresses.append(None) + self._in_paths.append(in_path) + print("Registered teacher {} with input path {}.".format( + self._num_teachers, in_path)) + + self._teacher_knowledge_queues.append(knowledge_queue) + self._num_teachers += 1 + + def _sync(self): + for i, queue in enumerate(self._cmd_queues): + if queue: + queue.put(SyncSignal()) + while True: + cmd = self._teacher_knowledge_queues[i].get() + self._teacher_knowledge_queues[i].task_done() + if isinstance(cmd, SyncSignal): + break + queue.join() + + def start(self): + """ + End teachers' registration and synchronize with all of them. + """ + + if self._started: + raise ValueError( + "The student cannot be started more than one time.") + self._sync() + self._started = True + + def _merge_knowledge(self, knowledge): + for k, tensors in knowledge.items(): + if len(tensors) == 0: + del knowledge[k] + elif len(tensors) == 1: + knowledge[k] = tensors[0] + else: + result = 0 + for tensor in tensors: + result += tensor + if self._merge_strategy[k] == "sum": + knowledge[k] = result + elif self._merge_strategy[k] == "mean": + knowledge[k] = result / len(tensors) + return knowledge + + def send(self, data, teacher_ids=None): + """ + Send data to teachers. + + Args: + data: A Python data object. + teacher_ids (list|None): A list of teacher ids to send data. If + set to None, send the data to all teachers. Default None. + """ + if not self._started: + raise ValueError("The method start() should be called first!") + + if teacher_ids is None: + teacher_ids = range(self._num_teachers) + + for i in teacher_ids: + if self._s2t_queues[i]: + self._s2t_queues[i].put(data) + else: + print("Warning: didn't send data to teacher {} for it is in " + "offline mode.".format(i)) + + def recv(self, teacher_id): + """ + Receive data from one teacher. + + Args: + teacher_id (int): The id of teacher that receives data from. + + Return: + The received data object. + """ + if not self._started: + raise ValueError("The method start() should be called first!") + + if self._t2s_queues[teacher_id]: + data = self._t2s_queues[teacher_id].get() + self._t2s_queues[teacher_id].task_done() + return data + else: + raise ValueError("Cannot receive data from teacher {} for it is " + "offline.".format(teacher_id)) + + def get_knowledge_desc(self): + """ + Get description for knowledge, including shape, data type and lod + level for each schema. + + Return: + dict: Knowledge description. + """ + if not self._started: + raise ValueError("The method start() should be called first!") + + if self._is_knowledge_desc_ready == False: + self._sync() + # get knowledge description + knowledge_desc = OrderedDict() + for idx, queue in enumerate(self._teacher_knowledge_queues): + desc = queue.get() + queue.task_done() + if idx > 0 and (set(knowledge_desc.keys()) & set(desc.keys()) + != set(self._common_schema)): + raise ValueError( + "Teacher {} has the same schema with other existed " + "teachers not in the merge_strategy.".format(idx)) + knowledge_desc.update(desc) + + print("Knowledge merging strategy: {}".format( + self._merge_strategy)) + print("Knowledge description after merging:") + for schema, desc in knowledge_desc.items(): + print("{}: {}".format(schema, desc)) + + self._knowledge_desc = knowledge_desc + self._is_knowledge_desc_ready = True + return self._knowledge_desc + + def get_knowledge_qsize(self): + """ + Get the real-time size of knowledge queue. If this size is denoted as + **qsize**, it means that there are **qsize** batch knowledge data + already pushed into knowledge queue and waiting for the knowledge + generator to pop out. It's dynamic and limited up to 100, the capacity + of the knowledge queue. + + Return: + int: The real-time size of knowledge queue. + """ + if not self._started: + raise ValueError("The method start() should be called first!") + + return self._knowledge_queue.qsize() + + def get_knowledge_generator(self, batch_size, drop_last=False): + """ + Get the generator for knowledge data, return None if last generator + doesn't finish yet. + + Args: + batch_size (int): The batch size of returned knowledge data. + drop_last (bool): Whether to drop the last batch if its size is less + than batch size. + + Return: + func: The wrapper of knowledge data generator. + """ + if not self._started: + raise ValueError("The method start() should be called first!") + + if batch_size <= 0: + raise ValueError("batch size must be positive!") + self._batch_size = batch_size + self._drop_last = drop_last + + # make sure only one generator is available at the same time + if self._is_knowledge_gen_locked: + print("WARNING: new knowledge generator is not available for the " + "last generator hasn't finished yielding all data yet! " + "Return None.") + return None + self._is_knowledge_gen_locked = True + + self.get_knowledge_desc() + + def split_batch(batch, num): + keys = batch.keys() + first, second = {}, {} + for key in keys: + first[key] = batch[key][0:num] + second[key] = batch[key][num:] + return first, second + + def concat_batches(batches): + keys = batches[0].keys() + ret_batch = {} + for key in keys: + ret_batch[key] = np.concatenate( + [batches[i][key] for i in range(len(batches))]) + return ret_batch + + def listen(queues, out_queue): + def data_receiver(queue, batch_size): + def wrapper(): + # The batch size of the teacher and student model may be + # not the same, make a new batch in the batch size of the + # student model. + batches, num_samples = [], 0 + while True: + batch_samples = queue.get() + queue.task_done() + if not isinstance(batch_samples, EndSignal): + cur_num_samples = list(batch_samples.values())[ + 0].shape[0] + if num_samples + cur_num_samples < batch_size: + batches.append(batch_samples) + num_samples += cur_num_samples + elif num_samples + cur_num_samples == batch_size: + batches.append(batch_samples) + yield concat_batches(batches) + batches, num_samples = [], 0 + else: + num_splited = batch_size - num_samples + first, second = split_batch(batch_samples, + num_splited) + batches.append(first) + yield concat_batches(batches) + num_left = cur_num_samples - num_splited + while num_left > batch_size: + first, second = split_batch(second, + batch_size) + yield first + num_left -= batch_size + batches, num_samples = [second], num_left + else: + if len(batches) > 0: + yield concat_batches(batches) + yield EndSignal() + break + + return wrapper + + data_receivers = [ + data_receiver(queue, self._batch_size)() for queue in queues + ] + + end_received = [0] * len(queues) + while True: + knowledge = OrderedDict( + [(k, []) for k, v in self._knowledge_desc.items()]) + for idx, receiver in enumerate(data_receivers): + if not end_received[idx]: + batch_samples = receiver.next( + ) if six.PY2 else receiver.__next__() + if not isinstance(batch_samples, EndSignal): + for k, v in batch_samples.items(): + knowledge[k].append(v) + else: + end_received[idx] = 1 + if sum(end_received) == len(queues): + break + knowledge = self._merge_knowledge(knowledge) + out_queue.put(knowledge) + out_queue.put(EndSignal()) + out_queue.join() + + # acquire data from teachers + for i, queue in enumerate(self._cmd_queues): + if queue: + queue.put(StartSignal()) + queue.join() + + self._listen_thread = Thread( + target=listen, + args=(self._teacher_knowledge_queues, self._knowledge_queue)) + self._listen_thread.dameon = True + self._listen_thread.start() + + def wrapper(): + samples = [] + + while True: + knowledge = self._knowledge_queue.get() + self._knowledge_queue.task_done() + if not isinstance(knowledge, EndSignal): + batch_size = list(knowledge.values())[0].shape[0] + if (batch_size < self._batch_size) and drop_last: + continue + yield knowledge + else: + break + # After all knowledge data yielded, make current knowledge desc invalid. + self._is_knowledge_desc_ready = False + self._is_knowledge_gen_locked = False + + return wrapper + + def __del__(self): + for i, path in enumerate(self._in_paths): + if path: + try: + self._cmd_queues[i].put(EndSignal()) + self._cmd_queues[i].join() + except: + pass diff --git a/paddleslim/pantheon/teacher.py b/paddleslim/pantheon/teacher.py new file mode 100644 index 0000000000000..6cd09efbbbc52 --- /dev/null +++ b/paddleslim/pantheon/teacher.py @@ -0,0 +1,501 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import six +if six.PY2: + import cPickle as pickle + import Queue +else: + import pickle + import queue as Queue + +from collections import OrderedDict, Iterable +import numpy as np +import copy +import multiprocessing +from multiprocessing.managers import BaseManager +from threading import Thread + +import paddle.fluid as fluid + +from paddleslim.pantheon.utils import convert_dtype, EndSignal, SyncSignal, StartSignal, public_authkey + +__all__ = ["Teacher"] + +knowledge_queue = Queue.Queue(100) +t2s_queue = Queue.Queue(100) +s2t_queue = Queue.Queue(100) +cmd_queue = Queue.Queue(5) + + +class MixedDataReader(object): + """ + The wrapper for iterable data loader, to solve the drop problem of last + batches when their number is less than the number of devices in prediction. + It implements two data generators, one for the prediction on all devices, + and another one for the prediction of remained data one single device, and + they two should be called in order. + + Args: + data_loader (fluid.io.DataLoader): The data loader. + base_number (int): The base number that the number of yielded data + batches for multiple devices should be its + multiple times. + """ + + def __init__(self, data_loader, base_number): + self._data_loader = data_loader + self._base_number = base_number + self._tail_data = [] + + def multi_dev_generator(self): + for data in self._data_loader(): + if len(self._tail_data) < self._base_number: + self._tail_data += data + if len(self._tail_data) == self._base_number: + yield self._tail_data + self._tail_data = [] + + def tail_generator(self): + for data in self._tail_data: + yield data + self._tail_data = [] + + +class Teacher(object): + """ + The class defined for the teacher model. Generate knowledge data and + transfer them to the student model. + + Args: + out_path (str|None): The path to dump knowledge for offline mode. + out_port (int|None): The IP port number to send out knowledge for + online mode, should be unique when launching multiple teachers in + the same node. + """ + + def __init__(self, out_path=None, out_port=None): + if out_path and out_port: + raise ValueError("Out path and out port should not be set at " + "the same time!") + + self._out_path = out_path + self._out_port = out_port + # knowledge description + self._knowledge_desc = {} + + self._sync_required = False + self._data_required = False + self._started = False + + def _start_manager(self): + def get_knowledge_queue(): + global knowledge_queue + return knowledge_queue + + def get_s2t_queue(): + global s2t_queue + return s2t_queue + + def get_t2s_queue(): + global t2s_queue + return t2s_queue + + def get_cmd_queue(): + global cmd_queue + return cmd_queue + + BaseManager.register( + "get_knowledge_queue", callable=get_knowledge_queue) + BaseManager.register("get_s2t_queue", callable=get_s2t_queue) + BaseManager.register("get_t2s_queue", callable=get_t2s_queue) + BaseManager.register("get_cmd_queue", callable=get_cmd_queue) + manager = BaseManager( + address=("", self._out_port), authkey=public_authkey.encode()) + manager.start() + print("listen on address: {}".format(manager._address)) + print("public authkey: {}".format(public_authkey)) + return manager + + def start(self): + """ + Start teacher service, sychronize with student and launch the thread + to monitor commands from student. + """ + if self._started: + raise ValueError( + "The teacher cannot be started more than one time.") + self._started = True + self._manager = self._start_manager() if self._out_port else None + if self._manager: + self._knowledge_queue = self._manager.get_knowledge_queue() + self._s2t_queue = self._manager.get_s2t_queue() + self._t2s_queue = self._manager.get_t2s_queue() + self._cmd_queue = self._manager.get_cmd_queue() + else: + self._knowledge_queue = None + self._s2t_queue = None + self._t2s_queue = None + self._cmd_queue = None + + self._out_file = open(self._out_path, "w") if self._out_path else None + if self._out_file: + return + + def wrapper(): + while True: + if not self._cmd_queue.empty(): + cmd = self._cmd_queue.get() + self._cmd_queue.task_done() + if isinstance(cmd, SyncSignal): + self._sync_required = True + elif isinstance(cmd, StartSignal): + self._data_required = True + else: + time.sleep(1.0) + + t = Thread(target=wrapper) + t.daemon = True + t.start() + + while True: + if self._sync_required: + self._knowledge_queue.put(SyncSignal()) + self._knowledge_queue.join() + self._sync_required = False + break + + def send(self, data): + """ + Send one data object to student. + + Args: + data (Python data): The data to be sent, can be any type of Python data object. + """ + if not self._started: + raise ValueError("The method start() should be called first!") + + if not self._t2s_queue: + raise ValueError("Cannot send data to stuent for this teacher " + "is offline!") + self._t2s_queue.put(data) + + def recv(self): + """ + Recieve one data object from student. + + Return: + The received data, can be any type of Python data object. + """ + if not self._started: + raise ValueError("The method start() should be called first!") + + if not self._s2t_queue: + raise ValueError( + "Cannot receive data from stuent for this teacher " + "is in offline mode!") + data = self._s2t_queue.get() + self._s2t_queue.task_done() + return data + + def dump(self, knowledge): + """ + Dump one batch knowledge data into output file, only used in the + offline mode. + + Args: + knowledge (dict): The knowledge data to be dumped. + """ + if not self._started: + raise ValueError("The method start() should be called first!") + + if not self._out_file: + raise ValueError("Cannot dump knowledge data in online mode!") + + if not isinstance(knowledge, dict) and not isinstance(knowledge, + OrderedDict): + raise ValueError( + "The knowledge data should be a dict or OrderedDict!") + + knowledge_desc = {} + for name, value in knowledge.items(): + knowledge_desc[name] = { + "shape": [-1] + list(value.shape[1:]), + "dtype": str(value.dtype), + "lod_level": 0 + } + if not self._knowledge_desc: + self._knowledge_desc = knowledge_desc + self._out_file.write(pickle.dumps(self._knowledge_desc)) + else: + if self._knowledge_desc != knowledge_desc: + raise ValueError( + "Current knowledge desc {} is not the same as " + "historic desc {}!".format(knowledge_desc, + self._knowledge_desc)) + + self._out_file.write(pickle.dumps(knowledge)) + + def start_knowledge_service(self, + feed_list, + schema, + program, + reader_config, + exe, + buf_size=10, + times=1): + """ + Start the knowledge service to generate and transfer knowledge data. + In GPU mode, the devices to execute knowledge prediction will be + determined by environment variable **FLAGS_selected_gpus**, or by + **CUDA_VISIBLE_DEVICES** if it is not set, and by **CPU_NUM** (default + 1) in CPU mode. Only supported in static graph. + + Args: + feed_list (list): A list of feed Variables or their names for the + input program. + schema (dict): A dictionary to specify names and fetched + Variables of knowledge. + program (fluid.Program): Inference program for the teacher model. + reader_config (dict): The config for data reader. Support all the + three types of generators used by `fluid.io.PyReader` and + `fluid.io.DataLoader`, and their configs contain the key-value + pair of the generator type and a generator object, plus + other necessary argument pairs. See the following: + + 1) sample generator: + reader_config={"sample_generator": #some_sample_generator, + "batch_size": #batch_size, "drop_last": #drop_last}, + 'drop_last' set to True by default, + 2) sample list generator: + reader_config={"sample_list_generator": + #some_sample_list_generator}, + 3) batch generator: + reader_config={"batch_generator": #some_batch_genrator}. + + The trial to parse config will be in the order of 1) -> 3), and + any other unrelated keys in these configs will be ignored. + exe (fluid.Executor): The executor to run the input program. + buf_size (int): The size of buffers for data reader and knowledge + writer on each device. + times (int): The maximum repeated serving times. Default 1. Whenever + the public method 'get_knowledge_generator()' in Student + object called once, the serving times will be added one, + until reaching the maximum and ending the service. + """ + if not self._started: + raise ValueError("The method start() should be called first!") + + if not isinstance(program, fluid.Program): + raise ValueError( + "Input argument 'program' should be a fluid Program!") + self._program = program._inference_optimize(prune_read_op=True) + + if not isinstance(feed_list, list): + raise ValueError("Input argument 'feed_list' should be a list!") + else: + self._feed_list = [] + for feed in feed_list: + if isinstance(feed, fluid.framework.Variable): + self._feed_list.append(feed) + elif isinstance(feed, str) or isinstance(feed, unicode): + self._feed_list.append(self._program.global_block().var( + feed)) + else: + raise ValueError( + "Input 'feed_list' should consist of feed " + "Variables or their names!") + + if not isinstance(schema, dict) and not isinstance(schema, + OrderedDict): + raise ValueError( + "Input argument 'schema' should be a dict or OrderedDict!") + self._schema = schema + + if not isinstance(reader_config, dict): + raise ValueError("The reader config must be a dictionary!") + + if not isinstance(exe, fluid.Executor): + raise ValueError("Input argument should be a fluid Executor!") + self._exe = exe + + if not buf_size > 0: + raise ValueError("The buffer size should be positive!") + self._buf_size = buf_size + + if not times > 0: + raise ValueError("Repeated serving times should be positive!") + self._times = times + + desc = {} + for name, var in schema.items(): + if not isinstance(var, fluid.framework.Variable): + raise ValueError( + "The member of schema must be fluid Variable.") + desc[name] = { + "shape": var.shape, + "dtype": convert_dtype(var.dtype), + "lod_level": var.lod_level + } + if not self._knowledge_desc: + self._knowledge_desc = desc + else: + if self._out_file and not self._knowledge_desc == desc: + raise ValueError("The knowledge description should be kept " + "consistent in offline mode!") + + if isinstance(self._exe.place, fluid.CUDAPlace): + places = fluid.cuda_places() + else: + places = fluid.cpu_places() + dev_count = len(places) + + data_loader = fluid.io.DataLoader.from_generator( + feed_list=self._feed_list, + capacity=self._buf_size * dev_count, + use_double_buffer=(dev_count == 1), + iterable=True) + + places = [fluid.CPUPlace()] if dev_count > 1 else [self._exe.place] + if "sample_generator" in reader_config: + if "batch_size" not in reader_config: + raise ValueError("batch size must be specified when using " + "sample generator!") + sample_generator = reader_config["sample_generator"] + batch_size = reader_config["batch_size"] + drop_last = reader_config[ + "drop_last"] if "drop_last" in reader_config else True + + data_loader.set_sample_generator( + reader=sample_generator, + batch_size=batch_size, + drop_last=drop_last, + places=places) + elif "sample_list_generator" in reader_config: + sample_list_generator = reader_config["sample_list_generator"] + data_loader.set_sample_list_generator( + reader=sample_list_generator, places=places) + elif "batch_generator" in reader_config: + batch_generator = reader_config["batch_generator"] + data_loader.set_batch_generator( + reader=batch_generator, places=places) + else: + raise ValueError( + "The reader config doesn't contain any valid " + "generator type, which should be one of 'sample_generator', " + "'sample_list_generator', and 'batch_generator'.") + + def writer(buf_queue, schema_keys): + samples_sent, batches_sent = 0, 0 + while True: + outputs = buf_queue.get() + buf_queue.task_done() + if not isinstance(outputs, EndSignal): + batch_samples = dict(zip(schema_keys, outputs)) + if self._knowledge_queue: + self._knowledge_queue.put(batch_samples) + if self._out_file: + self._out_file.write(pickle.dumps(batch_samples)) + else: + if self._knowledge_queue: + self._knowledge_queue.put(EndSignal()) + + # Asynchronous output + out_buf_queue = Queue.Queue(self._buf_size) + schema_keys, schema_vars = zip(*self._schema.items()) + out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys)) + out_thread.daemon = True + out_thread.start() + + compiled_program = fluid.compiler.CompiledProgram( + self._program).with_data_parallel() + + print("Knowledge description {}".format(self._knowledge_desc)) + print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + + " Teacher begins to serve ...") + # For offline dump, write the knowledge description to the head of file + if self._out_file: + self._out_file.write(pickle.dumps(self._knowledge_desc)) + print("output path: %s" % self._out_path) + + data_reader = MixedDataReader(data_loader, dev_count) + # For online mode, send knowledge description every time + for repeated in range(self._times): + if self._knowledge_queue: + # wait for the accessing of knowledge desc and data + while True: + if self._sync_required: + self._knowledge_queue.put(SyncSignal()) + self._knowledge_queue.put(self._knowledge_desc) + self._sync_required = False + if self._data_required: + self._data_required = False + break + self._knowledge_queue.join() + + print("No.{} time serving ... ".format(repeated)) + num_batches_sent = 0 + for dev_batches in data_reader.multi_dev_generator(): + if self._sync_required: + break + outputs = self._exe.run(compiled_program, + feed=dev_batches, + fetch_list=schema_vars) + out_buf_queue.put(outputs) + num_batches_sent += dev_count + if num_batches_sent % (100 * dev_count) == 0: + log = "Processed {} batch samples.".format( + num_batches_sent) + if self._knowledge_queue: + log += " Knowledge queue size {}.".format( + self._knowledge_queue.qsize()) + print(log) + + outputs = [] + for index, batch in enumerate(data_reader.tail_generator()): + if self._sync_required: + break + output = self._exe.run(self._program, + feed=batch, + fetch_list=schema_vars) + if outputs: + outputs = [ + np.concatenate( + (outs, out), axis=0) + for (outs, out) in zip(outputs, output) + ] + else: + outputs = copy.deepcopy(output) + if outputs: + out_buf_queue.put(outputs) + num_batches_sent += (index + 1) + + print("Processed {} batch samples in total.".format( + num_batches_sent)) + + out_buf_queue.put(EndSignal()) + out_buf_queue.join() + + if self._knowledge_queue: + self._knowledge_queue.join() + print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + + " Teacher ends serving.") + + def __del__(self): + if self._manager: + self._manager.shutdown() + if self._out_file: + self._out_file.close() diff --git a/paddleslim/pantheon/utils.py b/paddleslim/pantheon/utils.py new file mode 100644 index 0000000000000..b4c8001eb6f39 --- /dev/null +++ b/paddleslim/pantheon/utils.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +public_authkey = u"aBcXyZ123" + + +class StartSignal(): + pass + + +class EndSignal(): + pass + + +class SyncSignal(): + pass + + +def convert_dtype(dtype): + import paddle.fluid as fluid + if isinstance(dtype, fluid.core.VarDesc.VarType): + if dtype == fluid.core.VarDesc.VarType.BOOL: + return 'bool' + elif dtype == fluid.core.VarDesc.VarType.FP16: + return 'float16' + elif dtype == fluid.core.VarDesc.VarType.FP32: + return 'float32' + elif dtype == fluid.core.VarDesc.VarType.FP64: + return 'float64' + elif dtype == fluid.core.VarDesc.VarType.INT8: + return 'int8' + elif dtype == fluid.core.VarDesc.VarType.INT16: + return 'int16' + elif dtype == fluid.core.VarDesc.VarType.INT32: + return 'int32' + elif dtype == fluid.core.VarDesc.VarType.INT64: + return 'int64' + elif dtype == fluid.core.VarDesc.VarType.UINT8: + return 'uint8' + + +def check_ip(address): + import IPy + try: + IPy.IP(address) + return True + except Exception as e: + return False