Skip to content

Commit

Permalink
Release pantheon (PaddlePaddle#56)
Browse files Browse the repository at this point in the history
* Init pantheon

* Update README

* Fix pantheon import

* Update README

* Fix the possible bug when del student

* Format docs of public methods

* Add api guide & docs for pantheon

* Use str2bool instead of bool
  • Loading branch information
Yibing Liu authored Feb 4, 2020
1 parent 30fe5c4 commit 42bbcb1
Show file tree
Hide file tree
Showing 15 changed files with 1,926 additions and 4 deletions.
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,17 @@ We encapsulate each compression and search method to a compression strategy clas

### Knowledge Distillation

- PaddleSlim supports the following losses added on any paired layers between teacher and student models:
- Flow of the solution procedure (FSP) loss.
- L2 loss.
- **Naive knowledge distillation**: transfers dark knowledge by merging the teacher and student models into the same Program, and supports the following losses added on any paired layers between teacher and student models:
- Flow of the solution procedure (FSP) loss;
- L2 loss;
- Softmax with cross-entropy loss.

- **Paddle large-scale scalable knowledge distillation framework [Pantheon](paddleslim/pantheon)**: a universal solution for knowledge distillation, more flexible than naive knowledge distillation and easier to scale to large-scale applications.
- Decouple the teacher and student models --- they run in different processes in the same or different nodes, and transfer knowledge via TCP/IP ports or local files;
- Make it easy to assemble multiple teacher models, each of which can work independently in either online or offline mode;
- Merge knowledge from different teachers and make batch data for the student model automatically;
- Support the large-scale knowledge prediction of teacher models on multiple devices.

### Lightweight Network Architecture Search (Light-NAS)

- PaddleSlim provides Simulated Annealing (SA)-based lightweight network architecture search method.
Expand Down
2 changes: 2 additions & 0 deletions demo/pantheon/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

Toy examples for Pantheon; see details in [PaddleSlim/Pantheon](../../paddleslim/pantheon).
103 changes: 103 additions & 0 deletions demo/pantheon/run_student.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from paddleslim.pantheon import Student

from utils import str2bool


def parse_args(argv=None):
    """Parse the command-line options of the student demo.

    Args:
        argv (list[str]|None): Argument strings to parse. When None (the
            default), argparse reads them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Holds ``in_address0``, ``in_path0``,
            ``in_address1``, ``in_path1``, ``test_send_recv`` and
            ``batch_size``.
    """
    # The first positional parameter of ArgumentParser is ``prog``, so the
    # module docstring must be passed as ``description`` by keyword.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--in_address0",
        type=str,
        default=None,
        help="Input address for teacher 0. (default: %(default)s)")
    parser.add_argument(
        "--in_path0",
        type=str,
        default=None,
        help="Input file path for teacher 0. (default: %(default)s)")
    parser.add_argument(
        "--in_address1",
        type=str,
        default=None,
        help="Input address for teacher 1. (default: %(default)s)")
    parser.add_argument(
        "--in_path1",
        type=str,
        default=None,
        help="Input file path for teacher 1. (default: %(default)s)")
    parser.add_argument(
        "--test_send_recv",
        type=str2bool,
        default=False,
        help="Whether to test send/recv interfaces. (default: %(default)s)")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="The batch size of student model. (default: %(default)s)")
    return parser.parse_args(argv)


