diff --git a/docs/Changelog.md b/docs/Changelog.md
index f72b17d8c1b..65ca002f298 100644
--- a/docs/Changelog.md
+++ b/docs/Changelog.md
@@ -134,45 +134,6 @@ This version of the operator has been available since version 1 of the default O
Constrain input and output types to float tensors.
-### **Affine-1**
-
- Affine takes one input data (Tensor) and produces one output data
- (Tensor) where the affine function, y = alpha * x + beta,
- is applied to the tensor elementwise.
-
-#### Version
-
-No versioning maintained for experimental ops.
-#### Attributes
-
-
-- alpha : float (default is 1.0)
-- Value of alpha
-- beta : float (default is 0.0)
-- Value of beta
-
-
-#### Inputs
-
-
-- X : T
-- 1D input tensor
-
-
-#### Outputs
-
-
-- Y : T
-- 1D output tensor
-
-
-#### Type Constraints
-
-
-- T : tensor(float16), tensor(float), tensor(double)
-- Constrain input and output types to float tensors.
-
-
### **And-1**
Returns the tensor resulted from performing the `and` logical operation
@@ -739,45 +700,6 @@ This version of the operator has been available since version 1 of the default O
Constrain input and output types to float tensors.
-### **Crop-1**
-
- Crop and image to the specified spatial dimensions. If scale is given,
- then optionally start the crop offset by the left/top border amounts.
- If scale is not provided, crop the borders as provided.
-
-#### Version
-
-No versioning maintained for experimental ops.
-#### Attributes
-
-
-- border : list of ints
-- A 1-D values of (leftBorder, topBorder, rightBorder, bottomBorder).
-- scale : list of ints
-- A 1-D values of (height, width).
-
-
-#### Inputs
-
-
-- input : T
-- Input tensor of shape [N,C,H,W]
-
-
-#### Outputs
-
-
-- output : T
-- Result, has same type as input, with H and W dimensions reduced.
-
-
-#### Type Constraints
-
-
-- T : tensor(float16), tensor(float), tensor(double)
-- Constrain input and output types to float tensors.
-
-
### **DepthToSpace-1**
DepthToSpace rearranges (permutes) data from depth into blocks of spatial data.
@@ -1850,44 +1772,6 @@ This version of the operator has been available since version 1 of the default O
Only bool
-### **ImageScaler-1**
-
- Scale and bias the input image. Bias values are stored in
- the same ordering as the image pixel format.
-
-#### Version
-
-No versioning maintained for experimental ops.
-#### Attributes
-
-
-- bias : list of floats
-- Bias applied to each channel, same size as C.
-- scale : float (default is 1.0)
-- The scale to apply.
-
-
-#### Inputs
-
-
-- input : T
-- Input tensor of shape [N,C,H,W]
-
-
-#### Outputs
-
-
-- output : T
-- Result, has same shape and type as input
-
-
-#### Type Constraints
-
-
-- T : tensor(float16), tensor(float), tensor(double)
-- Constrain input and output types to float tensors.
-
-
### **InstanceNormalization-1**
Carries out instance normalization as described in the paper
@@ -3071,45 +2955,6 @@ This version of the operator has been available since version 1 of the default O
Constrain input and output types to float tensors.
-### **ParametricSoftplus-1**
-
- ParametricSoftplus takes one input data (Tensor) and produces one output data
- (Tensor) where the softplus function, y = alpha * ln(exp(beta * x) + 1), is applied to
- the tensor elementwise.
-
-#### Version
-
-No versioning maintained for experimental ops.
-#### Attributes
-
-
-- alpha : float
-- Value of alpha
-- beta : float
-- Value of beta
-
-
-#### Inputs
-
-
-- X : T
-- 1D input tensor
-
-
-#### Outputs
-
-
-- Y : T
-- 1D input tensor
-
-
-#### Type Constraints
-
-
-- T : tensor(float16), tensor(float), tensor(double)
-- Constrain input and output types to float tensors.
-
-
### **Pow-1**
Pow takes input data (Tensor) and exponent Tensor, and
@@ -4803,7 +4648,6 @@ This version of the operator has been available since version 1 of the default O
-Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which
contains the indices of the top k elements (original indices from the input
tensor).
-
Given two equivalent values, this operator uses the indices along the axis as
a tiebreaker. That is, the element with the lower index will appear first.
@@ -6976,9 +6820,9 @@ This version of the operator has been available since version 7 of the default O
### **Dropout-7**
- Dropout takes one input floating tensor and produces two tensor outputs,
- output (floating tensor) and mask (`Tensor`). Depending on whether it is
- in test mode or not, the output Y will either be a random dropout, or a simple
+ Dropout takes one input data (Tensor) and produces two Tensor outputs,
+ output (Tensor) and mask (Tensor). Depending on whether it is in
+ test mode or not, the output Y will either be a random dropout, or a simple
copy of the input. Note that our implementation of Dropout does scaling in
the training phase, so during testing nothing needs to be done.
This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.
@@ -7006,7 +6850,7 @@ This version of the operator has been available since version 7 of the default O
- output : T
- The output.
-- mask (optional) : T1
+- mask (optional) : T
- The output mask.
@@ -7015,8 +6859,6 @@ This version of the operator has been available since version 7 of the default O
- T : tensor(float16), tensor(float), tensor(double)
- Constrain input and output types to float tensors.
-- T1 : tensor(bool)
-- Constrain output mask types to boolean tensors.
### **Equal-7**
@@ -9049,6 +8891,47 @@ This version of the operator has been available since version 9 of the default O
Constrain index tensor to int64
+### **MeanVarianceNormalization-9**
+
+ A MeanVarianceNormalization Function: Perform mean variance normalization
+ on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ```
+
+#### Version
+
+This version of the operator has been available since version 9 of the default ONNX operator set.
+
+#### Attributes
+
+
+- axes : list of ints
+- A list of integers, along which to reduce. The default is to reduce over all the dimensions of the input tensor. Use [0,2,3] (without C axis for N-D cases) for calculating means and variances along channels. Two variables with the same C-coordinate are associated with the same mean and variance.
+
+
+#### Inputs
+
+
+- X : T
+- Input tensor
+
+
+#### Outputs
+
+
+- Y : T
+- Output tensor
+
+
+#### Type Constraints
+
+
+- T : tensor(float16), tensor(float), tensor(double)
+- Constrain input and output types to float tensors.
+
+
+#### Function
+
+The Function can be represented as a function.
+
### **NonZero-9**
Returns the indices of the elements that are non-zero
@@ -9717,3 +9600,54 @@ This version of the operator has been available since version 10 of the default
#### Type Constraints
+### **TopK-10**
+
+ Retrieve the top-K elements along a specified axis. Given an input tensor of
+ shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs:
+ -Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n]
+ which contains the values of the top k elements along the specified axis
+ -Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which
+ contains the indices of the top k elements (original indices from the input
+ tensor).
+
+ Given two equivalent values, this operator uses the indices along the axis as
+ a tiebreaker. That is, the element with the lower index will appear first.
+
+#### Version
+
+This version of the operator has been available since version 10 of the default ONNX operator set.
+
+#### Attributes
+
+
+- axis : int (default is -1)
+- Dimension on which to do the sort.
+
+
+#### Inputs
+
+
+- X : T
+- Tensor of shape [a_1, a_2, ..., a_n, r]
+- K : tensor(int64)
+- A 1-D tensor containing a single positive value corresponding to the number of top elements to retrieve
+
+
+#### Outputs
+
+
+- Values : T
+- Tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] containing top K values from the input tensor
+- Indices : I
+- Tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] containing the corresponding input tensor indices for the top K values.
+
+
+#### Type Constraints
+
+
+- T : tensor(float16), tensor(float), tensor(double)
+- Constrain input and output types to float tensors.
+- I : tensor(int64)
+- Constrain index tensor to int64
+
+
diff --git a/docs/Functions-ml.md b/docs/Functions-ml.md
deleted file mode 100644
index 4598d0d88aa..00000000000
--- a/docs/Functions-ml.md
+++ /dev/null
@@ -1,4 +0,0 @@
-## Functions
-*This file is automatically generated from the
- [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).
- Do not modify directly and instead edit function definitions.*
diff --git a/docs/Functions.md b/docs/Functions.md
deleted file mode 100644
index 3eb6f06934f..00000000000
--- a/docs/Functions.md
+++ /dev/null
@@ -1,36 +0,0 @@
-## Functions
-*This file is automatically generated from the
- [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).
- Do not modify directly and instead edit function definitions.*
-## ai.onnx (default)
- * MeanVarianceNormalization
-
-
-
-### **MeanVarianceNormalization**
-
- A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ```
INPUT: X(float/float16/double) with shape [N,C,W,H] or N-D shape
ATTRIBUTE:
axes: will be passed to ReducedMean Ops. Use [0,2,3] (without C axis for N-D cases) for calculating means and variances along channels. Two variables with the same C-coordinate are associated with the same mean and variance. Use [0,1,2,3] (with C axis) to calculate global mean and global variance with all variables sharing the same mean/variance.
(The KeepDims attribute in ReducedMean is set to true for calculation)
OUTPUT: X_MVN(float/float16/double) with the same shape as input X
-
-#### Version
-
-This version of the function has been available since version 9 of the default ONNX operator set.
-
-#### Inputs
-
-
-- X;
-
-
-#### Outputs
-
-
-- X_MVN;
-
-
-#### Attributes
-
-
-- axes;
-
-
-
diff --git a/docs/FunctionsChangelog-ml.md b/docs/FunctionsChangelog-ml.md
deleted file mode 100644
index 8ecc1586ed9..00000000000
--- a/docs/FunctionsChangelog-ml.md
+++ /dev/null
@@ -1,5 +0,0 @@
-## Function Changelog
-*This file is automatically generated from the
- [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).
- Do not modify directly and instead edit function definitions.*
-## ai.onnx.ml
diff --git a/docs/FunctionsChangelog.md b/docs/FunctionsChangelog.md
deleted file mode 100644
index b134e1430ad..00000000000
--- a/docs/FunctionsChangelog.md
+++ /dev/null
@@ -1,32 +0,0 @@
-## Function Changelog
-*This file is automatically generated from the
- [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).
- Do not modify directly and instead edit function definitions.*
-# ai.onnx (default)
-## Version 9 of domain ai.onnx (default)
-### **MeanVarianceNormalization-9**
-
- A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ```
INPUT: X(float/float16/double) with shape [N,C,W,H] or N-D shape
ATTRIBUTE:
axes: will be passed to ReducedMean Ops. Use [0,2,3] (without C axis for N-D cases) for calculating means and variances along channels. Two variables with the same C-coordinate are associated with the same mean and variance. Use [0,1,2,3] (with C axis) to calculate global mean and global variance with all variables sharing the same mean/variance.
(The KeepDims attribute in ReducedMean is set to true for calculation)
OUTPUT: X_MVN(float/float16/double) with the same shape as input X
-
-#### Version
-
-This version of the function has been available since version 9 of the default ONNX operator set.
-
-#### Inputs
-
-
-- X;
-
-
-#### Outputs
-
-
-- X_MVN;
-
-
-#### Attributes
-
-
-- axes;
-
-
diff --git a/docs/Operators.md b/docs/Operators.md
index 4084ed201f3..4342a28f14f 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -128,17 +128,16 @@
* Where
* Xor
* experimental ATen
- * experimental Affine
- * experimental Crop
* experimental DynamicSlice
* experimental GRUUnit
* experimental GivenTensorFill
- * experimental ImageScaler
- * experimental ParametricSoftplus
* experimental Scale
* experimental ScaledTanh
* experimental ThresholdedRelu
+ **Operators with function registered:**
+ * MeanVarianceNormalization
+
## ai.onnx (default)
### **Abs**
@@ -3054,9 +3053,9 @@ expect(node, inputs=[x, y], outputs=[z],
### **Dropout**
- Dropout takes one input floating tensor and produces two tensor outputs,
- output (floating tensor) and mask (`Tensor`). Depending on whether it is
- in test mode or not, the output Y will either be a random dropout, or a simple
+ Dropout takes one input data (Tensor) and produces two Tensor outputs,
+ output (Tensor) and mask (Tensor). Depending on whether it is in
+ test mode or not, the output Y will either be a random dropout, or a simple
copy of the input. Note that our implementation of Dropout does scaling in
the training phase, so during testing nothing needs to be done.
This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.
@@ -3086,7 +3085,7 @@ Other versions of this operator: Dropout-1,
- output : T
- The output.
-- mask (optional) : T1
+- mask (optional) : T
- The output mask.
@@ -3095,8 +3094,6 @@ Other versions of this operator: Dropout-1,
- T : tensor(float16), tensor(float), tensor(double)
- Constrain input and output types to float tensors.
-- T1 : tensor(bool)
-- Constrain output mask types to boolean tensors.
@@ -6819,6 +6816,85 @@ expect(node, inputs=[data_0, data_1], outputs=[result],
+### **MeanVarianceNormalization**
+
+ A MeanVarianceNormalization Function: Perform mean variance normalization
+ on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ```
+
+#### Version
+
+This version of the operator has been available since version 9 of the default ONNX operator set.
+
+#### Attributes
+
+
+- axes : list of ints
+- A list of integers, along which to reduce. The default is to reduce over all the dimensions of the input tensor. Use [0,2,3] (without C axis for N-D cases) for calculating means and variances along channels. Two variables with the same C-coordinate are associated with the same mean and variance.
+
+
+#### Inputs
+
+
+- X : T
+- Input tensor
+
+
+#### Outputs
+
+
+- Y : T
+- Output tensor
+
+
+#### Type Constraints
+
+
+- T : tensor(float16), tensor(float), tensor(double)
+- Constrain input and output types to float tensors.
+
+
+#### Function
+
+The Function can be represented as a function.
+
+
+#### Examples
+
+
+meanvariancenormalization
+
+```python
+node = onnx.helper.make_node(
+ 'MeanVarianceNormalization',
+ inputs=['X'],
+ outputs=['Y']
+)
+
+input_data = np.array([[[[0.8439683], [0.5665144], [0.05836735]],
+ [[0.02916367], [0.12964272], [0.5060197]],
+ [[0.79538304], [0.9411346], [0.9546573]]],
+ [[[0.17730942], [0.46192095], [0.26480448]],
+ [[0.6746842], [0.01665257], [0.62473077]],
+ [[0.9240844], [0.9722341], [0.11965699]]],
+ [[[0.41356155], [0.9129373], [0.59330076]],
+ [[0.81929934], [0.7862604], [0.11799799]],
+ [[0.69248444], [0.54119414], [0.07513223]]]], dtype=np.float32)
+
+# Calculate expected output data
+data_mean = np.mean(input_data, axis=(0, 2, 3), keepdims=1)
+data_mean_squared = np.power(data_mean, 2)
+data_squared = np.power(input_data, 2)
+data_squared_mean = np.mean(data_squared, axis=(0, 2, 3), keepdims=1)
+std = np.sqrt(data_squared_mean - data_mean_squared)
+expected_output = (input_data - data_mean) / (std + 1e-9)
+
+expect(node, inputs=[input_data], outputs=[expected_output],
+ name='test_mvn')
+```
+
+
+
+
### **Min**
Element-wise min of each of the input tensors (with Numpy-style broadcasting support).
@@ -12287,21 +12363,21 @@ expect(node,
-Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which
contains the indices of the top k elements (original indices from the input
tensor).
-
+
Given two equivalent values, this operator uses the indices along the axis as
a tiebreaker. That is, the element with the lower index will appear first.
#### Version
-This version of the operator has been available since version 1 of the default ONNX operator set.
+This version of the operator has been available since version 10 of the default ONNX operator set.
+
+Other versions of this operator: TopK-1
#### Attributes
- axis : int (default is -1)
- Dimension on which to do the sort.
-- k : int (required)
-- Number of top elements to retrieve
#### Inputs
@@ -12309,6 +12385,8 @@ This version of the operator has been available since version 1 of the default O
- X : T
- Tensor of shape [a_1, a_2, ..., a_n, r]
+- K : tensor(int64)
+- A 1-D tensor containing a single positive value corresponding to the number of top elements to retrieve
#### Outputs
@@ -12338,15 +12416,15 @@ This version of the operator has been available since version 1 of the default O
```python
node = onnx.helper.make_node(
'TopK',
- inputs=['x'],
+ inputs=['x', 'k'],
outputs=['values', 'indices'],
- k=3
)
X = np.array([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
], dtype=np.float32)
+K = np.array([3], dtype=np.int64)
values_ref = np.array([
[3, 2, 1],
[7, 6, 5],
@@ -12358,7 +12436,7 @@ indices_ref = np.array([
[3, 2, 1],
], dtype=np.int64)
-expect(node, inputs=[X], outputs=[values_ref, indices_ref],
+expect(node, inputs=[X, K], outputs=[values_ref, indices_ref],
name='test_top_k')
```
@@ -12808,86 +12886,6 @@ No versioning maintained for experimental ops.
-### experimental **Affine**
-
- Affine takes one input data (Tensor) and produces one output data
- (Tensor) where the affine function, y = alpha * x + beta,
- is applied to the tensor elementwise.
-
-#### Version
-
-No versioning maintained for experimental ops.
-#### Attributes
-
-
-- alpha : float (default is 1.0)
-- Value of alpha
-- beta : float (default is 0.0)
-- Value of beta
-
-
-#### Inputs
-
-
-- X : T
-- 1D input tensor
-
-
-#### Outputs
-
-
-- Y : T
-- 1D output tensor
-
-
-#### Type Constraints
-
-
-- T : tensor(float16), tensor(float), tensor(double)
-- Constrain input and output types to float tensors.
-
-
-
-### experimental **Crop**
-
- Crop and image to the specified spatial dimensions. If scale is given,
- then optionally start the crop offset by the left/top border amounts.
- If scale is not provided, crop the borders as provided.
-
-#### Version
-
-No versioning maintained for experimental ops.
-#### Attributes
-
-
-- border : list of ints
-- A 1-D values of (leftBorder, topBorder, rightBorder, bottomBorder).
-- scale : list of ints
-- A 1-D values of (height, width).
-
-
-#### Inputs
-
-
-- input : T
-- Input tensor of shape [N,C,H,W]
-
-
-#### Outputs
-
-
-- output : T
-- Result, has same type as input, with H and W dimensions reduced.
-
-
-#### Type Constraints
-
-
-- T : tensor(float16), tensor(float), tensor(double)
-- Constrain input and output types to float tensors.
-
-
-
### experimental **DynamicSlice**
Produces a slice of the input tensor along multiple axes. Similar to numpy:
@@ -13158,85 +13156,6 @@ No versioning maintained for experimental ops.
-### experimental **ImageScaler**
-
- Scale and bias the input image. Bias values are stored in
- the same ordering as the image pixel format.
-
-#### Version
-
-No versioning maintained for experimental ops.
-#### Attributes
-
-
-- bias : list of floats
-- Bias applied to each channel, same size as C.
-- scale : float (default is 1.0)
-- The scale to apply.
-
-
-#### Inputs
-
-
-- input : T
-- Input tensor of shape [N,C,H,W]
-
-
-#### Outputs
-
-
-- output : T
-- Result, has same shape and type as input
-
-
-#### Type Constraints
-
-
-- T : tensor(float16), tensor(float), tensor(double)
-- Constrain input and output types to float tensors.
-
-
-
-### experimental **ParametricSoftplus**
-
- ParametricSoftplus takes one input data (Tensor) and produces one output data
- (Tensor) where the softplus function, y = alpha * ln(exp(beta * x) + 1), is applied to
- the tensor elementwise.
-
-#### Version
-
-No versioning maintained for experimental ops.
-#### Attributes
-
-
-- alpha : float
-- Value of alpha
-- beta : float
-- Value of beta
-
-
-#### Inputs
-
-
-- X : T
-- 1D input tensor
-
-
-#### Outputs
-
-
-- Y : T
-- 1D input tensor
-
-
-#### Type Constraints
-
-
-- T : tensor(float16), tensor(float), tensor(double)
-- Constrain input and output types to float tensors.
-
-
-
### experimental **Scale**
Scale takes one input data (Tensor) and produces one output data
diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md
index be47332a285..5cb0e1efe06 100644
--- a/docs/TestCoverage.md
+++ b/docs/TestCoverage.md
@@ -5,9 +5,9 @@
* [Overall Test Coverage](#overall-test-coverage)
# Node Test Coverage
## Summary
-Node tests have covered 111/118 (94.07%, 5 generators excluded) common operators.
+Node tests have covered 112/119 (94.12%, 5 generators excluded) common operators.
-Node tests have covered 2/11 (18.18%, 0 generators excluded) experimental operators.
+Node tests have covered 2/7 (28.57%, 0 generators excluded) experimental operators.
* [Covered Common Operators](#covered-common-operators)
* [No Cover Common Operators](#no-cover-common-operators)
@@ -3578,6 +3578,43 @@ expect(node, inputs=[data_0, data_1], outputs=[result],
+### MeanVarianceNormalization
+There are 1 test cases, listed as following:
+
+meanvariancenormalization
+
+```python
+node = onnx.helper.make_node(
+ 'MeanVarianceNormalization',
+ inputs=['X'],
+ outputs=['Y']
+)
+
+input_data = np.array([[[[0.8439683], [0.5665144], [0.05836735]],
+ [[0.02916367], [0.12964272], [0.5060197]],
+ [[0.79538304], [0.9411346], [0.9546573]]],
+ [[[0.17730942], [0.46192095], [0.26480448]],
+ [[0.6746842], [0.01665257], [0.62473077]],
+ [[0.9240844], [0.9722341], [0.11965699]]],
+ [[[0.41356155], [0.9129373], [0.59330076]],
+ [[0.81929934], [0.7862604], [0.11799799]],
+ [[0.69248444], [0.54119414], [0.07513223]]]], dtype=np.float32)
+
+# Calculate expected output data
+data_mean = np.mean(input_data, axis=(0, 2, 3), keepdims=1)
+data_mean_squared = np.power(data_mean, 2)
+data_squared = np.power(input_data, 2)
+data_squared_mean = np.mean(data_squared, axis=(0, 2, 3), keepdims=1)
+std = np.sqrt(data_squared_mean - data_mean_squared)
+expected_output = (input_data - data_mean) / (std + 1e-9)
+
+expect(node, inputs=[input_data], outputs=[expected_output],
+ name='test_mvn')
+```
+
+
+
+
### Min
There are 1 test cases, listed as following:
@@ -6481,15 +6518,15 @@ There are 1 test cases, listed as following:
```python
node = onnx.helper.make_node(
'TopK',
- inputs=['x'],
+ inputs=['x', 'k'],
outputs=['values', 'indices'],
- k=3
)
X = np.array([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
], dtype=np.float32)
+K = np.array([3], dtype=np.int64)
values_ref = np.array([
[3, 2, 1],
[7, 6, 5],
@@ -6501,7 +6538,7 @@ indices_ref = np.array([
[3, 2, 1],
], dtype=np.int64)
-expect(node, inputs=[X], outputs=[values_ref, indices_ref],
+expect(node, inputs=[X, K], outputs=[values_ref, indices_ref],
name='test_top_k')
```
@@ -6922,24 +6959,12 @@ expect(node, inputs=[x], outputs=[y],
### ATen (call for test cases)
-### Affine (call for test cases)
-
-
-### Crop (call for test cases)
-
-
### GRUUnit (call for test cases)
### GivenTensorFill (call for test cases)
-### ImageScaler (call for test cases)
-
-
-### ParametricSoftplus (call for test cases)
-
-
### Scale (call for test cases)
diff --git a/onnx/backend/test/case/node/mvn.py b/onnx/backend/test/case/node/meanvariancenormalization.py
similarity index 97%
rename from onnx/backend/test/case/node/mvn.py
rename to onnx/backend/test/case/node/meanvariancenormalization.py
index e4b2cd64382..a2cc4575a55 100644
--- a/onnx/backend/test/case/node/mvn.py
+++ b/onnx/backend/test/case/node/meanvariancenormalization.py
@@ -10,7 +10,7 @@
from . import expect
-class MVN(Base):
+class MeanVarianceNormalization(Base):
@staticmethod
def export(): # type: () -> None
diff --git a/onnx/backend/test/case/node/topk.py b/onnx/backend/test/case/node/topk.py
index 6f5d60b452f..50dbbc29cee 100644
--- a/onnx/backend/test/case/node/topk.py
+++ b/onnx/backend/test/case/node/topk.py
@@ -16,15 +16,15 @@ class TopK(Base):
def export_top_k(): # type: () -> None
node = onnx.helper.make_node(
'TopK',
- inputs=['x'],
+ inputs=['x', 'k'],
outputs=['values', 'indices'],
- k=3
)
X = np.array([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
], dtype=np.float32)
+ K = np.array([3], dtype=np.int64)
values_ref = np.array([
[3, 2, 1],
[7, 6, 5],
@@ -36,5 +36,5 @@ def export_top_k(): # type: () -> None
[3, 2, 1],
], dtype=np.int64)
- expect(node, inputs=[X], outputs=[values_ref, indices_ref],
+ expect(node, inputs=[X, K], outputs=[values_ref, indices_ref],
name='test_top_k')
diff --git a/onnx/backend/test/data/node/test_top_k/model.onnx b/onnx/backend/test/data/node/test_top_k/model.onnx
index b13bfe8f801..b6f9c258a02 100644
--- a/onnx/backend/test/data/node/test_top_k/model.onnx
+++ b/onnx/backend/test/data/node/test_top_k/model.onnx
@@ -1,12 +1,16 @@
-backend-test:|
-$
-xvaluesindices"TopK*
-k
+backend-test:†
+
+x
+kvaluesindices"TopK
test_top_kZ
x
-b
+Z
+k
+
+
+b
values
@@ -14,4 +18,4 @@ test_top_kZ
indices
-B
\ No newline at end of file
+B
diff --git a/onnx/backend/test/data/node/test_top_k/test_data_set_0/input_1.pb b/onnx/backend/test/data/node/test_top_k/test_data_set_0/input_1.pb
new file mode 100644
index 00000000000..1e2082fc759
Binary files /dev/null and b/onnx/backend/test/data/node/test_top_k/test_data_set_0/input_1.pb differ
diff --git a/onnx/backend/test/runner/__init__.py b/onnx/backend/test/runner/__init__.py
index dce33a05d2c..a3d9433ab8f 100644
--- a/onnx/backend/test/runner/__init__.py
+++ b/onnx/backend/test/runner/__init__.py
@@ -163,20 +163,23 @@ def tests(self): # type: () -> Type[unittest.TestCase]
setattr(tests, name, item.func)
return tests
- @staticmethod
- def _assert_similar_outputs(ref_outputs, outputs, rtol, atol): # type: (Sequence[Any], Sequence[Any], float, float) -> None
+ @classmethod
+ def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol): # type: (Sequence[Any], Sequence[Any], float, float) -> None
np.testing.assert_equal(len(ref_outputs), len(outputs))
for i in range(len(outputs)):
np.testing.assert_equal(ref_outputs[i].dtype, outputs[i].dtype)
- np.testing.assert_allclose(
- ref_outputs[i],
- outputs[i],
- rtol=rtol,
- atol=atol)
+ if ref_outputs[i].dtype == np.object:
+ np.testing.assert_array_equal(ref_outputs[i], outputs[i])
+ else:
+ np.testing.assert_allclose(
+ ref_outputs[i],
+ outputs[i],
+ rtol=rtol,
+ atol=atol)
- @staticmethod
+ @classmethod
@retry_excute(3)
- def _download_model(model_test, model_dir, models_dir): # type: (TestCase, Text, Text) -> None
+ def download_model(cls, model_test, model_dir, models_dir): # type: (TestCase, Text, Text) -> None
# On Windows, NamedTemporaryFile can not be opened for a
# second time
download_file = tempfile.NamedTemporaryFile(delete=False)
@@ -196,8 +199,8 @@ def _download_model(model_test, model_dir, models_dir): # type: (TestCase, Text
finally:
os.remove(download_file.name)
- @staticmethod
- def _prepare_model_data(model_test): # type: (TestCase) -> Text
+ @classmethod
+ def prepare_model_data(cls, model_test): # type: (TestCase) -> Text
onnx_home = os.path.expanduser(os.getenv('ONNX_HOME', os.path.join('~', '.onnx')))
models_dir = os.getenv('ONNX_MODELS',
os.path.join(onnx_home, 'models'))
@@ -214,7 +217,7 @@ def _prepare_model_data(model_test): # type: (TestCase) -> Text
break
os.makedirs(model_dir)
- Runner._download_model(model_test=model_test, model_dir=model_dir, models_dir=models_dir)
+ cls.download_model(model_test=model_test, model_dir=model_dir, models_dir=models_dir)
return model_dir
def _add_test(self,
@@ -262,7 +265,7 @@ def _add_model_test(self, model_test, kind): # type: (TestCase, Text) -> None
def run(test_self, device): # type: (Any, Text) -> None
if model_test.model_dir is None:
- model_dir = Runner._prepare_model_data(model_test)
+ model_dir = self.prepare_model_data(model_test)
else:
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
@@ -282,9 +285,9 @@ def run(test_self, device): # type: (Any, Text) -> None
inputs = list(test_data['inputs'])
outputs = list(prepared_model.run(inputs))
ref_outputs = test_data['outputs']
- self._assert_similar_outputs(ref_outputs, outputs,
- rtol=model_test.rtol,
- atol=model_test.atol)
+ self.assert_similar_outputs(ref_outputs, outputs,
+ rtol=model_test.rtol,
+ atol=model_test.atol)
for test_data_dir in glob.glob(
os.path.join(model_dir, "test_data_set*")):
@@ -305,8 +308,8 @@ def run(test_self, device): # type: (Any, Text) -> None
tensor.ParseFromString(f.read())
ref_outputs.append(numpy_helper.to_array(tensor))
outputs = list(prepared_model.run(inputs))
- self._assert_similar_outputs(ref_outputs, outputs,
- rtol=model_test.rtol,
- atol=model_test.atol)
+ self.assert_similar_outputs(ref_outputs, outputs,
+ rtol=model_test.rtol,
+ atol=model_test.atol)
self._add_test(kind + 'Model', model_test.name, run, model_marker)
diff --git a/onnx/backend/test/stat_coverage.py b/onnx/backend/test/stat_coverage.py
index 28f25efa64b..25c2534e902 100644
--- a/onnx/backend/test/stat_coverage.py
+++ b/onnx/backend/test/stat_coverage.py
@@ -129,12 +129,12 @@ def gen_model_test_coverage(schemas, f, ml):
schema_dict = dict()
for schema in schemas:
schema_dict[schema.name] = schema
- # Load models from each model test using Runner._prepare_model_data
+ # Load models from each model test using Runner.prepare_model_data
# Need to grab associated nodes
attrs = dict() # type: Dict[Text, Dict[Text, List[Any]]]
model_paths = [] # type: List[Any]
for rt in load_model_tests(kind='real'):
- model_dir = Runner._prepare_model_data(rt)
+ model_dir = Runner.prepare_model_data(rt)
model_paths.append(os.path.join(model_dir, 'model.onnx'))
model_paths.sort()
model_written = False
diff --git a/onnx/checker.cc b/onnx/checker.cc
index d20dd23b709..cfbacb1676f 100644
--- a/onnx/checker.cc
+++ b/onnx/checker.cc
@@ -342,25 +342,15 @@ void check_node(
const auto* schema = ctx.get_schema_registry()->GetSchema(
node.op_type(), domain_version, node.domain());
- if (!schema || schema->Deprecated()) {
- // There's no primitive operator for the node.
- // Check whether it's referring to a function.
- auto func_registry = ctx.get_func_registry();
- if (nullptr == func_registry) {
+ if (!schema) {
fail_check(
- "No Op or Function registered for " + node.op_type() +
+ "No Op registered for " + node.op_type() +
" with domain_version of " +
ONNX_NAMESPACE::to_string(domain_version));
- }
- auto func = func_registry->GetFunction(
- node.op_type(), domain_version, node.domain());
- if (nullptr == func) {
- fail_check(
- "No Op or Function registered for " + node.op_type() +
- " with domain_version of " +
- ONNX_NAMESPACE::to_string(domain_version));
- }
- VerifyFunctionNode(node, *func, ctx, lex_ctx);
+ } else if (schema->Deprecated()) {
+ fail_check(
+ "Op registered for " + node.op_type() + " is deprecated in domain_version of " +
+ ONNX_NAMESPACE::to_string(domain_version));
} else {
schema->Verify(node);
}
@@ -597,20 +587,6 @@ void check_model(const ModelProto& model) {
check_model(model, ctx);
}
-void VerifyFunctionNode(
- const NodeProto& node,
- const FunctionProto& func,
- const CheckerContext& ctx,
- const LexicalScopeContext& lex_ctx) {
- // Create a temporary graphproto to hold the expanded subgraph
- GraphProto g;
- g.set_name("func_" + func.name() + "_expanded_subgraph");
- // To Generate unique internal tensor names
- // while preserving node's input/output names
- FunctionExpandHelper(node, func, g);
- check_graph(g, ctx, lex_ctx);
-}
-
#undef fail_check
#undef enforce_has_field
#undef enforce_has_repeated_field
diff --git a/onnx/checker.h b/onnx/checker.h
index e80d4df8832..121066cc88b 100644
--- a/onnx/checker.h
+++ b/onnx/checker.h
@@ -62,14 +62,6 @@ class CheckerContext final {
return schema_registry_;
}
- void set_func_registry(const IFunctionBuilderRegistry* func_registry) {
- func_registry_ = func_registry;
- }
-
- const IFunctionBuilderRegistry* get_func_registry() const {
- return func_registry_;
- }
-
void set_model_dir(const std::string& model_dir){
model_dir_ = model_dir;
}
@@ -85,8 +77,6 @@ class CheckerContext final {
std::unordered_map opset_imports_;
bool is_main_graph_ = true;
const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance();
- const IFunctionBuilderRegistry* func_registry_ =
- &FunctionBuilderRegistry::OnnxInstance();
std::string model_dir_;
};
@@ -117,11 +107,5 @@ void check_function(
void check_model(const ModelProto& model);
void check_model(const std::string& model_path);
-void VerifyFunctionNode(
- const NodeProto&,
- const FunctionProto&,
- const CheckerContext&,
- const LexicalScopeContext&);
-
} // namespace checker
} // namespace ONNX_NAMESPACE
diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc
index f8fed8888eb..c4021f9c48b 100644
--- a/onnx/cpp2py_export.cc
+++ b/onnx/cpp2py_export.cc
@@ -45,9 +45,15 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
&OpSchema::has_type_and_shape_inference_function)
.def_property_readonly(
"type_constraints", &OpSchema::typeConstraintParams)
- .def_static("is_infinite", [](int v) {
- return v == std::numeric_limits::max();
- });
+ .def_static(
+ "is_infinite",
+ [](int v) { return v == std::numeric_limits::max(); })
+ .def_property_readonly("has_function", &OpSchema::HasFunction)
+ .def_property_readonly("_function_body", [](OpSchema* op) -> py::bytes {
+ std::string bytes = "";
+ if (op->HasFunction())
+ op->GetFunction()->SerializeToString(&bytes);
+ return py::bytes(bytes);});
py::class_(op_schema, "Attribute")
.def_readonly("name", &OpSchema::Attribute::name)
@@ -151,34 +157,6 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
return OpSchemaRegistry::get_all_schemas_with_history();
});
- defs.def(
- "get_all_functions",
- [](const std::string& domain)
- -> std::unordered_map> {
- std::multimap temp_ptr_map;
- std::unordered_map> temp_map;
- FunctionBuilderRegistry& function_registry =
- FunctionBuilderRegistry::OnnxInstance();
-
- Common::Status status =
- function_registry.GetFunctions(domain, &temp_ptr_map);
- if (!status.IsOK()) {
- throw std::runtime_error(
- "Failed to retrieve function list for domain '" + domain + "'!");
- }
- for (auto iter = temp_ptr_map.begin(); iter != temp_ptr_map.end();
- ++iter) {
- std::string bytes;
- if (!iter->second->SerializeToString(&bytes)) {
- throw std::runtime_error(
- "Failed to serialize registered function for '" + iter->first +
- "'!");
- }
- temp_map[iter->first].emplace_back(py::bytes(bytes));
- }
- return temp_map;
- });
-
// Submodule `checker`
auto checker = onnx_cpp2py_export.def_submodule("checker");
checker.doc() = "Checker submodule";
@@ -245,7 +223,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
checker::check_model(proto);
});
- checker.def("check_model_path", (void (*) (const std::string&)) &checker::check_model);
+ checker.def(
+ "check_model_path",
+ (void (*)(const std::string&)) & checker::check_model);
// Submodule `optimizer`
auto optimizer = onnx_cpp2py_export.def_submodule("optimizer");
@@ -285,8 +265,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
ModelProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
shape_inference::InferShapes(proto);
- auto result =
- version_conversion::ConvertVersion(proto, target);
+ auto result = version_conversion::ConvertVersion(proto, target);
std::string out;
result.SerializeToString(&out);
return py::bytes(out);
diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py
index 1a7ce63f837..5dd3b1ebf5f 100644
--- a/onnx/defs/__init__.py
+++ b/onnx/defs/__init__.py
@@ -23,7 +23,15 @@ def onnx_opset_version(): # type: () -> int
return C.schema_version_map()[ONNX_DOMAIN][1]
-OpSchema = C.OpSchema
+@property # type: ignore
+def _Function_proto(self): # type: ignore
+ func_proto = FunctionProto()
+ func_proto.ParseFromString(self._function_body)
+ return func_proto
+
+
+OpSchema = C.OpSchema # type: ignore
+C.OpSchema.function_body = _Function_proto # type: ignore
@property # type: ignore
@@ -36,12 +44,6 @@ def _Attribute_default_value(self): # type: ignore
OpSchema.Attribute.default_value = _Attribute_default_value # type: ignore
-def get_functions(domain=ONNX_DOMAIN): # type: ignore
- function_map = defaultdict(list) # type: Dict[int, List[FunctionProto]]
- function_byte_map = C.get_all_functions(domain) # type: ignore
- for function_name, raw_functions in function_byte_map.items():
- for function_bytes in raw_functions:
- function_proto = FunctionProto()
- function_proto.ParseFromString(function_bytes)
- function_map[function_name].append(function_proto)
- return function_map
+def get_function_ops(): # type: () -> List[OpSchema]
+ schemas = C.get_all_schemas()
+ return [schema for schema in schemas if schema.has_function] # type: ignore
diff --git a/onnx/defs/experiments/defs.cc b/onnx/defs/experiments/defs.cc
index efaa5fc6a51..94d27023022 100644
--- a/onnx/defs/experiments/defs.cc
+++ b/onnx/defs/experiments/defs.cc
@@ -12,28 +12,6 @@ using SupportType = ONNX_NAMESPACE::OpSchema::SupportType;
// do not need to implement these ops. An experimental op should be either removed
// or promoted after a while. In this file, a default since_version "1" is used for all exp ops.
-static const char* Affine_ver1_doc = R"DOC(
-Affine takes one input data (Tensor) and produces one output data
-(Tensor) where the affine function, y = alpha * x + beta,
-is applied to the tensor elementwise.
-)DOC";
-
-ONNX_OPERATOR_SET_SCHEMA(
- Affine,
- 1,
- OpSchema()
- .SetSupportLevel(SupportType::EXPERIMENTAL)
- .SetDoc(Affine_ver1_doc)
- .Attr("alpha", "Value of alpha", AttributeProto::FLOAT, 1.0f)
- .Attr("beta", "Value of beta", AttributeProto::FLOAT, 0.0f)
- .Input(0, "X", "1D input tensor", "T")
- .Output(0, "Y", "1D output tensor", "T")
- .TypeConstraint(
- "T",
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
- "Constrain input and output types to float tensors.")
- .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
-
static const char* ThresholdedRelu_ver1_doc = R"DOC(
ThresholdedRelu takes one input data (Tensor) and produces one output data
(Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise,
@@ -81,28 +59,6 @@ ONNX_OPERATOR_SET_SCHEMA(
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
-static const char* ParametricSoftplus_ver1_doc = R"DOC(
-ParametricSoftplus takes one input data (Tensor) and produces one output data
-(Tensor) where the softplus function, y = alpha * ln(exp(beta * x) + 1), is applied to
-the tensor elementwise.
-)DOC";
-
-ONNX_OPERATOR_SET_SCHEMA(
- ParametricSoftplus,
- 1,
- OpSchema()
- .SetSupportLevel(SupportType::EXPERIMENTAL)
- .SetDoc(ParametricSoftplus_ver1_doc)
- .Attr("alpha", "Value of alpha", AttributeProto::FLOAT, OPTIONAL)
- .Attr("beta", "Value of beta", AttributeProto::FLOAT, OPTIONAL)
- .Input(0, "X", "1D input tensor", "T")
- .Output(0, "Y", "1D input tensor", "T")
- .TypeConstraint(
- "T",
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
- "Constrain input and output types to float tensors.")
- .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
-
ONNX_OPERATOR_SET_SCHEMA(
GivenTensorFill,
1,
@@ -240,67 +196,6 @@ ONNX_OPERATOR_SET_SCHEMA(
"tensor(double)"},
"Constrain output types to bool, int32, int64, float16, float, double tensors."));
-static const char* ImageScaler_ver1_doc =
- R"DOC(Scale and bias the input image. Bias values are stored in
-the same ordering as the image pixel format.)DOC";
-
-ONNX_OPERATOR_SET_SCHEMA(
- ImageScaler,
- 1,
- OpSchema()
- .SetSupportLevel(SupportType::EXPERIMENTAL)
- .SetDoc(ImageScaler_ver1_doc)
- .Attr(
- "bias",
- "Bias applied to each channel, same size as C.",
- AttributeProto::FLOATS,
- OPTIONAL)
- .Attr(
- "scale",
- "The scale to apply.",
- AttributeProto::FLOAT,
- 1.0f)
- .Input(0, "input", "Input tensor of shape [N,C,H,W]", "T")
- .Output(0, "output", "Result, has same shape and type as input", "T")
- .TypeConstraint(
- "T",
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
- "Constrain input and output types to float tensors.")
- .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
-
-static const char* Crop_ver1_doc =
- R"DOC(Crop and image to the specified spatial dimensions. If scale is given,
-then optionally start the crop offset by the left/top border amounts.
-If scale is not provided, crop the borders as provided.)DOC";
-
-ONNX_OPERATOR_SET_SCHEMA(
- Crop,
- 1,
- OpSchema()
- .SetSupportLevel(SupportType::EXPERIMENTAL)
- .SetDoc(Crop_ver1_doc)
- .Attr(
- "border",
- "A 1-D values of (leftBorder, topBorder, rightBorder, bottomBorder).",
- AttributeProto::INTS,
- OPTIONAL)
- .Attr(
- "scale",
- "A 1-D values of (height, width).",
- AttributeProto::INTS,
- OPTIONAL)
- .Input(0, "input", "Input tensor of shape [N,C,H,W]", "T")
- .Output(
- 0,
- "output",
- "Result, has same type as input, with H and W dimensions reduced.",
- "T")
- .TypeConstraint(
- "T",
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
- "Constrain input and output types to float tensors."));
-
-
static const char* DynamicSlice_ver1_doc = R"DOC(
Produces a slice of the input tensor along multiple axes. Similar to numpy:
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
diff --git a/onnx/defs/experiments/experiments_functions.cc b/onnx/defs/experiments/functions.cc
similarity index 58%
rename from onnx/defs/experiments/experiments_functions.cc
rename to onnx/defs/experiments/functions.cc
index bc07f3d4fe7..92205771f44 100644
--- a/onnx/defs/experiments/experiments_functions.cc
+++ b/onnx/defs/experiments/functions.cc
@@ -3,38 +3,16 @@
#include "onnx/common/constants.h"
#include "onnx/common/model_helpers.h"
-#include "onnx/defs/function.h"
+#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
-static Common::Status BuildMVN(std::unique_ptr* func_proto) {
- if (nullptr == func_proto) {
- return Common::Status(
- Common::CHECKER,
- Common::INVALID_ARGUMENT,
- "func_proto should not be nullptr.");
- }
-
- func_proto->reset(new FunctionProto);
- auto& func = **func_proto;
- func.set_name("MeanVarianceNormalization");
- func.set_doc_string(
- "A MeanVarianceNormalization Function: Perform mean variance normalization "
- "on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ```
"
- "INPUT: X(float/float16/double) with shape [N,C,W,H] or N-D shape
"
- "ATTRIBUTE:
axes: will be passed to ReducedMean "
- "Ops. Use [0,2,3] (without C axis for N-D cases) for calculating means and variances "
- "along channels. Two variables with the same C-coordinate are associated "
- "with the same mean and variance. Use [0,1,2,3] (with C axis) to calculate "
- "global mean and global variance with all variables sharing the same mean/variance.
"
- " (The KeepDims attribute in ReducedMean is set to true for calculation)
"
- "
OUTPUT: X_MVN(float/float16/double) with the same shape as input X
");
- func.set_since_version(9);
- func.add_input("X");
- func.add_output("X_MVN");
- func.add_attribute("axes");
- func.set_status(OperatorStatus::STABLE);
-
- NodeProto* initial_node0 = func.add_node();
+using SupportType = OpSchema::SupportType;
+using SupportType = ONNX_NAMESPACE::OpSchema::SupportType;
+
+static std::vector BuildMVNFunctionBody() {
+ std::vector function_nodes;
+
+ NodeProto initial_node0;
BuildNode(
"Pow_exponent_0",
ONNX_DOMAIN,
@@ -42,8 +20,8 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"Constant",
std::vector{},
std::vector{"Exponent"},
- initial_node0);
- AttributeProto* value_attr_0 = initial_node0->add_attribute();
+ &initial_node0);
+ AttributeProto* value_attr_0 = initial_node0.add_attribute();
value_attr_0->set_name("value");
value_attr_0->set_doc_string(
"Exponent (default to 2.0) to element-wisely calculate the square of a tensor");
@@ -51,8 +29,9 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
TensorProto* tensor_proto_0 = value_attr_0->mutable_t();
tensor_proto_0->set_data_type(TensorProto_DataType_FLOAT);
tensor_proto_0->add_float_data(2.0); // [2.0]
+ function_nodes.emplace_back(initial_node0);
- NodeProto* initial_node1 = func.add_node();
+ NodeProto initial_node1;
BuildNode(
"Div_epsilon_0",
ONNX_DOMAIN,
@@ -60,8 +39,8 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"Constant",
std::vector{},
std::vector{"Epsilon"},
- initial_node1);
- AttributeProto* value_attr_1 = initial_node1->add_attribute();
+ &initial_node1);
+ AttributeProto* value_attr_1 = initial_node1.add_attribute();
value_attr_1->set_name("value");
value_attr_1->set_doc_string(
"Epsilon (default to 1e-9) to element-wisely add to the divisor tensor");
@@ -69,8 +48,9 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
TensorProto* tensor_proto_1 = value_attr_1->mutable_t();
tensor_proto_1->set_data_type(TensorProto_DataType_FLOAT);
tensor_proto_1->add_float_data((float)1e-9); // [1e-9]
+ function_nodes.emplace_back(initial_node1);
- NodeProto* node0 = func.add_node();
+ NodeProto node0;
BuildNode(
"Reduced_Mean_0",
ONNX_DOMAIN,
@@ -78,13 +58,14 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"ReduceMean",
std::vector{"X"},
std::vector{"X_RM"},
- node0);
- AttributeProto* attr0 = node0->add_attribute();
+ &node0);
+ AttributeProto* attr0 = node0.add_attribute();
attr0->set_ref_attr_name("axes");
attr0->set_name("axes");
attr0->set_type(AttributeProto_AttributeType_INTS);
+ function_nodes.emplace_back(node0);
- NodeProto* node1 = func.add_node();
+ NodeProto node1;
BuildNode(
"Pow_0",
ONNX_DOMAIN,
@@ -92,9 +73,10 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"Pow",
std::vector{"X_RM", "Exponent"},
std::vector{"EX_squared"},
- node1);
+ &node1);
+ function_nodes.emplace_back(node1);
- NodeProto* node2 = func.add_node();
+ NodeProto node2;
BuildNode(
"Pow_1",
ONNX_DOMAIN,
@@ -102,9 +84,10 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"Pow",
std::vector{"X", "Exponent"},
std::vector{"X_squared"},
- node2);
+ &node2);
+ function_nodes.emplace_back(node2);
- NodeProto* node3 = func.add_node();
+ NodeProto node3;
BuildNode(
"Reduced_Mean_1",
ONNX_DOMAIN,
@@ -112,13 +95,14 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"ReduceMean",
std::vector{"X_squared"},
std::vector{"E_Xsquared"},
- node3);
- AttributeProto* attr1 = node3->add_attribute();
+ &node3);
+ AttributeProto* attr1 = node3.add_attribute();
attr1->set_ref_attr_name("axes");
attr1->set_name("axes");
attr1->set_type(AttributeProto_AttributeType_INTS);
+ function_nodes.emplace_back(node3);
- NodeProto* node4 = func.add_node();
+ NodeProto node4;
BuildNode(
"SUB_0",
ONNX_DOMAIN,
@@ -126,9 +110,10 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"Sub",
std::vector{"E_Xsquared", "EX_squared"},
std::vector{"Variance"},
- node4);
+ &node4);
+ function_nodes.emplace_back(node4);
- NodeProto* node5 = func.add_node();
+ NodeProto node5;
BuildNode(
"SQRT_0",
ONNX_DOMAIN,
@@ -136,9 +121,10 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"Sqrt",
std::vector{"Variance"},
std::vector{"STD"},
- node5);
+ &node5);
+ function_nodes.emplace_back(node5);
- NodeProto* node6 = func.add_node();
+ NodeProto node6;
BuildNode(
"SUB_1",
ONNX_DOMAIN,
@@ -146,9 +132,10 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"Sub",
std::vector{"X", "X_RM"},
std::vector{"X_variance"},
- node6);
+ &node6);
+ function_nodes.emplace_back(node6);
- NodeProto* node7 = func.add_node();
+ NodeProto node7;
BuildNode(
"ADD_0",
ONNX_DOMAIN,
@@ -156,24 +143,49 @@ static Common::Status BuildMVN(std::unique_ptr* func_proto) {
"Add",
std::vector{"STD", "Epsilon"},
std::vector{"Processed_STD"},
- node7);
+ &node7);
+ function_nodes.emplace_back(node7);
- NodeProto* node8 = func.add_node();
+ NodeProto node8;
BuildNode(
"DIV_0",
ONNX_DOMAIN,
"Calculate MVN-ed tensor for output",
"Div",
std::vector{"X_variance", "Processed_STD"},
- std::vector{"X_MVN"},
- node8);
+ std::vector{"Y"},
+ &node8);
+ function_nodes.emplace_back(node8);
- return Common::Status::OK();
+ return function_nodes;
}
-ONNX_FUNCTION_BUILD(
+static const char* mvn_ver9_doc = R"DOC(
+ A MeanVarianceNormalization Function: Perform mean variance normalization
+ on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ```
+)DOC";
+
+ONNX_OPERATOR_SET_SCHEMA(
MeanVarianceNormalization,
9,
- FunctionBuilder().SetDomain(ONNX_DOMAIN).SetBuildFunction(BuildMVN));
+ OpSchema()
+ .SetSupportLevel(SupportType::COMMON)
+ .SetDoc(mvn_ver9_doc)
+ .Input(0, "X", "Input tensor", "T")
+ .Output(0, "Y", "Output tensor", "T")
+ .Attr(
+ "axes",
+ "A list of integers, along which to reduce. The default is to reduce over "
+ "all the dimensions of the input tensor. Use [0,2,3] (without C axis for "
+ "N-D cases) for calculating means and variances along channels. Two "
+ "variables with the same C-coordinate are associated "
+ "with the same mean and variance.",
+ AttributeProto::INTS,
+ OPTIONAL)
+ .TypeConstraint(
+ "T",
+ {"tensor(float16)", "tensor(float)", "tensor(double)"},
+ "Constrain input and output types to all numeric tensors.")
+ .FunctionBody(BuildMVNFunctionBody()));
} // namespace ONNX_NAMESPACE
diff --git a/onnx/defs/function.cc b/onnx/defs/function.cc
index 4dca85b33c8..6e1c114148c 100644
--- a/onnx/defs/function.cc
+++ b/onnx/defs/function.cc
@@ -2,148 +2,9 @@
// Licensed under the MIT license.
#include "onnx/defs/function.h"
-#include "onnx/checker.h"
-#include "onnx/defs/operator_sets.h"
-#include "onnx/defs/schema.h"
#include "onnx/string_utils.h"
namespace ONNX_NAMESPACE {
-using namespace checker;
-FunctionBuilder& FunctionBuilder::SetDomain(const std::string& domain) {
- domain_ = domain;
- return *this;
-}
-
-const std::string& FunctionBuilder::GetDomain() const {
- return domain_;
-}
-
-FunctionBuilder& FunctionBuilder::SetBuildFunction(BuildFunction build_func) {
- build_func_ = build_func;
- return *this;
-}
-
-BuildFunction FunctionBuilder::GetBuildFunction() const {
- return build_func_;
-}
-
-Common::Status FunctionBuilderRegistry::Register(
- const FunctionBuilder& function_builder) {
- std::lock_guard lock(mutex_);
- function_builders.push_back(function_builder);
- std::unique_ptr function_proto;
- auto status = function_builder.GetBuildFunction()(&function_proto);
- if (!status.IsOK()) {
- return status;
- }
-
- CheckerContext ctx;
- std::unordered_map op_set;
- auto version_range =
- OpSchemaRegistry::DomainToVersionRange::Instance().Map().at(
- function_builder.GetDomain());
- if (function_proto->since_version() > version_range.second ||
- function_proto->since_version() < version_range.first) {
- fail_check("Invalid function version in '", function_proto->name(), "'");
- }
- op_set.insert(
- {function_builder.GetDomain(), (int)function_proto->since_version()});
- ctx.set_opset_imports(op_set);
- ctx.set_is_main_graph(false);
- LexicalScopeContext lex_ctx;
- try {
- check_function(*function_proto, ctx, lex_ctx);
- } catch (ValidationError& ex) {
- return Common::Status(Common::CHECKER, Common::INVALID_PROTOBUF, ex.what());
- }
-
- auto& func_name = function_proto->name();
- // Check no op version conflicts.
- auto range =
- domain_functions_map[function_builder.GetDomain()].equal_range(func_name);
- for (auto i = range.first; i != range.second; ++i) {
- auto version = i->second->since_version();
- if (function_proto->since_version() == version) {
- return Common::Status(
- Common::CHECKER,
- Common::FAIL,
- ONNX_NAMESPACE::MakeString(
- "A function (",
- func_name,
- ") with version (",
- version,
- ") has already been registered."));
- }
- }
- domain_functions_map[function_builder.GetDomain()].emplace(
- func_name, std::move(function_proto));
- return Common::Status::OK();
-}
-
-// Get functions for specific domain.
-Common::Status FunctionBuilderRegistry::GetFunctions(
- const std::string& domain,
- /*out*/
- std::multimap* function_set) const {
- if (nullptr == function_set) {
- return Common::Status(
- Common::CHECKER,
- Common::INVALID_ARGUMENT,
- "function_set should not be nullptr.");
- }
-
-#ifndef __ONNX_DISABLE_STATIC_REGISTRATION
- static bool ONNX_UNUSED functionBuilder_registerer =
- (RegisterOnnxFunctionBuilder(), false);
-#endif
-
- auto function_name_map_iter = domain_functions_map.find(domain);
- if (function_name_map_iter != domain_functions_map.end()) {
- for (auto iter = function_name_map_iter->second.begin();
- iter != function_name_map_iter->second.end();
- ++iter) {
- function_set->emplace(iter->first, iter->second.get());
- }
- }
- return Common::Status::OK();
-}
-
-const FunctionProto* FunctionBuilderRegistry::GetFunction(
- const std::string& func_name,
- const int maxInclusiveVersion,
- const std::string& domain) const {
- std::multimap funcs;
- auto status = GetFunctions(domain, &funcs);
- if (!status.IsOK()) {
- return nullptr;
- }
- std::map version_to_func;
- auto range = funcs.equal_range(func_name);
- for (auto i = range.first; i != range.second; ++i) {
- version_to_func[static_cast(i->second->since_version())] =
- std::move(i->second);
- }
-
- if (version_to_func.empty()) {
- return nullptr;
- }
- auto pos = version_to_func.lower_bound(maxInclusiveVersion);
- if (version_to_func.begin() == pos && pos->first > maxInclusiveVersion) {
- return nullptr;
- }
- if (version_to_func.end() == pos || pos->first > maxInclusiveVersion) {
- // All versions are less than specified version, or,
- // The version is greater than specified version.
- pos--;
- }
- return pos->second;
-}
-
-FunctionBuilderRegistry& FunctionBuilderRegistry::OnnxInstance() {
- static FunctionBuilderRegistry func_builder_registry;
- return func_builder_registry;
-}
-
std::string InteralTensorNameGenerator(
const std::string& node_name,
const std::string& internal_name) {
diff --git a/onnx/defs/function.h b/onnx/defs/function.h
index 285636066a2..121a6607e58 100644
--- a/onnx/defs/function.h
+++ b/onnx/defs/function.h
@@ -13,125 +13,10 @@
#include "onnx/onnx-operators_pb.h"
namespace ONNX_NAMESPACE {
-
-typedef Common::Status (*BuildFunction)(std::unique_ptr*);
-
-class FunctionBuilder {
- public:
- FunctionBuilder& SetDomain(const std::string& domain);
- const std::string& GetDomain() const;
- FunctionBuilder& SetBuildFunction(BuildFunction build_func);
- BuildFunction GetBuildFunction() const;
-
- private:
- std::string domain_;
- BuildFunction build_func_;
-};
-
-class IFunctionBuilderRegistry {
- public:
- virtual ~IFunctionBuilderRegistry() = default;
-
- virtual const FunctionProto* GetFunction(
- const std::string& func_name,
- const int maxInclusiveVersion,
- const std::string& domain = ONNX_DOMAIN) const = 0;
-};
-
-class FunctionBuilderRegistry : public IFunctionBuilderRegistry {
- public:
- FunctionBuilderRegistry() = default;
-
- Common::Status Register(const FunctionBuilder& function_builder);
-
- // Get functions for specific domain.
- Common::Status GetFunctions(
- const std::string& domain,
- /*out*/
- std::multimap* function_set) const;
-
- const FunctionProto* GetFunction(
- const std::string& func_name,
- const int maxInclusiveVersion,
- const std::string& domain = ONNX_DOMAIN) const override;
-
- static FunctionBuilderRegistry& OnnxInstance();
-
- private:
- std::vector function_builders;
- std::unordered_map<
- std::string,
- std::multimap>>
- domain_functions_map;
- std::mutex mutex_;
-};
-
-template
-FunctionBuilder GetFunctionBuilder();
-
-#define ONNX_FUNCTION_BUILDER_CLASS_NAME(domain, ver, name) \
- name##_##domain##_ver##ver
-
-#define ONNX_FUNCTION_BUILD(name, ver, build_func) \
- ONNX_FUNCTION_BUILD_HELPER(name, Onnx, ONNX_DOMAIN, ver, build_func)
-
-#define ONNX_FUNCTION_BUILD_HELPER(name, domain, domain_str, ver, build_func) \
- class ONNX_FUNCTION_BUILDER_CLASS_NAME(domain, ver, name); \
- template <> \
- FunctionBuilder \
- GetFunctionBuilder() { \
- return build_func; \
- }
-
-#define ONNX_FUNCTION(function_builder) \
- ONNX_FUNCTION_UNIQ_HELPER(__COUNTER__, function_builder)
-
-#define ONNX_FUNCTION_UNIQ_HELPER(counter, function_builder) \
- ONNX_FUNCTION_UNIQ(counter, function_builder)
-
-#define ONNX_FUNCTION_UNIQ(counter, function_builder) \
- static Common::Status function_builder_##counter##_status = \
- FunctionBuilderRegistry::OnnxInstance().Register(function_builder);
-
-inline void RegisterOneFunctionBuilder(FunctionBuilder&& func_builder) {
- ONNX_FUNCTION(func_builder);
-}
-
-// Registers all function builder of a given operator set
-template
-void RegisterFunctionBuilder() {
- T::ForEachFunctionBuilder(RegisterOneFunctionBuilder);
-};
-
// Helper function to expand a function node given the function proto
void FunctionExpandHelper(
const NodeProto& node,
const FunctionProto& func,
GraphProto& g,
const std::string& node_prefix = "");
-
-// Example to register a function.
-// Common::Status BuildFc(std::unique_ptr* func_proto) {
-// if (nullptr == func_proto) {
-// return Status(
-// Common::CHECKER,
-// Common::INVALID_ARGUMENT,
-// "func_proto should not be nullptr.");
-// }
-//
-// func_proto->reset(new FunctionProto);
-// auto& func = **func_proto;
-// func.set_name("FC");
-// set function inputs.
-// set function outputs.
-// set function attributes.
-// set function description.
-// set function body (nodes).
-//
-// return Status::OK();
-//}
-//
-// ONNX_FUNCTION_BUILD(Name, Ver,
-// FunctionBuilder().SetDomain("").SetBuildFunction(BuildFc));
-
} // namespace ONNX_NAMESPACE
diff --git a/onnx/defs/gen_doc.py b/onnx/defs/gen_doc.py
index 5ea1a973cdd..bc30dad7f8b 100644
--- a/onnx/defs/gen_doc.py
+++ b/onnx/defs/gen_doc.py
@@ -79,12 +79,6 @@ def display_version_link(name, version): # type: (Text, int) -> Text
return '{}'.format(changelog_md, name_with_ver, name_with_ver)
-def display_function_version_link(name, version): # type: (Text, int) -> Text
- changelog_md = 'FunctionsChangelog' + ext
- name_with_ver = '{}-{}'.format(name, version)
- return '{}'.format(changelog_md, name_with_ver, name_with_ver)
-
-
def display_schema(schema, versions): # type: (OpSchema, Sequence[OpSchema]) -> Text
s = ''
@@ -203,57 +197,10 @@ def format_value(value): # type: (Any) -> Text
s += '{}\n'.format(type_constraint.description)
s += '\n'
- return s
-
-
-def display_function(function, versions, domain=ONNX_DOMAIN): # type: (FunctionProto, List[int], Text) -> Text
- s = ''
-
- if domain:
- domain_prefix = '{}.'.format(ONNX_ML_DOMAIN)
- else:
- domain_prefix = ''
-
- # doc
- if function.doc_string:
- s += '\n'
- s += '\n'.join(' ' + line
- for line in function.doc_string.lstrip().splitlines())
- s += '\n'
-
- # since version
- s += '\n#### Version\n'
- s += '\nThis version of the function has been available since version {}'.format(function.since_version)
- s += ' of {}.\n'.format(display_domain(domain_prefix))
- if len(versions) > 1:
- s += '\nOther versions of this function: {}\n'.format(
- ', '.join(display_function_version_link(domain_prefix + function.name, v) for v in versions if v != function.since_version))
-
- # inputs
- s += '\n#### Inputs'
- s += '\n\n'
- if function.input:
- s += '\n'
- for input in function.input:
- s += '- {};
\n'.format(input)
- s += '
\n'
-
- # outputs
- s += '\n#### Outputs'
- s += '\n\n'
- if function.output:
- s += '\n'
- for output in function.output:
- s += '- {};
\n'.format(output)
- s += '
\n'
-
- # attributes
- if function.attribute:
- s += '\n#### Attributes\n\n'
- s += '\n'
- for attr in function.attribute:
- s += '- {};
\n'.format(attr)
- s += '
\n'
+ # Function Body
+ if schema.has_function: # type: ignore
+ s += '\n#### Function\n'
+        s += '\nThis operator can be represented as a function.\n'
return s
@@ -263,11 +210,6 @@ def support_level_str(level): # type: (OpSchema.SupportType) -> Text
"experimental " if level == OpSchema.SupportType.EXPERIMENTAL else ""
-def function_status_str(status=OperatorStatus.Value("EXPERIMENTAL")): # type: ignore
- return \
- "experimental " if status == OperatorStatus.Value('EXPERIMENTAL') else "" # type: ignore
-
-
def main(args): # type: (Type[Args]) -> None
with io.open(args.changelog, 'w', newline='') as fout:
fout.write('## Operator Changelog\n')
@@ -300,45 +242,6 @@ def main(args): # type: (Type[Args]) -> None
fout.write(s)
- with io.open(args.fn_changelog, 'w', newline='') as fout:
- fout.write('## Function Changelog\n')
- fout.write(
- "*This file is automatically generated from the\n"
- " [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
- " Do not modify directly and instead edit function definitions.*\n")
-
- if ONNX_ML:
- all_functions = defs.get_functions(ONNX_ML_DOMAIN)
- else:
- all_functions = defs.get_functions('')
-
- changelog_versionmap = defaultdict(list) # type: Dict[int, List[FunctionProto]]
- for fn_name, functions in sorted(all_functions.items()):
- for func in functions:
- changelog_versionmap[func.since_version].append(func)
-
- if ONNX_ML:
- s = '## {}\n'.format(ONNX_ML_DOMAIN)
- domain_display_name = ONNX_ML_DOMAIN
- domain_prefix = '{}.'.format(ONNX_ML_DOMAIN)
- else:
- s = '# ai.onnx (default)\n'
- domain_display_name = 'ai.onnx (default)'
- domain_prefix = ''
- fout.write(s)
-
- for version, function_list in sorted(changelog_versionmap.items()):
- s = ""
- for function in function_list:
- s += '## Version {} of domain {}\n'.format(version, domain_display_name)
- name_with_ver = '{}-{}'.format(domain_prefix
- + fn_name, function.since_version)
- s += '### **{}**\n'.format(name_with_ver, name_with_ver)
- available_versions = [func.since_version for func in all_functions[function.name]]
- s += display_function(function, available_versions, domain_prefix)
- s += '\n'
- fout.write(s)
-
with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
fout.write('## Operator Schemas\n')
fout.write(
@@ -378,9 +281,21 @@ def main(args): # type: (Type[Args]) -> None
for domain, supportmap in operator_schemas:
s = '* {}\n'.format(display_domain_short(domain))
fout.write(s)
-
+ function_ops = list()
for _, namemap in supportmap:
for n, schema, versions in namemap:
+ if schema.has_function: # type: ignore
+ function_ops.append((n, schema, versions))
+ continue
+ s = ' * {}{}\n'.format(
+ support_level_str(schema.support_level),
+ format_name_with_domain(domain, n),
+ format_name_with_domain(domain, n))
+ fout.write(s)
+ if len(function_ops):
+ fout.write('\n')
+ fout.write(' **Operators with function registered:**\n')
+ for n, schema, versions in function_ops:
s = ' * {}{}\n'.format(
support_level_str(schema.support_level),
format_name_with_domain(domain, n),
@@ -424,54 +339,6 @@ def main(args): # type: (Type[Args]) -> None
fout.write(s)
- with io.open(args.function_output, 'w', newline='') as fout:
- fout.write('## Functions\n')
- fout.write(
- "*This file is automatically generated from the\n"
- " [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
- " Do not modify directly and instead edit function definitions.*\n")
-
- if ONNX_ML:
- all_functions = defs.get_functions(ONNX_ML_DOMAIN)
- else:
- all_functions = defs.get_functions('')
-
- if all_functions:
- if ONNX_ML:
- s = '## {}\n'.format(ONNX_ML_DOMAIN)
- domain_prefix = '{}.'.format(ONNX_ML_DOMAIN)
- else:
- s = '## ai.onnx (default)\n'
- domain_prefix = ''
- fout.write(s)
-
- existing_functions = set() # type: Set[Text]
- for function_name, functions in sorted(all_functions.items()):
- for function in sorted(functions, key=lambda s: s.since_version, reverse=True):
- if function.name in existing_functions:
- continue
- existing_functions.add(function.name)
- s = ' * {}{}\n'.format(
- function_status_str(function.status),
- domain_prefix + function.name, domain_prefix + function.name)
- fout.write(s)
-
- fout.write('\n')
-
- fout.write('\n\n')
-
- for function_name, functions in sorted(all_functions.items()):
- available_versions = [func.since_version for func in functions]
- function = sorted(functions, key=lambda s: s.since_version, reverse=True)[0]
- s = '### {}**{}**\n'.format(
- function_status_str(function.status),
- domain_prefix + function.name, domain_prefix + function.name.lower(),
- domain_prefix + function.name)
-
- s += display_function(function, available_versions, domain_prefix)
- s += '\n\n'
- fout.write(s)
-
if __name__ == '__main__':
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
@@ -479,7 +346,5 @@ def main(args): # type: (Type[Args]) -> None
class Args(object):
output = os.path.join(docs_dir, 'Operators' + ext)
- function_output = os.path.join(docs_dir, 'Functions' + ext)
changelog = os.path.join(docs_dir, 'Changelog' + ext)
- fn_changelog = os.path.join(docs_dir, 'FunctionsChangelog' + ext)
main(Args)
diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc
index c843f226ba6..cd0c7fcb8a6 100644
--- a/onnx/defs/math/defs.cc
+++ b/onnx/defs/math/defs.cc
@@ -823,7 +823,7 @@ ONNX_OPERATOR_SET_SCHEMA(
resultShape;
}));
-static const char* TopK_ver1_doc = R"DOC(
+static const char* TopK_ver10_doc = R"DOC(
Retrieve the top-K elements along a specified axis. Given an input tensor of
shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs:
-Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n]
@@ -831,17 +831,18 @@ shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs:
-Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which
contains the indices of the top k elements (original indices from the input
tensor).
-
+
Given two equivalent values, this operator uses the indices along the axis as
a tiebreaker. That is, the element with the lower index will appear first.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
TopK,
- 1,
+ 10,
OpSchema()
- .SetDoc(TopK_ver1_doc)
+ .SetDoc(TopK_ver10_doc)
.Input(0, "X", "Tensor of shape [a_1, a_2, ..., a_n, r]", "T")
+ .Input(1, "K", "A 1-D tensor containing a single positive value corresponding to the number of top elements to retrieve", "tensor(int64)")
.Output(
0,
"Values",
@@ -863,40 +864,47 @@ ONNX_OPERATOR_SET_SCHEMA(
"I",
{"tensor(int64)"},
"Constrain index tensor to int64")
- .Attr(
- "k",
- "Number of top elements to retrieve",
- AttributeProto::INT,
- true)
.Attr(
"axis",
"Dimension on which to do the sort.",
AttributeProto::INT,
static_cast<int64_t>(-1))
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
- // Type inference:
- propagateElemTypeFromInputToOutput(ctx, 0, 0);
- updateOutputElemType(ctx, 1, TensorProto::INT64);
+ // Type inference:
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ updateOutputElemType(ctx, 1, TensorProto::INT64);
- // Shape inference:
- if (!hasInputShape(ctx, 0))
- return;
- auto& input_shape = getInputShape(ctx, 0);
- int64_t rank = input_shape.dim_size();
- int64_t axis = getAttribute(ctx, "axis", -1);
- if (axis < 0)
- axis += rank;
- if (axis < 0 || axis >= rank)
- fail_shape_inference("Invalid value for attribute axis");
- int64_t k = getAttribute(ctx, "k", -1);
- if (k <= 0)
- fail_shape_inference("Invalid value for attribute k");
- // TODO: unclear what results should be if axis has less than k
- // elements.
+ // Shape inference:
+ if (!hasInputShape(ctx, 0))
+ return;
+ auto& input_shape = getInputShape(ctx, 0);
+ int64_t rank = input_shape.dim_size();
+ int64_t axis = getAttribute(ctx, "axis", -1);
+ if (axis < 0)
+ axis += rank;
+ if (axis < 0 || axis >= rank)
+ fail_shape_inference("Invalid value for attribute axis");
+ // TODO: unclear what results should be if axis has less than k
+ // elements.
+ // Infer output shape if 'K' is available
+ const auto* k = ctx.getInputData(1);
+ if (nullptr != k) {
+ if (k->dims_size() != 1 || k->int64_data_size() != 1 || k->data_type() != TensorProto::INT64)
+ fail_shape_inference("K input must be a one-dimensional tensor of size 1 and of type int64.");
TensorShapeProto result_shape = input_shape;
- result_shape.mutable_dim(static_cast<int>(axis))->set_dim_value(k);
+ result_shape.mutable_dim(static_cast<int>(axis))->set_dim_value(k->int64_data(0));
updateOutputShape(ctx, 0, result_shape);
updateOutputShape(ctx, 1, result_shape);
+ } else {
+ // Infer output shapes' rank in any case
+ auto* output_shape_0 = getOutputShape(ctx, 0);
+ auto* output_shape_1 = getOutputShape(ctx, 1);
+ for (int i = 0; i < input_shape.dim_size(); ++i) {
+ output_shape_0->add_dim();
+ output_shape_1->add_dim();
+ }
+ }
+ return;
}));
static const char* Sin_ver7_doc = R"DOC(
diff --git a/onnx/defs/math/old.cc b/onnx/defs/math/old.cc
index a6534fa5a65..5b00a8f5236 100644
--- a/onnx/defs/math/old.cc
+++ b/onnx/defs/math/old.cc
@@ -1221,4 +1221,78 @@ ONNX_OPERATOR_SET_SCHEMA(
resultShape;
}));
+static const char* TopK_ver1_doc = R"DOC(
+Retrieve the top-K elements along a specified axis. Given an input tensor of
+shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs:
+ -Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n]
+ which contains the values of the top k elements along the specified axis
+ -Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which
+ contains the indices of the top k elements (original indices from the input
+ tensor).
+Given two equivalent values, this operator uses the indices along the axis as
+ a tiebreaker. That is, the element with the lower index will appear first.
+)DOC";
+
+ONNX_OPERATOR_SET_SCHEMA(
+ TopK,
+ 1,
+ OpSchema()
+ .SetDoc(TopK_ver1_doc)
+ .Input(0, "X", "Tensor of shape [a_1, a_2, ..., a_n, r]", "T")
+ .Output(
+ 0,
+ "Values",
+ "Tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] "
+ "containing top K values from the input tensor",
+ "T")
+ .Output(
+ 1,
+ "Indices",
+ "Tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] "
+ "containing the corresponding input tensor indices for the top K "
+ "values.",
+ "I")
+ .TypeConstraint(
+ "T",
+ {"tensor(float16)", "tensor(float)", "tensor(double)"},
+ "Constrain input and output types to float tensors.")
+ .TypeConstraint(
+ "I",
+ {"tensor(int64)"},
+ "Constrain index tensor to int64")
+ .Attr(
+ "k",
+ "Number of top elements to retrieve",
+ AttributeProto::INT,
+ true)
+ .Attr(
+ "axis",
+ "Dimension on which to do the sort.",
+ AttributeProto::INT,
+ static_cast<int64_t>(-1))
+ .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+ // Type inference:
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ updateOutputElemType(ctx, 1, TensorProto::INT64);
+
+ // Shape inference:
+ if (!hasInputShape(ctx, 0))
+ return;
+ auto& input_shape = getInputShape(ctx, 0);
+ int64_t rank = input_shape.dim_size();
+ int64_t axis = getAttribute(ctx, "axis", -1);
+ if (axis < 0)
+ axis += rank;
+ if (axis < 0 || axis >= rank)
+ fail_shape_inference("Invalid value for attribute axis");
+ int64_t k = getAttribute(ctx, "k", -1);
+ if (k <= 0)
+ fail_shape_inference("Invalid value for attribute k");
+ // TODO: unclear what results should be if axis has less than k
+ // elements.
+ TensorShapeProto result_shape = input_shape;
+ result_shape.mutable_dim(static_cast<int>(axis))->set_dim_value(k);
+ updateOutputShape(ctx, 0, result_shape);
+ updateOutputShape(ctx, 1, result_shape);
+ }));
} // namespace ONNX_NAMESPACE
diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc
index 16734d58c83..46577c9cbcb 100644
--- a/onnx/defs/nn/defs.cc
+++ b/onnx/defs/nn/defs.cc
@@ -52,7 +52,8 @@ void convPoolTypeAndShapeInference(
}
// don't bother with legacy auto_pad for now
- if (ctx.getAttribute("auto_pad")) {
+ const auto* auto_pad_attr = ctx.getAttribute("auto_pad");
+ if ((nullptr != auto_pad_attr) && (auto_pad_attr->s() != "NOTSET")) {
return;
}
@@ -1246,7 +1247,7 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateShapeAndTypeFromFirstInput(ctx);
}));
-static const char* Dropout_ver7_doc = R"DOC(
+static const char* Dropout_ver10_doc = R"DOC(
Dropout takes one input floating tensor and produces two tensor outputs,
output (floating tensor) and mask (`Tensor<bool>`). Depending on whether it is
in test mode or not, the output Y will either be a random dropout, or a simple
@@ -1256,9 +1257,9 @@ the training phase, so during testing nothing needs to be done.
ONNX_OPERATOR_SET_SCHEMA(
Dropout,
- 7,
+ 10,
OpSchema()
- .SetDoc(Dropout_ver7_doc + GenerateOptionalArgumentsDoc())
+ .SetDoc(Dropout_ver10_doc + GenerateOptionalArgumentsDoc())
.Attr(
"ratio",
"The ratio of random dropout",
@@ -1276,13 +1277,13 @@ ONNX_OPERATOR_SET_SCHEMA(
{"tensor(bool)"},
"Constrain output mask types to boolean tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
- propagateShapeAndTypeFromFirstInput(ctx);
- if (ctx.getNumOutputs() == 2) {
- updateOutputElemType(ctx, 1, TensorProto::BOOL);
- if (hasNInputShapes(ctx, 1)) {
- propagateShapeFromInputToOutput(ctx, 0, 1);
- }
+ propagateShapeAndTypeFromFirstInput(ctx);
+ if (ctx.getNumOutputs() == 2) {
+ updateOutputElemType(ctx, 1, TensorProto::BOOL);
+ if (hasNInputShapes(ctx, 1)) {
+ propagateShapeFromInputToOutput(ctx, 0, 1);
}
+ }
}));
static const char* Shrink_ver9_doc = R"DOC(
diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc
index 7f45609affe..91b88ec8744 100644
--- a/onnx/defs/nn/old.cc
+++ b/onnx/defs/nn/old.cc
@@ -342,6 +342,33 @@ ONNX_OPERATOR_SET_SCHEMA(
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
+static const char* Dropout_ver7_doc = R"DOC(
+Dropout takes one input data (Tensor<float>) and produces two Tensor outputs,
+output (Tensor<float>) and mask (Tensor<bool>). Depending on whether it is in
+test mode or not, the output Y will either be a random dropout, or a simple
+copy of the input. Note that our implementation of Dropout does scaling in
+the training phase, so during testing nothing needs to be done.
+)DOC";
+
+ONNX_OPERATOR_SET_SCHEMA(
+ Dropout,
+ 7,
+ OpSchema()
+ .SetDoc(Dropout_ver7_doc + GenerateOptionalArgumentsDoc())
+ .Attr(
+ "ratio",
+ "The ratio of random dropout",
+ AttributeProto::FLOAT,
+ 0.5f)
+ .Input(0, "data", "The input data as Tensor.", "T")
+ .Output(0, "output", "The output.", "T")
+ .Output(1, "mask", "The output mask.", "T", OpSchema::Optional)
+ .TypeConstraint(
+ "T",
+ {"tensor(float16)", "tensor(float)", "tensor(double)"},
+ "Constrain input and output types to float tensors.")
+ .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
+
static const char* BatchNorm_ver6_doc = R"DOC(
Carries out batch normalization as described in the paper
https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,
@@ -515,111 +542,107 @@ static const char* BatchNormalization_ver7_doc = R"DOC(
Output case #1: Y, mean, var, saved_mean, saved_var (training mode)
Output case #2: Y (test mode)
)DOC";
-
- ONNX_OPERATOR_SET_SCHEMA(
- BatchNormalization,
- 7,
- OpSchema()
- .NumOutputs({1, 5})
- .SetDoc(BatchNormalization_ver7_doc + GenerateOptionalArgumentsDoc())
- .Attr(
- "spatial",
- "If true, compute the mean and variance across per activation. "
- "If false, compute the mean and variance across per feature over "
- "each mini-batch.",
- AttributeProto::INT,
- static_cast(1))
- .Attr(
- "epsilon",
- "The epsilon value to use to avoid division by zero.",
- AttributeProto::FLOAT,
- 1e-5f)
- .Attr(
- "momentum",
- "Factor used in computing the running mean and variance."
- "e.g., running_mean = running_mean * momentum + mean * (1 - momentum).",
- AttributeProto::FLOAT,
- 0.9f)
- .Input(
- 0,
- "X",
- "Input data tensor from the previous operator; "
- "dimensions for image case are (N x C x H x W), "
- "where N is the batch size, C is the number of "
- "channels, and H and W are the height and the "
- "width of the data. For non image case, the "
- "dimensions are in the form of "
- "(N x C x D1 x D2 ... Dn), where N is the batch "
- "size.",
- "T")
- .Input(
- 1,
- "scale",
- "If spatial is true, the dimension of scale is (C). "
- "If spatial is false, the dimensions of scale are "
- "(C x D1 x ... x Dn)",
- "T")
- .Input(
- 2,
- "B",
- "If spatial is true, the dimension of bias is (C). "
- "If spatial is false, the dimensions of bias are "
- "(C x D1 x ... x Dn)",
- "T")
- .Input(
- 3,
- "mean",
- "If spatial is true, the dimension of the running mean "
- "(training) or the estimated mean (testing) is (C). "
- "If spatial is false, the dimensions of the running mean "
- "(training) or the estimated mean (testing) are (C x D1 x ... x Dn).",
- "T")
- .Input(
- 4,
- "var",
- "If spatial is true, the dimension of the running variance"
- "(training) or the estimated variance (testing) is (C). "
- "If spatial is false, the dimensions of the running variance"
- "(training) or the estimated variance (testing) are (C x D1 x ... x Dn).",
- "T")
- .Output(
- 0,
- "Y",
- "The output tensor of the same shape as X",
- "T")
- .Output(
- 1,
- "mean",
- "The running mean after the BatchNormalization operator.",
- "T",
- OpSchema::Optional)
- .Output(
- 2,
- "var",
- "The running variance after the BatchNormalization operator.",
- "T",
- OpSchema::Optional)
- .Output(
- 3,
- "saved_mean",
- "Saved mean used during training to speed up gradient "
- "computation.",
- "T",
- OpSchema::Optional)
- .Output(
- 4,
- "saved_var",
- "Saved variance used during training to speed up "
- "gradient computation.",
- "T",
- OpSchema::Optional)
- .TypeConstraint(
- "T",
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
- "Constrain input and output types to float tensors.")
- .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
- propagateShapeAndTypeFromFirstInput(ctx);
- // TODO in training mode, it may be possible to infer some of
- // the other outputs as well.
- }));
+
+ONNX_OPERATOR_SET_SCHEMA(
+ BatchNormalization,
+ 7,
+ OpSchema()
+ .NumOutputs({1, 5})
+ .SetDoc(BatchNormalization_ver7_doc + GenerateOptionalArgumentsDoc())
+ .Attr(
+ "spatial",
+ "If true, compute the mean and variance across per activation. "
+ "If false, compute the mean and variance across per feature over "
+ "each mini-batch.",
+ AttributeProto::INT,
+ static_cast(1))
+ .Attr(
+ "epsilon",
+ "The epsilon value to use to avoid division by zero.",
+ AttributeProto::FLOAT,
+ 1e-5f)
+ .Attr(
+ "momentum",
+ "Factor used in computing the running mean and variance."
+ "e.g., running_mean = running_mean * momentum + mean * (1 - momentum).",
+ AttributeProto::FLOAT,
+ 0.9f)
+ .Input(
+ 0,
+ "X",
+ "Input data tensor from the previous operator; "
+ "dimensions for image case are (N x C x H x W), "
+ "where N is the batch size, C is the number of "
+ "channels, and H and W are the height and the "
+ "width of the data. For non image case, the "
+ "dimensions are in the form of "
+ "(N x C x D1 x D2 ... Dn), where N is the batch "
+ "size.",
+ "T")
+ .Input(
+ 1,
+ "scale",
+ "If spatial is true, the dimension of scale is (C). "
+ "If spatial is false, the dimensions of scale are "
+ "(C x D1 x ... x Dn)",
+ "T")
+ .Input(
+ 2,
+ "B",
+ "If spatial is true, the dimension of bias is (C). "
+ "If spatial is false, the dimensions of bias are "
+ "(C x D1 x ... x Dn)",
+ "T")
+ .Input(
+ 3,
+ "mean",
+ "If spatial is true, the dimension of the running mean "
+ "(training) or the estimated mean (testing) is (C). "
+ "If spatial is false, the dimensions of the running mean "
+ "(training) or the estimated mean (testing) are (C x D1 x ... x Dn).",
+ "T")
+ .Input(
+ 4,
+ "var",
+ "If spatial is true, the dimension of the running variance"
+ "(training) or the estimated variance (testing) is (C). "
+ "If spatial is false, the dimensions of the running variance"
+ "(training) or the estimated variance (testing) are (C x D1 x ... x Dn).",
+ "T")
+ .Output(0, "Y", "The output tensor of the same shape as X", "T")
+ .Output(
+ 1,
+ "mean",
+ "The running mean after the BatchNormalization operator.",
+ "T",
+ OpSchema::Optional)
+ .Output(
+ 2,
+ "var",
+ "The running variance after the BatchNormalization operator.",
+ "T",
+ OpSchema::Optional)
+ .Output(
+ 3,
+ "saved_mean",
+ "Saved mean used during training to speed up gradient "
+ "computation.",
+ "T",
+ OpSchema::Optional)
+ .Output(
+ 4,
+ "saved_var",
+ "Saved variance used during training to speed up "
+ "gradient computation.",
+ "T",
+ OpSchema::Optional)
+ .TypeConstraint(
+ "T",
+ {"tensor(float16)", "tensor(float)", "tensor(double)"},
+ "Constrain input and output types to float tensors.")
+ .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+ propagateShapeAndTypeFromFirstInput(ctx);
+ // TODO in training mode, it may be possible to infer some of
+ // the other outputs as well.
+ }));
} // namespace ONNX_NAMESPACE
diff --git a/onnx/defs/operator_sets.h b/onnx/defs/operator_sets.h
index bb7ed1f808c..10742a2d913 100644
--- a/onnx/defs/operator_sets.h
+++ b/onnx/defs/operator_sets.h
@@ -3,7 +3,6 @@
#pragma once
-#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
@@ -13,7 +12,6 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, DynamicSlice);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, ATen);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Abs);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Add);
-class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Affine);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, And);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, ArgMax);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, ArgMin);
@@ -26,7 +24,6 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Concat);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Constant);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Conv);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, ConvTranspose);
-class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Crop);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, DepthToSpace);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Div);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Dropout);
@@ -48,7 +45,6 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, HardSigmoid);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Hardmax);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Identity);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, If);
-class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, ImageScaler);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, InstanceNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, LRN);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, LSTM);
@@ -71,7 +67,6 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Not);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Or);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, PRelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Pad);
-class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, ParametricSoftplus);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, Pow);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, RNN);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 1, RandomNormal);
@@ -125,7 +120,6 @@ class OpSet_Onnx_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
- fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
@@ -141,7 +135,6 @@ class OpSet_Onnx_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
- fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
@@ -169,8 +162,6 @@ class OpSet_Onnx_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
- fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
@@ -195,8 +186,6 @@ class OpSet_Onnx_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
- fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
- }
- static void ForEachFunctionBuilder(
- std::function fn) {
- fn(GetFunctionBuilder());
}
};
// Forward declarations for ai.onnx version 10
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 10, StringNormalizer);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 10, TopK);
// Iterate over schema from ai.onnx version 10
class OpSet_Onnx_ver10 {
@@ -552,6 +539,8 @@ class OpSet_Onnx_ver10 {
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 10, StringNormalizer)>());
+ fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 10, TopK)>());
}
};
@@ -568,7 +557,4 @@ inline void RegisterOnnxOperatorSetSchema() {
RegisterOpSetSchema();
}
-inline void RegisterOnnxFunctionBuilder() {
- RegisterFunctionBuilder();
-}
} // namespace ONNX_NAMESPACE
diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc
index 6187115d9b3..0cc3c337f69 100644
--- a/onnx/defs/schema.cc
+++ b/onnx/defs/schema.cc
@@ -622,6 +622,18 @@ void OpSchema::ParseAndSetTypes(
}
}
+OpSchema& OpSchema::FunctionBody(const std::vector<NodeProto>& func_nodes) {
+ for (const auto node : func_nodes) {
+ auto new_node = function_body_.add_node();
+ new_node->CopyFrom(node);
+ }
+ return *this;
+}
+
+const FunctionProto* OpSchema::GetFunction() const {
+ return function_body_.node_size()>0 ? &function_body_ : nullptr;
+}
+
OpSchema& OpSchema::FillUsing(const std::function& populator) {
if (populator) {
populator(*this);
@@ -629,6 +641,22 @@ OpSchema& OpSchema::FillUsing(const std::function& populator) {
return *this;
}
+void OpSchema::BuildFunction(){
+ function_body_.set_name(this->name_);
+ function_body_.set_doc_string(this->doc_);
+ function_body_.set_since_version(this->since_version_);
+ function_body_.set_status(OperatorStatus(1 - (int)this->support_));
+ for (auto& i : inputs_) {
+ function_body_.add_input(i.GetName());
+ }
+ for (auto& o : outputs_) {
+ function_body_.add_output(o.GetName());
+ }
+ for (auto& a : attributes_) {
+ function_body_.add_attribute(a.first);
+ }
+}
+
void OpSchema::Finalize() {
#define ENFORCE(x) \
do { \
@@ -692,6 +720,10 @@ void OpSchema::Finalize() {
ParseAndSetTypes(&inputs_);
ParseAndSetTypes(&outputs_);
+
+ if (this->HasFunction()) {
+ BuildFunction();
+ }
}
std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
@@ -775,6 +807,7 @@ OpName_Domain_Version_Schema_Map& OpSchemaRegistry::map() {
#endif
RegisterOnnxOperatorSetSchema();
+
#ifdef ONNX_ML
RegisterOnnxMLOperatorSetSchema();
#endif
@@ -795,8 +828,8 @@ OpName_Domain_Version_Schema_Map& OpSchemaRegistry::map() {
private:
static size_t GetRegisteredSchemaCount() {
size_t count = 0;
- for (auto x : GetMapWithoutEnsuringRegistration()) {
- for (auto y : x.second) {
+ for (auto& x : GetMapWithoutEnsuringRegistration()) {
+ for (auto& y : x.second) {
count += y.second.size();
}
}
diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h
index f6f2be1e27b..c74813cfc4d 100644
--- a/onnx/defs/schema.h
+++ b/onnx/defs/schema.h
@@ -9,6 +9,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -20,6 +21,7 @@
#include "data_type_utils.h"
#include "onnx/common/constants.h"
#include "onnx/defs/shape_inference.h"
+#include "onnx/onnx-operators_pb.h"
namespace ONNX_NAMESPACE {
class SchemaError final : public std::runtime_error {
@@ -558,6 +560,14 @@ class OpSchema final {
return tensor_inference_function_ ? true : false;
}
+ bool HasFunction() const {
+ return function_body_.node_size() > 0;
+ }
+
+ OpSchema& FunctionBody(const std::vector<NodeProto>& func_nodes);
+
+ const FunctionProto* GetFunction() const;
+
// Verifies that the schema is valid and all specifications are compatible.
// It will also parse all type strings specified for inputs/outputs into valid
// TypeProto and create global unique string pointer as the DataType for
@@ -568,6 +578,9 @@ class OpSchema final {
void ParseAndSetTypes(
/*out*/ std::vector* formalParameters);
+ // Build function with information stored in opschema
+ void BuildFunction();
+
std::string name_;
std::string file_;
std::string doc_;
@@ -591,6 +604,7 @@ class OpSchema final {
std::function num_inputs_allowed_ = [](int) { return true; };
std::function num_outputs_allowed_ = [](int) { return true; };
InferenceFunction tensor_inference_function_;
+ FunctionProto function_body_;
};
// Map type to store operator schemas. The format is,
@@ -700,7 +714,9 @@ class OpSchemaRegistry final : public ISchemaRegistry {
<< "in onnx/defs/schema.h)." << std::endl;
fail_schema(err.str());
}
- m[op_name][op_domain].emplace(std::make_pair(ver, op_schema));
+
+
+ m[op_name][op_domain].insert(std::pair<int, OpSchema&&>(ver, std::move(op_schema)));
} catch (const std::exception& e) {
std::cerr << "Schema error: " << e.what() << std::endl;
@@ -778,9 +794,9 @@ class OpSchemaRegistry final : public ISchemaRegistry {
public:
static const std::vector get_all_schemas_with_history() {
std::vector r;
- for (auto x : map()) {
- for (auto y : x.second) {
- for (auto z : y.second) {
+ for (auto& x : map()) {
+ for (auto& y : x.second) {
+ for (auto& z : y.second) {
r.emplace_back(z.second);
}
}
@@ -790,8 +806,8 @@ class OpSchemaRegistry final : public ISchemaRegistry {
static const std::vector get_all_schemas() {
std::vector r;
- for (auto x : map()) {
- for (auto y : x.second) {
+ for (auto& x : map()) {
+ for (auto& y : x.second) {
auto& version2schema = y.second;
r.emplace_back(version2schema.rbegin()->second);
}
diff --git a/onnx/onnxifi_ext.h b/onnx/onnxifi_ext.h
index f3d260396fc..12c9486811d 100644
--- a/onnx/onnxifi_ext.h
+++ b/onnx/onnxifi_ext.h
@@ -7,14 +7,6 @@
extern "C" {
#endif
-/*
- * This is the super set of all extension functions we support in onnxifi.
- * All backend should support a subset of function of this list.
- */
-static const int ALL_EXT_FUNCTION_NUMBER = 2;
-static const char* ALL_EXT_FUNCTION_LIST[] = {"onnxGetExtensionFunctionAddress",
- "onnxSetIOAndRunGraph"};
-
/**
* Generic ONNXIFI extension function pointer.
*
diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc
index 1c44ab557fb..0c2d3a5514f 100644
--- a/onnx/shape_inference/implementation.cc
+++ b/onnx/shape_inference/implementation.cc
@@ -101,14 +101,13 @@ static void InferShapesImpl(
const std::unordered_map&
outer_scope_value_types_by_name,
const std::unordered_map& opset_imports,
- const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
- const IFunctionBuilderRegistry* func_registry =
- &FunctionBuilderRegistry::OnnxInstance()) {
+ const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance()
+ ) {
std::unordered_map valueTypesByName{
outer_scope_value_types_by_name};
GraphInferenceContext graphInferenceContext{
- valueTypesByName, opset_imports, schema_registry, func_registry};
+ valueTypesByName, opset_imports, schema_registry};
for (auto& vi : *g->mutable_value_info()) {
if (vi.has_type())
@@ -154,33 +153,27 @@ static void InferShapesImpl(
InferenceContextImpl ctx(
n, valueTypesByName, inputDataByName, &graphInferenceContext);
if (!schema) {
- if (nullptr == func_registry) {
- continue;
- }
- // The node is not referring a primitive operator.
- // Check whether it's referring to a function.
- // If it's referring to a function.
- auto func =
- func_registry->GetFunction(n.op_type(), domain_version, n.domain());
- if (nullptr == func) {
- continue;
- }
+ continue;
+ } else if (schema->has_type_and_shape_inference_function()){
try {
- InferShapeForFunctionNode(*func, schema_registry, ctx);
+ schema->GetTypeAndShapeInferenceFunction()(ctx);
} catch (const ONNX_NAMESPACE::InferenceError& ex) {
(void)ex;
// Continue with inference for remaining nodes
continue;
}
- } else {
+ } else if (schema->HasFunction()) {
try {
- schema->GetTypeAndShapeInferenceFunction()(ctx);
- } catch (const ONNX_NAMESPACE::InferenceError& ex) {
- (void)ex;
- // Continue with inference for remaining nodes
+ InferShapeForFunctionNode(
+ schema->GetFunction(), schema_registry, ctx);
+ } catch (const ONNX_NAMESPACE::InferenceError& function_ex) {
+ (void)function_ex;
continue;
}
- }
+ } else {
+ // Continue with inference for remaining nodes
+ continue;
+ }
try {
for (int i = 0; i < n.output_size(); ++i) {
@@ -227,20 +220,19 @@ static void InferShapesImpl(
void InferShapes(
GraphProto* g,
const std::unordered_map& opset_imports,
- const ISchemaRegistry* schema_registry,
- const IFunctionBuilderRegistry* func_registry) {
+ const ISchemaRegistry* schema_registry
+ ) {
InferShapesImpl(
g,
std::unordered_map(0),
opset_imports,
- schema_registry,
- func_registry);
+ schema_registry);
}
void InferShapes(
ModelProto& m,
- const ISchemaRegistry* schema_registry,
- const IFunctionBuilderRegistry* func_registry) {
+ const ISchemaRegistry* schema_registry
+ ) {
std::unordered_map opset_imports;
for (const auto& opset_import : m.opset_import()) {
opset_imports[opset_import.domain()] =
@@ -251,38 +243,37 @@ void InferShapes(
g,
std::unordered_map(0),
opset_imports,
- schema_registry,
- func_registry);
+ schema_registry);
}
void InferShapeForFunctionNode(
- const FunctionProto& func,
+ const FunctionProto* func,
const ISchemaRegistry* schema_registry,
InferenceContext& ctx) {
- int domain_version = (int)func.since_version();
+ int domain_version = (int)func->since_version();
GraphProto g;
// Get a temporary tensor-shape map
std::unordered_map temp_valueTypesByName;
- std::vector temp_types_cache(func.input_size());
- for (int i = 0; i < func.input_size(); ++i) {
+ std::vector temp_types_cache(func->input_size());
+ for (int i = 0; i < func->input_size(); ++i) {
temp_types_cache[i] = *ctx.getInputType(i);
- temp_valueTypesByName[func.input().Get(i)] = &temp_types_cache.back();
+ temp_valueTypesByName[func->input().Get(i)] = &temp_types_cache.back();
}
// Get a temporary initial value map
std::unordered_map temp_initializersByName;
for (int i = 0; i < static_cast(ctx.getNumInputs()); ++i) {
- if (ctx.getInputData(i) != nullptr && i < func.input_size()) {
- temp_initializersByName[func.input().Get(i)] = ctx.getInputData(i);
+ if (ctx.getInputData(i) != nullptr && i < func->input_size()) {
+ temp_initializersByName[func->input().Get(i)] = ctx.getInputData(i);
}
}
std::unordered_map attr_map;
- for (auto& attr : func.attribute()) {
+ for (auto& attr : func->attribute()) {
if (ctx.getAttribute(attr) != nullptr) {
attr_map[attr] = ctx.getAttribute(attr);
}
}
- for (auto& n : func.node()) {
+ for (auto& n : func->node()) {
const auto schema =
schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
if (!schema) {
@@ -334,8 +325,8 @@ void InferShapeForFunctionNode(
temp_valueTypesByName[copy_n.output(i)] = existingType;
}
}
- for (int i = 0; i < func.output_size(); ++i) {
- std::string output_name = func.output().Get(i);
+ for (int i = 0; i < func->output_size(); ++i) {
+ std::string output_name = func->output().Get(i);
// Skip if no type inferred for the tensor
if (!temp_valueTypesByName.count(output_name)) {
continue;
@@ -398,8 +389,7 @@ std::vector GraphInferencerImpl::doInferencing(
g_,
*context_->outer_scope_value_types_by_name, // never null
context_->opset_imports,
- context_->schema_registry,
- context_->func_registry);
+ context_->schema_registry);
std::vector graphOutputTypes;
for (const ValueInfoProto& output : g_->output()) {
diff --git a/onnx/shape_inference/implementation.h b/onnx/shape_inference/implementation.h
index b550714fc67..590120cdf40 100644
--- a/onnx/shape_inference/implementation.h
+++ b/onnx/shape_inference/implementation.h
@@ -13,19 +13,17 @@ struct GraphInferenceContext {
const std::unordered_map&
outer_scope_value_types_by_name_in,
const std::unordered_map opset_imports_in,
- const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
- const IFunctionBuilderRegistry* func_registry_in =
- &FunctionBuilderRegistry::OnnxInstance())
+ const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance())
: outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in},
opset_imports{opset_imports_in},
- schema_registry{schema_registry_in},
- func_registry{func_registry_in} {}
+ schema_registry{schema_registry_in} {}
+
const std::unordered_map*
outer_scope_value_types_by_name;
const std::unordered_map opset_imports;
const ISchemaRegistry* schema_registry;
- const IFunctionBuilderRegistry* func_registry;
+
};
class GraphInferencerImpl : public GraphInferencer {
@@ -170,19 +168,17 @@ void mergeShapesAndTypes(
void InferShapes(
ModelProto& m,
- const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
- const IFunctionBuilderRegistry* func_registry =
- &FunctionBuilderRegistry::OnnxInstance());
+ const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance()
+ );
void InferShapes(
GraphProto* g,
const std::unordered_map& opset_imports,
- const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
- const IFunctionBuilderRegistry* func_registry =
- &FunctionBuilderRegistry::OnnxInstance());
+ const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance()
+ );
void InferShapeForFunctionNode(
- const FunctionProto& func,
+ const FunctionProto* func,
const ISchemaRegistry* schema_registry,
InferenceContext& ctx);
diff --git a/onnx/test/cpp/function_get_test.cc b/onnx/test/cpp/function_get_test.cc
index 4023685d0c4..d013302262d 100644
--- a/onnx/test/cpp/function_get_test.cc
+++ b/onnx/test/cpp/function_get_test.cc
@@ -1,28 +1,16 @@
#include <iostream>
#include "gtest/gtest.h"
#include "onnx/common/constants.h"
-#include "onnx/defs/function.h"
+#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
namespace Test {
-TEST(FunctionAPITest, Get_All_Functions) {
- std::multimap temp_map;
- FunctionBuilderRegistry& function_registry =
- FunctionBuilderRegistry::OnnxInstance();
- Common::Status status =
- function_registry.GetFunctions(ONNX_DOMAIN, &temp_map);
- size_t input_size = temp_map.size();
- EXPECT_EQ(input_size, 1);
- EXPECT_EQ(temp_map.count("MeanVarianceNormalization"), 1);
- auto temp_iter = temp_map.find("MeanVarianceNormalization");
- EXPECT_EQ(temp_iter->second->attribute_size(), 1);
-}
-TEST(FunctionAPITest, Get_Function_With_Version) {
- FunctionBuilderRegistry& function_registry =
- FunctionBuilderRegistry::OnnxInstance();
- auto func = function_registry.GetFunction(
- "MeanVarianceNormalization", 9, ONNX_DOMAIN);
+TEST(FunctionAPITest, Get_Function_op_With_Version) {
+ const auto* schema = OpSchemaRegistry::Schema("MeanVarianceNormalization", 9, "");
+ EXPECT_TRUE(schema);
+ EXPECT_TRUE(schema->HasFunction());
+ auto func = schema->GetFunction();
EXPECT_EQ(func->name(), "MeanVarianceNormalization");
}
diff --git a/onnx/test/cpp/function_verify_test.cc b/onnx/test/cpp/function_verify_test.cc
new file mode 100644
index 00000000000..a47cd87707e
--- /dev/null
+++ b/onnx/test/cpp/function_verify_test.cc
@@ -0,0 +1,159 @@
+#include <iostream>
+#include <set>
+#include "gtest/gtest.h"
+#include "onnx/checker.h"
+#include "onnx/common/constants.h"
+#include "onnx/defs/schema.h"
+#include "onnx/onnx-operators_pb.h"
+#include "onnx/onnx_pb.h"
+
+namespace ONNX_NAMESPACE {
+namespace Test {
+using namespace checker;
+using TENSOR_TYPES_MAP =
+    std::unordered_map<std::string, std::vector<std::string>>;
+
+void VerifyTypeConstraint(
+ const OpSchema& function_op,
+ const FunctionProto* function_proto,
+ int& counter
+) {
+ // TC for function nodes should satisfy the definition defined in the opschema
+ // This is designed to be a best-effort test
+ // TODO: Revisit to have a more consummate check on it
+ TENSOR_TYPES_MAP tc_map;
+ std::set<std::string> primitive_types(
+ OpSchema::all_tensor_types().begin(), OpSchema::all_tensor_types().end());
+ for (const auto& input : function_op.inputs()) {
+ std::string name = input.GetName();
+ for (const auto& t : input.GetTypes()) {
+ if (!primitive_types.count(*t)) {
+ return; // skip variable types check for now
+ }
+ tc_map[name].emplace_back(*t);
+ }
+ }
+
+ for (const auto& output : function_op.outputs()) {
+ std::string name = output.GetName();
+ for (const auto& t : output.GetTypes()) {
+ if (!primitive_types.count(*t)) {
+ return; // skip variable types check for now
+ }
+ tc_map[name].emplace_back(*t);
+ }
+ }
+
+ for (auto& node : function_proto->node()) {
+ std::string op_type = node.op_type();
+ const OpSchema* schema = OpSchemaRegistry::Schema(
+ op_type, function_op.since_version(), function_op.domain());
+
+ std::unordered_map<std::string, int> input_tensor_name_idx_map;
+ std::unordered_map<std::string, int> output_tensor_name_idx_map;
+ // Enforce it on input
+ for (unsigned int i = 0; i < schema->inputs().size(); ++i) {
+ auto& input = schema->inputs().at(i);
+ input_tensor_name_idx_map[input.GetName()] = i;
+ }
+ for (auto& tensor_name_tc : tc_map) {
+ auto iter = input_tensor_name_idx_map.find(tensor_name_tc.first);
+ if (iter == input_tensor_name_idx_map.end())
+ continue;
+ const auto& types = schema->inputs().at(iter->second).GetTypes();
+ std::unordered_set<std::string> allowed_types;
+ for (auto& s : types) {
+ allowed_types.insert(*s);
+ }
+ for (auto& type : tensor_name_tc.second) {
+ if (allowed_types.find(type) == allowed_types.end()) {
+ fail_check(
+ "Input type " + type + " defined in " + schema->Name() +
+ "'s function body is not allowed in node " + op_type);
+ }
+ }
+ }
+
+ // Enforce it on output
+ for (unsigned int i = 0; i < schema->outputs().size(); ++i) {
+ auto& output = schema->outputs().at(i);
+ output_tensor_name_idx_map[output.GetName()] = i;
+ }
+
+ for (auto& tensor_name_tc : tc_map) {
+ auto iter = output_tensor_name_idx_map.find(tensor_name_tc.first);
+ if (iter == output_tensor_name_idx_map.end())
+ continue;
+ const auto& types = schema->outputs().at(iter->second).GetTypes();
+ std::unordered_set<std::string> allowed_types;
+ for (auto& s : types) {
+ allowed_types.insert(*s);
+ }
+ for (auto& type : tensor_name_tc.second) {
+ if (allowed_types.find(type) == allowed_types.end()) {
+ fail_check(
+ "Output type " + type + " defined in " + schema->Name() +
+ "'s function body is not allowed in node " + op_type);
+ }
+ }
+ }
+ }
+
+ ++counter;
+}
+
+void VerifyFunction(const OpSchema& op, const FunctionProto* function_proto, int& counter) {
+ // Verify function proto is valid
+ if (!function_proto) {
+ fail_check("Cannot get function body for op '", op.Name(), "'");
+ }
+ CheckerContext ctx;
+ std::unordered_map<std::string, int> op_set;
+ if ((int)function_proto->since_version() != op.since_version()) {
+ fail_check("Unmatched since_version defined in function op '", op.Name(), "'");
+ }
+ auto version_range =
+ OpSchemaRegistry::DomainToVersionRange::Instance().Map().at(
+ op.domain());
+ if (function_proto->since_version() > version_range.second ||
+ function_proto->since_version() < version_range.first) {
+ fail_check("Invalid function version in function op '", op.Name(), "'");
+ }
+
+ op_set.insert(
+ {op.domain(), op.since_version()});
+ ctx.set_opset_imports(op_set);
+ ctx.set_is_main_graph(false);
+ LexicalScopeContext lex_ctx;
+ try {
+ check_function(*function_proto, ctx, lex_ctx);
+ } catch (ValidationError& ex) {
+ fail_check(ex.what());
+ }
+
+ // Verify function op has compatible Type constraints defined in
+ // op and function body.
+ VerifyTypeConstraint(op, function_proto, counter);
+}
+
+// Verify registered ops with function body has compatible
+// definition on TypeConstraints between ops and function body
+TEST(FunctionVerification, VerifyFunctionOps) {
+ const std::vector<OpSchema> schemas = OpSchemaRegistry::get_all_schemas();
+ int function_counter = 0, verified_counter = 0;
+ for (const auto s : schemas) {
+ if (!s.HasFunction()) continue;
+ try{
+ ++function_counter;
+ auto function_body = s.GetFunction();
+ VerifyFunction(s, function_body, verified_counter);
+ }catch (ONNX_NAMESPACE::checker::ValidationError e){
+ FAIL() << e.what();
+ }
+ }
+ std::cerr << "[ ] Verified " << verified_counter << "/"
+ << function_counter << " Functions." << std::endl;
+}
+
+} // namespace Test
+} // namespace ONNX_NAMESPACE
diff --git a/onnx/test/cpp/op_reg_test.cc b/onnx/test/cpp/op_reg_test.cc
index 9ff2c6605ea..92eb706363f 100644
--- a/onnx/test/cpp/op_reg_test.cc
+++ b/onnx/test/cpp/op_reg_test.cc
@@ -6,15 +6,15 @@ namespace ONNX_NAMESPACE
{
namespace Test
{
- TEST(OpRegistrationTest, AffineOp)
+ TEST(OpRegistrationTest, GemmOp)
{
- auto opSchema = OpSchemaRegistry::Schema("Affine");
+ auto opSchema = OpSchemaRegistry::Schema("Gemm");
EXPECT_TRUE(nullptr != opSchema);
size_t input_size = opSchema->inputs().size();
- EXPECT_EQ(input_size, 1);
+ EXPECT_EQ(input_size, 3);
EXPECT_EQ(opSchema->inputs()[0].GetTypes(), opSchema->outputs()[0].GetTypes());
size_t attr_size = opSchema->attributes().size();
- EXPECT_EQ(attr_size, 2);
+ EXPECT_EQ(attr_size, 4);
EXPECT_NE(opSchema->attributes().count("alpha"), 0);
EXPECT_EQ(opSchema->attributes().at("alpha").type, AttributeProto_AttributeType_FLOAT);
EXPECT_NE(opSchema->attributes().count("beta"), 0);
diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py
index 56f8eb9950c..820a84604a9 100644
--- a/onnx/test/shape_inference_test.py
+++ b/onnx/test/shape_inference_test.py
@@ -642,8 +642,9 @@ def test_lstm_forward(self): # type: () -> None
def test_topk_default_axis(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.FLOAT, (3, 4, 5, 10))],
- [make_node('TopK', ['x'], ['y', 'z'], k=2)],
- [])
+ [make_node('TopK', ['x', 'k'], ['y', 'z'])],
+ [],
+ initializer=[make_tensor('k', TensorProto.INT64, (1,), (2,))])
self._assert_inferred(graph,
[make_tensor_value_info('y', TensorProto.FLOAT, (3, 4, 5, 2)),
make_tensor_value_info('z', TensorProto.INT64, (3, 4, 5, 2))])
@@ -651,8 +652,9 @@ def test_topk_default_axis(self): # type: () -> None
def test_topk(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.FLOAT, (3, 4, 5, 10))],
- [make_node('TopK', ['x'], ['y', 'z'], k=2, axis=2)],
- [])
+ [make_node('TopK', ['x', 'k'], ['y', 'z'], axis=2)],
+ [],
+ initializer=[make_tensor('k', TensorProto.INT64, (1,), (2,))])
self._assert_inferred(graph,
[make_tensor_value_info('y', TensorProto.FLOAT, (3, 4, 2, 10)),
make_tensor_value_info('z', TensorProto.INT64, (3, 4, 2, 10))])