Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Visualization #3878

Merged
merged 15 commits into from
Jul 12, 2021
2 changes: 2 additions & 0 deletions docs/en_US/NAS/QuickStart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ Visualize the Experiment

Users can visualize their experiment in the same way as visualizing a normal hyper-parameter tuning experiment. For example, open ``localhost:8081`` in your browser, where 8081 is the port that you set in ``exp.run``. Please refer to `here <../../Tutorial/WebUI.rst>`__ for details.

We support visualizing models with 3rd-party visualization engines (like `Netron <https://netron.app/>`__). To use it, click ``Visualization`` in the detail panel of each trial. Note that the current visualization is based on `onnx <https://onnx.ai/>`__. Built-in evaluators (e.g., Classification) will automatically export the model into a file; for your own evaluator, you need to save your model into ``$NNI_OUTPUT_DIR/model.onnx`` to make this work.

Export Top Models
-----------------

Expand Down
2 changes: 2 additions & 0 deletions docs/en_US/NAS/WriteTrainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The simplest way to customize a new evaluator is with functional APIs, which is

.. note:: Due to a limitation of our current implementation, the ``fit`` function should be put in another python file instead of the main file. This limitation will be fixed in a future release.

.. note:: When using customized evaluators, if you want to visualize models, you need to export your model and save it into ``$NNI_OUTPUT_DIR/model.onnx`` in your evaluator.

With PyTorch-Lightning
----------------------

Expand Down
11 changes: 6 additions & 5 deletions nni/experiment/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from subprocess import Popen
import sys
import time
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Any

import colorama

Expand Down Expand Up @@ -43,7 +43,7 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, str(config.experiment_working_directory))
config.experiment_name, proc.pid, str(config.experiment_working_directory), [])
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc
Expand Down Expand Up @@ -78,7 +78,7 @@ def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int,
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, config.experiment_working_directory)
config.experiment_name, proc.pid, config.experiment_working_directory, ['retiarii'])
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc, pipe
Expand Down Expand Up @@ -156,9 +156,10 @@ def _check_rest_server(port: int, retry: int = 3) -> None:
rest.get(port, '/check-status')


def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str,
name: str, pid: int, logDir: str, tag: List[Any]) -> None:
experiments_config = Experiments()
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag)


def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
Expand Down
56 changes: 44 additions & 12 deletions nni/retiarii/evaluator/pytorch/lightning.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import warnings
from typing import Dict, Union, Optional, List
from pathlib import Path
from typing import Dict, NoReturn, Union, Optional, List, Type

import pytorch_lightning as pl
import torch.nn as nn
Expand All @@ -18,7 +20,13 @@


class LightningModule(pl.LightningModule):
def set_model(self, model):
"""
Basic wrapper of generated model.

Lightning modules used in NNI should inherit this class.
"""

def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> NoReturn:
if isinstance(model, type):
self.model = model()
else:
Expand Down Expand Up @@ -112,13 +120,23 @@ class _SupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: Union[Path, str, bool, None] = None):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})

if export_onnx is None or export_onnx is True:
self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
self.export_onnx.parent.mkdir(exist_ok=True)
elif export_onnx:
self.export_onnx = Path(export_onnx)
else:
self.export_onnx = None
self._already_exported = False

def forward(self, x):
y_hat = self.model(x)
return y_hat
Expand All @@ -135,6 +153,11 @@ def training_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)

if not self._already_exported:
self.to_onnx(self.export_onnx, x, export_params=True)
self._already_exported = True

self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
self.log('val_' + name, metric(y_hat, y), prog_bar=True)
Expand All @@ -152,9 +175,8 @@ def configure_optimizers(self):
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())

def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self._get_validation_metrics())
def on_fit_end(self):
nni.report_final_result(self._get_validation_metrics())

def _get_validation_metrics(self):
if len(self.metrics) == 1:
Expand All @@ -175,9 +197,11 @@ class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)


class Classification(Lightning):
Expand All @@ -200,6 +224,8 @@ class Classification(Lightning):
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, model will be exported to ``model.onnx`` before training starts. default: true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
Expand All @@ -211,9 +237,10 @@ def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
**trainer_kwargs):
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)

Expand All @@ -223,9 +250,11 @@ class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)


class Regression(Lightning):
Expand All @@ -248,6 +277,8 @@ class Regression(Lightning):
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, model will be exported to ``model.onnx`` before training starts. default: true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
Expand All @@ -259,8 +290,9 @@ def __init__(self, criterion: nn.Module = nn.MSELoss,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
**trainer_kwargs):
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
1 change: 1 addition & 0 deletions test/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ _generated_model
data
generated
lightning_logs
model.onnx
2 changes: 1 addition & 1 deletion ts/nni_manager/common/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ abstract class Manager {
public abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]>;
public abstract getLatestMetricData(): Promise<MetricDataRecord[]>;

public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior would be more predictable to split into get text and get binary.
If you have time, not mandatory...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having a hard time thinking of another meaningful name.
Let's do it in future when we actually need this one.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change.


public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
public abstract getStatus(): NNIManagerStatus;
Expand Down
4 changes: 2 additions & 2 deletions ts/nni_manager/common/trainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
*/
type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED';

type LogType = 'TRIAL_LOG' | 'TRIAL_STDOUT' | 'TRIAL_ERROR';
type LogType = 'TRIAL_LOG' | 'TRIAL_STDOUT' | 'TRIAL_ERROR' | 'MODEL.onnx';
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove it if the parameter type is 'string' or keep this type for parameter and rename it as "FileName"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just removed LogType as I think it's no longer used.


