
Optional Type

An optional type represents a reference to either an element (a Tensor, Sequence, Map, or Sparse Tensor) or a null value. Optional types can appear in model inputs and outputs, as well as in intermediate values.
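To make the nesting concrete, here is a minimal plain-Python sketch that models an optional type as a wrapper around exactly one element type. The class names (`TensorType`, `SparseTensorType`, `SequenceType`, `MapType`, `OptionalType`) are hypothetical and only loosely mirror the nested structure of ONNX's `TypeProto`. The string format imitates the `optional(tensor(float))` style used in ONNX operator type constraints. This is an illustration, not the `onnx` package API.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical element types: plain dataclasses, not onnx.TypeProto.
@dataclass(frozen=True)
class TensorType:
    elem: str  # element data type, e.g. "float" or "int64"

    def __str__(self) -> str:
        return f"tensor({self.elem})"


@dataclass(frozen=True)
class SparseTensorType:
    elem: str

    def __str__(self) -> str:
        return f"sparse_tensor({self.elem})"


@dataclass(frozen=True)
class SequenceType:
    inner: "ElementType"

    def __str__(self) -> str:
        return f"seq({self.inner})"


@dataclass(frozen=True)
class MapType:
    key: str
    value: "ElementType"

    def __str__(self) -> str:
        return f"map({self.key}, {self.value})"


ElementType = Union[TensorType, SparseTensorType, SequenceType, MapType]


@dataclass(frozen=True)
class OptionalType:
    # Wraps exactly one element type. A value of this type is either an
    # element of that type or null (no element).
    inner: ElementType

    def __str__(self) -> str:
        return f"optional({self.inner})"


print(OptionalType(TensorType("float")))                # optional(tensor(float))
print(OptionalType(SequenceType(TensorType("int64"))))  # optional(seq(tensor(int64)))
```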

Use-cases

The optional type lets users represent more dynamic typing scenarios in ONNX. Like the Optional[X] type hint in Python's typing module, which is equivalent to Union[None, X], an ONNX optional type may reference a single element or null.
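The Python side of this analogy can be checked directly with the standard `typing` module. The sketch below confirms the `Optional`/`Union` equivalence and shows the usual "element or None" branching. The function `add_if_present` is a hypothetical example, not part of ONNX or PyTorch.

```python
from typing import Optional, Union, get_args

# Optional[X] is shorthand for a Union with None; Union ignores argument order.
print(Optional[int] == Union[None, int])  # True
print(get_args(Optional[int]))            # (<class 'int'>, <class 'NoneType'>)


def add_if_present(x: float, y: Optional[float] = None) -> float:
    # y holds either a float or None, much like an ONNX optional that
    # either references an element or is empty.
    if y is not None:
        return x + y
    return x


print(add_if_present(1.0))       # 1.0
print(add_if_present(1.0, 2.0))  # 3.0
```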

Examples in PyTorch

Optional types appear only in TorchScript graphs generated by the JIT script compiler (torch.jit.script). Scripting a model captures dynamic types, so a variable declared as Optional can hold either None or a value.

  • Example 1

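      import torch
      from typing import Optional
      from torch import Tensor

      class Model(torch.nn.Module):
          def forward(self, x, y: Optional[Tensor] = None):
              # Add y only when it is provided (not None).
              if y is not None:
                  return x + y
              return x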
    

    Corresponding TorchScript graph:

      Graph(
          %self : __torch__.Model, 
          %x.1 : Tensor, 
          %y.1 : Tensor?
      ): 
          %11 : int = prim::Constant[value=1]() 
          %4 : None = prim::Constant() 
          %5 : bool = aten::__isnot__(%y.1, %4) 
          %6 : Tensor = prim::If(%5) 
              block0(): 
                  %y.4 : Tensor = prim::unchecked_cast(%y.1) 
                  %12 : Tensor = aten::add(%x.1, %y.4, %11) 
              -> (%12) 
              block1(): 
              -> (%x.1) 
          return (%6)
    

    ONNX graph:

      Graph(
          %x.1 : Float(2, 3), 
          %y.1 : Float(2, 3)
      ): 
          %2 : Bool(1) = onnx::OptionalHasElement(%y.1)
          %5 : Float(2, 3) = onnx::If(%2) 
              block0():
                  %3 : Float(2, 3) = onnx::OptionalGetElement(%y.1) 
                  %4 : Float(2, 3) = onnx::Add(%x.1, %3)
              -> (%4) 
              block1(): 
                  %x.2 : Float(2, 3) = onnx::Identity(%x.1) 
              -> (%x.2) 
          return (%5)
    
  • Example 2

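      import torch
      from typing import Optional
      from torch import Tensor

      class Model(torch.nn.Module):
          def forward(
                  self,
                  src_tokens,
                  return_all_hiddens=torch.tensor([False]),
          ):
              # encoder_states stays None unless return_all_hiddens is true.
              encoder_states: Optional[Tensor] = None
              if return_all_hiddens:
                  encoder_states = src_tokens

              return src_tokens, encoder_states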
    

    Corresponding TorchScript graph:

      Graph(
          %src_tokens.1 : Float(3, 2, 4),
          %return_all_hiddens.1 : Bool(1)
      ): 
          %3 : None = prim::Constant() 
          %encoder_states : Tensor? = prim::If(%return_all_hiddens.1) 
              block0(): 
              -> (%src_tokens.1) 
              block1(): 
              -> (%3)
          return (%src_tokens.1, %encoder_states) 
    

    ONNX graph:

      Graph(
          %src_tokens.1 : Float(3, 2, 4), 
          %return_all_hiddens.1 : Bool(1)
      ): 
          %2 : Float(3, 2, 4) = onnx::Optional[type=tensor(float)]()
          %3 : Float(3, 2, 4) = onnx::If(%return_all_hiddens.1) 
              block0():
              -> (%src_tokens.1) 
              block1(): 
              -> (%2)
          return (%3)
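The control flow of the Example 1 ONNX graph can be traced in plain Python. This is a rough sketch, not ONNX Runtime, and it rests on a few assumptions. The helpers `optional_has_element` and `optional_get_element` are hypothetical stand-ins for `onnx::OptionalHasElement` and `onnx::OptionalGetElement`. An empty optional is modeled as `None`, and tensors are modeled as flat lists of floats.

```python
from typing import List, Optional

Tensor = List[float]  # hypothetical stand-in for a flattened Float(2, 3) tensor


def optional_has_element(opt: Optional[Tensor]) -> bool:
    # Stand-in for onnx::OptionalHasElement: True if the optional holds a value.
    return opt is not None


def optional_get_element(opt: Optional[Tensor]) -> Tensor:
    # Stand-in for onnx::OptionalGetElement. This sketch raises on an empty optional.
    if opt is None:
        raise ValueError("optional has no element")
    return opt


def model_example_1(x: Tensor, y: Optional[Tensor]) -> Tensor:
    if optional_has_element(y):                       # %2
        y_val = optional_get_element(y)               # %3, then-branch (block0)
        return [a + b for a, b in zip(x, y_val)]      # %4 = onnx::Add
    return list(x)                                    # %x.2 = onnx::Identity, else-branch (block1)


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(model_example_1(x, None))       # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(model_example_1(x, [1.0] * 6))  # [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
```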
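In the Example 2 ONNX graph, `onnx::Optional[type=tensor(float)]()` takes no input, so it produces an empty optional. `onnx::If` then selects either the input tensor or that empty optional. The plain-Python sketch below follows the model's `forward` method. It assumes an empty optional can be modeled as `None`, and `make_empty_optional` is a hypothetical helper standing in for the input-less `onnx::Optional` node.

```python
from typing import List, Optional, Tuple

Tensor = List[float]  # hypothetical stand-in for a flattened Float(3, 2, 4) tensor


def make_empty_optional() -> Optional[Tensor]:
    # Stand-in for onnx::Optional called with no input: an optional with no element.
    return None


def model_example_2(src_tokens: Tensor,
                    return_all_hiddens: bool) -> Tuple[Tensor, Optional[Tensor]]:
    empty = make_empty_optional()                                   # %2
    # onnx::If: then-branch returns src_tokens, else-branch the empty optional.
    encoder_states = src_tokens if return_all_hiddens else empty    # %3
    return src_tokens, encoder_states


tokens = [0.5] * 24  # 3 * 2 * 4 elements
_, states = model_example_2(tokens, False)
print(states is None)    # True
_, states = model_example_2(tokens, True)
print(states is tokens)  # True
```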