Skip to content

Commit

Permalink
[ONNX] Introduce 'diagnostics' to 'dynamo_export' api (pytorch#99668)
Browse files Browse the repository at this point in the history
Summary
* Introduce `DiagnosticContext` to `torch.onnx.dynamo_export`.
* Remove `DiagnosticEngine` in preparations to update 'diagnostics' in `dynamo_export` to drop dependencies on global diagnostic context. No plans to update `torch.onnx.export` diagnostics.

Next steps
* Separate `torch.onnx.export` diagnostics and `torch.onnx.dynamo_export` diagnostics.
* Drop dependencies on global diagnostic context. pytorch#100219
* Replace 'print's with 'logger.log'.
Pull Request resolved: pytorch#99668
Approved by: https://github.com/justinchuby, https://github.com/abock
  • Loading branch information
BowenBao authored and valentinandrei committed May 2, 2023
1 parent f7070b8 commit 13cfdcc
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 181 deletions.
3 changes: 0 additions & 3 deletions docs/source/onnx_diagnostics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,3 @@ API Reference

.. autoclass:: torch.onnx._internal.diagnostics.ExportDiagnostic
:members:

.. autoclass:: torch.onnx._internal.diagnostics.infra.DiagnosticEngine
:members:
24 changes: 22 additions & 2 deletions test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Owner(s): ["module: onnx"]
import io
import logging
import os

import onnx

import torch
from beartype import roar
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx._internal import exporter
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.exporter import (
_DEFAULT_OPSET_VERSION,
ExportOutputSerializer,
Expand Down Expand Up @@ -132,11 +133,30 @@ def serialize(
with open(path, "r") as fp:
self.assertEquals(fp.read(), expected_buffer)

def test_save_sarif_log_to_file_with_successful_export(self):
with common_utils.TemporaryFileName() as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).diagnostic_context.dump(
path
)
self.assertTrue(os.path.exists(path))

def test_save_sarif_log_to_file_with_failed_export(self):
class ModelWithExportError(torch.nn.Module):
def forward(self, x):
raise RuntimeError("Export error")

with self.assertRaises(RuntimeError):
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
self.assertTrue(os.path.exists(exporter._DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH))

def test_raise_on_invalid_save_argument_type(self):
with self.assertRaises(roar.BeartypeException):
ExportOutput(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
export_output = ExportOutput(
onnx.ModelProto(), exporter.InputAdapter(), exporter.OutputAdapter()
onnx.ModelProto(),
exporter.InputAdapter(),
exporter.OutputAdapter(),
infra.DiagnosticContext("test", "1.0"),
)
with self.assertRaises(roar.BeartypeException):
export_output.save(None) # type: ignore[arg-type]
Expand Down
75 changes: 30 additions & 45 deletions test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,26 @@
import io
import typing
import unittest
from typing import AbstractSet, Tuple
from typing import AbstractSet, Protocol, Tuple

import torch
from torch.onnx import errors
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import sarif
from torch.testing._internal import common_utils


class _SarifLogBuilder(Protocol):
def sarif_log(self) -> sarif.SarifLog:
...


def _assert_has_diagnostics(
engine: infra.DiagnosticEngine,
sarif_log_builder: _SarifLogBuilder,
rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
sarif_log = engine.sarif_log()
sarif_log = sarif_log_builder.sarif_log()
unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs}
actual_results = []
for run in sarif_log.runs:
Expand All @@ -40,7 +46,7 @@ def _assert_has_diagnostics(
@contextlib.contextmanager
def assert_all_diagnostics(
test_suite: unittest.TestCase,
engine: infra.DiagnosticEngine,
sarif_log_builder: _SarifLogBuilder,
rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
"""Context manager to assert that all diagnostics are emitted.
Expand All @@ -55,7 +61,7 @@ def assert_all_diagnostics(
Args:
test_suite: The test suite instance.
engine: The diagnostic engine.
sarif_log_builder: The SARIF log builder.
rule_level_pairs: A set of rule and level pairs to assert.
Returns:
Expand All @@ -70,12 +76,12 @@ def assert_all_diagnostics(
except errors.OnnxExporterError:
test_suite.assertIn(infra.Level.ERROR, {level for _, level in rule_level_pairs})
finally:
_assert_has_diagnostics(engine, rule_level_pairs)
_assert_has_diagnostics(sarif_log_builder, rule_level_pairs)


def assert_diagnostic(
test_suite: unittest.TestCase,
engine: infra.DiagnosticEngine,
sarif_log_builder: _SarifLogBuilder,
rule: infra.Rule,
level: infra.Level,
):
Expand All @@ -92,7 +98,7 @@ def assert_diagnostic(
Args:
test_suite: The test suite instance.
engine: The diagnostic engine.
sarif_log_builder: The SARIF log builder.
rule: The rule to assert.
level: The level to assert.
Expand All @@ -103,7 +109,7 @@ def assert_diagnostic(
AssertionError: If the diagnostic is not emitted.
"""

return assert_all_diagnostics(test_suite, engine, {(rule, level)})
return assert_all_diagnostics(test_suite, sarif_log_builder, {(rule, level)})


class TestOnnxDiagnostics(common_utils.TestCase):
Expand Down Expand Up @@ -240,30 +246,12 @@ class TestDiagnosticsInfra(common_utils.TestCase):
"""Test cases for diagnostics infra."""

def setUp(self):
self.engine = infra.DiagnosticEngine()
self.rules = _RuleCollectionForTest()
with contextlib.ExitStack() as stack:
self.context = stack.enter_context(
self.engine.create_diagnostic_context("test", "1.0.0")
)
self.context = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
self.addCleanup(stack.pop_all().close)
return super().setUp()

def test_diagnostics_engine_records_diagnosis_reported_in_nested_contexts(
self,
):
with self.engine.create_diagnostic_context("inner_test", "1.0.1") as context:
context.diagnose(self.rules.rule_without_message_args, infra.Level.WARNING)
sarif_log = self.engine.sarif_log()
self.assertEqual(len(sarif_log.runs), 2)
self.assertEqual(len(sarif_log.runs[0].results), 0)
self.assertEqual(len(sarif_log.runs[1].results), 1)
self.context.diagnose(self.rules.rule_without_message_args, infra.Level.ERROR)
sarif_log = self.engine.sarif_log()
self.assertEqual(len(sarif_log.runs), 2)
self.assertEqual(len(sarif_log.runs[0].results), 1)
self.assertEqual(len(sarif_log.runs[1].results), 1)

def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
custom_rules = infra.RuleCollection.custom_collection_from_list(
"CustomRuleCollection",
Expand All @@ -281,23 +269,20 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
],
)

with self.engine.create_diagnostic_context(
"custom_rules", "1.0"
) as diagnostic_context:
with assert_all_diagnostics(
self,
self.engine,
{
(custom_rules.custom_rule, infra.Level.WARNING), # type: ignore[attr-defined]
(custom_rules.custom_rule_2, infra.Level.ERROR), # type: ignore[attr-defined]
},
):
diagnostic_context.diagnose(
custom_rules.custom_rule, infra.Level.WARNING # type: ignore[attr-defined]
)
diagnostic_context.diagnose(
custom_rules.custom_rule_2, infra.Level.ERROR # type: ignore[attr-defined]
)
with assert_all_diagnostics(
self,
self.context,
{
(custom_rules.custom_rule, infra.Level.WARNING), # type: ignore[attr-defined]
(custom_rules.custom_rule_2, infra.Level.ERROR), # type: ignore[attr-defined]
},
):
self.context.diagnose(
custom_rules.custom_rule, infra.Level.WARNING # type: ignore[attr-defined]
)
self.context.diagnose(
custom_rules.custom_rule_2, infra.Level.ERROR # type: ignore[attr-defined]
)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/test_fx_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def assert_has_diagnostics(
engine: infra.DiagnosticEngine,
engine: diagnostics.ExportDiagnosticEngine,
rule: infra.Rule,
level: infra.Level,
expected_error_node: str,
Expand Down
2 changes: 2 additions & 0 deletions torch/onnx/_internal/diagnostics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
engine,
export_context,
ExportDiagnostic,
ExportDiagnosticEngine,
)
from ._rules import rules
from .infra import levels

__all__ = [
"ExportDiagnostic",
"ExportDiagnosticEngine",
"rules",
"levels",
"engine",
Expand Down
64 changes: 56 additions & 8 deletions torch/onnx/_internal/diagnostics/_diagnostic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Diagnostic components for PyTorch ONNX export."""
"""Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""
from __future__ import annotations

import contextlib
import gzip
from collections.abc import Generator
from typing import Optional
from typing import List, Optional, Type

import torch

from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace


Expand Down Expand Up @@ -78,10 +81,10 @@ def record_fx_graphmodule(self, gm: torch.fx.GraphModule) -> None:
self.with_graph(infra.Graph(gm.print_readable(False), gm.__class__.__name__))


class ExportDiagnosticEngine(infra.DiagnosticEngine):
class ExportDiagnosticEngine:
"""PyTorch ONNX Export diagnostic engine.
The only purpose of creating this class instead of using the base class directly
The only purpose of creating this class instead of using `DiagnosticContext` directly
is to provide a background context for `diagnose` calls inside exporter.
By design, one `torch.onnx.export` call should initialize one diagnostic context.
Expand All @@ -94,10 +97,11 @@ class ExportDiagnosticEngine(infra.DiagnosticEngine):
established.
"""

contexts: List[infra.DiagnosticContext]
_background_context: infra.DiagnosticContext

def __init__(self) -> None:
super().__init__()
self.contexts = []
self._background_context = infra.DiagnosticContext(
name="torch.onnx",
version=torch.__version__,
Expand All @@ -108,12 +112,55 @@ def __init__(self) -> None:
def background_context(self) -> infra.DiagnosticContext:
return self._background_context

def create_diagnostic_context(
self,
name: str,
version: str,
options: Optional[infra.DiagnosticOptions] = None,
diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic,
) -> infra.DiagnosticContext:
"""Creates a new diagnostic context.
Args:
name: The subject name for the diagnostic context.
version: The subject version for the diagnostic context.
options: The options for the diagnostic context.
Returns:
A new diagnostic context.
"""
if options is None:
options = infra.DiagnosticOptions()
context = infra.DiagnosticContext(
name, version, options, diagnostic_type=diagnostic_type
)
self.contexts.append(context)
return context

def clear(self):
super().clear()
"""Clears all diagnostic contexts."""
self.contexts.clear()
self._background_context.diagnostics.clear()

def to_json(self) -> str:
return formatter.sarif_to_json(self.sarif_log())

def dump(self, file_path: str, compress: bool = False) -> None:
"""Dumps the SARIF log to a file."""
if compress:
with gzip.open(file_path, "wt") as f:
f.write(self.to_json())
else:
with open(file_path, "w") as f:
f.write(self.to_json())

def sarif_log(self):
log = super().sarif_log()
log = sarif.SarifLog(
version=sarif_version.SARIF_VERSION,
schema_uri=sarif_version.SARIF_SCHEMA_LINK,
runs=[context.sarif() for context in self.contexts],
)

log.runs.append(self._background_context.sarif())
return log

Expand Down Expand Up @@ -154,7 +201,8 @@ def diagnose(
) -> ExportDiagnostic:
"""Creates a diagnostic and record it in the global diagnostic context.
This is a wrapper around `context.record` that uses the global diagnostic context.
This is a wrapper around `context.add_diagnostic` that uses the global diagnostic
context.
"""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
diagnostic = ExportDiagnostic(
Expand Down
3 changes: 1 addition & 2 deletions torch/onnx/_internal/diagnostics/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
Tag,
ThreadFlowLocation,
)
from .engine import Diagnostic, DiagnosticContext, DiagnosticEngine
from .context import Diagnostic, DiagnosticContext

__all__ = [
"Diagnostic",
"DiagnosticContext",
"DiagnosticEngine",
"DiagnosticOptions",
"Graph",
"Invocation",
Expand Down
Loading

0 comments on commit 13cfdcc

Please sign in to comment.