forked from open-mmlab/mmrazor
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add a demo for rewrite (open-mmlab#145)
* add a demo for rewrite * remove register symbolics * place two image in one line * width = 300 * use table to show two images * fix typo
- Loading branch information
Showing
5 changed files
with
130 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
[settings] | ||
known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch | ||
known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
## Demo | ||
We provide a demo showing what our mmdeploy can do for general model deployment. | ||
|
||
In `demo_rewrite.py`, a resnet18 model from `torchvision` is rewritten through mmdeploy tool. In our rewritten model, the forward function of resnet gets modified to only down sample the original input to 4x. Original onnx model of resnet18 and its rewritten are visualized through [netron](https://netron.app/). | ||
|
||
### Prerequisite | ||
Before we run `demp_rewrite.py`, we need to install `pyppeteer` through: | ||
``` | ||
pip install pyppeteer | ||
``` | ||
|
||
### Demo results | ||
The original resnet18 model and its modified one are visualized as follows. The left model is the original resnet18 while the right model is exported after rewritten. | ||
Original resnet18 | Rewritten model | ||
:-------------------------:|:-------------------------: | ||
![](resources/original.png) | ![](resources/rewritten.png) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import asyncio | ||
import logging | ||
import os | ||
import shutil | ||
|
||
import torch | ||
from pyppeteer import launch | ||
from torchvision.models import resnet18 | ||
|
||
from mmdeploy.core import FUNCTION_REWRITER, RewriterContext, patch_model | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='torchvision.models.ResNet._forward_impl') | ||
def forward_of_resnet(ctx, self, x): | ||
"""Rewrite the forward implementation of resnet. | ||
Early return the feature map after two down-sampling steps. | ||
""" | ||
x = self.conv1(x) | ||
x = self.bn1(x) | ||
x = self.relu(x) | ||
x = self.maxpool(x) | ||
|
||
x = self.layer1(x) | ||
return x | ||
|
||
|
||
def rewrite_resnet18(original_path: str, rewritten_path: str): | ||
# prepare inputs and original model | ||
inputs = torch.rand(1, 3, 224, 224) | ||
original_model = resnet18(pretrained=False) | ||
|
||
# export original model | ||
torch.onnx.export(original_model, inputs, original_path) | ||
|
||
# patch model | ||
patched_model = patch_model(original_model, cfg={}, backend='default') | ||
|
||
# export rewritten onnx under a rewriter context manager | ||
with RewriterContext(cfg={}, backend='default'), torch.no_grad(): | ||
torch.onnx.export(patched_model, inputs, rewritten_path) | ||
|
||
|
||
def screen_size(): | ||
"""Get windows size through tkinter.""" | ||
import tkinter | ||
tk = tkinter.Tk() | ||
width = tk.winfo_screenwidth() | ||
height = tk.winfo_screenheight() | ||
tk.quit() | ||
return width, height | ||
|
||
|
||
async def visualize(original_path: str, rewritten_path: str): | ||
# launch a web browser | ||
browser = await launch(headless=False, args=['--start-maximized']) | ||
# create two new pages | ||
page2 = await browser.newPage() | ||
page1 = await browser.newPage() | ||
# go to netron.app | ||
width, height = screen_size() | ||
await page1.setViewport({'width': width, 'height': height}) | ||
await page2.setViewport({'width': width, 'height': height}) | ||
await page1.goto('https://netron.app/') | ||
await page2.goto('https://netron.app/') | ||
await asyncio.sleep(2) | ||
|
||
# open local two onnx files | ||
mupinput1 = await page1.querySelector("input[type='file']") | ||
mupinput2 = await page2.querySelector("input[type='file']") | ||
await mupinput1.uploadFile(original_file_path) | ||
await mupinput2.uploadFile(rewritten_file_path) | ||
await asyncio.sleep(4) | ||
for _ in range(6): | ||
await page1.click('#zoom-out-button') | ||
await asyncio.sleep(0.3) | ||
await asyncio.sleep(1) | ||
await page1.screenshot({'path': original_path.replace('.onnx', '.png')}, | ||
clip={ | ||
'x': width / 4, | ||
'y': 0, | ||
'width': width / 2, | ||
'height': height | ||
}) | ||
await page2.screenshot({'path': rewritten_path.replace('.onnx', '.png')}, | ||
clip={ | ||
'x': width / 4, | ||
'y': 0, | ||
'width': width / 2, | ||
'height': height | ||
}) | ||
await browser.close() | ||
|
||
|
||
if __name__ == '__main__': | ||
tmp_dir = os.getcwd() + '/tmp' | ||
if not os.path.exists(tmp_dir): | ||
os.mkdir(tmp_dir) | ||
original_file_path = os.path.join(tmp_dir, 'original.onnx') | ||
rewritten_file_path = os.path.join(tmp_dir, 'rewritten.onnx') | ||
logging.info('Generating resnet18 and its rewritten model...') | ||
rewrite_resnet18(original_file_path, rewritten_file_path) | ||
|
||
logging.info('Visualizing models through netron...') | ||
asyncio.get_event_loop().run_until_complete( | ||
visualize(original_file_path, rewritten_file_path)) | ||
import mmcv | ||
image1 = mmcv.imread(original_file_path.replace('.onnx', '.png')) | ||
image2 = mmcv.imread(rewritten_file_path.replace('.onnx', '.png')) | ||
mmcv.imshow(image1, win_name='original') | ||
mmcv.imshow(image2, win_name='rewritten') | ||
shutil.rmtree(tmp_dir) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.