diff --git a/docs/en_US/NAS/QuickStart.rst b/docs/en_US/NAS/QuickStart.rst index 2aa8492036..5c4538907d 100644 --- a/docs/en_US/NAS/QuickStart.rst +++ b/docs/en_US/NAS/QuickStart.rst @@ -182,6 +182,8 @@ Visualize the Experiment Users can visualize their experiment in the same way as visualizing a normal hyper-parameter tuning experiment. For example, open ``localhost::8081`` in your browser, 8081 is the port that you set in ``exp.run``. Please refer to `here <../../Tutorial/WebUI.rst>`__ for details. +We support visualizing models with 3rd-party visualization engines (like `Netron <https://netron.app/>`__). To use it, click ``Visualization`` in the detail panel of each trial. Note that the current visualization is based on `onnx <https://onnx.ai/>`__. Built-in evaluators (e.g., Classification) automatically export the model into a file; for your own evaluator, you need to save your model into ``$NNI_OUTPUT_DIR/model.onnx`` to make this work. + Export Top Models ----------------- diff --git a/docs/en_US/NAS/WriteTrainer.rst b/docs/en_US/NAS/WriteTrainer.rst index 8c11cadca9..47ce011a39 100644 --- a/docs/en_US/NAS/WriteTrainer.rst +++ b/docs/en_US/NAS/WriteTrainer.rst @@ -24,6 +24,8 @@ The simplest way to customize a new evaluator is with functional APIs, which is .. note:: Due to our current implementation limitation, the ``fit`` function should be put in another python file instead of putting it in the main file. This limitation will be fixed in future release. +.. note:: When using customized evaluators, if you want to visualize models, you need to export your model and save it into ``$NNI_OUTPUT_DIR/model.onnx`` in your evaluator. 
+ With PyTorch-Lightning ---------------------- diff --git a/nni/experiment/launcher.py b/nni/experiment/launcher.py index 7c6ed5a490..92fbf0ae5d 100644 --- a/nni/experiment/launcher.py +++ b/nni/experiment/launcher.py @@ -8,7 +8,7 @@ from subprocess import Popen import sys import time -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Any import colorama @@ -43,7 +43,7 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo _check_rest_server(port) platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform _save_experiment_information(exp_id, port, start_time, platform, - config.experiment_name, proc.pid, str(config.experiment_working_directory)) + config.experiment_name, proc.pid, str(config.experiment_working_directory), []) _logger.info('Setting up...') rest.post(port, '/experiment', config.json()) return proc @@ -78,7 +78,7 @@ def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, _check_rest_server(port) platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform _save_experiment_information(exp_id, port, start_time, platform, - config.experiment_name, proc.pid, config.experiment_working_directory) + config.experiment_name, proc.pid, config.experiment_working_directory, ['retiarii']) _logger.info('Setting up...') rest.post(port, '/experiment', config.json()) return proc, pipe @@ -156,9 +156,10 @@ def _check_rest_server(port: int, retry: int = 3) -> None: rest.get(port, '/check-status') -def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None: +def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, + name: str, pid: int, logDir: str, tag: List[Any]) -> None: experiments_config = Experiments() - experiments_config.add_experiment(experiment_id, port, start_time, platform, 
name, pid=pid, logDir=logDir) + experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag) def get_stopped_experiment_config(exp_id: str, mode: str) -> None: diff --git a/nni/retiarii/evaluator/pytorch/lightning.py b/nni/retiarii/evaluator/pytorch/lightning.py index 4399844ac6..6dd83f5d00 100644 --- a/nni/retiarii/evaluator/pytorch/lightning.py +++ b/nni/retiarii/evaluator/pytorch/lightning.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import os import warnings -from typing import Dict, Union, Optional, List +from pathlib import Path +from typing import Dict, NoReturn, Union, Optional, List, Type import pytorch_lightning as pl import torch.nn as nn @@ -18,7 +20,13 @@ class LightningModule(pl.LightningModule): - def set_model(self, model): + """ + Basic wrapper of generated model. + + Lightning modules used in NNI should inherit this class. + """ + + def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> NoReturn: if isinstance(model, type): self.model = model() else: @@ -112,13 +120,23 @@ class _SupervisedLearningModule(LightningModule): def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric], learning_rate: float = 0.001, weight_decay: float = 0., - optimizer: optim.Optimizer = optim.Adam): + optimizer: optim.Optimizer = optim.Adam, + export_onnx: Union[Path, str, bool, None] = None): super().__init__() self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay') self.criterion = criterion() self.optimizer = optimizer self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()}) + if export_onnx is None or export_onnx is True: + self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx' + self.export_onnx.parent.mkdir(exist_ok=True) + elif export_onnx: + self.export_onnx = Path(export_onnx) + else: + self.export_onnx = None + self._already_exported = False + def 
forward(self, x): y_hat = self.model(x) return y_hat @@ -135,6 +153,11 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): x, y = batch y_hat = self(x) + + if not self._already_exported: + self.to_onnx(self.export_onnx, x, export_params=True) + self._already_exported = True + self.log('val_loss', self.criterion(y_hat, y), prog_bar=True) for name, metric in self.metrics.items(): self.log('val_' + name, metric(y_hat, y), prog_bar=True) @@ -152,9 +175,8 @@ def configure_optimizers(self): def on_validation_epoch_end(self): nni.report_intermediate_result(self._get_validation_metrics()) - def teardown(self, stage): - if stage == 'fit': - nni.report_final_result(self._get_validation_metrics()) + def on_fit_end(self): + nni.report_final_result(self._get_validation_metrics()) def _get_validation_metrics(self): if len(self.metrics) == 1: @@ -175,9 +197,11 @@ class _ClassificationModule(_SupervisedLearningModule): def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss, learning_rate: float = 0.001, weight_decay: float = 0., - optimizer: optim.Optimizer = optim.Adam): + optimizer: optim.Optimizer = optim.Adam, + export_onnx: bool = True): super().__init__(criterion, {'acc': _AccuracyWithLogits}, - learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer) + learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer, + export_onnx=export_onnx) class Classification(Lightning): @@ -200,6 +224,8 @@ class Classification(Lightning): val_dataloaders : DataLoader or List of DataLoader Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples. If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped. + export_onnx : bool + If true, model will be exported to ``model.onnx`` before training starts. default true trainer_kwargs : dict Optional keyword arguments passed to trainer. See `Lightning documentation `__ for details. 
@@ -211,9 +237,10 @@ def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss, optimizer: optim.Optimizer = optim.Adam, train_dataloader: Optional[DataLoader] = None, val_dataloaders: Union[DataLoader, List[DataLoader], None] = None, + export_onnx: bool = True, **trainer_kwargs): module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate, - weight_decay=weight_decay, optimizer=optimizer) + weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx) super().__init__(module, Trainer(**trainer_kwargs), train_dataloader=train_dataloader, val_dataloaders=val_dataloaders) @@ -223,9 +250,11 @@ class _RegressionModule(_SupervisedLearningModule): def __init__(self, criterion: nn.Module = nn.MSELoss, learning_rate: float = 0.001, weight_decay: float = 0., - optimizer: optim.Optimizer = optim.Adam): + optimizer: optim.Optimizer = optim.Adam, + export_onnx: bool = True): super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError}, - learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer) + learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer, + export_onnx=export_onnx) class Regression(Lightning): @@ -248,6 +277,8 @@ class Regression(Lightning): val_dataloaders : DataLoader or List of DataLoader Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples. If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped. + export_onnx : bool + If true, model will be exported to ``model.onnx`` before training starts. default: true trainer_kwargs : dict Optional keyword arguments passed to trainer. See `Lightning documentation `__ for details. 
@@ -259,8 +290,9 @@ def __init__(self, criterion: nn.Module = nn.MSELoss, optimizer: optim.Optimizer = optim.Adam, train_dataloader: Optional[DataLoader] = None, val_dataloaders: Union[DataLoader, List[DataLoader], None] = None, + export_onnx: bool = True, **trainer_kwargs): module = _RegressionModule(criterion=criterion, learning_rate=learning_rate, - weight_decay=weight_decay, optimizer=optimizer) + weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx) super().__init__(module, Trainer(**trainer_kwargs), train_dataloader=train_dataloader, val_dataloaders=val_dataloaders) diff --git a/test/.gitignore b/test/.gitignore index 1065b0ee85..35133a8063 100644 --- a/test/.gitignore +++ b/test/.gitignore @@ -10,3 +10,4 @@ _generated_model data generated lightning_logs +model.onnx diff --git a/test/ut/tools/nnictl/mock/restful_server.py b/test/ut/tools/nnictl/mock/restful_server.py index bf7c760f0c..c104d51b0c 100644 --- a/test/ut/tools/nnictl/mock/restful_server.py +++ b/test/ut/tools/nnictl/mock/restful_server.py @@ -156,7 +156,7 @@ def mock_get_latest_metric_data(): def mock_get_trial_log(): responses.add( - responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-log/:id/:type', + responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-file/:id/:filename', json={"status":"RUNNING","errors":[]}, status=200, content_type='application/json', diff --git a/ts/nni_manager/common/manager.ts b/ts/nni_manager/common/manager.ts index f94bf57d8b..718832e787 100644 --- a/ts/nni_manager/common/manager.ts +++ b/ts/nni_manager/common/manager.ts @@ -4,7 +4,7 @@ 'use strict'; import { MetricDataRecord, MetricType, TrialJobInfo } from './datastore'; -import { TrialJobStatus, LogType } from './trainingService'; +import { TrialJobStatus } from './trainingService'; import { ExperimentConfig } from './experimentConfig'; type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE' | 'MAX_TRIAL_NUM'; @@ -59,7 +59,7 @@ abstract class Manager { public 
abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise; public abstract getLatestMetricData(): Promise; - public abstract getTrialLog(trialJobId: string, logType: LogType): Promise; + public abstract getTrialFile(trialJobId: string, fileName: string): Promise; public abstract getTrialJobStatistics(): Promise; public abstract getStatus(): NNIManagerStatus; diff --git a/ts/nni_manager/common/trainingService.ts b/ts/nni_manager/common/trainingService.ts index 2e460bd1db..c59b4bc42f 100644 --- a/ts/nni_manager/common/trainingService.ts +++ b/ts/nni_manager/common/trainingService.ts @@ -8,8 +8,6 @@ */ type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED'; -type LogType = 'TRIAL_LOG' | 'TRIAL_STDOUT' | 'TRIAL_ERROR'; - interface TrainingServiceMetadata { readonly key: string; readonly value: string; @@ -81,7 +79,7 @@ abstract class TrainingService { public abstract submitTrialJob(form: TrialJobApplicationForm): Promise; public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise; public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise; - public abstract getTrialLog(trialJobId: string, logType: LogType): Promise; + public abstract getTrialFile(trialJobId: string, fileName: string): Promise; public abstract setClusterMetadata(key: string, value: string): Promise; public abstract getClusterMetadata(key: string): Promise; public abstract getTrialOutputLocalPath(trialJobId: string): Promise; @@ -103,5 +101,5 @@ class NNIManagerIpConfig { export { TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm, TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters, - NNIManagerIpConfig, LogType + NNIManagerIpConfig }; diff --git a/ts/nni_manager/core/nnimanager.ts b/ts/nni_manager/core/nnimanager.ts index 959371d09a..1e49109202 100644 --- a/ts/nni_manager/core/nnimanager.ts +++ 
b/ts/nni_manager/core/nnimanager.ts @@ -19,7 +19,7 @@ import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/exp import { ExperimentManager } from '../common/experimentManager'; import { TensorboardManager } from '../common/tensorboardManager'; import { - TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType + TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus } from '../common/trainingService'; import { delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getTunerProc, getLogLevel, isAlive, killPid } from '../common/utils'; import { @@ -403,8 +403,8 @@ class NNIManager implements Manager { // FIXME: unit test } - public async getTrialLog(trialJobId: string, logType: LogType): Promise { - return this.trainingService.getTrialLog(trialJobId, logType); + public async getTrialFile(trialJobId: string, fileName: string): Promise { + return this.trainingService.getTrialFile(trialJobId, fileName); } public getExperimentProfile(): Promise { diff --git a/ts/nni_manager/core/test/mockedTrainingService.ts b/ts/nni_manager/core/test/mockedTrainingService.ts index 0f16c95317..68b83a95f6 100644 --- a/ts/nni_manager/core/test/mockedTrainingService.ts +++ b/ts/nni_manager/core/test/mockedTrainingService.ts @@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred'; import { Provider } from 'typescript-ioc'; import { MethodNotImplementedError } from '../../common/errors'; -import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService'; +import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService'; const testTrainingServiceProvider: Provider = { get: () => { return new MockedTrainingService(); } @@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService { return deferred.promise; } - public getTrialLog(trialJobId: 
string, logType: LogType): Promise { + public getTrialFile(trialJobId: string, fileName: string): Promise { throw new MethodNotImplementedError(); } diff --git a/ts/nni_manager/package.json b/ts/nni_manager/package.json index 0ad8f5f78b..3aba4a4d47 100644 --- a/ts/nni_manager/package.json +++ b/ts/nni_manager/package.json @@ -15,6 +15,7 @@ "child-process-promise": "^2.2.1", "express": "^4.17.1", "express-joi-validator": "^2.0.1", + "http-proxy": "^1.18.1", "ignore": "^5.1.8", "js-base64": "^3.6.1", "kubernetes-client": "^6.12.1", @@ -37,6 +38,7 @@ "@types/chai-as-promised": "^7.1.0", "@types/express": "^4.17.2", "@types/glob": "^7.1.3", + "@types/http-proxy": "^1.17.7", "@types/js-base64": "^3.3.1", "@types/js-yaml": "^4.0.1", "@types/lockfile": "^1.0.0", diff --git a/ts/nni_manager/rest_server/nniRestServer.ts b/ts/nni_manager/rest_server/nniRestServer.ts index cc8c016c94..de335af1e1 100644 --- a/ts/nni_manager/rest_server/nniRestServer.ts +++ b/ts/nni_manager/rest_server/nniRestServer.ts @@ -5,6 +5,7 @@ import * as bodyParser from 'body-parser'; import * as express from 'express'; +import * as httpProxy from 'http-proxy'; import * as path from 'path'; import * as component from '../common/component'; import { RestServer } from '../common/restServer' @@ -21,6 +22,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo'; @component.Singleton export class NNIRestServer extends RestServer { private readonly LOGS_ROOT_URL: string = '/logs'; + protected netronProxy: any = null; protected API_ROOT_URL: string = '/api/v1/nni'; /** @@ -29,6 +31,7 @@ export class NNIRestServer extends RestServer { constructor() { super(); this.API_ROOT_URL = getAPIRootUrl(); + this.netronProxy = httpProxy.createProxyServer(); } /** @@ -39,6 +42,14 @@ export class NNIRestServer extends RestServer { this.app.use(bodyParser.json({limit: '50mb'})); this.app.use(this.API_ROOT_URL, createRestHandler(this)); this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir())); + 
this.app.all('/netron/*', (req: express.Request, res: express.Response) => { + delete req.headers.host; + req.url = req.url.replace('/netron', '/'); + this.netronProxy.web(req, res, { + changeOrigin: true, + target: 'https://netron.app' + }); + }); this.app.get('*', (req: express.Request, res: express.Response) => { res.sendFile(path.resolve('static/index.html')); }); diff --git a/ts/nni_manager/rest_server/restHandler.ts b/ts/nni_manager/rest_server/restHandler.ts index 851276c745..302a78e8a6 100644 --- a/ts/nni_manager/rest_server/restHandler.ts +++ b/ts/nni_manager/rest_server/restHandler.ts @@ -19,7 +19,7 @@ import { NNIRestServer } from './nniRestServer'; import { getVersion } from '../common/utils'; import { MetricType } from '../common/datastore'; import { ProfileUpdateType } from '../common/manager'; -import { LogType, TrialJobStatus } from '../common/trainingService'; +import { TrialJobStatus } from '../common/trainingService'; const expressJoi = require('express-joi-validator'); @@ -53,6 +53,7 @@ class NNIRestHandler { this.version(router); this.checkStatus(router); this.getExperimentProfile(router); + this.getExperimentMetadata(router); this.updateExperimentProfile(router); this.importData(router); this.getImportedData(router); @@ -66,7 +67,7 @@ class NNIRestHandler { this.getMetricData(router); this.getMetricDataByRange(router); this.getLatestMetricData(router); - this.getTrialLog(router); + this.getTrialFile(router); this.exportData(router); this.getExperimentsInfo(router); this.startTensorboardTask(router); @@ -296,13 +297,20 @@ class NNIRestHandler { }); } - private getTrialLog(router: Router): void { - router.get('/trial-log/:id/:type', async(req: Request, res: Response) => { - this.nniManager.getTrialLog(req.params.id, req.params.type as LogType).then((log: string) => { - if (log === '') { - log = 'No logs available.' 
+ private getTrialFile(router: Router): void { + router.get('/trial-file/:id/:filename', async(req: Request, res: Response) => { + let encoding: string | null = null; + const filename = req.params.filename; + if (!filename.includes('.') || filename.match(/.*\.(txt|log)/g)) { + encoding = 'utf8'; + } + this.nniManager.getTrialFile(req.params.id, filename).then((content: Buffer | string) => { + if (content instanceof Buffer) { + res.header('Content-Type', 'application/octet-stream'); + } else if (content === '') { + content = `${filename} is empty.`; } - res.send(log); + res.send(content); }).catch((err: Error) => { this.handleError(err, res); }); @@ -319,6 +327,24 @@ class NNIRestHandler { }); } + private getExperimentMetadata(router: Router): void { + router.get('/experiment-metadata', (req: Request, res: Response) => { + Promise.all([ + this.nniManager.getExperimentProfile(), + this.experimentsManager.getExperimentsInfo() + ]).then(([profile, experimentInfo]) => { + for (const info of experimentInfo as any) { + if (info.id === profile.id) { + res.send(info); + break; + } + } + }).catch((err: Error) => { + this.handleError(err, res); + }); + }); + } + private getExperimentsInfo(router: Router): void { router.get('/experiments-info', (req: Request, res: Response) => { this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => { diff --git a/ts/nni_manager/rest_server/test/mockedNNIManager.ts b/ts/nni_manager/rest_server/test/mockedNNIManager.ts index 78b58cee51..27946a6001 100644 --- a/ts/nni_manager/rest_server/test/mockedNNIManager.ts +++ b/ts/nni_manager/rest_server/test/mockedNNIManager.ts @@ -13,7 +13,7 @@ import { TrialJobStatistics, NNIManagerStatus } from '../../common/manager'; import { - TrialJobApplicationForm, TrialJobDetail, TrialJobStatus, LogType + TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService'; export const testManagerProvider: Provider = { @@ -129,7 +129,7 @@ export class 
MockedNNIManager extends Manager { public getLatestMetricData(): Promise { throw new MethodNotImplementedError(); } - public getTrialLog(trialJobId: string, logType: LogType): Promise { + public getTrialFile(trialJobId: string, fileName: string): Promise { throw new MethodNotImplementedError(); } public getExperimentProfile(): Promise { diff --git a/ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts b/ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts index b747ad84fc..ad36fa8c1d 100644 --- a/ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts +++ b/ts/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts @@ -14,7 +14,7 @@ import {getExperimentId} from '../../common/experimentStartupInfo'; import {getLogger, Logger} from '../../common/log'; import {MethodNotImplementedError} from '../../common/errors'; import { - NNIManagerIpConfig, TrialJobDetail, TrialJobMetric, LogType + NNIManagerIpConfig, TrialJobDetail, TrialJobMetric } from '../../common/trainingService'; import {delay, getExperimentRootDir, getIPV4Address, getJobCancelStatus, getVersion, uniqueString} from '../../common/utils'; import {AzureStorageClientUtility} from './azureStorageClientUtils'; @@ -99,7 +99,7 @@ abstract class KubernetesTrainingService { return Promise.resolve(kubernetesTrialJob); } - public async getTrialLog(_trialJobId: string, _logType: LogType): Promise { + public async getTrialFile(_trialJobId: string, _filename: string): Promise { throw new MethodNotImplementedError(); } diff --git a/ts/nni_manager/training_service/local/localTrainingService.ts b/ts/nni_manager/training_service/local/localTrainingService.ts index 1bc33a8340..32b529caa7 100644 --- a/ts/nni_manager/training_service/local/localTrainingService.ts +++ b/ts/nni_manager/training_service/local/localTrainingService.ts @@ -13,7 +13,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo'; import { getLogger, Logger } from 
'../../common/log'; import { HyperParameters, TrainingService, TrialJobApplicationForm, - TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType + TrialJobDetail, TrialJobMetric, TrialJobStatus } from '../../common/trainingService'; import { delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString @@ -170,18 +170,20 @@ class LocalTrainingService implements TrainingService { return trialJob; } - public async getTrialLog(trialJobId: string, logType: LogType): Promise { - let logPath: string; - if (logType === 'TRIAL_LOG') { - logPath = path.join(this.rootDir, 'trials', trialJobId, 'trial.log'); - } else if (logType === 'TRIAL_STDOUT'){ - logPath = path.join(this.rootDir, 'trials', trialJobId, 'stdout'); - } else if (logType === 'TRIAL_ERROR') { - logPath = path.join(this.rootDir, 'trials', trialJobId, 'stderr'); - } else { - throw new Error('unexpected log type'); + public async getTrialFile(trialJobId: string, fileName: string): Promise { + // check filename here for security + if (!['trial.log', 'stderr', 'model.onnx', 'stdout'].includes(fileName)) { + throw new Error(`File unaccessible: ${fileName}`); + } + let encoding: string | null = null; + if (!fileName.includes('.') || fileName.match(/.*\.(txt|log)/g)) { + encoding = 'utf8'; + } + const logPath = path.join(this.rootDir, 'trials', trialJobId, fileName); + if (!fs.existsSync(logPath)) { + throw new Error(`File not found: ${logPath}`); } - return fs.promises.readFile(logPath, 'utf8'); + return fs.promises.readFile(logPath, {encoding: encoding as any}); } public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void { diff --git a/ts/nni_manager/training_service/pai/paiTrainingService.ts b/ts/nni_manager/training_service/pai/paiTrainingService.ts index 36cf78835a..31c976cc87 100644 --- a/ts/nni_manager/training_service/pai/paiTrainingService.ts +++ b/ts/nni_manager/training_service/pai/paiTrainingService.ts @@ -15,7 +15,7 @@ import { 
getLogger, Logger } from '../../common/log'; import { MethodNotImplementedError } from '../../common/errors'; import { HyperParameters, NNIManagerIpConfig, TrainingService, - TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType + TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService'; import { delay } from '../../common/utils'; import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig'; @@ -127,7 +127,7 @@ class PAITrainingService implements TrainingService { return jobs; } - public async getTrialLog(_trialJobId: string, _logType: LogType): Promise { + public async getTrialFile(_trialJobId: string, _fileName: string): Promise { throw new MethodNotImplementedError(); } diff --git a/ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts b/ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts index da6c5551bb..80dd02127b 100644 --- a/ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts +++ b/ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts @@ -16,7 +16,7 @@ import { getLogger, Logger } from '../../common/log'; import { ObservableTimer } from '../../common/observableTimer'; import { HyperParameters, TrainingService, TrialJobApplicationForm, - TrialJobDetail, TrialJobMetric, LogType + TrialJobDetail, TrialJobMetric } from '../../common/trainingService'; import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus, @@ -204,7 +204,7 @@ class RemoteMachineTrainingService implements TrainingService { * @param _trialJobId ID of trial job * @param _logType 'TRIAL_LOG' | 'TRIAL_STDERR' */ - public async getTrialLog(_trialJobId: string, _logType: LogType): Promise { + public async getTrialFile(_trialJobId: string, _fileName: string): Promise { throw new MethodNotImplementedError(); } diff --git 
a/ts/nni_manager/training_service/reusable/routerTrainingService.ts b/ts/nni_manager/training_service/reusable/routerTrainingService.ts index bc9f413d05..4ecd1594b5 100644 --- a/ts/nni_manager/training_service/reusable/routerTrainingService.ts +++ b/ts/nni_manager/training_service/reusable/routerTrainingService.ts @@ -6,7 +6,7 @@ import { getLogger, Logger } from '../../common/log'; import { MethodNotImplementedError } from '../../common/errors'; import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig'; -import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService'; +import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService'; import { delay } from '../../common/utils'; import { PAITrainingService } from '../pai/paiTrainingService'; import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService'; @@ -52,7 +52,7 @@ class RouterTrainingService implements TrainingService { return await this.internalTrainingService.getTrialJob(trialJobId); } - public async getTrialLog(_trialJobId: string, _logType: LogType): Promise { + public async getTrialFile(_trialJobId: string, _fileName: string): Promise { throw new MethodNotImplementedError(); } diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index 576e6f0786..96082cb750 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -13,7 +13,7 @@ import * as component from '../../common/component'; import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors'; import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo'; import { getLogger, Logger } from '../../common/log'; -import { TrainingService, TrialJobApplicationForm, 
TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService'; +import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus } from '../../common/trainingService'; import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils'; import { ExperimentConfig, SharedStorageConfig } from '../../common/experimentConfig'; import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands'; @@ -157,7 +157,7 @@ class TrialDispatcher implements TrainingService { return trial; } - public async getTrialLog(_trialJobId: string, _logType: LogType): Promise { + public async getTrialFile(_trialJobId: string, _fileName: string): Promise { throw new MethodNotImplementedError(); } diff --git a/ts/nni_manager/training_service/test/localTrainingService.test.ts b/ts/nni_manager/training_service/test/localTrainingService.test.ts index f1b664870b..e8d5073a76 100644 --- a/ts/nni_manager/training_service/test/localTrainingService.test.ts +++ b/ts/nni_manager/training_service/test/localTrainingService.test.ts @@ -100,8 +100,8 @@ describe('Unit Test for LocalTrainingService', () => { fs.mkdirSync(jobDetail.workingDirectory) fs.writeFileSync(path.join(jobDetail.workingDirectory, 'trial.log'), 'trial log') fs.writeFileSync(path.join(jobDetail.workingDirectory, 'stderr'), 'trial stderr') - chai.expect(await localTrainingService.getTrialLog(jobDetail.id, 'TRIAL_LOG')).to.be.equals('trial log'); - chai.expect(await localTrainingService.getTrialLog(jobDetail.id, 'TRIAL_ERROR')).to.be.equals('trial stderr'); + chai.expect(await localTrainingService.getTrialFile(jobDetail.id, 'trial.log')).to.be.equals('trial log'); + chai.expect(await localTrainingService.getTrialFile(jobDetail.id, 'stderr')).to.be.equals('trial stderr'); fs.unlinkSync(path.join(jobDetail.workingDirectory, 
'trial.log')) fs.unlinkSync(path.join(jobDetail.workingDirectory, 'stderr')) fs.rmdirSync(jobDetail.workingDirectory) diff --git a/ts/nni_manager/yarn.lock b/ts/nni_manager/yarn.lock index b805155705..ade43c2c41 100644 --- a/ts/nni_manager/yarn.lock +++ b/ts/nni_manager/yarn.lock @@ -503,6 +503,13 @@ "@types/minimatch" "*" "@types/node" "*" +"@types/http-proxy@^1.17.7": + version "1.17.7" + resolved "https://registry.yarnpkg.com/@types/http-proxy/-/http-proxy-1.17.7.tgz#30ea85cc2c868368352a37f0d0d3581e24834c6f" + integrity sha512-9hdj6iXH64tHSLTY+Vt2eYOGzSogC+JQ2H7bdPWkuh7KXP5qLllWx++t+K9Wk556c3dkDdPws/SpMRi0sdCT1w== + dependencies: + "@types/node" "*" + "@types/js-base64@^3.3.1": version "3.3.1" resolved "https://registry.yarnpkg.com/@types/js-base64/-/js-base64-3.3.1.tgz#36c2d6dc126277ea28a4d0599d0cafbf547b51e6" @@ -2071,6 +2078,11 @@ etag@~1.8.1: resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887" integrity sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc= +eventemitter3@^4.0.0: + version "4.0.7" + resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-4.0.7.tgz#2de9b68f6528d5644ef5c59526a1b4a07306169f" + integrity sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw== + execa@^0.7.0: version "0.7.0" resolved "https://registry.yarnpkg.com/execa/-/execa-0.7.0.tgz#944becd34cc41ee32a63a9faf27ad5a65fc59777" @@ -2273,6 +2285,11 @@ flatted@^3.1.0: resolved "https://registry.yarnpkg.com/flatted/-/flatted-3.1.1.tgz#c4b489e80096d9df1dfc97c79871aea7c617c469" integrity sha512-zAoAQiudy+r5SvnSw3KJy5os/oRJYHzrzja/tBDqrZtNhUw8bt6y8OBzMWcjWr+8liV8Eb6yOhw8WZ7VFZ5ZzA== +follow-redirects@^1.0.0: + version "1.14.1" + resolved "https://registry.yarnpkg.com/follow-redirects/-/follow-redirects-1.14.1.tgz#d9114ded0a1cfdd334e164e6662ad02bfd91ff43" + integrity sha512-HWqDgT7ZEkqRzBvc2s64vSZ/hfOceEol3ac/7tKwzuvEyWx3/4UegXh5oBOIotkGsObyk3xznnSRVADBgWSQVg== + for-in@^0.1.3: version "0.1.8" resolved 
"https://registry.yarnpkg.com/for-in/-/for-in-0.1.8.tgz#d8773908e31256109952b1fdb9b3fa867d2775e1" @@ -2723,6 +2740,15 @@ http-proxy-agent@^4.0.1: agent-base "6" debug "4" +http-proxy@^1.18.1: + version "1.18.1" + resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.18.1.tgz#401541f0534884bbf95260334e72f88ee3976549" + integrity sha512-7mz/721AbnJwIVbnaSv1Cz3Am0ZLT/UBwkC92VlxhXv/k/BBQfM2fXElQNC27BVGr0uwUpplYPQM9LnaBMR5NQ== + dependencies: + eventemitter3 "^4.0.0" + follow-redirects "^1.0.0" + requires-port "^1.0.0" + http-signature@~1.2.0: version "1.2.0" resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.2.0.tgz#9aecd925114772f3d95b65a60abb8f7c18fbace1" @@ -5030,6 +5056,11 @@ require-main-filename@^2.0.0: resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-2.0.0.tgz#d0b329ecc7cc0f61649f62215be69af54aa8989b" integrity sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg== +requires-port@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff" + integrity sha1-kl0mAdOaxIXgkc8NpcbmlNw9yv8= + resolve-from@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/resolve-from/-/resolve-from-4.0.0.tgz#4abcd852ad32dd7baabfe9b40e00a36db5f392e6" diff --git a/ts/webui/src/components/public-child/OpenRow.tsx b/ts/webui/src/components/public-child/OpenRow.tsx index 52224f5134..1bde938de6 100644 --- a/ts/webui/src/components/public-child/OpenRow.tsx +++ b/ts/webui/src/components/public-child/OpenRow.tsx @@ -56,8 +56,13 @@ class OpenRow extends React.Component { } }; - openTrialLog = (type: string): void => { - window.open(`${MANAGER_IP}/trial-log/${this.props.trialId}/${type}`); + openTrialLog = (filename: string): void => { + window.open(`${MANAGER_IP}/trial-file/${this.props.trialId}/${filename}`); + }; + + openModelOnnx = (): void => { + // TODO: netron might need prefix. 
+ window.open(`/netron/index.html?url=${MANAGER_IP}/trial-file/${this.props.trialId}/model.onnx`); }; render(): React.ReactNode { @@ -113,16 +118,16 @@ class OpenRow extends React.Component {
@@ -132,6 +137,18 @@ class OpenRow extends React.Component { ) } + {EXPERIMENT.metadata.tag.includes('retiarii') ? ( + +
+
Visualize models with 3rd-party tools.
+ +
+
+ ) : null} diff --git a/ts/webui/src/static/interface.ts b/ts/webui/src/static/interface.ts index e1c54eb9ae..561fddbbb2 100644 --- a/ts/webui/src/static/interface.ts +++ b/ts/webui/src/static/interface.ts @@ -165,6 +165,21 @@ interface ExperimentProfile { revision: number; } +interface ExperimentMetadata { + id: string; + port: number; + startTime: number | string; + endTime: number | string; + status: string; + platform: string; + experimentName: string; + tag: any[]; + pid: number; + webuiUrl: any[]; + logDir: string; + prefixUrl: string | null; +} + interface NNIManagerStatus { status: string; errors: string[]; @@ -230,6 +245,7 @@ export { MetricDataRecord, TrialJobInfo, ExperimentProfile, + ExperimentMetadata, NNIManagerStatus, EventMap, SingleAxis, diff --git a/ts/webui/src/static/model/experiment.ts b/ts/webui/src/static/model/experiment.ts index 9c8b12fa60..0bf940c6bc 100644 --- a/ts/webui/src/static/model/experiment.ts +++ b/ts/webui/src/static/model/experiment.ts @@ -1,6 +1,6 @@ import { MANAGER_IP } from '../const'; import { ExperimentConfig, toSeconds } from '../experimentConfig'; -import { ExperimentProfile, NNIManagerStatus } from '../interface'; +import { ExperimentProfile, ExperimentMetadata, NNIManagerStatus } from '../interface'; import { requestAxios } from '../function'; import { SearchSpace } from './searchspace'; @@ -32,8 +32,24 @@ const emptyProfile: ExperimentProfile = { revision: 0 }; +const emptyMetadata: ExperimentMetadata = { + id: '', + port: 0, + startTime: '', + endTime: '', + status: '', + platform: '', + experimentName: '', + tag: [], + pid: 0, + webuiUrl: [], + logDir: '', + prefixUrl: null +}; + class Experiment { private profileField?: ExperimentProfile; + private metadataField?: ExperimentMetadata = undefined; private statusField?: NNIManagerStatus = undefined; private isNestedExperiment: boolean = false; private isexperimentError: boolean = false; @@ -82,10 +98,14 @@ class Experiment { public async update(): Promise { let 
updated = false; - await requestAxios(`${MANAGER_IP}/experiment`) - .then(data => { - updated = updated || !compareProfiles(this.profileField, data); - this.profileField = data; + await Promise.all([requestAxios(`${MANAGER_IP}/experiment`), requestAxios(`${MANAGER_IP}/experiment-metadata`)]) + .then(([profile, metadata]) => { + updated ||= !compareProfiles(this.profileField, profile); + this.profileField = profile; + + if (JSON.stringify(this.metadataField) !== JSON.stringify(metadata)) { + this.metadataField = metadata; + } }) .catch(error => { this.isexperimentError = true; @@ -111,6 +131,10 @@ class Experiment { return this.profileField === undefined ? emptyProfile : this.profileField; } + get metadata(): ExperimentMetadata { + return this.metadataField === undefined ? emptyMetadata : this.metadataField; + } + get config(): ExperimentConfig { return this.profile.params; } diff --git a/ts/webui/src/static/style/openRow.scss b/ts/webui/src/static/style/openRow.scss index ba49acba06..52a2b8a7ec 100644 --- a/ts/webui/src/static/style/openRow.scss +++ b/ts/webui/src/static/style/openRow.scss @@ -55,3 +55,8 @@ $bgColor: #f2f2f2; } } } + +#visualizationText { + margin: 5px 0 10px 15px; + font-size: 14px; +}