def run(args):
    """Collect knowledge from two teachers and print it sample by sample.

    Args:
        args (argparse.Namespace): Parsed options. For each teacher ``i`` in
            (0, 1) exactly one of ``in_address<i>`` and ``in_path<i>`` must
            be set; ``test_send_recv`` and ``batch_size`` are read as well.

    Raises:
        ValueError: If both or neither of the address/path options are set
            for one of the teachers.
    """
    if args.in_address0 and args.in_path0:
        raise ValueError(
            "args.in_address0 and args.in_path0 should not be valid "
            "at the same time!")
    if not args.in_address0 and not args.in_path0:
        raise ValueError(
            "One of args.in_address0 and args.in_path0 must be valid!")

    if args.in_address1 and args.in_path1:
        raise ValueError(
            "args.in_address1 and args.in_path1 should not be valid "
            "at the same time!")
    if not args.in_address1 and not args.in_path1:
        raise ValueError(
            "One of args.in_address1 and args.in_path1 must be valid")

    # NOTE(review): presumably the "result" entries of both teachers are
    # summed when their knowledge is merged -- confirm against Student docs.
    student = Student(merge_strategy={"result": "sum"})

    student.register_teacher(
        in_address=args.in_address0, in_path=args.in_path0)
    student.register_teacher(
        in_address=args.in_address1, in_path=args.in_path1)
    student.start()

    if args.test_send_recv:
        # `range` replaces the Python-2-only `xrange`, which raises NameError
        # on Python 3. Three messages are received from each of the two
        # teachers, then one message is sent back.
        for t in range(2):
            for _ in range(3):
                print(student.recv(t))
        student.send("message from student!")

    # The returned description is not used by this demo; the call itself is
    # kept because it may be required before the generator can be built --
    # TODO confirm.
    student.get_knowledge_desc()
    data_generator = student.get_knowledge_generator(
        batch_size=args.batch_size, drop_last=False)
    for batch_data in data_generator():
        # Every value is indexed along its first axis below, so the first
        # value's leading dimension gives the number of samples in the batch.
        batch_size = list(batch_data.values())[0].shape[0]
        keys = batch_data.keys()
        for i in range(batch_size):
            print({key: batch_data[key][i] for key in keys})


# Script entry point: parse the command line and run the student demo.
if __name__ == "__main__":
    run(parse_args())
80 changes: 80 additions & 0 deletions demo/pantheon/run_teacher1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle.fluid as fluid

from utils import parse_args, sample_generator, sample_list_generator, batch_generator
from paddleslim.pantheon import Teacher


def run(args):
    """Serve the knowledge of the toy program ``y = 2*x - 1`` as teacher 1.

    Args:
        args (argparse.Namespace): Parsed options. Exactly one of
            ``out_path`` and ``out_port`` must be set; ``use_cuda``,
            ``generator_type``, ``batch_size``, ``test_send_recv`` and
            ``serving_times`` are read as well.

    Raises:
        ValueError: If both or neither of ``out_path`` and ``out_port``
            are set.
    """
    if args.out_path and args.out_port:
        raise ValueError("args.out_path and args.out_port should not be valid "
                         "at the same time")
    if not args.out_path and not args.out_port:
        # The original message was missing its verb ("... be valid").
        raise ValueError(
            "One of args.out_path and args.out_port must be valid")

    # user-defined program: y = 2*x - 1
    startup = fluid.Program()
    program = fluid.Program()
    with fluid.program_guard(program, startup):
        inp_x = fluid.layers.data(name='x', shape=[-1, 1], dtype="int64")
        y = inp_x * 2 - 1
        result = fluid.layers.assign(y)

    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    teacher = Teacher(out_path=args.out_path, out_port=args.out_port)
    teacher.start()

    # The key of reader_config names the kind of generator being passed.
    if args.generator_type == "sample_generator":
        reader_config = {
            "sample_generator": sample_generator(max_n=1000),
            "batch_size": args.batch_size,
            "drop_last": False
        }
    elif args.generator_type == "sample_list_generator":
        reader_config = {
            "sample_list_generator": sample_list_generator(
                max_n=1000, batch_size=args.batch_size)
        }
    else:
        reader_config = {
            "batch_generator": batch_generator(
                max_n=1000, batch_size=args.batch_size)
        }

    if args.test_send_recv:
        # Send three messages of different Python types, then wait for one
        # message in return.
        teacher.send("greetings from teacher1")
        teacher.send({"x": 1, "y": 2})
        teacher.send({3, 5})
        print("recved {}".format(teacher.recv()))

    teacher.start_knowledge_service(
        feed_list=[inp_x.name],
        schema={"x": inp_x,
                "2x-1": y,
                "result": result},
        program=program,
        reader_config=reader_config,
        exe=exe,
        times=args.serving_times)


