Skip to content

Commit

Permalink
add pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
khizirsiddiqui committed Jun 10, 2021
1 parent 934d1d3 commit ccddfb9
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 54 deletions.
1 change: 1 addition & 0 deletions KD_Lib/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .pipeline import Pipeline
102 changes: 102 additions & 0 deletions KD_Lib/utils/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from itertools import islice
from tqdm import tqdm
import time

from KD_Lib.common import BaseClass


class Pipeline():
"""
Pipeline of knowledge distillation, pruning and quantization methods
supported by KD_Lib. Sequentially applies a list of methods on the student model.
All the elements in list must implement either train_student, prune or quantize
methods.
:param: steps (list) list of KD_Lib.KD or KD_Lib.Pruning or KD_Lib.Quantization
:param: epochs (int) number of iterations through whole batch for each method in
list
:param: plot_losses (bool) Plot a graph of losses during training
:param: save_model (bool) Save model after performing the list methods
:param: save_model_pth (str) Path where model is saved if save_model is True
:param: verbose (int) Verbose
"""
def __init__(
self,
steps,
epochs=5,
plot_losses=True,
save_model=True,
save_model_pth="./models/student.pt",
verbose=0):
self.steps = steps
self.device = device
self.verbose = verbose

self.plot_losses = plot_losses
self.save_model = save_model
self.save_model_path = save_model_pth
self._validate_steps()
self.epochs = epochs

def _validate_steps(self):
name, process = zip(*self.steps)

for t in process:
if (not hasattr(t, ('train_student', 'prune', 'quantize'))):
raise TypeError("All the steps must support at least one of "
"train_student, prune or quantize method, {} is not"
" supported yet".format(str(t)))

def get_steps(self):
return self.steps

def _iter(self, num_steps=-1):
_length = len(self.steps) if num_steps == -1 else num_steps

for idx, (name, process) in enumerate(islice(self.steps, 0, _length)):
yield idx, name, process

def _fit(self):

if self.verbose:
pbar = tqdm(total=len(self))

for idx, name, process in self._iter():
print("Starting {}".format(name))
if idx != 0:
if hasattr(process, 'train_student'):
if hasattr(self.steps[idx-1], 'train_student'):
process.student_model = self.steps[idx-1].student_model
else:
process.student_model = self.steps[idx-1].model
t1 = time.time()
if hasattr(process, 'train_student'):
process.train_student(self.epochs, self.plot_losses, self.save_model, self.save_model_path)
elif hasattr(proces, 'prune'):
process.prune()
elif hasattr(process, 'quantize'):
process.quantize()
else:
raise TypeError("{} is not supported by the pipeline yet."
.format(process))

t2 = time.time() - t1
print("{} completed in {}hr {}min {}s".format(name, t2 // (60 * 60), t2 // 60, t2 % 60)

if self.verbose:
pbar.update(1)

if self.verbose:
pbar.close()

def train(self):
"""
Train the (student) model sequentially through the list.
"""
self._validate_steps()

t1 = time.time()
self._fit()
t2 = time.time() - t1
print("Pipeline execution completed in {}hr {}min {}s".format(t2 // (60 * 60), t2 // 60, t2 % 60)
115 changes: 61 additions & 54 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,67 +13,74 @@
LONG_DESCRIPTION = f.read()

# Define the keywords
KEYWORDS = ["Knowledge Distillation", "Pruning", "Quantization", "pytorch", "machine learning", "deep learning"]
KEYWORDS = [
"Knowledge Distillation",
"Pruning",
"Quantization",
"pytorch",
"machine learning",
"deep learning",
]
REQUIRE_PATH = "requirements.txt"
PROJECT = os.path.abspath(os.path.dirname(__file__))
setup_requirements = ['pytest-runner']
setup_requirements = ["pytest-runner"]

test_requirements = ['pytest', 'pytest-cov']
test_requirements = ["pytest", "pytest-cov"]

requirements = [
'pip==19.3.1',
'transformers==4.6.1',
'sacremoses',
'tokenizers==0.10.1',
'huggingface-hub==0.0.8',
'torchtext==0.9.1',
'bumpversion==0.5.3',
'wheel==0.32.1',
'watchdog==0.9.0',
'flake8==3.5.0',
'tox==3.5.2',
'coverage==4.5.1',
'Sphinx==1.8.1',
'twine==1.12.1',
'pytest==3.8.2',
'pytest-runner==4.2',
'pytest-cov==2.6.1',
'matplotlib==3.2.1',
'torch==1.8.1',
'torchvision==0.9.1',
'tensorboard==2.2.1',
'contextlib2==0.6.0.post1',
'pandas==1.0.1',
'tqdm==4.42.1',
'numpy==1.18.1',
'sphinx-rtd-theme==0.5.0',
"pip==19.3.1",
"transformers==4.6.1",
"sacremoses",
"tokenizers==0.10.1",
"huggingface-hub==0.0.8",
"torchtext==0.9.1",
"bumpversion==0.5.3",
"wheel==0.32.1",
"watchdog==0.9.0",
"flake8==3.5.0",
"tox==3.5.2",
"coverage==4.5.1",
"Sphinx==1.8.1",
"twine==1.12.1",
"pytest==3.8.2",
"pytest-runner==4.2",
"pytest-cov==2.6.1",
"matplotlib==3.2.1",
"torch==1.8.1",
"torchvision==0.9.1",
"tensorboard==2.2.1",
"contextlib2==0.6.0.post1",
"pandas==1.0.1",
"tqdm==4.42.1",
"numpy==1.18.1",
"sphinx-rtd-theme==0.5.0",
]


if __name__ == "__main__":
setup(
author="Het Shah",
author_email='[email protected]',
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
],
description="A Pytorch Library to help extend all Knowledge Distillation works",
install_requires=requirements,
license="MIT license",
long_description=LONG_DESCRIPTION,
include_package_data=True,
keywords=KEYWORDS,
name='KD_Lib',
packages=find_packages(where=PROJECT),
setup_requires=setup_requirements,
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/SforAiDL/KD_Lib",
version='0.0.29',
zip_safe=False,
)
author="Het Shah",
author_email="[email protected]",
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
],
description="A Pytorch Library to help extend all Knowledge Distillation works",
install_requires=requirements,
license="MIT license",
long_description=LONG_DESCRIPTION,
include_package_data=True,
keywords=KEYWORDS,
name="KD_Lib",
packages=find_packages(where=PROJECT),
setup_requires=setup_requirements,
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/SforAiDL/KD_Lib",
version="0.0.29",
zip_safe=False,
)
52 changes: 52 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from KD_Lib.utils import Pipeline
from KD_Lib.KD import VanillaKD
from KD_Lib.Pruning import Lottery_Tickets_Pruner
from KD_Lib.Quantization import Dynamic_Quantizer
from KD_Lib.models import Shallow

import torch


train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)


def test_Pipeline():
teacher = Shallow(hidden_size=400)
student = Shallow(hidden_size=100)

t_optimizer = optim.SGD(teac.parameters(), 0.01)
s_optimizer = optim.SGD(stud.parameters(), 0.01)

distiller = VanillaKD(
teacher, student, train_loader, test_loader, t_optimizer, s_optimizer
)

pruner = Lottery_Tickets_Pruner(student, train_loader, test_loader)

quantizer = Dynamic_Quantizer(student, test_loader, {torch.nn.Linear})

pipe = Pipeline([distiller, pruner, quantizer], 1)
pipe.train()

0 comments on commit ccddfb9

Please sign in to comment.