forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Init pantheon * Update README * Fix pantheon import * Update README * Fix the possible bug when del student * Format docs of public methods * Add api guide & docs for pantheon * Use str2bool instead of bool
- Loading branch information
Yibing Liu
authored
Feb 4, 2020
1 parent
30fe5c4
commit 42bbcb1
Showing
15 changed files
with
1,926 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
The toy examples for Pantheon, see details in [PaddleSlim/Pantheon](../../paddleslim/pantheon). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
from paddleslim.pantheon import Student | ||
|
||
from utils import str2bool | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(__doc__) | ||
parser.add_argument( | ||
"--in_address0", | ||
type=str, | ||
default=None, | ||
help="Input address for teacher 0. (default: %(default)s)") | ||
parser.add_argument( | ||
"--in_path0", | ||
type=str, | ||
default=None, | ||
help="Input file path for teacher 0. (default: %(default)s)") | ||
parser.add_argument( | ||
"--in_address1", | ||
type=str, | ||
default=None, | ||
help="Input address for teacher 1. (default: %(default)s)") | ||
parser.add_argument( | ||
"--in_path1", | ||
type=str, | ||
default=None, | ||
help="Input file path for teacher 1. (default: %(default)s)") | ||
parser.add_argument( | ||
"--test_send_recv", | ||
type=str2bool, | ||
default=False, | ||
help="Whether to test send/recv interfaces. (default: %(default)s)") | ||
parser.add_argument( | ||
"--batch_size", | ||
type=int, | ||
default=32, | ||
help="The batch size of student model. (default: %(default)s)") | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def run(args): | ||
if args.in_address0 and args.in_path0: | ||
raise ValueError( | ||
"args.in_address0 and args.in_path0 should not be valid " | ||
"at the same time!") | ||
if not args.in_address0 and not args.in_path0: | ||
raise ValueError( | ||
"One of args.in_address0 and args.in_path0 must be valid!") | ||
|
||
if args.in_address1 and args.in_path1: | ||
raise ValueError( | ||
"args.in_address1 and args.in_path1 should not be valid " | ||
"at the same time!") | ||
if not args.in_address1 and not args.in_path1: | ||
raise ValueError( | ||
"One of args.in_address1 and args.in_path1 must be valid") | ||
|
||
student = Student(merge_strategy={"result": "sum"}) | ||
|
||
student.register_teacher( | ||
in_address=args.in_address0, in_path=args.in_path0) | ||
student.register_teacher( | ||
in_address=args.in_address1, in_path=args.in_path1) | ||
student.start() | ||
|
||
if args.test_send_recv: | ||
for t in xrange(2): | ||
for i in xrange(3): | ||
print(student.recv(t)) | ||
student.send("message from student!") | ||
|
||
knowledge_desc = student.get_knowledge_desc() | ||
data_generator = student.get_knowledge_generator( | ||
batch_size=args.batch_size, drop_last=False) | ||
for batch_data in data_generator(): | ||
batch_size = list(batch_data.values())[0].shape[0] | ||
keys = batch_data.keys() | ||
for i in range(batch_size): | ||
data = {} | ||
for key in keys: | ||
data[key] = batch_data[key][i] | ||
print(data) | ||
|
||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
run(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
import paddle.fluid as fluid | ||
|
||
from utils import parse_args, sample_generator, sample_list_generator, batch_generator | ||
from paddleslim.pantheon import Teacher | ||
|
||
|
||
def run(args): | ||
if args.out_path and args.out_port: | ||
raise ValueError("args.out_path and args.out_port should not be valid " | ||
"at the same time") | ||
if not args.out_path and not args.out_port: | ||
raise ValueError("One of args.out_path and args.out_port be valid") | ||
|
||
# user-defined program: y = 2*x - 1 | ||
startup = fluid.Program() | ||
program = fluid.Program() | ||
with fluid.program_guard(program, startup): | ||
inp_x = fluid.layers.data(name='x', shape=[-1, 1], dtype="int64") | ||
y = inp_x * 2 - 1 | ||
result = fluid.layers.assign(y) | ||
|
||
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() | ||
exe = fluid.Executor(place) | ||
exe.run(startup) | ||
|
||
teacher = Teacher(out_path=args.out_path, out_port=args.out_port) | ||
teacher.start() | ||
|
||
if args.generator_type == "sample_generator": | ||
reader_config = { | ||
"sample_generator": sample_generator(max_n=1000), | ||
"batch_size": args.batch_size, | ||
"drop_last": False | ||
} | ||
elif args.generator_type == "sample_list_generator": | ||
reader_config = { | ||
"sample_list_generator": sample_list_generator( | ||
max_n=1000, batch_size=args.batch_size) | ||
} | ||
else: | ||
reader_config = { | ||
"batch_generator": batch_generator( | ||
max_n=1000, batch_size=args.batch_size) | ||
} | ||
|
||
if args.test_send_recv: | ||
teacher.send("greetings from teacher1") | ||
teacher.send({"x": 1, "y": 2}) | ||
teacher.send({3, 5}) | ||
print("recved {}".format(teacher.recv())) | ||
|
||
teacher.start_knowledge_service( | ||
feed_list=[inp_x.name], | ||
schema={"x": inp_x, | ||
"2x-1": y, | ||
"result": result}, | ||
program=program, | ||
reader_config=reader_config, | ||
exe=exe, | ||
times=args.serving_times) | ||
|
||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
run(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
import paddle.fluid as fluid | ||
|
||
from utils import parse_args, sample_generator, sample_list_generator, batch_generator | ||
from paddleslim.pantheon import Teacher | ||
|
||
|
||
def run(args): | ||
if args.out_path and args.out_port: | ||
raise ValueError("args.out_path and args.out_port should not be valid " | ||
"at the same time") | ||
if not args.out_path and not args.out_port: | ||
raise ValueError("One of args.out_path and args.out_port be valid") | ||
|
||
# user-defined program: y = 2*x + 1 | ||
startup = fluid.Program() | ||
program = fluid.Program() | ||
with fluid.program_guard(program, startup): | ||
inp_x = fluid.layers.data(name='x', shape=[-1, 1], dtype="int64") | ||
y = inp_x * 2 + 1 | ||
result = fluid.layers.assign(y) | ||
|
||
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() | ||
exe = fluid.Executor(place) | ||
exe.run(startup) | ||
|
||
teacher = Teacher(out_path=args.out_path, out_port=args.out_port) | ||
teacher.start() | ||
|
||
if args.generator_type == "sample_generator": | ||
reader_config = { | ||
"sample_generator": sample_generator(max_n=1000), | ||
"batch_size": args.batch_size, | ||
"drop_last": False | ||
} | ||
elif args.generator_type == "sample_list_generator": | ||
reader_config = { | ||
"sample_list_generator": sample_list_generator( | ||
max_n=1000, batch_size=args.batch_size) | ||
} | ||
else: | ||
reader_config = { | ||
"batch_generator": batch_generator( | ||
max_n=1000, batch_size=args.batch_size) | ||
} | ||
|
||
if args.test_send_recv: | ||
teacher.send("greetings from teacher2") | ||
teacher.send([1]) | ||
teacher.send({1, 2, 3}) | ||
print("recved {}".format(teacher.recv())) | ||
|
||
teacher.start_knowledge_service( | ||
feed_list=[inp_x.name], | ||
schema={"2x+1": y, | ||
"result": result}, | ||
program=program, | ||
reader_config=reader_config, | ||
exe=exe, | ||
times=args.serving_times) | ||
|
||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
run(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
import argparse | ||
|
||
|
||
def str2bool(v): | ||
return v.lower() in ("true", "t", "1") | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(__doc__) | ||
parser.add_argument( | ||
"--out_port", | ||
type=int, | ||
default=None, | ||
help="IP port number for sending out data. (default: %(default)s)") | ||
parser.add_argument( | ||
"--out_path", | ||
type=str, | ||
default=None, | ||
help="The file path to dump knowledge data. (default: %(default)s)") | ||
parser.add_argument( | ||
"--use_cuda", | ||
type=str2bool, | ||
default=False, | ||
help="Whether to use GPU for prediction. (default: %(default)s)") | ||
parser.add_argument( | ||
"--test_send_recv", | ||
type=str2bool, | ||
default=False, | ||
help="Whether to test send/recv interfaces. (default: %(default)s)") | ||
parser.add_argument( | ||
"--generator_type", | ||
type=str, | ||
choices=[ | ||
"sample_generator", "sample_list_generator", "batch_generator" | ||
], | ||
default="batch_generator", | ||
help="Which data generator to use. (default: %(default)s)") | ||
parser.add_argument( | ||
"--batch_size", | ||
type=int, | ||
default=32, | ||
help="The batch size per device for data generators. (default: %(default)s)" | ||
) | ||
parser.add_argument( | ||
"--serving_times", | ||
type=int, | ||
default=1, | ||
help="The maximum times of teacher serving knowledge. (default: %(default)s)" | ||
) | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def sample_generator(max_n): | ||
def wrapper(): | ||
for i in range(max_n): | ||
yield [i] | ||
|
||
return wrapper | ||
|
||
|
||
def sample_list_generator(max_n, batch_size=500): | ||
def wrapper(): | ||
sample_list = [] | ||
for sample in sample_generator(max_n)(): | ||
if len(sample_list) < batch_size: | ||
sample_list.append(sample) | ||
if len(sample_list) == batch_size: | ||
yield sample_list | ||
sample_list = [] | ||
if len(sample_list) > 0: | ||
yield sample_list | ||
|
||
return wrapper | ||
|
||
|
||
# data_generator | ||
def batch_generator(max_n, batch_size=500): | ||
def wrapper(): | ||
batch = [] | ||
for sample in sample_generator(max_n)(): | ||
if len(batch) < batch_size: | ||
batch.append(sample) | ||
if len(batch) == batch_size: | ||
yield [np.array(batch).astype('int64').reshape((-1, 1))] | ||
batch = [] | ||
if len(batch) > 0: | ||
yield [np.array(batch).astype('int64').reshape((-1, 1))] | ||
|
||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.