Skip to content

Commit

Permalink
add a demo for rewrite (open-mmlab#145)
Browse files Browse the repository at this point in the history
* add a demo for rewrite

* remove register symbolics

* place two image in one line

* width = 300

* use table to show two images

* fix typo
  • Loading branch information
AllentDan authored Oct 27, 2021
1 parent 0b46c50 commit a563a03
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[settings]
known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch
known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision
16 changes: 16 additions & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
## Demo
We provide a demo showing what our mmdeploy can do for general model deployment.

In `demo_rewrite.py`, a resnet18 model from `torchvision` is rewritten through mmdeploy tool. In our rewritten model, the forward function of resnet gets modified to only down sample the original input to 4x. Original onnx model of resnet18 and its rewritten are visualized through [netron](https://netron.app/).

### Prerequisite
Before we run `demp_rewrite.py`, we need to install `pyppeteer` through:
```
pip install pyppeteer
```

### Demo results
The original resnet18 model and its modified one are visualized as follows. The left model is the original resnet18 while the right model is exported after rewritten.
Original resnet18 | Rewritten model
:-------------------------:|:-------------------------:
![](resources/original.png) | ![](resources/rewritten.png)
113 changes: 113 additions & 0 deletions demo/demo_rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import asyncio
import logging
import os
import shutil

import torch
from pyppeteer import launch
from torchvision.models import resnet18

from mmdeploy.core import FUNCTION_REWRITER, RewriterContext, patch_model


@FUNCTION_REWRITER.register_rewriter(
func_name='torchvision.models.ResNet._forward_impl')
def forward_of_resnet(ctx, self, x):
"""Rewrite the forward implementation of resnet.
Early return the feature map after two down-sampling steps.
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
return x


def rewrite_resnet18(original_path: str, rewritten_path: str):
# prepare inputs and original model
inputs = torch.rand(1, 3, 224, 224)
original_model = resnet18(pretrained=False)

# export original model
torch.onnx.export(original_model, inputs, original_path)

# patch model
patched_model = patch_model(original_model, cfg={}, backend='default')

# export rewritten onnx under a rewriter context manager
with RewriterContext(cfg={}, backend='default'), torch.no_grad():
torch.onnx.export(patched_model, inputs, rewritten_path)


def screen_size():
"""Get windows size through tkinter."""
import tkinter
tk = tkinter.Tk()
width = tk.winfo_screenwidth()
height = tk.winfo_screenheight()
tk.quit()
return width, height


async def visualize(original_path: str, rewritten_path: str):
# launch a web browser
browser = await launch(headless=False, args=['--start-maximized'])
# create two new pages
page2 = await browser.newPage()
page1 = await browser.newPage()
# go to netron.app
width, height = screen_size()
await page1.setViewport({'width': width, 'height': height})
await page2.setViewport({'width': width, 'height': height})
await page1.goto('https://netron.app/')
await page2.goto('https://netron.app/')
await asyncio.sleep(2)

# open local two onnx files
mupinput1 = await page1.querySelector("input[type='file']")
mupinput2 = await page2.querySelector("input[type='file']")
await mupinput1.uploadFile(original_file_path)
await mupinput2.uploadFile(rewritten_file_path)
await asyncio.sleep(4)
for _ in range(6):
await page1.click('#zoom-out-button')
await asyncio.sleep(0.3)
await asyncio.sleep(1)
await page1.screenshot({'path': original_path.replace('.onnx', '.png')},
clip={
'x': width / 4,
'y': 0,
'width': width / 2,
'height': height
})
await page2.screenshot({'path': rewritten_path.replace('.onnx', '.png')},
clip={
'x': width / 4,
'y': 0,
'width': width / 2,
'height': height
})
await browser.close()


if __name__ == '__main__':
tmp_dir = os.getcwd() + '/tmp'
if not os.path.exists(tmp_dir):
os.mkdir(tmp_dir)
original_file_path = os.path.join(tmp_dir, 'original.onnx')
rewritten_file_path = os.path.join(tmp_dir, 'rewritten.onnx')
logging.info('Generating resnet18 and its rewritten model...')
rewrite_resnet18(original_file_path, rewritten_file_path)

logging.info('Visualizing models through netron...')
asyncio.get_event_loop().run_until_complete(
visualize(original_file_path, rewritten_file_path))
import mmcv
image1 = mmcv.imread(original_file_path.replace('.onnx', '.png'))
image2 = mmcv.imread(rewritten_file_path.replace('.onnx', '.png'))
mmcv.imshow(image1, win_name='original')
mmcv.imshow(image2, win_name='rewritten')
shutil.rmtree(tmp_dir)
Binary file added demo/resources/original.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/resources/rewritten.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit a563a03

Please sign in to comment.