TensorIR (TIR) is TVM's intermediate representation (IR) for describing operations over tensors. In particular, TIR allows for expressing low-level operations on tensors in a manner that is amenable to automatically implementing common optimization strategies like loop tiling or unrolling, especially in the context of autotuning, wherein these transformations are applied using a learned cost function.
This document is intended to serve primarily as a high-level reference for TIR's semantics (observable behavior), aiming to describe a high-level, portable subset of TIR as it is meant to be initially passed to the compiler. The subset of TIR described in this document should be accepted by the compiler and lowered to any hardware back-end without issue; a reader of this document should be able to correctly describe how each language construct will ultimately be executed and give the final result of running the program. Thus, this specification can serve as a reference for users of TIR's front-end (who can ensure their programs will behave as intended) as well as compiler implementers (who can ensure different compiler optimizations maintain the guarantees on visible behavior provided by the specification).
There are two main reasons this document describes only a subset of TIR. The first is that, unlike many programming languages, the grammar includes lots of auxiliary information intended for the compiler's internal use. Skillful users of TIR can take advantage of the compiler implementation by specifying this additional information appropriately, but the compiler implementation is greatly subject to change and it is unclear that such behavior should be part of the contract between users and the compiler. Indeed, one purpose of this specification is to clarify the distinction between the "front-end" interface to TIR and the compiler internals, since this has previously been ambiguous due to the degree to which the TIR implementation exposes compiler internals. The second reason is that the hardware back-ends supported by TIR have greatly varying properties. While it is possible that some of these details might be specified in the future or described in separate documents, this version of the specification will not account for TIR programs that make use of certain internal details of the compiler implementation or low-level details of specific hardware back-ends. The aim, rather, is to establish simple ground rules for the language that account for its most common uses.
Note: This specification corresponds to the intended functionality of the language, which may differ from how it has been implemented. Portions of the specification that differ from the implementation are enclosed in «double angle quotation marks (guillemets)» (color-coding would be preferable, but GitHub Markdown does not support it). These discrepancies should be corrected or addressed.
TIR is an imperative language that describes tensor operations mainly in terms of bounded loops over indices with defined iteration domains, wherein the loop bodies apply scalar operations to tensor elements. The main abstractions in TIR are values (scalars or vectors) and buffers (regions of memory, which tend to represent tensors); elements from buffers can be read and written in loop bodies in order to implement operations on tensors. In most cases, the bounded ranges for loop iterations provide the compiler with more information for optimizations and autotuning procedures.
In addition to its primary functionality in describing loops over tensor elements, TIR is also capable of interfacing with TVM's object system and can invoke arbitrary packed functions via intrinsics, though this version of the specification will not go into detail on intrinsics.
As noted in the preamble, this version of the specification covers a portable, high-level subset of TIR, excluding in particular behaviors that might make use of compiler implementation internal details or low-level properties of hardware back-ends. While such uses of TIR exist and will continue to exist, we defer specifying them until future versions of this specification in order to establish basic expectations for the language's behavior and additionally in order to avoid "committing" the compiler implementers to supporting certain behaviors in perpetuity; these lower-level details generally reflect conditions of the deep learning stack that are highly subject to change. Instead, in this version of the specification, we account for high-level uses of TIR intended to correspond to very common applications:
- The output of the TE (Tensor Expressions) library,
- Implementations of new tensor operators intended to be used with TVM's auto-tuning libraries, and
- TIR code intended to be invoked from Relax via TVM Unity.
This document will note which features are outside the subset of TIR intended to be specified at present; any program that makes use of these features or does not abide by the restrictions described is thus considered to be unspecified: the specification makes no guarantees on that program's behavior.
Notation: `[x]` means a sequence (zero or more) of `x`, `{x: y}` means "a map from `x` to `y`," and `x?` means "optionally `x`."
PrimFunc ::= PrimFunc(params: [Var], body: Stmt, ret_type: Type?,
buffer_map: {Var: Buffer}, attrs: Attrs)
Type ::=
PrimType(dtype: DataType)
| PointerType(element_type: Type, storage_scope: str)
| TupleType(fields: [Type])
DataType ::= DataType(code: DTTypeCode, bits: DTSize, lanes: DTLanes)
DTTypeCode ::= Int() | UInt() | Float() | BFloat() | Handle()
DTSize ::= 0 | 1 | 8 | 16 | 32 | 64
DTLanes ::= 1 | 4 | 8 | 16 | 32 | 64
Stmt ::=
LetStmt(var: Var, value: PrimExpr, body: Stmt)
| AttrStmt(node: ObjectRef**, attr_key: str, value: PrimExpr, body: Stmt)
| AssertStmt(condition: PrimExpr, message: PrimExpr, body: Stmt)
| BufferStore(buffer: Buffer, value: PrimExpr, indices: [PrimExpr])
| BufferRealize(buffer: Buffer, bounds: [Range], condition: PrimExpr, body: Stmt)
| Allocate(buffer_var: Var, dtype: DataType, extents: [PrimExpr],
condition: PrimExpr, body: Stmt, annotations: {str: Object*})
| DeclBuffer(buffer: Buffer, body: Stmt)
| SeqStmt(seq: [Stmt])
| IfThenElse(condition: PrimExpr, then_case: Stmt, else_case: Stmt?)
| Evaluate(value: PrimExpr)
| For(loop_var: Var, min: PrimExpr, extent: PrimExpr,
kind: ForKind, body: Stmt, thread_binding: IterVar?, annotations: {str: Object*})
| While(condition: PrimExpr, body: Stmt)
| Block(iter_vars: [IterVar], reads: [BufferRegion],
writes: [BufferRegion], name_hint: str, body: Stmt,
init: Stmt?, alloc_buffers: [Buffer],
match_buffers: [MatchBufferRegion],
annotations: {str: Object*})
| BlockRealize(iter_values: [PrimExpr], predicate: PrimExpr, block: Block)
Buffer ::= Buffer(data: Var, dtype: DataType, shape: [PrimExpr],
axis_separators: [IntImm], strides: [PrimExpr],
elem_offset: PrimExpr?, name: str, data_alignment: int,
offset_factor: int, buffer_type: BufferType)
BufferType ::=
kDefault()
| kAutoBroadcast()
BufferRegion ::= BufferRegion(buffer: Buffer, region: [Range])
MatchBufferRegion ::= MatchBufferRegion(buffer: Buffer, source: BufferRegion)
PrimExpr ::=
Var(name_hint: str, dtype: DataType, type_annotation: Type)
| IntImm(value: int, dtype: DataType)
| FloatImm(value: float, dtype: DataType)
| StringImm(value: str)
| Cast(value: PrimExpr, dtype: DataType)
| Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr)
| BufferLoad(buffer: Buffer, indices: [PrimExpr])
| Ramp(base: PrimExpr, stride: PrimExpr, lanes: int)
| Broadcast(value: PrimExpr, lanes: int)
| Let(var: Var, value: PrimExpr, body: PrimExpr)
| Call(dtype: DataType, op: Op|GlobalVar, args: [PrimExpr])
| Shuffle(vectors: [PrimExpr], indices: [PrimExpr])
| BinaryOp
| CmpOp
| LogicalOp
LogicalOp ::=
And(a: PrimExpr, b: PrimExpr)
| Or(a: PrimExpr, b: PrimExpr)
| Not(a: PrimExpr)
BinaryOp ::=
Add(a: PrimExpr, b: PrimExpr)
| Sub(a: PrimExpr, b: PrimExpr)
| Mul(a: PrimExpr, b: PrimExpr)
| Div(a: PrimExpr, b: PrimExpr)
| Mod(a: PrimExpr, b: PrimExpr)
| FloorDiv(a: PrimExpr, b: PrimExpr)
| FloorMod(a: PrimExpr, b: PrimExpr)
| Min(a: PrimExpr, b: PrimExpr)
| Max(a: PrimExpr, b: PrimExpr)
CmpOp ::=
Eq(a: PrimExpr, b: PrimExpr)
| NE(a: PrimExpr, b: PrimExpr)
| LT(a: PrimExpr, b: PrimExpr)
| LE(a: PrimExpr, b: PrimExpr)
| GE(a: PrimExpr, b: PrimExpr)
| GT(a: PrimExpr, b: PrimExpr)
ForKind ::=
kSerial()
| kParallel()
| kVectorized()
| kUnrolled()
| kThreadBinding()
Range ::= Range(min: PrimExpr, extent: PrimExpr?)
IterVar ::= IterVar(dom: Range?, var: Var, iter_type: IterVarType, thread_tag: str)
IterVarType ::=
kDataPar()
| kThreadIndex()
| kCommReduce()
| kOrdered()
| kOpaque()
| kUnrolled()
| kVectorized()
| kParallelized()
| kTensorized()
Attrs ::= Attrs(contents: {str: Object*})
GlobalVar ::= GlobalVar(name_hint: str)
IRModule ::= IRModule(functions: {GlobalVar: PrimFunc|BaseFunc***})
*Note that attributes and annotations can contain arbitrary TVM objects as values. These objects are used only at compile time.
**Refers to one of the base classes in the TVM object representation. In practice, `ObjectRef`s are usually TIR AST nodes. Which ones are appropriate depends on the specific attributes (only those listed under the semantics have any visible effects; the rest are used only at compile time).
***Refers to `BaseFunc`s other than `PrimFunc`s. They are not referenced in TIR `PrimFunc`s.
Additionally, at run time, `PrimFunc`s take in parameters corresponding to buffers via the DLPack library's `DLTensor` struct, defined in `dlpack.h`:
```c
typedef struct {
  void* data;           // pointer to the buffer contents
  DLDevice device;      // not discussed in this specification
  int32_t ndim;
  DLDataType dtype;     // has the same fields as DataType in the above AST
  int64_t* shape;       // array giving the shape of the corresponding buffer
  int64_t* strides;     // can be null
  uint64_t byte_offset;
} DLTensor;
```
The correspondence of the fields of `DLTensor` to `Buffer` will be discussed with the semantics for invoking a `PrimFunc`.
Programs in TIR are provided using `IRModule`s, which are collections of global functions (both in TIR and in other languages like Relay and Relax), each denoted by a `GlobalVar`. TIR is concerned only with `PrimFunc`s: execution begins with a particular `PrimFunc`, which may contain calls to other `PrimFunc`s or to itself (recursively).
Expressions in TIR (`PrimExpr`s) operate on three kinds of values:
- Scalars, which are single members of TIR's numerical datatypes: floating point (`Float`), brain floating point (`BFloat`), signed integer (`Int`), and unsigned integer (`UInt`). `Int` and `UInt` values can have bitwidths of 8, 16, 32, or 64. `UInt` values can also have a bitwidth of 1 (corresponding to a Boolean value). `Float` values can have bitwidths of 16, 32, or 64. `BFloat` values must have a bitwidth of 16. Scalar values are immutable.
- Vectors, which correspond to an immutable grouping of multiple members of the above-mentioned datatypes (`Float`, `BFloat`, `Int`, or `UInt`), with the same bitwidths permitted for `Float`, `BFloat`, `Int`, and `UInt` scalars. Vectors may contain 4, 8, 16, 32, or 64 elements of the listed datatype. Their representation at run time is back-end–specific, so we make no stipulations about how the data in a vector is represented.
- Pointer values (which have the `Handle` datatype, indicating that they are "handles" to data), which are indices to memory locations that contain scalars or other data of interest. Pointers may address a value with a known datatype (`Float`, `BFloat`, `Int`, or `UInt`), or they may address values whose datatype is unknown at compile time or opaque data intended only for calls to builtins (external procedures). In principle, a pointer could be used to address a vector, but this is presently not supported. To avoid confusion with `PointerType` (described below), we will generally refer to pointer values as "handles" in this document.
Note that a pointer is simply an index into memory; the management of the memory is part of the program state. In TIR, regions of memory that are valid to address via pointers are commonly indicated in the AST using the `Buffer` construct, which defines the size and arrangement of data in some region along with other information, such as the buffer's shape and strides. However, "buffers" themselves in TIR are not values in the language: they are not returned by expressions and manifest at run time only as handles.
For convenience in the specification, we will use some shorthand for common concepts:
- `node->field`: This refers to the field named `field` on an AST node named `node`.
- `list[index]`: This refers to the `index`th element of a list called `list`, using zero-based indexing.
- `vector.index`: This refers to the `index`th element of a vector value called `vector`, using zero-based indexing. This notation is meant to distinguish vector values in TIR from lists.
- `len(list)`: This gives the length of `list`.
- `||vector||`: This gives the length of `vector`. (Again, this notation is meant to distinguish vector values from lists.)
- `dtype(expr)`: This denotes the datatype derived from a given expression `expr`. The section below describes how datatypes are derived. This "function" should not be confused with AST fields named `dtype` (e.g., on `Buffer` nodes).
All TIR `PrimExpr`s have an associated `DataType` that describes the datatype of the result of evaluating the `PrimExpr`. In the implementation, the `DataType` class is defined in `include/tvm/runtime/data_type.h`, and each expression's datatype is stored in `PrimExprNode::dtype`.
`DataType`s have three fields:
- `code`, which describes the type of elements in the datatype. The following are the type codes used in TIR (note: for readability, this document will leave off the initial `k` when referring to these codes):
  - `kInt` and `kUInt` for signed and unsigned integer values, respectively.
  - `kFloat` for floating-point values and `kBFloat` for the `bfloat16` format (brain floating point).
  - `kHandle` for pointer values (handles).
- `bits`, which describes the bitwidth of the elements of the datatype. Common bitwidths in TIR are 1, 8, 16, 32, and 64 (for integers).
- `lanes`, which describes the number of elements in a vector value. `lanes` is 1 for a scalar value and greater than 1 for a vector (4, 8, 16, 32, or 64 lanes are common in TIR).
We classify `DataType`s into the following categories:
- Scalar. A `DataType` with the `Int`, `UInt`, `Float`, or `BFloat` code is a scalar type if `lanes` is 1.
- Vector. A `DataType` with the `Int`, `UInt`, `Float`, or `BFloat` code is a vector type if `lanes` is greater than 1.
- Handle (Pointer). A `DataType` is a handle type if it has the `Handle` code. The value of `bits` (if nonzero) corresponds to the size of the pointer (this is 64 on most devices supported by TIR, but pointers on some lower-powered devices are 32 bits wide). (As mentioned above, we use the term "handle" to distinguish these from `PointerType`, described in the section below.) The `lanes` field is undefined for handle values, though the implementation always sets it to 1. Note that if `bits` is 0, then the datatype instead refers to the `Void` type.
- Boolean (`Bool`). Refers to a `UInt` datatype with a bitwidth of 1, since these are used to represent the results of logical operators in TIR. `Bool` datatypes can be either scalars (1 lane) or vectors (4, 8, 16, 32, or 64 lanes).
- `Void`. Refers to a `Handle` datatype with a bitwidth of 0, indicating an opaque object inaccessible to TIR (but which may be used by calls to builtins).
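The categories above can be summarized as a small classification function. This is a minimal sketch that assumes a hypothetical Python stand-in for `DataType` whose `code` field holds the type-code names used in this document. It is not part of TVM's API.

```python
from dataclasses import dataclass

# Hypothetical stand-in for TIR's DataType (not TVM's actual class).
# `code` holds one of "Int", "UInt", "Float", "BFloat", or "Handle".
@dataclass(frozen=True)
class DataType:
    code: str
    bits: int
    lanes: int = 1

NUMERIC_CODES = {"Int", "UInt", "Float", "BFloat"}

def categories(dt: DataType) -> set:
    """Return the categories (as defined above) that `dt` belongs to."""
    result = set()
    if dt.code in NUMERIC_CODES:
        if dt.lanes == 1:
            result.add("scalar")
        elif dt.lanes > 1:
            result.add("vector")
        # Bool is UInt with a bitwidth of 1, either as a scalar or a vector.
        if dt.code == "UInt" and dt.bits == 1:
            result.add("bool")
    elif dt.code == "Handle":
        # A Handle with 0 bits denotes Void rather than a pointer.
        result.add("void" if dt.bits == 0 else "handle")
    return result
```

Note that `Bool` overlaps with the scalar and vector categories rather than being disjoint from them. For example, `categories(DataType("UInt", 1, 4))` contains both `"vector"` and `"bool"`.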
The following string notations for datatypes are used only sparingly in this specification, but are often used in TIR's documentation:
- For scalars and pointers: The string format is `{code}{bits}`, with the code in lowercase. For example, `int32` refers to 32-bit integers and `float16` refers to 16-bit floating-point numbers.
- For vectors: The string format is `{code}{bits}x{lanes}`. For example, a vector of 16-bit floating-point values with 4 members is `float16x4` and a vector of 8-bit integers with 8 members is `int8x8`.
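As an illustration of the notation above, the following sketch converts between a `DataType` and its string form. It uses a hypothetical Python stand-in for `DataType` and implements only the two formats described here. It does not attempt to reproduce any special spellings that TVM's own printer may use for particular types.

```python
import re
from dataclasses import dataclass

# Hypothetical stand-in for TIR's DataType (not TVM's actual class).
@dataclass(frozen=True)
class DataType:
    code: str   # "Int", "UInt", "Float", "BFloat", or "Handle"
    bits: int
    lanes: int = 1

# Lowercase spelling of each type code in the string notation.
CODE_TO_STR = {"Int": "int", "UInt": "uint", "Float": "float",
               "BFloat": "bfloat", "Handle": "handle"}
STR_TO_CODE = {v: k for k, v in CODE_TO_STR.items()}

def dtype_to_str(dt: DataType) -> str:
    """{code}{bits} for scalars and handles, {code}{bits}x{lanes} for vectors."""
    base = f"{CODE_TO_STR[dt.code]}{dt.bits}"
    return base if dt.lanes == 1 else f"{base}x{dt.lanes}"

def str_to_dtype(s: str) -> DataType:
    """Parse the notation above back into a DataType."""
    m = re.fullmatch(r"([a-z]+)(\d+)(?:x(\d+))?", s)
    if m is None or m.group(1) not in STR_TO_CODE:
        raise ValueError(f"unrecognized datatype string: {s!r}")
    lanes = int(m.group(3)) if m.group(3) else 1
    return DataType(STR_TO_CODE[m.group(1)], int(m.group(2)), lanes)
```

For example, `dtype_to_str(DataType("Float", 16, 4))` gives `"float16x4"`, and `str_to_dtype("int8x8")` gives `DataType("Int", 8, 8)`.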
TIR `PrimExpr`s have a datatype (accessible in the implementation via the `dtype` field) that indicates the datatype resulting from evaluating the expression. However, TIR variables can have a finer-grained type in their `type_annotation` field. These finer-grained types are denoted in the AST under `Type` (`tvm::ir::Type` in the implementation).
The finer-grained types are as follows:
- `PrimType` indicates that the `Var` does not have a more refined type than its `DataType` and provides no further information. The `dtype` field of the `PrimType` must be equal to the `dtype` field of the `Var`.
- `PointerType` indicates that the `Var` is bound to a pointer value. Unlike the `Handle` datatype, `PointerType` describes the datatype of the value being referenced by the pointer. This is most often used for the `data` field of a `Buffer`, as the `data` field is a pointer to a specific region of memory on a specific device. If a `Var` has a `type_annotation` that is a `PointerType`, its `dtype` field must have the `kHandle` code. There are two fields in `PointerType`:
  - `element_type`: A `Type` (which must be a `PrimType`) that describes the type of the value the pointer refers to.
  - `storage_scope`: A string that conveys device-specific information regarding the region of memory that the pointer addresses. For example, `"shared"` refers to shared memory, and `"shared.dyn"` refers to dynamic shared memory on CUDA GPUs.
- `TupleType` is used much less frequently than the previous two. It is used in only two cases in TIR: for the `ret_type` field of `PrimFunc`s, or as the `type_annotation` field for a `Var` that references a value with a `Void` datatype. In these cases, the `TupleType` must have an empty list for its `fields` value (in the case of a `PrimFunc`, this means that it does not return a value; for a `Var`, it means that the value is `Void`).
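The constraints above relating a `Var`'s `dtype` to its `type_annotation` can be checked mechanically. This is a minimal sketch that uses hypothetical Python stand-ins for `DataType`, `PrimType`, `PointerType`, and `TupleType` (not TVM's classes). `Void` is modeled as a `Handle` datatype with 0 bits, as described earlier.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# Hypothetical stand-ins for the AST's DataType and Type nodes.
@dataclass(frozen=True)
class DataType:
    code: str
    bits: int
    lanes: int = 1

@dataclass(frozen=True)
class PrimType:
    dtype: DataType

@dataclass(frozen=True)
class PointerType:
    element_type: "Type"
    storage_scope: str

@dataclass(frozen=True)
class TupleType:
    fields: Tuple = ()

Type = Union[PrimType, PointerType, TupleType]
VOID = DataType("Handle", 0)

def annotation_is_consistent(var_dtype: DataType, ann: Type) -> bool:
    """Check a Var's dtype against its type_annotation per the rules above."""
    if isinstance(ann, PrimType):
        # PrimType adds no information: its dtype must equal the Var's.
        return ann.dtype == var_dtype
    if isinstance(ann, PointerType):
        # Pointer Vars have the Handle code; element_type must be a PrimType.
        return var_dtype.code == "Handle" and isinstance(ann.element_type, PrimType)
    if isinstance(ann, TupleType):
        # Only the empty TupleType is allowed, and it marks a Void Var.
        return ann.fields == () and var_dtype == VOID
    return False
```

This sketch checks only the relationships stated above. It makes no attempt to validate `storage_scope` strings.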
For each `PrimExpr`, we define the rule determining its `dtype` field below:
- `Var(name_hint, dtype, type_annotation)`: There are two ways to construct a `Var`: by specifying either `dtype` or `type_annotation` (but not both).
  - If `dtype` is specified and `type_annotation` is not, then the resulting datatype is `dtype`. If `dtype` is not `Void`, then `type_annotation` should be set to `PrimType(dtype)`. If `dtype` is `Void`, then `type_annotation` should be set to `TupleType([])`.
  - If `type_annotation` is specified and `dtype` is not, then the resulting datatype is determined from `type_annotation` as follows:
    - If `type_annotation` is a `PrimType`, then the resulting datatype is `type_annotation->dtype`.
    - If `type_annotation` is a `PointerType`, then the resulting datatype is `DataType(Handle, 64, 1)`.
    - If `type_annotation` is `TupleType([])`, then the resulting datatype is `Void`.
    - Any other `type_annotation` should be considered invalid and result in a type error.
- `IntImm(value, dtype)`:
  - The following conditions must hold or else there is a type error:
    - `dtype` must be a scalar (have exactly one lane).
    - `dtype` must have a typecode of either `Int` or `UInt`.
    - If `dtype->code` is `UInt`, `value` must be greater than or equal to 0. If `dtype->bits` is less than 64, `value` must be strictly less than $2^b$, where $b$ is `dtype->bits`.
    - If `dtype->code` is `Int` and `dtype->bits` is greater than 1 and less than 64, `value` must be greater than or equal to $-2^{b-1}$ and strictly less than $2^{b-1}$, where $b$ is `dtype->bits`. If the bitwidth is exactly 1, then `value` must be either 0 or 1.
  - The resulting datatype is `dtype`.
- `FloatImm(value, dtype)`:
  - The following conditions must hold or else there is a type error:
    - `dtype->code` must be `Float` or `BFloat`.
    - `value` must be `NaN`, `+inf`, `-inf`, or between the minimum and maximum values for a floating-point number of the given bitwidth: for 16-bit `Float`s, $\pm 65504$; for 16-bit `BFloat`s, $\pm 3.38953139 \cdot 10^{38}$; for 32 bits, $\pm 3.402823466 \cdot 10^{38}$; and for 64 bits, $\pm 1.7976931348623158 \cdot 10^{308}$.
  - The resulting datatype is `dtype`.
- `StringImm(value)`: Its datatype is `DataType(Handle, 64, 1)`.
- `Cast(value, dtype)`: The number of lanes in `dtype(value)` must match the number of lanes in `dtype` or else there is a type error. «If `value` has a `Handle` datatype, then `dtype` must also be `Handle` or else there is a type error; if `dtype` is `Handle`, then `value` must have a typecode of `Int`, `UInt`, or `Handle` or else there is a type error.» The resulting datatype is `dtype`.
- `Select(condition, true_value, false_value)`:
  - The following conditions must hold, or else there is a type error:
    - `dtype(condition)` must be a `Bool` datatype (not necessarily a scalar).
    - `dtype(true_value)` and `dtype(false_value)` must match.
    - `dtype(condition)->lanes` must either be 1 or match `dtype(true_value)->lanes`.
  - The resulting datatype is `dtype(true_value)`.
- `BufferLoad(buffer, indices)`:
  - Suppose `len(indices)` is `n`. If `n` is greater than 0, the first `n - 1` members of `indices` must have a scalar datatype (i.e., exactly one lane). That is, all indices except the last one must have scalar datatypes.
  - «All members of `indices` must have datatypes with `Int` or `UInt` typecodes, and they must all have the same bitwidth. In principle, hardware back-ends expect indices of some specific size (most commonly 64-bit, but possibly 32-bit or 16-bit on lower-powered systems), but any integer width is permitted in TIR (it will be cast to the expected width at run time).»
  - Let `index_lanes` be 1 if `indices` is of length 0. If `len(indices) > 0`, then let `index_lanes` be `dtype(indices[len(indices)-1])->lanes` (the last member's `lanes`).
  - Let `buffer_lanes` be `buffer->dtype->lanes`.
  - The resulting datatype will be `DataType(code=buffer->dtype->code, bits=buffer->dtype->bits, lanes=index_lanes*buffer_lanes)`.
- `Ramp(base, stride, lanes)`:
  - The following conditions must hold or else there is a type error:
    - The value of `lanes` must be strictly greater than 1.
    - `dtype(base)` and `dtype(stride)` must match, and `dtype(base)->lanes` and `dtype(stride)->lanes` must both be 1.
    - «`dtype(base)->code` and `dtype(stride)->code` must be `Int` or `UInt`.»
  - The resulting datatype will be `DataType(code=dtype(base)->code, bits=dtype(base)->bits, lanes=lanes)`.
- `Broadcast(value, lanes)`:
  - The following conditions must hold or else there is a type error:
    - `dtype(value)->lanes` must be 1.
    - `lanes` must be strictly greater than 1.
  - The resulting datatype will be `DataType(code=dtype(value)->code, bits=dtype(value)->bits, lanes=lanes)`.
- `Let(var, value, body)`:
  - `dtype(var)` must match `dtype(value)` or else there is a type error.
  - The resulting datatype is `dtype(body)`.
- `Call(dtype, op, args)`: The resulting datatype is `dtype`; the datatype is not otherwise checked.
- `Shuffle(vectors, indices)`:
  - «`len(vectors)` must be at least 1 or else there is a type error.»
  - The datatypes of all elements of `vectors` must have the same typecode and bitwidth; it is a type error otherwise. Let the typecode be `vector_code` and the bitwidth be `vector_bits`.
  - Let `total_lanes` be the sum of `dtype(vectors[i])->lanes` over all `i` from 0 to `len(vectors) - 1`, inclusive.
  - `len(indices)` must equal `total_lanes` or else it is a type error.
  - «All members of `indices` must be `Int` or `UInt` scalars or else it is a type error.»
  - The resulting datatype will be `DataType(code=vector_code, bits=vector_bits, lanes=total_lanes)`.
- Binary ops (with arguments `a` and `b`), which are `Add`, `Sub`, `Mul`, `Div`, `Mod`, `FloorDiv`, `FloorMod`, `Min`, and `Max`: `dtype(a)` and `dtype(b)` must match «and the typecode must not be `Handle`», or else it is a type error. The result will be `dtype(a)`. «For `Mod`, `dtype(a)` and `dtype(b)` must also both have either the `Int` or `UInt` typecode.»
- Logical ops `And` and `Or`, with arguments `a` and `b`: `a` and `b` must both have `Bool` datatypes and have the same number of lanes. The result will be `dtype(a)`.
- `Not(a)`: `a` must have a `Bool` datatype (or else it is a type error). The result will have a `Bool` datatype with the same number of lanes as `dtype(a)`.
- Comparison operators (with arguments `a` and `b`), which are `Eq`, `NE`, `LT`, `LE`, `GE`, and `GT`: `a` and `b` must have the same datatype «and the typecode must not be `Handle`», or else it is a type error. The result has a `Bool` datatype and the same number of lanes as `dtype(a)`.
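Several of the rules above are mechanical enough to express as small functions. The sketch below covers `IntImm`, `BufferLoad`, `Ramp`, and `Broadcast` only. It assumes a hypothetical Python stand-in for `DataType` (not TVM's class) and omits the «»-marked rules. It illustrates the rules as written here, not the compiler's actual checker.

```python
from dataclasses import dataclass, replace
from typing import List

class TIRTypeError(Exception):
    """Raised when one of the typing rules above is violated."""

# Hypothetical stand-in for DataType; `code` is "Int", "UInt", "Float", etc.
@dataclass(frozen=True)
class DataType:
    code: str
    bits: int
    lanes: int = 1

def check_int_imm(value: int, dtype: DataType) -> DataType:
    """IntImm: a scalar Int/UInt dtype, with `value` in range for its bitwidth."""
    if dtype.lanes != 1 or dtype.code not in ("Int", "UInt"):
        raise TIRTypeError("IntImm needs a scalar Int or UInt dtype")
    b = dtype.bits
    if dtype.code == "UInt":
        if value < 0 or (b < 64 and value >= 2 ** b):
            raise TIRTypeError(f"{value} is out of range for uint{b}")
    elif b == 1:
        if value not in (0, 1):
            raise TIRTypeError("a 1-bit Int must be 0 or 1")
    elif b < 64:
        # Int with 1 < bits < 64; the rules state no bound for 64-bit Int values.
        if not -(2 ** (b - 1)) <= value < 2 ** (b - 1):
            raise TIRTypeError(f"{value} is out of range for int{b}")
    return dtype

def buffer_load_type(buffer_dtype: DataType, index_dtypes: List[DataType]) -> DataType:
    """BufferLoad: lanes = index_lanes * buffer_lanes."""
    if any(d.lanes != 1 for d in index_dtypes[:-1]):
        raise TIRTypeError("all indices except the last must be scalars")
    index_lanes = index_dtypes[-1].lanes if index_dtypes else 1
    return replace(buffer_dtype, lanes=index_lanes * buffer_dtype.lanes)

def ramp_type(base: DataType, stride: DataType, lanes: int) -> DataType:
    """Ramp: matching scalar base/stride, more than one lane."""
    if lanes <= 1:
        raise TIRTypeError("Ramp needs more than one lane")
    if base != stride or base.lanes != 1:
        raise TIRTypeError("base and stride must be matching scalars")
    return replace(base, lanes=lanes)

def broadcast_type(value: DataType, lanes: int) -> DataType:
    """Broadcast: a scalar value replicated into more than one lane."""
    if value.lanes != 1 or lanes <= 1:
        raise TIRTypeError("Broadcast needs a scalar value and more than one lane")
    return replace(value, lanes=lanes)
```

For instance, loading from a `float32x4` buffer with a single `int32x8` index gives `DataType("Float", 32, 32)`, because `index_lanes * buffer_lanes` is 8 × 4.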
Even though statements do not produce values themselves, many contain `PrimExpr`s and have requirements on the types of those `PrimExpr`s. There is a type error if any condition listed below does not hold for the given statement (assuming its subfields type check individually). Some of these rules also include structural requirements not related to datatypes or type annotations.
- `LetStmt(var, value, body)`: If `var->type_annotation` is a `PointerType`, then `dtype(value)->code` must be `Handle` and `dtype(value)->bits` must be nonzero. (Note that there is no requirement on the `element_type` field: this allows for implicit casts of pointers.) Otherwise, `dtype(var)` and `dtype(value)` must match.
- `AttrStmt(node, attr_key, value, body)`: Always valid.
- `AssertStmt(condition, message, body)`: `message` must be either a `PrimExpr` whose datatype is `int32` or a `StringImm` node. Additionally, `condition` must be a `Bool` scalar.
- `BufferStore(buffer, value, indices)`:
  - Suppose `len(indices)` is `n`. If `n` is greater than 0, the first `n - 1` members of `indices` must have a scalar datatype (i.e., exactly one lane). That is, all indices except the last one must have scalar datatypes.
  - Let `index_lanes` be 1 if `len(indices)` is 0. If `len(indices) > 0`, then let `index_lanes` be `dtype(indices[len(indices)-1])->lanes` (the lanes of the last member's datatype).
  - «All members of `indices` must have a typecode of `Int` or `UInt`; in fact, they must all have the same typecode and bitwidth.»
  - Let `buffer_lanes` be `buffer->dtype->lanes`. `dtype(value)->lanes` must be equal to `index_lanes * buffer_lanes`.
- `BufferRealize(buffer, bounds, condition, body)`: The following conditions must hold:
  - «`condition` is a `Bool` scalar.»
  - «All members of `bounds` are `Int` or `UInt` scalars; their datatypes and bitwidths match.»
- `Allocate(buffer_var, dtype, extents, condition, body, annotations)`: The following conditions must hold:
  - Either `buffer_var->type_annotation` is `PointerType(dtype)`, or `dtype` is a `Bool` scalar and `buffer_var->type_annotation` is `PointerType(int8)`.
  - All members of `extents` must have a scalar datatype. «All members' typecodes must be `Int` or `UInt`, and their typecodes and bitwidths must all match.»
  - `dtype(condition)` must be a `Bool` scalar.
- `DeclBuffer(buffer, body)`: Always valid.
- `SeqStmt(seq)`: Always valid.
- `IfThenElse(condition, then_case, else_case)`: `dtype(condition)` must be a `Bool` scalar.
- `Evaluate(value)`: Always valid.
- `For(loop_var, min, extent, kind, body, thread_binding, annotations)`: The following conditions must hold:
  - The datatypes of `min`, `extent`, and `loop_var` must all be scalars «with an `Int` or `UInt` typecode».
  - `dtype(min)->bits` and `dtype(extent)->bits` must be less than or equal to `dtype(loop_var)->bits`.
    - If `min` is an `IntImm` node and `dtype(min)->bits < dtype(loop_var)->bits`, then consider `min` to have the same datatype as `loop_var` (i.e., "promote" its datatype).
    - If `extent` is an `IntImm` node and `dtype(extent)->bits < dtype(loop_var)->bits`, then "promote" its datatype as with `min`.
  - After performing the datatype "promotions," if necessary, `dtype(loop_var)`, `dtype(min)`, and `dtype(extent)` must all match exactly.
  - «If `kind` is `kVectorized`, `body` must not contain `While` statements. Additionally, `min` must be an `IntImm` with a value of 0 and `extent` must be an `IntImm` with a value of at least 1.»
- `While(condition, body)`: `condition` must have a scalar datatype with an `Int` or `UInt` typecode. Additionally, `condition` must not be an `IntImm` node.
- `Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers, annotations)`: «The datatypes of the `var` fields for all members of `iter_vars` must be `Int` or `UInt` scalars.»
- `BlockRealize(iter_values, predicate, block)`: `len(iter_values)` must match `len(block->iter_vars)`. Additionally, `predicate` must have a `Bool` datatype.
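The `For` datatype rules (including the "promotion" of integer literals) and the `BufferStore` lane rule can be sketched as follows. This is a hypothetical Python model, not TVM code. `Operand` is an invented summary of a `PrimExpr` that records only its datatype and whether it is an `IntImm` node. The «»-marked typecode rules are omitted.

```python
from dataclasses import dataclass
from typing import List

class TIRTypeError(Exception):
    """Raised when one of the statement typing rules above is violated."""

# Hypothetical stand-in for DataType; `code` is "Int", "UInt", "Float", etc.
@dataclass(frozen=True)
class DataType:
    code: str
    bits: int
    lanes: int = 1

# Hypothetical summary of a PrimExpr: its datatype and whether it is an IntImm.
@dataclass(frozen=True)
class Operand:
    dtype: DataType
    is_int_imm: bool = False

def check_for_bounds(loop_var: DataType, min_: Operand, extent: Operand) -> None:
    """Check the datatype rules for For(loop_var, min, extent, ...)."""
    for d in (loop_var, min_.dtype, extent.dtype):
        if d.lanes != 1:
            raise TIRTypeError("loop_var, min, and extent must be scalars")
    effective = []
    for op in (min_, extent):
        if op.dtype.bits > loop_var.bits:
            raise TIRTypeError("min and extent may not be wider than loop_var")
        if op.is_int_imm and op.dtype.bits < loop_var.bits:
            effective.append(loop_var)  # a narrower IntImm is "promoted"
        else:
            effective.append(op.dtype)
    if any(d != loop_var for d in effective):
        raise TIRTypeError("after promotion, min and extent must match loop_var")

def check_buffer_store(buffer: DataType, value: DataType, indices: List[DataType]) -> None:
    """Check the lane rule for BufferStore(buffer, value, indices)."""
    if any(d.lanes != 1 for d in indices[:-1]):
        raise TIRTypeError("all indices except the last must be scalars")
    index_lanes = indices[-1].lanes if indices else 1
    if value.lanes != index_lanes * buffer.lanes:
        raise TIRTypeError("value lanes must equal index_lanes * buffer_lanes")
```

Suppose the loop variable has type `int64`. An `int32` literal used as `min` is accepted, because it is promoted. A non-literal `int32` expression in the same position is rejected, because only `IntImm` nodes are promoted.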
Certain language constructs like `IterVar` are neither `PrimExpr`s nor statements but are used to construct `PrimExpr`s and statements. They have some typing rules as well. The conditions listed below for each construct must hold (assuming their subfields type check) or else there is a type error.
- `IterVar(dom, var, iter_type, thread_tag)`: If `dom` is specified and `dom->extent` is defined, then `dom->extent` must have an `Int` datatype and `dtype(dom->extent)` must match `dtype(var)`.
- `Range(min, extent)`: Always valid.
- `Buffer(data, dtype, shape, axis_separators, strides, elem_offset, name, data_alignment, offset_factor, buffer_type)`: The following must hold:
  - «All members of `shape` must have `Int` or `UInt` scalar datatypes.»
  - «All members of `strides` must have `Int` or `UInt` scalar datatypes, and they must all match exactly. (In this specification, we do not permit users to specify strides themselves, so all members must be fresh `Var` nodes.)»
  - «`elem_offset` must have an `Int` or `UInt` scalar datatype. (In this specification, we do not permit users to specify `elem_offset` themselves, so it must be a fresh `Var` node.)»
  - `data->type_annotation` must be a `PointerType` and `data->type_annotation->element_type` must be a `PrimType`.
  - If `buffer_type` is `kAutoBroadcast`, `strides` is empty, and `shape` is nonempty, then treat `strides` as a list of fresh `Var` nodes of the same length as `shape`, where `dtype(strides[i])` matches `dtype(shape[i])` for all `i` from 0 to `len(shape) - 1`.
- `BufferRegion(buffer, region)`: `region` must be of the same length as `buffer->shape`.
- `MatchBufferRegion(buffer, source)`: The following must hold:
  - Let `source_buffer` be `source->buffer`. Let `region` be `source->region`. Let `shape` be `buffer->shape`.
  - «`buffer->dtype` and `source_buffer->dtype` must match.»
  - «`source_buffer->data_alignment` must be divisible by `buffer->data_alignment`.»
  - «If `buffer->elem_offset` is an `IntImm` with a value of 0, then `source_buffer->elem_offset` must also be an `IntImm` with a value of 0. (In this specification, we do not permit users to specify `elem_offset` themselves, so this condition should not come up.)»
  - `buffer->data->storage_scope` must match `source_buffer->data->storage_scope`.
  - `buffer->buffer_type` and `source_buffer->buffer_type` must be `kDefault`.
  - Let `offset` be `len(region) - len(shape)`, i.e., the number of leading dimensions of `region` that `shape` does not cover. `offset` must be greater than or equal to 0.
  - For all `i` from 0 to `offset - 1`, the compiler must be able to statically prove that `region[i]->extent` is numerically equal to 1 (including via arithmetic simplification) or else there is an error.
  - For all `i` from 0 to `len(shape) - 1`, if `shape[i]` is not a `Var`, the compiler must be able to statically prove that `shape[i]` is numerically equal to `region[i + offset]->extent` (including via arithmetic simplification) or else there is an error.
- `PrimFunc`: If `ret_type` is not defined (note: this is usually the case in practice), treat `ret_type` as `TupleType([])`. Moreover, if no call to the `tir.ret` builtin appears in the `PrimFunc`'s `body` field, then `ret_type` must be `TupleType([])`. If there is at least one call to the `tir.ret` builtin in the `PrimFunc`'s `body`, then the `ret_type` field must match the type of the value passed to the builtin:
  - If the returned value is a `Var`, then `ret_type` must match the `type_annotation` field on the `Var`.
  - If the returned value is not a `Var` and has a `Handle` datatype, then `ret_type` must be a `PointerType`. (Note: the `PointerType` can contain more detailed information than the `Handle` datatype conveys.)
  - If the returned value is neither a `Var` nor a `Handle`, then `ret_type` must be a `PrimType` with a matching datatype.
  - «If there are multiple calls to `tir.ret` and the types of the returned values do not all match, it is considered a type-checking error.» Otherwise, if there are multiple calls to `tir.ret` and the types of the returned values match, use the matching type for `ret_type`.
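The shape-matching part of the `MatchBufferRegion` rule can be sketched in Python as below. To keep the sketch small, it assumes (hypothetically) that every region extent has already been simplified to an integer constant, so "statically prove equal" reduces to `==`. It also represents a `Var` shape entry by its name as a string. `offset` is taken as `len(region) - len(shape)`, the number of leading region dimensions not covered by `shape`. This is the reading under which the index `region[i + offset]` stays in bounds.

```python
from typing import List, Union

class TIRTypeError(Exception):
    """Raised when the MatchBufferRegion shape rule is violated."""

# A shape entry is an int constant or a str naming a Var (hypothetical encoding).
ShapeDim = Union[int, str]

def check_match_buffer_shape(shape: List[ShapeDim], region_extents: List[int]) -> None:
    """Check `buffer->shape` against the extents of `source->region`."""
    offset = len(region_extents) - len(shape)
    if offset < 0:
        raise TIRTypeError("buffer has more dimensions than the source region")
    # Leading region dimensions not covered by `shape` must have extent 1.
    for i in range(offset):
        if region_extents[i] != 1:
            raise TIRTypeError(f"region dimension {i} must have extent 1")
    # The remaining dimensions line up with `shape`; Var entries match anything.
    for i, dim in enumerate(shape):
        extent = region_extents[i + offset]
        if not isinstance(dim, str) and dim != extent:
            raise TIRTypeError(f"shape[{i}] = {dim} does not match extent {extent}")
```

For example, matching a `[16, 16]` buffer against a region with extents `[1, 16, 16]` succeeds, because the extra leading dimension has extent 1.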
TIR enforces static single assignment (SSA), meaning that all variables must be unique and are bound exactly once. TIR follows lexical scoping, meaning that variables are scoped to the "block" (lexical block, not the `Block` node) in which they are bound: if a variable is in scope, it is valid to reference it; conversely, once it leaves scope, it may no longer be referenced.
Note that any GlobalVar
in the IRModule
that maps to a PrimFunc
may be referenced from any expression that is inside a PrimFunc
, including the GlobalVar
corresponding to the PrimFunc
currently executing (for recursive calls). That is, GlobalVar
s are always in scope.
Here is a list of binding sites:

- `PrimFunc`: Variables that appear in the `params` field are in scope for the entirety of the `PrimFunc` body.
- `Let` and `LetStmt`: The variable in `var` is in scope for the entirety of `body` and leaves scope afterwards.
- `BlockRealize`: Each variable contained in `iter_vars` is in scope when `block` is executed and leaves scope afterwards.
- `For`: `loop_var` is in scope for the entirety of `body` and leaves scope afterwards.
- `Allocate`: The `var` within `buffer_var` is bound for the entirety of `body` and leaves scope afterwards.
- `BufferRealize`: Also acts as an allocation of the buffer. This makes it a binding site for the buffer's `data` field and for any fresh variables in the buffer's `strides`, `elem_offset`, and `shape` fields.
- `AttrStmt`: Certain attributes (`thread_extent` and `virtual_thread`) act as binding sites, since they introduce a variable (in the `node` field) that is in scope when the body is executed.
- `Block`: The `Block` acts as a binding site for the buffers (namely their `data` field) mentioned in `alloc_buffers` and `match_buffers`; they leave scope at the end of the `body`. However, the buffers in `alloc_buffers` are not permitted to have unbound `Var`s in their `shape`, `strides`, or `elem_offset` fields. For that reason, `alloc_buffers` does not act as a binding site for those variables (any variables in those fields should already be bound).
TIR programs operate by modifying aspects of the program state. Here is the program state that may be accessed or altered from TIR:

- A TIR program begins by calling a `PrimFunc` that may take as arguments some regions of memory (which are organized in a buffer map). Any buffers in the buffer map are expected to have been allocated by the caller. TIR may access any location within those buffers and may modify their contents as well.
- TIR may allocate more memory (i.e., buffers other than those passed as arguments). Any new buffers will also be deallocated by TIR; buffers are generally associated with particular scopes and will be deallocated at the end of the scope.
- Depending on the back-end, TIR statements may also launch new threads and utilize synchronization primitives.
- External calls are capable of invoking arbitrary TVM `PackedFunc`s and can therefore alter any system state. External calls can also invoke device-specific routines that affect the device state. Any usage of external calls to modify state is entirely the caller's responsibility to manage; the TIR compiler makes no assumptions about such resources.
- Raising errors or exiting abnormally. This may be done by calls to builtins, but also by `AssertStmt`.
In terms of this specification, we consider reads and writes to memory (buffers) and any sort of abnormal exit or I/O side effects to be externally visible, so the semantics for TIR will be described in terms of these actions. For the purposes of the specification, we do not consider memory allocations/deallocations to be directly "observable" by the user, in order to give the compiler greater freedom to rearrange or consolidate memory allocations. Similarly, even though latency and other metrics of performance are very important in practice, the specification does not consider them to be "observable." This provides the compiler with the greatest freedom to make performance-related changes to the code, so long as the other observable behavior remains unchanged. (That is, the specification does not make any promises that specific optimizations are applied in specific situations. The compiler may do those things, but it must preserve the other observable behavior.)
Certain TIR constructs refer to buffers and operations on the memory that underlies them (allocations, accesses, updates, and deallocations). In this version of the specification, we will treat buffers as abstract multidimensional arrays of values of the listed datatype. The specification will thus refer only to the indices of buffers in terms of their shape. A buffer with dtype `ty` and shape `(d1, d2, ..., dn)` could thus be conceived of as an array containing `d1` elements. Each of these is an array containing `d2` elements, and so on, until we finally obtain an array of `dn` elements, each of which is a member of `ty`.
For example, if the shape is `(2, 2, 3)`, an array representation of it could be `[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]`; index `(0, 1, 1)` would give us element 5 and index `(1, 0, 2)` would give us element 9.
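The nested-array view of buffers can be sketched in Python. The helpers `make_buffer` and `load` below are hypothetical illustrations, not TVM APIs. They build the `(2, 2, 3)` example from a flat list of values and index it one dimension at a time. They also show the `()` shape, which holds a single value.

```python
import math

# Hypothetical helpers (not TVM APIs): model a TIR buffer as nested lists.

def make_buffer(shape, values):
    """Build a nested list with the given shape from a flat list of values.

    A shape of () yields a single scalar value."""
    if not shape:
        return values[0]
    step = math.prod(shape[1:])  # number of values under each outer index
    return [make_buffer(shape[1:], values[i * step:(i + 1) * step])
            for i in range(shape[0])]

def load(buffer, indices):
    """Read the element at `indices` by indexing one dimension at a time."""
    for i in indices:
        buffer = buffer[i]
    return buffer

buf = make_buffer((2, 2, 3), list(range(1, 13)))
print(buf)                              # [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
print(load(buf, (0, 1, 1)))             # 5
print(load(buf, (1, 0, 2)))             # 9
print(load(make_buffer((), [42]), ()))  # 42
```

This model only captures which value sits at which index. It says nothing about memory layout, which the specification deliberately leaves unspecified.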
Note that a buffer can have `()` as a shape; such buffers will be interpreted as storing a single value of the buffer's `dtype`. Each buffer should be assumed to be unique (not aliasing any other), unless it is indicated in the `match_buffers` field of a `Block` (see the semantics for `Block`).
Even though these buffers are represented on real hardware back-ends in terms of memory and operations on buffers are implemented as reads and writes over memory, we will not specify buffers in terms of pointer or indexing arithmetic. Details about memory access vary greatly between hardware back-ends (for example, some back-ends use two-dimensional physical indices). Additionally, TIR does not permit direct manipulation of pointers in the language (there is no arithmetic defined for values with the `Handle` datatype). Optimizations like sharing single memory allocations between buffers (e.g., using different offsets or strides) are left to the lower levels of the compiler to implement. Hence, at the top level of TIR, we do not make any guarantees about the representation of buffers in memory, so any TIR code that makes use of such details is not specified by this document.
`PrimExpr`s in TIR yield values when executed (hence the word "evaluate"). This section describes the value produced by executing each `PrimExpr`.
- `Var`: If the variable is in scope, return the value bound to that variable. (It is otherwise an error to reference an unbound variable.)
- `IntImm(value, dtype)`: Evaluates to an integer value equal to `value` with datatype `dtype`.
- `FloatImm(value, dtype)`: Evaluates to a floating point value equal to `value` with datatype `dtype`.
- `StringImm(value)`: Evaluates to a `Handle` value. The `Handle` points to the start of a string with the characters of `value`. (The strings are null-terminated C-compatible strings and they are never deallocated.)
- `Cast(value, dtype)`: Evaluate `value` (calling the result `v`). Cast `v` to datatype `dtype`, returning the cast value. If `v` is a vector, then perform the cast element by element (assuming no particular ordering). For a scalar or for each element of a vector, casts behave as they would in C (see the ISO C Standard for full formal detail) or in C++'s `static_cast`:
  - Integer casts to a narrower width discard the most significant bits.
  - Integer casts to a wider width pad with zeros when the source is unsigned and sign-extend when the source is signed.
  - Casts from floating point to integer truncate toward zero.

  Special cases:
  - If `value` has a `Handle` datatype, `dtype` must be `Handle`; `Handle` values cannot be cast to other datatypes.
  - If `dtype` is `Handle` and `value` does not have a `Handle` datatype, this is still valid. The numerical value is cast to a pointer, but the specification makes no guarantees about the result except in the case where `v` is an integer scalar with the value 0: this is treated as a null pointer, which can be used by some builtins.
- `Select(condition, true_value, false_value)`:
  - Evaluate `condition` and call the result `c`.
  - `Select` is not short-circuiting: evaluate `true_value` and call the result `t`, and evaluate `false_value` and call the result `f`.
  - If `condition` is a scalar, then if `c` is 1, the result is `t`. If `c` is 0, the result is `f`.
  - If `condition` is a vector, then the result will be a vector of the same width. Let us call the result `r`. For `i` between 0 and `||c|| - 1`, `r.i` is `t.i` if `c.i` is 1 and `f.i` otherwise. No specific order of execution for instantiating the elements of `r` is guaranteed.
- `BufferLoad(buffer, indices)`:
  - If `indices` is of length 0, then the buffer stores only a single element. In this case, return that single element (do not perform the steps below).
  - If `indices` is nonempty, evaluate the members of `indices` in order, calling the list of resulting values `indices'`. Let `n` be `len(indices)`.
  - Cast all values in `indices'` to the integer type expected for the hardware back-end (most commonly, 64-bit unsigned integers as per C's `size_t`, but it may be smaller on some hardware back-ends).
  - If all members of `indices'` are scalars, then let `v` be the member of `buffer` at index `(indices'[0], indices'[1], ..., indices'[n-1])`.
  - If `indices'[n-1]` is a vector, let `i_lanes` be its number of lanes. Let `elems` be a list of buffer elements of length `i_lanes`, each with a datatype of `buffer->dtype`. For each `j` from 0 to `i_lanes - 1` (inclusive), `elems[j]` is the element of `buffer` at the indices `(indices'[0], indices'[1], ..., indices'[n-1].j)`. Let `v` be `concat(elems[0], elems[1], ..., elems[i_lanes-1])`, a single vector with `i_lanes * buffer->dtype->lanes` lanes.
  - Note that if any set of buffer indices is out of bounds at run time (e.g., if any single member of `indices'` is out of bounds), there is no guarantee on what will result. By default, TIR does not check bounds at run time.
  - Return `v`.
- `Ramp(base, stride, lanes)`:
  - Evaluate `base` and call it `b`. Evaluate `stride` and call it `s`.
  - The result is a vector with the same bitwidth and typecode as `dtype(b)` and with `lanes` lanes. The `i`th element of the vector is equal to `b + i * s`, where `i` ranges from 0 to `lanes - 1` (inclusive). This follows the arithmetical semantics given under the rule for binary operators below, casting `i` to the datatype of `s` if necessary.
- `Broadcast(value, lanes)`: Evaluate `value` and call the result `v` (per the type system, `v` must be a scalar). Return a vector with datatype `DataType(code=dtype(value)->code, bits=dtype(value)->bits, lanes=lanes)`, where all elements of the vector have the value `v`.
- `Let(var, value, body)`:
  - Evaluate `value`. Let us call the result `v`.
  - Create a new scope where `v` is bound to `var`.
  - Next, evaluate `body` in the new scope. Let us call the result `b`.
  - Pop the scope (i.e., remove the binding of `var` to `v`).
  - Return `b`.
- `Call(dtype, op, args)`:
  - If `op` is not a `GlobalVar` (i.e., it does not refer to a function in the `IRModule`), then this node calls a TIR builtin. Each TIR builtin has its own semantics, so no general semantics can be given for `Call(dtype, op, args)`; it is not even guaranteed that the members of `args` will be evaluated. Instead, see the section on builtins for a discussion of the semantics of builtins.
  - If `op` is a `GlobalVar` corresponding to a TIR `PrimFunc`, then this will call the denoted `PrimFunc`. Note that this can be the `PrimFunc` currently executing, resulting in a recursive call.
    - Evaluate the members of `args`, from left (index 0) to right (index `len(args) - 1`). Let these values be denoted collectively as the `vi`.
    - Push a new scope.
    - Evaluate the `PrimFunc` denoted by `op` according to the rules in the statement semantics, using the `vi` for the parameter values. Buffer parameters should be provided using `DLTensor`s, which may necessitate conversion at run time.
    - Pop the scope.
    - If the `PrimFunc` returns a value via the `tir.ret` builtin, then return that value. Otherwise, treat the return value as a `Void` value. (The latter will usually be the case, since `PrimFunc`s typically mutate buffer inputs rather than return values.)
  - If `op` denotes a `BaseFunc` in the `IRModule` other than a `PrimFunc`, this is invalid. Raise a run-time error.
- `Shuffle(vectors, indices)`:
  - Evaluate the elements of `vectors` in order, calling the list of results `vectors'`.
  - Evaluate the elements of `indices` in order, calling the list of results `indices'`. Cast the members of `indices'` to the indexing type expected by the hardware back-end (most commonly a 64-bit unsigned integer, but the width may be smaller on some systems).
  - Let `v` be the result of concatenating all members of `vectors'` together in order. (Note: `v` does not have to be materialized in the implementation of this operator, but it allows for specifying the result concisely.)
  - The result is a vector with datatype `DataType(code=dtype(vectors[0])->code, bits=dtype(vectors[0])->bits, lanes=len(indices))`. Let us call the result `r`. For all `i` from 0 to `len(indices) - 1` (inclusive), `r.i` is set to `v.(indices'[i])`, the lane of the concatenated vector `v` selected by `indices'[i]`.
- Binary ops (with arguments `a` and `b`), which are `Add`, `Sub`, `Mul`, `Div`, `Mod`, `FloorDiv`, `FloorMod`, `Min`, and `Max`:
  - In all cases, evaluate `a` and then `b`, calling the resulting values `v1` and `v2`. Per the type system, these must have the same datatype. For all operators, we will consider a function `f` that describes the semantics of that operator for a single element. If `v1` and `v2` are scalars, then the result will be `f(v1, v2)`, using the definitions of `f` below. If `v1` and `v2` are vectors, then the result will be a vector of the same size, where the `i`th element of the result is `f(v1.i, v2.i)` for all `i` from 0 to `||v1|| - 1`. Note that for computing elements of vectors, no particular order of execution should be assumed. The result of evaluating the expression will have the same datatype as `a` and `b`.
  - For values with a `Float` typecode, the arithmetic operators below follow the semantics supported by the hardware back-end (generally expected to be IEEE 754, but specialized devices may deviate from it). Analogously, for `BFloat` values, the `bfloat16` specification should be followed. For integers, the operations should be taken to act on the binary representation of the integers (two's complement for signed integers), with the corresponding overflow and underflow behavior as a result. If the bitwidth is $b$, the max value for unsigned integers is $2^b - 1$. For signed integers, the min value is $-(2^{b-1})$ and the max value is $2^{b-1} - 1$.
  - For `Add`, $f(x, y) = x + y$.
  - For `Sub`, $f(x, y) = x - y$.
  - For `Mul`, $f(x, y) = x \cdot y$.
  - For `Div`, $f(x, y) = x / y$. If $x$ and $y$ have `Int` or `UInt` typecodes, dividing by zero is an error, and the quotient truncates toward zero, as in C. For `Float` or `BFloat` typecodes, the division should follow the floating point or Brain float standards for division, respectively.
  - For `Mod` (only defined for `Int` and `UInt` operands), $f(x, y) = x \text{ mod } y$ (i.e., $x - ((x / y) \cdot y)$, where $x / y$ is the truncating division of `Div`). Note that, as in C, the sign of the result will be the sign of $x$, regardless of the sign of $y$.
  - For `FloorDiv`, $f(x, y) = \lfloor x / y \rfloor$.
  - For `FloorMod`, $f(x, y) = x - (\lfloor x / y \rfloor \cdot y)$, the remainder of the floor division. (See the note below comparing `Div`, `Mod`, `FloorDiv`, and `FloorMod`.)
  - For `Min`, $f(x, y) = \text{min}(x, y)$.
  - For `Max`, $f(x, y) = \text{max}(x, y)$.
- Logical ops `And` and `Or`, with arguments `a` and `b`:
  - If `a` and `b` are scalars, then we implement short-circuiting semantics:
    - For `And`, evaluate `a` and call the result `v1`. If `v1` is 0, then return 0 (without evaluating `b`). If `v1` is 1, then evaluate `b` and call the result `v2`; return `v2`.
    - For `Or`, evaluate `a` and call the result `v1`. If `v1` is 1, then return 1 (without evaluating `b`). If `v1` is 0, then evaluate `b` and call the result `v2`; return `v2`.
  - If `a` and `b` are vectors, then we make no guarantee as to whether the implementation is short-circuiting on a per-element level.
    - For safety, neither `a` nor `b` should contain side effects (which may happen in calls to builtins); if it is important for there to be side effects, we recommend instead decomposing the vector into scalars.
    - Suppose that `v1` and `v2` are the results of evaluating `a` and `b`, respectively (though it is not guaranteed that all elements of both will be evaluated). We return a vector of the same size as `a` where the `i`th element of the vector is `f(v1.i, v2.i)` for each `i` from 0 to `||v1|| - 1`, using the definitions of `f` below:
      - For `And`, $f(x, y) = x \land y$.
      - For `Or`, $f(x, y) = x \lor y$.
- Logical op `Not` with a unary argument `a`: Evaluate `a` (which must be a `Bool` value, per the type system) and call the result `v`. If `v` is a scalar, return 0 if `v` is 1 and 1 if `v` is 0. If `v` is a vector, return a vector of the same size where the `i`th element is 1 if `v.i` is 0 and 0 if `v.i` is 1, for all `i` from 0 to `||v|| - 1`.
- Comparison operators (with arguments `a` and `b`), which are `EQ`, `NE`, `LT`, `LE`, `GE`, and `GT`:
  - In all cases, evaluate `a` and then `b`, calling the resulting values `v1` and `v2`. Per the type system, these must have the same datatype. For all operators, we will consider a function `f` that describes the semantics of that operator for a single element. If `v1` and `v2` are scalars, then the result will be `f(v1, v2)`, using the definitions of `f` below. If `v1` and `v2` are vectors, then the result will be a vector of the same size, where the `i`th element of the result is `f(v1.i, v2.i)` for all `i` from 0 to `||v1|| - 1`. Note that for computing elements of vectors, no particular order of execution should be assumed.
  - If `v1` and `v2` have a `Float` typecode, use the semantics supported by the hardware back-end (again, generally expected to be IEEE 754) to determine the results, especially for comparisons with `NaN`, `+inf`, and `-inf`. If they have `Int` or `UInt` typecodes, interpret the comparisons mathematically. Analogously, for `BFloat` values, the `bfloat16` specification should be followed.
  - The datatype of the result is `Bool` in all cases (a `Bool` vector with the same number of lanes as `v1` when the operands are vectors).
  - For `EQ`, $f(x, y)$ is 1 if $x = y$ (numerically equal) and 0 otherwise.
  - For `NE`, $f(x, y)$ is 1 if $x \neq y$ (numerically unequal) and 0 otherwise.
  - For `LE`, $f(x, y)$ is 1 if $x \leq y$ and 0 otherwise.
  - For `LT`, $f(x, y)$ is 1 if $x < y$ and 0 otherwise.
  - For `GE`, $f(x, y)$ is 1 if $x \geq y$ and 0 otherwise.
  - For `GT`, $f(x, y)$ is 1 if $x > y$ and 0 otherwise.
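The vector rules for `Ramp`, `Broadcast`, `Select`, `Shuffle`, and a `BufferLoad` with a vector index can be modeled in a few lines of Python. This is a sketch under simplifying assumptions:

- Vectors are plain Python lists.
- Buffers are one-dimensional lists of scalars, so `buffer->dtype->lanes` is 1.
- Datatypes and bitwidth wraparound are ignored.
- The function names are hypothetical, not TVM APIs.

```python
# Hypothetical models of vector PrimExpr semantics; vectors are Python lists.

def ramp(b, s, lanes):
    """Ramp(base, stride, lanes): lane i holds b + i * s."""
    return [b + i * s for i in range(lanes)]

def broadcast(v, lanes):
    """Broadcast(value, lanes): every lane holds the scalar v."""
    return [v] * lanes

def select(c, t, f):
    """Select(c, t, f): t and f arrive already evaluated (no short-circuiting)."""
    if isinstance(c, list):
        # Vector condition: choose lane by lane.
        return [ti if ci == 1 else fi for ci, ti, fi in zip(c, t, f)]
    return t if c == 1 else f

def shuffle(vectors, indices):
    """Shuffle(vectors, indices): concatenate the inputs, then pick lanes."""
    v = [lane for vec in vectors for lane in vec]
    return [v[i] for i in indices]

def buffer_load(buffer, index):
    """BufferLoad on a 1-D buffer of scalars.

    A scalar index reads one element; a vector index reads one element per
    lane and concatenates the results into a vector."""
    if isinstance(index, list):
        return [buffer[j] for j in index]
    return buffer[index]

A = [10, 11, 12, 13, 14, 15, 16, 17]
idx = ramp(2, 2, 3)                              # [2, 4, 6]
print(buffer_load(A, idx))                       # [12, 14, 16]
print(select([1, 0, 1], idx, broadcast(0, 3)))   # [2, 0, 6]
print(shuffle([[1, 2], [3, 4]], [3, 0, 2]))      # [4, 1, 3]
```

Because `select` receives both branch values as arguments, any side effects in either branch would already have happened before the choice is made. This matches the non-short-circuiting rule for `Select`.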
With `Div`, integer division truncates toward zero as in C, so `5/2` evaluates to 2 and `-5/2` evaluates to -2. `Mod` follows the same sign rule as `%` in C (`-5 % 2` evaluates to -1). Guido van Rossum makes this point in an article about Python's integer arithmetic. Under these C rules, when `a` is negative, the identity `(a/b)*b + (a % b) == a` holds only because the remainder `a % b` is itself negative. With the `FloorDiv` and `FloorMod` rules for `/` and `%`, the identity still holds, while `a % b` is never negative when `b` is positive.
As a result, below are some properties of `FloorMod` and `FloorDiv` that TIR's compiler uses for optimizations. All of them assume that `N` is a positive integer and are quantified over all `x`.

- `0 <= FloorMod(x, N) < N`.
- `FloorDiv(x + N, N) = FloorDiv(x, N) + 1`.
- `FloorMod(x + N, N) = FloorMod(x, N)`.

With `Div` and `Mod`, these rules would all depend on the sign of `x` or the sign of `x + N`.
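A short Python sketch shows the difference between the truncating and flooring operators, and why only the flooring pair satisfies the properties above for every `x`. The helpers `trunc_div`, `trunc_mod`, `floor_div`, and `floor_mod` are hypothetical illustrations rather than TVM APIs. They use unbounded Python integers, so fixed-bitwidth overflow is ignored. Python's own `//` and `%` already floor, matching `FloorDiv` and `FloorMod`.

```python
# Hypothetical models of TIR integer division on unbounded Python ints.

def floor_div(x: int, y: int) -> int:
    """FloorDiv: round the quotient toward negative infinity."""
    return x // y

def floor_mod(x: int, y: int) -> int:
    """FloorMod: x - floor(x / y) * y."""
    return x - floor_div(x, y) * y

def trunc_div(x: int, y: int) -> int:
    """Div on integers: round the quotient toward zero, as in C."""
    q = abs(x) // abs(y)
    return q if (x >= 0) == (y > 0) else -q

def trunc_mod(x: int, y: int) -> int:
    """Mod: x - (x / y) * y with truncating division; sign follows x."""
    return x - trunc_div(x, y) * y

print(trunc_div(-5, 2), trunc_mod(-5, 2))  # -2 -1
print(floor_div(-5, 2), floor_mod(-5, 2))  # -3 1

# Both pairs satisfy (a/b)*b + (a % b) == a.
for a in range(-20, 21):
    for b in (-3, -2, 2, 3):
        assert trunc_div(a, b) * b + trunc_mod(a, b) == a
        assert floor_div(a, b) * b + floor_mod(a, b) == a

# Only FloorDiv/FloorMod satisfy the three properties for every x (N > 0).
N = 4
for x in range(-20, 21):
    assert 0 <= floor_mod(x, N) < N
    assert floor_div(x + N, N) == floor_div(x, N) + 1
    assert floor_mod(x + N, N) == floor_mod(x, N)

# A counterexample for the truncating versions at x = -1.
print(trunc_mod(-1, N))                            # -1, which is negative
print(trunc_div(-1 + N, N), trunc_div(-1, N) + 1)  # 0 1
```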
Builtins are external procedures that can be invoked from TIR via the `Call` `PrimExpr`. Note that some TIR documentation and comments refer to builtins as "intrinsics." In this document, we use the term "builtin" to distinguish TIR's builtins from platform-specific intrinsics.
Each builtin in TIR can essentially be treated as a PrimExpr all its own, albeit one that is used too rarely or too situationally to be made a "first-class" part of the AST. TIR builtins generally fit within the following broad categories:
- Platform-specific intrinsics (especially on GPUs)
- Hints to the compiler that have no effect of their own
- Less common mathematical operations
- Interactions with TVM's object system
TIR builtins are also categorized in terms of the effects they may have:

- `kExprAnnotation`: Acts as an annotation for the benefit of the compiler and acts as the identity function for its inputs.
- `kPure`: Acts as a pure function (evaluates its inputs and returns a value, having no other effects).
- `kReadState`: May read memory other than its arguments. For example, this may be global memory or memory that results from dereferencing its arguments (which must also be constructed via builtins).
- `kUpdateState`: May update memory.
- `kOpaque`: No assumptions can be made as to whether it reads or writes any state.
- `kSpecialCallArg`: The builtin's result is a special value that is valid only as an argument to certain other builtin calls. The result is meant to be used only in that specific context and should not appear outside of it.
- `kEmbedInfo`: Acts similarly to `kExprAnnotation`, except that the result of its call is removed from the final generated code (i.e., treat the argument as never being evaluated).
- `kControlJump`: Affects control flow.
We will enumerate and describe the builtins in a separate document, as they are added and changed more frequently than other language constructs.
However, one builtin is particularly important and should be discussed directly in the core language semantics, since it is used to implement return values for `PrimFunc`s: `tir.ret`. This builtin takes one argument, which it evaluates. After the argument is evaluated, the execution of the current `PrimFunc`'s `body` halts (see the semantics for `PrimFunc`s below) and the `PrimFunc` returns the evaluated argument (the "return value").
Unlike `PrimExpr`s, statements do not return values. Instead, they operate by modifying the program state. These rules describe how each variety of statement in TIR affects the program state. Statements do, however, depend on values produced by evaluating `PrimExpr`s.
- `PrimFunc(params, body, ret_type, buffer_map, attrs)`: A `PrimFunc` is not technically a statement, but TIR execution always begins by calling a `PrimFunc`.
  1. First, the variables in `params` enter the scope with the called values (passed externally). Note that not all members of `params` need to be passed in as an external argument. Namely, if a variable in `params` appears in the `shape`, `strides`, or `elem_offset` field of any member of `buffer_map`, it will be assigned below, in step ii.
  2. If a member of `params` (let us call it `v`) is a key in `buffer_map`, that means that `v` corresponds to a buffer at run time. The external caller must pass in a pointer to a `DLTensor` (defined in the `dlpack` library); it will correspond to `buffer_map[v]` in the program. Let us refer to the `DLTensor` and `buffer_map[v]` as `t` and `b`, respectively.
     - The elements of the buffer are read from `t->data` depending on the shape and striding defined. If `t->shape` is empty, then the `DLTensor` stores a single element at location `t->data`; `b` will accordingly also store a single element in the same manner. Otherwise, let `n` be `len(t->shape)` and `S` be the size in bytes of a member of `t->dtype`.
       - If `t->strides` is null, then `t` has a tightly packed, row-major representation, so the element at `indices` of `b` is at address `t->data + S*(indices[0]*(t->shape[1]*t->shape[2]*...*t->shape[n-1]) + indices[1]*(t->shape[2]*...*t->shape[n-1]) + ... + indices[n-2]*t->shape[n-1] + indices[n-1])`.
       - If `t->strides` is not null, then the element at `indices` is at address `t->data + S*(indices[0]*t->strides[0] + indices[1]*t->strides[1] + ... + indices[n-1]*t->strides[n-1])`.
     - Additionally, the following correspondences are checked between `b` and `t`:
       - `b->dtype` and `t->dtype` must match, or else there is an error.
       - Let `elem_offset` be `t->byte_offset` divided by the size in bytes of a member of `b->dtype`. If `b->elem_offset` is an `IntImm`, its value must match `elem_offset`, or else an error is raised. If `b->elem_offset` is a `Var` that is already bound, then its bound value must match `elem_offset`, or else an error is raised. If `b->elem_offset` is an unbound `Var`, bind it to `elem_offset`.
       - If `t->strides` is null, then `b->strides` must satisfy one of two conditions. It may be empty. Otherwise, it must have length `n`, where `n` is `len(t->shape)`. In that case, for all `i` from 0 to `n - 2`, `b->strides[i]` must be an `IntImm` or bound `Var` with the value `t->shape[i+1] * t->shape[i+2] * ... * t->shape[n-1]`, and `b->strides[n - 1]` must be an `IntImm` or bound `Var` with the value 1. (These values for strides are equivalent to having a tightly packed row-major representation.) If neither condition is met, then an error is raised.
       - If `t->strides` is not null, then `len(t->strides)` must match `len(b->strides)`, or else an error is raised. For all `i` from 0 to `len(b->strides) - 1`, `b->strides[i]` must be one of the following:
         - an `IntImm` whose value matches `t->strides[i]` (or else an error is raised);
         - an already bound `Var` whose bound value is `t->strides[i]` (or else an error is raised);
         - an unbound `Var`, in which case it is bound to the value `t->strides[i]`.
       - `len(t->shape)` and `len(b->shape)` must match, or else an error is raised. For all `i` from 0 to `len(b->shape) - 1`, `b->shape[i]` must be one of the following:
         - an `IntImm` whose value matches `t->shape[i]` (or else an error is raised);
         - an already bound `Var` whose bound value is `t->shape[i]` (or else an error is raised);
         - an unbound `Var`, in which case it is bound to the value `t->shape[i]`.
     - One further condition that `PrimFunc`s expect of their `DLTensor` arguments: no two `DLTensor` arguments are permitted to alias each other.
  3. Next, `body` is executed. The `PrimFunc` produces outputs by mutating values in buffers passed as the inputs; these changes can be observed by the caller via the `DLTensor` representations passed in step i. If a `tir.ret` builtin was executed, the `PrimFunc` returns the value that was passed to `tir.ret`; otherwise, the `PrimFunc` returns a void value. Upon returning (whether after encountering a call to `tir.ret` or the end of `body`), the current scope is popped and all values allocated within the `PrimFunc` are freed.
- `LetStmt(var, value, body)`:
  1. Evaluate `value` (let us call the result `v`). If `var->type_annotation` is a `PointerType`, then implicitly cast `v` to a pointer to `var->type_annotation->element_type` (as far as TIR is concerned, it is simply a `Handle` value).
  2. Push a new scope.
  3. Bind `v` to `var` in the new scope.
  4. Execute `body`.
  5. Pop the scope.
- `AttrStmt(node, attr_key, value, body)`: For almost all values of `attr_key`, this node has no functional semantics of its own and serves only to provide additional information to the compiler; for those cases, simply evaluate `body`. However, certain values of `attr_key` do affect the semantics, as described below:
  - `thread_extent` and `virtual_thread`: These attributes have semantics similar to a parallel `For` loop (though they are realized on hardware in different ways). For these attributes, `node` must be a `Var` node and `value` must be an `IntImm` giving an upper bound.
    1. Evaluate `value` and call the result `v`.
    2. Evaluate `body` `v` times in parallel, binding `node` to `i` in the `i`th parallel execution (for `i` from 0 to `v - 1`). Any interleaving of execution is permitted; additionally, there is no guarantee about how many distinct threads will be created to execute the loop body. If an error occurs in one execution, it is guaranteed that execution will not proceed past the `AttrStmt`. It is not guaranteed that all parallel executions will stop simultaneously. If multiple executions raise errors, it is not guaranteed which error will be displayed. `node` is in scope only during the execution of `body` and leaves scope afterwards. `node` cannot be reassigned or modified during the loop body.
  - `volatile_scope`: In this case, `node` must be a `Var` node. The variable denoted by `node` should be bound somewhere in `body`. This attribute indicates to the compiler that the assignment is volatile in the same sense as the `volatile` keyword in C: the binding and any references to `node` in `body` must not be optimized away by the compiler under any circumstances. The only other semantics is to evaluate `body`.
- `AssertStmt(condition, message, body)`: Evaluate `condition` (let us call the result `v`). If `v` is 1, then execute `body`. Otherwise, raise an assertion error with `message` as the error message.
- `BufferStore(buffer, value, indices)`:
  1. Let `buffer_lanes` be `buffer->dtype->lanes`.
  2. Evaluate `value` and call the result `v`. Let `value_lanes` be `dtype(v)->lanes`.
  3. Evaluate `indices` and call the array of results `indices'`. Cast all values in `indices'` to the integer type expected for the hardware back-end (most commonly, 64-bit unsigned integers as per C's `size_t`, but it may be smaller on some hardware back-ends). Let `n` be `len(indices')`.
  4. Let the number of lanes of `indices'[n-1]` be `i_lanes`.
  5. Depending on `n` and `i_lanes`:
     - If `n` is 0, then the shape of `buffer` must be `()`. Store `v` as `buffer`'s only element.
     - If all members of `indices'` are scalars, then store `v` to index `(indices'[0], indices'[1], ..., indices'[n-1])` in `buffer`.
     - If `i_lanes` is greater than 1, then let `m` be `||v||` and let `W` be `m / i_lanes` (truncating division). `W` is expected to equal `buffer_lanes`, the number of lanes in one element of `buffer`. For each `j` from 0 to `i_lanes - 1`:
       - Let `p` be `concat(v.(j*W), v.((j * W) + 1), ..., v.(((j + 1) * W) - 1))`, i.e., take the vector consisting of the `(j*W)`th lane of `v` through the `((j + 1)*W)`th lane of `v` (exclusive), with `W` lanes in total.
       - Store `p` to element `(indices'[0], indices'[1], ..., indices'[n-2], indices'[n-1].j)` of `buffer`.

  Note that if any buffer index is out of bounds at run time, there is no guarantee on what will result. By default, TIR does not check bounds at run time.
- `BufferRealize(buffer, bounds, condition, body)`:
  1. Push a new scope.
  2. Evaluate `buffer->shape`, allocating a new buffer of that shape with the datatype `buffer->dtype` (note: in the compiler implementation, this may be implemented either as an actual allocation or by loading in external data). This acts as an assignment to `buffer->data`.
  3. Evaluate `body`.
  4. Pop the scope and deallocate the newly allocated buffer.

  `bounds` and `condition` provide additional information for the compiler for code generation, but do not affect the semantics.
- `Allocate(buffer_var, dtype, extents, condition, body, annotations)`:
  1. Evaluate `condition` and call the result `c`. If `c` is 0, then finish executing the statement.
  2. Evaluate the members of `extents` in order, calling the list of results `extents'`.
  3. Push a new scope.
  4. Allocate a buffer of shape `extents'` whose entries are of datatype `dtype`. `buffer_var` will be assigned a pointer to this buffer. Note that we do not specify the layout of this buffer in memory. The resulting allocated buffer will not alias any buffer existing in the program. For the purposes of this specification, we assume each allocation corresponds to one and only one buffer; lower levels of compilation may attempt to consolidate memory allocations, but that should not be done at the front end.
  5. Execute `body`.
  6. Deallocate the memory (i.e., delete the buffer) allocated in step iv. Pop the scope.
- `DeclBuffer(buffer, body)`: Indicates to the compiler that `buffer` will be in scope for `body`. The only semantics at run time is that `body` is executed.
- `SeqStmt(seq)`: Execute the statements in `seq` one after the other in the order of the list.
- `IfThenElse(condition, then_case, else_case)`: Evaluate `condition` (let us call the result `v`). If `v` is 1, execute `then_case`. Otherwise, if `else_case` is present, execute `else_case` (if `else_case` is not present, then do nothing further).
- `Evaluate(value)`: Simply evaluate `value`, which is a `PrimExpr`. The result of evaluating `value` cannot be accessed or used by any other statement, so this is effectively a no-op unless `value` contains a call to a builtin.
- `For(loop_var, min, extent, kind, body, thread_binding, annotations)`: The semantics of a `For` statement depend on `kind`:
  - If `kind` is `kSerial`:
    1. Evaluate `min` and call the result `m`. Cast `m` to the bitwidth of `loop_var`.
    2. Push a new scope.
    3. In the new scope, bind `m` to `loop_var`.
    4. Evaluate `extent` and call the result `e`. Cast `e` to the bitwidth of `loop_var`.
    5. If `loop_var` is greater than or equal to `m + e`, then pop the scope and finish executing the statement.
    6. Evaluate `body`.
    7. Bind the current value of `loop_var` plus 1 to `loop_var`. Return to step e and resume execution from there with this new value of `loop_var`. (Note that `loop_var` cannot be mutated from within the loop body.)
  - If `kind` is `kParallel`:
    1. Evaluate `min` and call the result `m`, and evaluate `extent` and call the result `e`. Let `i_1` be `m`, `i_2` be `m + 1`, ..., and `i_e` be `m + e - 1`.
    2. Evaluate `body` `e` times in parallel, with `loop_var` bound to `i_j` in the `j`th parallel execution. Any interleaving of execution is permitted. There is no guarantee about how many distinct threads will be created to execute the loop body, or whether the executions will actually run in parallel. If an error occurs in one execution, it is guaranteed that execution will not proceed past the `For` statement. It is not guaranteed that all parallel executions will stop simultaneously. If multiple executions raise errors, it is not guaranteed which error will be displayed. If any loop iteration writes to a buffer index that is read by any other loop iteration (earlier or later), there is no guarantee on the resulting semantics. `loop_var` is in scope only during the execution of `body` and leaves scope afterwards. `loop_var` cannot be reassigned or modified during the loop body.
- If kind is kVectorized: The visible semantics are the same as those for kParallel, in that the loop body will be evaluated extent times (min must be 0 if kind is kVectorized) and no dependencies between loop iterations are permitted: no loop iteration may write to a buffer index that is read by any other loop iteration (no guarantee is made on the resulting semantics if that is the case). In terms of the implementation, the compiler will combine loop iterations into single invocations of vectorized operations where possible. However, loads and stores to buffers and any other side effects should not be affected by this change: unlike with kParallel, the ordering of side effects (including errors) must be preserved.
- If kind is kUnrolled: The semantics are the same as for kSerial (this kind simply indicates that the compiler should generate code for the loop by unrolling it rather than emitting jumps; it does not change the semantics).
- If kind is kThreadBinding: The semantics are the same as for kParallel, but this kind indicates to the compiler that the loop iterations should be mapped to hardware threads (as in CUDA), with the thread_binding field giving further information for the compiler to use for the mapping.
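A minimal Python sketch of the kSerial and kParallel rules follows. It assumes, as in TVM's For node, that extent counts iterations (loop_var takes the values min, ..., min + extent - 1), ignores the bitwidth casts, and models a parallel schedule as an explicit iteration order. run_serial and run_parallel are hypothetical helpers, not TVM APIs.

```python
import itertools
from typing import Callable, List, Sequence


def run_serial(min_val: int, extent: int, body: Callable[[int], None]) -> None:
    """kSerial: evaluate min and extent once, then step loop_var by one."""
    m, e = min_val, extent
    loop_var = m                      # bind m to loop_var in the new scope
    while loop_var < m + e:           # bound check: stop once loop_var >= m + e
        body(loop_var)                # execute the body with the current binding
        loop_var = loop_var + 1       # rebind loop_var to its successor


def run_parallel(min_val: int, extent: int, body: Callable[[int], None],
                 order: Sequence[int]) -> None:
    """kParallel: one execution of body per value in [min, min + extent).

    `order` is a permutation of range(extent) standing in for whatever
    interleaving the hardware happens to choose.
    """
    values = [min_val + j for j in range(extent)]
    for j in order:
        body(values[j])


# Serial execution visits min, min + 1, ..., min + extent - 1 in order.
trace: List[int] = []
run_serial(2, 3, trace.append)
print(trace)  # [2, 3, 4]

# Each iteration writes only its own index (no cross-iteration dependency),
# so every permitted order produces the same buffer contents.
results = set()
for order in itertools.permutations(range(4)):
    out = [0] * 4

    def body(i: int) -> None:
        out[i] = i * i

    run_parallel(0, 4, body, order)
    results.add(tuple(out))
print(results)  # {(0, 1, 4, 9)}
```

If body instead read an index written by another iteration, different orders could yield different results, which is exactly the case the specification leaves without guaranteed semantics.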
While(condition, body):
  i. First, evaluate condition (let us call the result v).
  ii. If v is 0, the statement is finished executing.
  iii. If v is 1, execute body. Resume execution from step i.
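The control-flow rules for SeqStmt, IfThenElse, and While can be checked with a small interpreter. The Python sketch below is hypothetical: statements are dataclasses that borrow the TIR names but not their real fields, BufferStore is simplified to a single constant index, and expressions are plain Python functions (conditions return 0 or 1).

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

# Hypothetical model: buffers are Python lists keyed by name, and
# expressions are functions from the buffer environment to an int.
Env = Dict[str, List[int]]
Expr = Callable[[Env], int]


@dataclass
class BufferStore:
    """Simplified stand-in for BufferStore: one constant index, no vectors."""
    buffer: str
    index: int
    value: Expr


@dataclass
class SeqStmt:
    seq: List["Stmt"]


@dataclass
class IfThenElse:
    condition: Expr
    then_case: "Stmt"
    else_case: Optional["Stmt"] = None


@dataclass
class While:
    condition: Expr
    body: "Stmt"


Stmt = Union[BufferStore, SeqStmt, IfThenElse, While]


def execute(stmt: Stmt, env: Env) -> None:
    if isinstance(stmt, BufferStore):
        env[stmt.buffer][stmt.index] = stmt.value(env)
    elif isinstance(stmt, SeqStmt):
        for s in stmt.seq:                  # one after the other, in list order
            execute(s, env)
    elif isinstance(stmt, IfThenElse):
        v = stmt.condition(env)
        if v == 1:
            execute(stmt.then_case, env)
        elif stmt.else_case is not None:    # no else_case: do nothing further
            execute(stmt.else_case, env)
    elif isinstance(stmt, While):
        while True:
            v = stmt.condition(env)         # step i: evaluate condition
            if v == 0:                      # step ii: finished
                return
            execute(stmt.body, env)         # step iii: run body, back to step i
    else:
        raise TypeError(f"unsupported statement: {stmt!r}")


# Count x[0] up to 5, then record which branch was taken in y[0].
env: Env = {"x": [0], "y": [0]}
program = SeqStmt([
    While(lambda e: int(e["x"][0] < 5),
          BufferStore("x", 0, lambda e: e["x"][0] + 1)),
    IfThenElse(lambda e: int(e["x"][0] == 5),
               BufferStore("y", 0, lambda e: 10),
               BufferStore("y", 0, lambda e: 20)),
])
execute(program, env)
print(env)  # {'x': [5], 'y': [10]}
```

Note that the While condition is re-evaluated before every iteration, so the loop observes the stores performed by its own body.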
Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers, annotations):
  i. Push a new scope.
  ii. For each i ranging from 0 to len(alloc_buffers) - 1, consider alloc_buffers[i]:
    a. Evaluate the members of alloc_buffers[i]->shape, calling the list of results shape'.
    b. Allocate a buffer of datatype alloc_buffers[i]->dtype with shape shape', binding the result to alloc_buffers[i]->data.
  iii. For each j ranging from 0 to len(match_buffers) - 1:
    a. Let buffer be match_buffers[j]->buffer, source_buffer be match_buffers[j]->source->buffer, and region be match_buffers[j]->source->region. Let n be len(region), which must equal len(source_buffer->shape), and let m be len(buffer->shape), which must be at most n (otherwise there is an error). Let offset be n - m.
    b. In the current scope, we treat all instances of buffer as aliases of source_buffer, where every read or write is offset by the indices indices', defined as [region[0]->min, region[1]->min, ..., region[n-1]->min]. That is, all BufferLoad operations on buffer are treated as BufferLoad operations on source_buffer (adding indices' to the BufferLoad's indices), and BufferStore operations are handled analogously. This is the only form of aliasing permitted in the specification.
      - If buffer->elem_offset or buffer->strides contain any hitherto unbound variables, they will be bound by pattern matching on source_buffer's run-time representation. Since the specification at this level leaves the element offset and strides as implementation details, it does not specify how these values are determined (this rule is included only to indicate that these variables are bound).
      - For all i from 0 to m - 1, evaluate region[offset + i]->extent, which we will call extent. buffer->shape[i] must be an IntImm whose value matches extent (or else an error is raised), an already bound Var whose bound value is extent (or else an error is raised), or an unbound Var (in which case it is bound to the value extent).
      - The first offset members of indices' are used as-is rather than added to any index; that is, buffer can have fewer dimensions than source_buffer. For example, given a node BufferLoad(buffer, indices), where len(indices) is m, treat it as BufferLoad(source_buffer, [indices'[0], indices'[1], ..., indices'[offset-1], Add(indices[0], indices'[offset]), Add(indices[1], indices'[offset+1]), ..., Add(indices[m-1], indices'[n-1])]). If indices[m-1] is a vector rather than a scalar, the last element should instead be Add(indices[m-1], BroadcastTo(indices'[n-1], ||indices[m-1]||)).
      - Additionally, no read or write through buffer may reach an index of source_buffer outside the region: in each dimension k, the translated index must be at least region[k]->min and less than Add(region[k]->min, region[k]->extent). As with other out-of-bounds accesses, this is not checked at run time by default, and no semantics are guaranteed for any access outside these bounds.
  iv. If any of the variables in iter_vars is a reduction IterVar (its iter_type is kCommReduce), then we consider the block to be a "reduction block." If the block is a reduction block, the block is located in the body of a For or While loop, and init is specified, the init statement is executed (before body) only during the first iteration of the loop. Otherwise, if init is specified, it is executed (before body) each time the block is executed.
  v. Execute body. The other fields are included only for the benefit of the compiler in code generation and do not affect the visible semantics.
  vi. Pop the scope, deallocating any buffers allocated in step ii.
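The index translation for match_buffers is easy to get wrong, so here is a minimal Python sketch of it for scalar indices only (the vector case with BroadcastTo is not modeled). The helpers shapes_match, translate_indices, and in_region are hypothetical; the sketch assumes, following the offsetting rule above, that the m dimensions of buffer align with the trailing m dimensions of region, and that region covers every dimension of source_buffer.

```python
from typing import List, Sequence


def shapes_match(buffer_shape: Sequence[int],
                 region_extents: Sequence[int]) -> bool:
    """buffer's dimensions must line up with the trailing region extents."""
    offset = len(region_extents) - len(buffer_shape)
    return offset >= 0 and all(
        s == region_extents[offset + k] for k, s in enumerate(buffer_shape))


def translate_indices(indices: Sequence[int],
                      region_mins: Sequence[int]) -> List[int]:
    """Map scalar indices on the matched buffer to indices on source_buffer.

    region_mins plays the role of indices' (one entry per dimension of
    source_buffer); indices has m <= n entries, one per dimension of buffer.
    """
    n, m = len(region_mins), len(indices)
    if m > n:
        raise ValueError("buffer has more dimensions than region")
    offset = n - m
    # The first `offset` mins are used as-is; the rest are added to indices.
    return list(region_mins[:offset]) + [
        idx + region_mins[offset + k] for k, idx in enumerate(indices)]


def in_region(source_indices: Sequence[int], region_mins: Sequence[int],
              region_extents: Sequence[int]) -> bool:
    """True if the access stays inside [min, min + extent) in every dimension."""
    return all(lo <= i < lo + ext for i, lo, ext
               in zip(source_indices, region_mins, region_extents))


# source_buffer has shape (4, 8, 8); the matched region is
# [2:3, 4:6, 0:8] and buffer has shape (2, 8), so offset is 1.
mins, extents = [2, 4, 0], [1, 2, 8]
print(shapes_match([2, 8], extents))            # True
src = translate_indices([1, 5], mins)
print(src, in_region(src, mins, extents))       # [2, 5, 5] True
src = translate_indices([2, 0], mins)
print(src, in_region(src, mins, extents))       # [2, 6, 0] False
```

The last access is out of bounds for the matched region even though it is a valid index of source_buffer; as the text notes, such an access is not checked at run time by default.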
BlockRealize(iter_values, predicate, block):
  i. Open a new scope.
  ii. For each i from 0 to len(iter_values) - 1:
    a. Evaluate iter_values[i] and call the result iter_value.
    b. Let iter_var be block->iter_vars[i].
    c. Bind iter_value to iter_var->var.
  iii. Execute block. predicate provides additional information for the compiler, but does not affect the semantics.
  iv. Pop the scope (removing all variables added in step ii).
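The binding steps of BlockRealize and the init rule for reduction blocks can be illustrated together with a sum reduction. The Python sketch below is a hypothetical model (block_realize, init_c, and accumulate are not TVM APIs); it ignores alloc_buffers, match_buffers, and predicate, and the caller passes first_iteration because the model does not track the enclosing loop.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical model: the environment maps names to ints or to buffers
# (Python lists). Copying the dict gives a new scope whose bindings vanish
# when it is dropped, while buffers stay shared with the outer scope.
Env = Dict[str, Any]


def block_realize(env: Env,
                  iter_values: List[Callable[[Env], int]],
                  iter_var_names: List[str],
                  body: Callable[[Env], None],
                  init: Optional[Callable[[Env], None]] = None,
                  first_iteration: bool = True) -> None:
    scope = dict(env)                                   # i. open a new scope
    for value, name in zip(iter_values, iter_var_names):
        scope[name] = value(scope)                      # ii. evaluate and bind
    # iii. Execute the block. For a reduction block inside a loop, init runs
    # before body and only on the loop's first iteration.
    if init is not None and first_iteration:
        init(scope)
    body(scope)
    # iv. `scope` is discarded on return, removing the iter var bindings.


def init_c(s: Env) -> None:
    s["C"][0] = 0                       # reduction identity


def accumulate(s: Env) -> None:
    s["C"][0] += s["A"][s["vk"]]        # C[0] = C[0] + A[vk]


A = [3, 1, 4, 1, 5]
C = [99]                                # stale value that init must overwrite
env: Env = {"A": A, "C": C}
for k in range(len(A)):                 # serial loop over the reduction axis
    env["k"] = k
    block_realize(env, [lambda e: e["k"]], ["vk"], accumulate,
                  init=init_c, first_iteration=(k == 0))
print(C, "vk" in env)                   # [14] False
```

Because init runs only on the first iteration, the stale value 99 is discarded once and the later iterations accumulate into the running sum instead of resetting it.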
Some constructs in the TIR AST implementation exist primarily to support lowering from TE into TIR; they are, however, not part of TIR itself and will be lowered into other TIR constructs. These constructs are the following:
- CommReducer(lhs: [Var], rhs: [Var], result: [PrimExpr], identity_element: [PrimExpr])
- Reduce(combiner: CommReducer, source: [PrimExpr], init: [PrimExpr], axis: [IterVar], value_index: int, condition: PrimExpr?)
- Prefetch(buffer: Buffer, bounds: [Range])
Some constructs in the TIR implementation have been deprecated; they should no longer be used and will not be supported. They include the following:
- Load(dtype: DataType, buffer_var: Var, index: PrimExpr, predicate: PrimExpr): Replaced by BufferLoad.
- Store(buffer_var: Var, value: PrimExpr, index: PrimExpr, predicate: PrimExpr): Replaced by BufferStore.
- AllocateConst(buffer_var: Var, data_or_idx: NDArray | IntImm, dtype: DataType, extents: [PrimExpr], body: Stmt, annotations: {str: Object*}): Never well-supported in the first place. Use an ordinary Allocate or an argument to a PrimFunc instead.