Skip to content

Commit

Permalink
[DAPHNE-#364] Added type inference for idxMin/idxMax.
Browse files Browse the repository at this point in the history
- Their result is always a matrix of value type Size.
- Added a respective type inference trait.
  • Loading branch information
pdamme committed Oct 14, 2022
1 parent b56a77f commit 8f028ba
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/compiler/inference/TypeInferenceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ mlir::Type inferTypeByTraits(O * op) {
break;
}
}
else if(op->template hasTrait<ValueTypeSize>())
resVts = {IndexType::get(ctx)};

// --------------------------------------------------------------------
// Create the result type
Expand Down
6 changes: 6 additions & 0 deletions src/ir/daphneir/DaphneInferTypesOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ class ValueTypeFromArgsInt : public TraitBase<ConcreteOp, ValueTypeFromArgsInt>
*/
template<class ConcreteOp>
class ValueTypesConcat : public TraitBase<ConcreteOp, ValueTypesConcat> {};

/**
* @brief The value type (of the single result) is `Size`.
*/
template<class ConcreteOp>
class ValueTypeSize : public TraitBase<ConcreteOp, ValueTypeSize> {};

// ============================================================================
// Traits determining data type and value type together
Expand Down
8 changes: 4 additions & 4 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ class Daphne_ColAggOp<string name, Type inScalarType, Type outScalarType = inSca
def Daphne_RowAggSumOp : Daphne_RowAggOp<"sumRow" , NumScalar, NumScalar, [ValueTypeFromFirstArg, DeclareOpInterfaceMethods<VectorizableOpInterface>]>;
def Daphne_RowAggMinOp : Daphne_RowAggOp<"minRow" , AnyScalar, AnyScalar, [ValueTypeFromFirstArg, DeclareOpInterfaceMethods<VectorizableOpInterface>]>;
def Daphne_RowAggMaxOp : Daphne_RowAggOp<"maxRow" , AnyScalar, AnyScalar, [ValueTypeFromFirstArg, DeclareOpInterfaceMethods<VectorizableOpInterface>, DeclareOpInterfaceMethods<DistributableOpInterface>]>;
def Daphne_RowAggIdxMinOp : Daphne_RowAggOp<"idxminRow", NumScalar, Size>;
def Daphne_RowAggIdxMaxOp : Daphne_RowAggOp<"idxmaxRow", NumScalar, Size>;
def Daphne_RowAggIdxMinOp : Daphne_RowAggOp<"idxminRow", NumScalar, Size, [ValueTypeSize]>;
def Daphne_RowAggIdxMaxOp : Daphne_RowAggOp<"idxmaxRow", NumScalar, Size, [ValueTypeSize]>;
def Daphne_RowAggMeanOp : Daphne_RowAggOp<"meanRow" , NumScalar, NumScalar, [ValueTypeFromArgsFP]>;
def Daphne_RowAggVarOp : Daphne_RowAggOp<"varRow" , NumScalar, NumScalar, [ValueTypeFromArgsFP]>;
def Daphne_RowAggStddevOp : Daphne_RowAggOp<"stddevRow", NumScalar, NumScalar, [ValueTypeFromArgsFP]>;
Expand All @@ -352,8 +352,8 @@ def Daphne_ColAggSumOp : Daphne_ColAggOp<"sumCol" , NumScalar, NumScalar, [
DeclareOpInterfaceMethods<VectorizableOpInterface>]>;
def Daphne_ColAggMinOp : Daphne_ColAggOp<"minCol" , AnyScalar, AnyScalar, [ValueTypeFromFirstArg]>;
def Daphne_ColAggMaxOp : Daphne_ColAggOp<"maxCol" , AnyScalar, AnyScalar, [ValueTypeFromFirstArg]>;
def Daphne_ColAggIdxMinOp : Daphne_ColAggOp<"idxminCol", NumScalar, Size>;
def Daphne_ColAggIdxMaxOp : Daphne_ColAggOp<"idxmaxCol", NumScalar, Size>;
def Daphne_ColAggIdxMinOp : Daphne_ColAggOp<"idxminCol", NumScalar, Size, [ValueTypeSize]>;
def Daphne_ColAggIdxMaxOp : Daphne_ColAggOp<"idxmaxCol", NumScalar, Size, [ValueTypeSize]>;
def Daphne_ColAggMeanOp : Daphne_ColAggOp<"meanCol" , NumScalar, NumScalar, [ValueTypeFromArgsFP]>;
def Daphne_ColAggVarOp : Daphne_ColAggOp<"varCol" , NumScalar, NumScalar, [ValueTypeFromArgsFP]>;
def Daphne_ColAggStddevOp : Daphne_ColAggOp<"stddevCol", NumScalar, NumScalar, [ValueTypeFromArgsFP]>;
Expand Down
1 change: 1 addition & 0 deletions src/ir/daphneir/DaphneTypeInferenceTraits.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def ValueTypeFromArgs: NativeOpTrait<"ValueTypeFromArgs">;
def ValueTypeFromArgsFP: NativeOpTrait<"ValueTypeFromArgsFP">;
def ValueTypeFromArgsInt: NativeOpTrait<"ValueTypeFromArgsInt">;
def ValueTypesConcat: NativeOpTrait<"ValueTypesConcat">;
def ValueTypeSize: NativeOpTrait<"ValueTypeSize">;

// ****************************************************************************
// Traits determining data type and value type together
Expand Down

0 comments on commit 8f028ba

Please sign in to comment.