# Script entry point: parse the command line and run teacher 1.
if __name__ == "__main__":
    run(parse_args())
79 changes: 79 additions & 0 deletions demo/pantheon/run_teacher2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle.fluid as fluid

from utils import parse_args, sample_generator, sample_list_generator, batch_generator
from paddleslim.pantheon import Teacher


def run(args):
    """Serve the knowledge of the toy program ``y = 2*x + 1`` as teacher 2.

    Args:
        args (argparse.Namespace): Parsed options. Exactly one of
            ``out_path`` and ``out_port`` must be set; ``use_cuda``,
            ``generator_type``, ``batch_size``, ``test_send_recv`` and
            ``serving_times`` are read as well.

    Raises:
        ValueError: If both or neither of ``out_path`` and ``out_port``
            are set.
    """
    if args.out_path and args.out_port:
        raise ValueError("args.out_path and args.out_port should not be valid "
                         "at the same time")
    if not args.out_path and not args.out_port:
        # The original message was missing its verb ("... be valid").
        raise ValueError(
            "One of args.out_path and args.out_port must be valid")

    # user-defined program: y = 2*x + 1
    startup = fluid.Program()
    program = fluid.Program()
    with fluid.program_guard(program, startup):
        inp_x = fluid.layers.data(name='x', shape=[-1, 1], dtype="int64")
        y = inp_x * 2 + 1
        result = fluid.layers.assign(y)

    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    teacher = Teacher(out_path=args.out_path, out_port=args.out_port)
    teacher.start()

    # The key of reader_config names the kind of generator being passed.
    if args.generator_type == "sample_generator":
        reader_config = {
            "sample_generator": sample_generator(max_n=1000),
            "batch_size": args.batch_size,
            "drop_last": False
        }
    elif args.generator_type == "sample_list_generator":
        reader_config = {
            "sample_list_generator": sample_list_generator(
                max_n=1000, batch_size=args.batch_size)
        }
    else:
        reader_config = {
            "batch_generator": batch_generator(
                max_n=1000, batch_size=args.batch_size)
        }

    if args.test_send_recv:
        # Send three messages of different Python types, then wait for one
        # message in return.
        teacher.send("greetings from teacher2")
        teacher.send([1])
        teacher.send({1, 2, 3})
        print("recved {}".format(teacher.recv()))

    # Unlike the input "x", only derived values are put into the schema.
    teacher.start_knowledge_service(
        feed_list=[inp_x.name],
        schema={"2x+1": y,
                "result": result},
        program=program,
        reader_config=reader_config,
        exe=exe,
        times=args.serving_times)


# Script entry point: parse the command line and run teacher 2.
if __name__ == "__main__":
    run(parse_args())
91 changes: 91 additions & 0 deletions demo/pantheon/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np
import argparse


def str2bool(v):
    """Convert a command-line flag value to a bool.

    Intended as an argparse ``type=`` callable, because ``type=bool`` would
    treat every non-empty string (including "False") as True.

    Args:
        v (str|bool): The value to convert. A bool is returned unchanged;
            a string is compared case-insensitively, ignoring surrounding
            whitespace.

    Returns:
        bool: True for "true", "t" or "1"; False for any other string.
    """
    if isinstance(v, bool):
        # Already a bool (e.g. when called programmatically); calling
        # .lower() on it would raise AttributeError.
        return v
    return v.strip().lower() in ("true", "t", "1")