interface TrainingServiceMetadata {
readonly key: string;
Expand Down Expand Up @@ -81,7 +81,7 @@ abstract class TrainingService {
public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>;
public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
public abstract setClusterMetadata(key: string, value: string): Promise<void>;
public abstract getClusterMetadata(key: string): Promise<string>;
public abstract getTrialOutputLocalPath(trialJobId: string): Promise<string>;
Expand Down
4 changes: 2 additions & 2 deletions ts/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,8 @@ class NNIManager implements Manager {
// FIXME: unit test
}

public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
return this.trainingService.getTrialLog(trialJobId, logType);
public async getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string> {
return this.trainingService.getTrialFile(trialJobId, fileName);
}

public getExperimentProfile(): Promise<ExperimentProfile> {
Expand Down
2 changes: 1 addition & 1 deletion ts/nni_manager/core/test/mockedTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService {
return deferred.promise;
}

public getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
throw new MethodNotImplementedError();
}

Expand Down
1 change: 1 addition & 0 deletions ts/nni_manager/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"child-process-promise": "^2.2.1",
"express": "^4.17.1",
"express-joi-validator": "^2.0.1",
"http-proxy": "^1.18.1",
"ignore": "^5.1.8",
"js-base64": "^3.6.1",
"kubernetes-client": "^6.12.1",
Expand Down
12 changes: 12 additions & 0 deletions ts/nni_manager/rest_server/nniRestServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import { getLogDir } from '../common/utils';
import { createRestHandler } from './restHandler';
import { getAPIRootUrl } from '../common/experimentStartupInfo';

const httpProxy = require('http-proxy');
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not import? type definition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idk. Copied code. I'll try import.


/**
* NNI Main rest server, provides rest API to support
* # nnictl CLI tool
Expand All @@ -21,6 +23,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo';
@component.Singleton
export class NNIRestServer extends RestServer {
private readonly LOGS_ROOT_URL: string = '/logs';
protected netronProxy: any = null;
protected API_ROOT_URL: string = '/api/v1/nni';

/**
Expand All @@ -29,6 +32,7 @@ export class NNIRestServer extends RestServer {
constructor() {
super();
this.API_ROOT_URL = getAPIRootUrl();
this.netronProxy = httpProxy.createProxyServer();
}

/**
Expand All @@ -39,6 +43,14 @@ export class NNIRestServer extends RestServer {
this.app.use(bodyParser.json({limit: '50mb'}));
this.app.use(this.API_ROOT_URL, createRestHandler(this));
this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir()));
this.app.all('/netron/*', (req: express.Request, res: express.Response) => {
delete req.headers.host;
req.url = req.url.replace('/netron', '/');
this.netronProxy.web(req, res, {
changeOrigin: true,
target: 'https://netron.app'
});
});
this.app.get('*', (req: express.Request, res: express.Response) => {
res.sendFile(path.resolve('static/index.html'));
});
Expand Down
40 changes: 33 additions & 7 deletions ts/nni_manager/rest_server/restHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class NNIRestHandler {
this.version(router);
this.checkStatus(router);
this.getExperimentProfile(router);
this.getExperimentMetadata(router);
this.updateExperimentProfile(router);
this.importData(router);
this.getImportedData(router);
Expand All @@ -66,7 +67,7 @@ class NNIRestHandler {
this.getMetricData(router);
this.getMetricDataByRange(router);
this.getLatestMetricData(router);
this.getTrialLog(router);
this.getTrialFile(router);
this.exportData(router);
this.getExperimentsInfo(router);
this.startTensorboardTask(router);
Expand Down Expand Up @@ -296,13 +297,20 @@ class NNIRestHandler {
});
}

private getTrialLog(router: Router): void {
router.get('/trial-log/:id/:type', async(req: Request, res: Response) => {
this.nniManager.getTrialLog(req.params.id, req.params.type as LogType).then((log: string) => {
if (log === '') {
log = 'No logs available.'
private getTrialFile(router: Router): void {
router.get('/trial-file/:id/:filename', async(req: Request, res: Response) => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the unit test (UT) in restful_server.py still uses /trial-log/:id/:type.

let encoding: string | null = null;
const filename = req.params.filename;
if (!filename.includes('.') || filename.match(/.*\.(txt|log)/g)) {
encoding = 'utf8';
}
this.nniManager.getTrialFile(req.params.id, filename).then((content: Buffer | string) => {
if (content instanceof Buffer) {
res.header('Content-Type', 'application/octet-stream');
} else if (content === '') {
content = 'No logs available.'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The message "No logs available." seems strange, since this is the getTrialFile() function.

}
res.send(log);
res.send(content);
}).catch((err: Error) => {
this.handleError(err, res);
});
Expand All @@ -319,6 +327,24 @@ class NNIRestHandler {
});
}

private getExperimentMetadata(router: Router): void {
router.get('/experiment-metadata', (req: Request, res: Response) => {
Promise.all([
this.nniManager.getExperimentProfile(),
this.experimentsManager.getExperimentsInfo()
]).then(([profile, experimentInfo]) => {
for (const info of experimentInfo as any) {
if (info.id === profile.id) {
res.send(info);
break;
}
}
}).catch((err: Error) => {
this.handleError(err, res);
});
});
}

private getExperimentsInfo(router: Router): void {
router.get('/experiments-info', (req: Request, res: Response) => {
this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => {
Expand Down
2 changes: 1 addition & 1 deletion ts/nni_manager/rest_server/test/mockedNNIManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ export class MockedNNIManager extends Manager {
public getLatestMetricData(): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError();
}
public getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
throw new MethodNotImplementedError();
}
public getExperimentProfile(): Promise<ExperimentProfile> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ abstract class KubernetesTrainingService {
return Promise.resolve(kubernetesTrialJob);
}

public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
public async getTrialFile(_trialJobId: string, _filename: string): Promise<string | Buffer> {
throw new MethodNotImplementedError();
}

Expand Down
Loading