def parse_args(argv=None):
    """Parse the command-line options shared by the teacher demos.

    Args:
        argv (list[str]|None): Argument strings to parse. When None (the
            default), argparse reads them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Holds ``out_port``, ``out_path``, ``use_cuda``,
            ``test_send_recv``, ``generator_type``, ``batch_size`` and
            ``serving_times``.
    """
    # The first positional parameter of ArgumentParser is ``prog``, so the
    # module docstring must be passed as ``description`` by keyword.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--out_port",
        type=int,
        default=None,
        help="IP port number for sending out data. (default: %(default)s)")
    parser.add_argument(
        "--out_path",
        type=str,
        default=None,
        help="The file path to dump knowledge data. (default: %(default)s)")
    parser.add_argument(
        "--use_cuda",
        type=str2bool,
        default=False,
        help="Whether to use GPU for prediction. (default: %(default)s)")
    parser.add_argument(
        "--test_send_recv",
        type=str2bool,
        default=False,
        help="Whether to test send/recv interfaces. (default: %(default)s)")
    parser.add_argument(
        "--generator_type",
        type=str,
        choices=[
            "sample_generator", "sample_list_generator", "batch_generator"
        ],
        default="batch_generator",
        help="Which data generator to use. (default: %(default)s)")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="The batch size per device for data generators. (default: %(default)s)"
    )
    parser.add_argument(
        "--serving_times",
        type=int,
        default=1,
        help="The maximum times of teacher serving knowledge. (default: %(default)s)"
    )
    return parser.parse_args(argv)


def sample_generator(max_n):
    """Build a reader that yields the one-element samples [0] .. [max_n - 1].

    Args:
        max_n (int): Number of samples produced by each pass over the reader.

    Returns:
        callable: A zero-argument function; each call returns a fresh
            generator over the samples.
    """

    def wrapper():
        for value in range(max_n):
            sample = [value]
            yield sample

    return wrapper


def sample_list_generator(max_n, batch_size=500):
    """Build a reader that yields lists of samples from sample_generator.

    Args:
        max_n (int): Total number of samples, forwarded to sample_generator.
        batch_size (int): Number of samples per yielded list. The last list
            may be shorter when max_n is not a multiple of batch_size.

    Returns:
        callable: A zero-argument function; each call returns a fresh
            generator over the sample lists.

    Raises:
        ValueError: If batch_size is smaller than 1. Such a value used to
            yield only empty lists (0) or nothing at all (negative).
    """
    if batch_size < 1:
        raise ValueError(
            "batch_size must be a positive integer, but got {}".format(
                batch_size))

    def wrapper():
        sample_list = []
        for sample in sample_generator(max_n)():
            # The list is reset as soon as it is full, so it always has room
            # for one more sample here.
            sample_list.append(sample)
            if len(sample_list) == batch_size:
                yield sample_list
                sample_list = []
        # Flush the trailing, smaller-than-batch_size group.
        if sample_list:
            yield sample_list

    return wrapper


# data_generator
def batch_generator(max_n, batch_size=500):
    """Build a reader that yields int64 batches from sample_generator.

    Args:
        max_n (int): Total number of samples, forwarded to sample_generator.
        batch_size (int): Number of samples per batch. The last batch may be
            smaller when max_n is not a multiple of batch_size.

    Returns:
        callable: A zero-argument function; each call returns a fresh
            generator. Every yielded item is a one-element list holding an
            int64 numpy array reshaped to (-1, 1).

    Raises:
        ValueError: If batch_size is smaller than 1. Such a value used to
            yield only empty arrays (0) or nothing at all (negative).
    """
    if batch_size < 1:
        raise ValueError(
            "batch_size must be a positive integer, but got {}".format(
                batch_size))

    def to_batch(samples):
        # One column per sample value, one row per sample.
        return [np.array(samples).astype('int64').reshape((-1, 1))]

    def wrapper():
        batch = []
        for sample in sample_generator(max_n)():
            # The list is reset as soon as it is full, so it always has room
            # for one more sample here.
            batch.append(sample)
            if len(batch) == batch_size:
                yield to_batch(batch)
                batch = []
        # Flush the trailing, smaller-than-batch_size batch.
        if batch:
            yield to_batch(batch)

    return wrapper
2 changes: 2 additions & 0 deletions docs/docs/api/api_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

- [单进程蒸馏](./single_distiller_api.md)

- [大规模可扩展知识蒸馏框架 Pantheon](./pantheon_api.md)

- [通道剪裁](./prune_api.md)

### [量化](./quantization_api.md)
Expand Down
Loading

0 comments on commit 42bbcb1

Please sign in to comment.