Tensor Operations
TBLIS computes operations on tensors (and, as special cases, on matrices and vectors). A tensor is essentially a multidimensional array. Each tensor has a name and a list of indices, each of which has a specific length. For example, consider a multidimensional array in C:
```c
double A[10][4][5][12];
```
In tensor notation, this is denoted:

$$A_{pqrs}, \qquad n_p = 10,\; n_q = 4,\; n_r = 5,\; n_s = 12$$

where $n_x$ is the length of index $x$. The index labels $pqrs$ are arbitrary, but they are useful when defining tensor operations. The simplest tensor operations take at most one tensor as input and produce at most one tensor as output. For example, the operation:
takes each element of A, scales it by a scalar factor, and writes it to a different element of B (the element written to is different because the indices are transposed, i.e. permuted). Operations of this kind are called unary or "level-1" operations, by analogy with BLAS level-1 functions.
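The element-wise behaviour described above can be sketched as a plain C loop nest over the example array. The particular permutation used here, $B_{qpsr} = \alpha A_{pqrs}$, is a hypothetical choice for illustration, and `scale_transpose` is a hypothetical name; this is not a TBLIS routine.

```c
#include <assert.h>
#include <stddef.h>

/* Index lengths from the example array double A[10][4][5][12]. */
enum { NP = 10, NQ = 4, NR = 5, NS = 12 };

/* Hypothetical unary operation B_{qpsr} = alpha * A_{pqrs}: each element of A
   is scaled by alpha and written to the element of B whose indices are a
   permutation of A's (p and q swapped, r and s swapped). */
static void scale_transpose(double alpha,
                            double A[NP][NQ][NR][NS],
                            double B[NQ][NP][NS][NR])
{
    /* The permuted write is only correct if A and B do not overlap. */
    assert((void *)A != (void *)B);

    for (size_t p = 0; p < NP; p++)
        for (size_t q = 0; q < NQ; q++)
            for (size_t r = 0; r < NR; r++)
                for (size_t s = 0; s < NS; s++)
                    B[q][p][s][r] = alpha * A[p][q][r][s];
}
```

Because B's indices are a permutation of A's, every element of A is read exactly once and every element of B is written exactly once; only the position changes.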
The remaining tensor operations take two tensors as inputs and produce one output tensor. These are called binary or "level-3" operations (even though in some cases the matrix equivalent is technically a level-2 BLAS operation). An example is:
This is an example of tensor contraction, the tensor analogue of matrix multiplication. Here, elements of A and B are multiplied together, the products are summed over the ranges of the t and u indices, and the results are written to elements of C.
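A naive loop nest makes the summation over t and u explicit. The index layout $C_{pq} = \sum_{t,u} A_{ptu} B_{tuq}$, the small index lengths, and the function name `contract` are hypothetical choices for this sketch, not the operation or API defined by TBLIS.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical index lengths for a small example. */
enum { NP = 2, NQ = 3, NT = 4, NU = 5 };

/* C_{pq} = sum over t, u of A_{ptu} * B_{tuq}.
   p appears in A and C, q in B and C, and t, u appear only in the inputs,
   so t and u are the summed indices. */
static void contract(double A[NP][NT][NU], double B[NT][NU][NQ],
                     double C[NP][NQ])
{
    /* The output must not alias either input. */
    assert((void *)C != (void *)A && (void *)C != (void *)B);

    for (size_t p = 0; p < NP; p++)
        for (size_t q = 0; q < NQ; q++) {
            double sum = 0.0;
            for (size_t t = 0; t < NT; t++)
                for (size_t u = 0; u < NU; u++)
                    sum += A[p][t][u] * B[t][u][q];
            C[p][q] = sum;
        }
}
```

With row-major storage, fusing the pair $(t,u)$ into a single index of length $n_t n_u = 20$ turns this loop nest into an ordinary $2 \times 20$ by $20 \times 3$ matrix product, which is why contraction is regarded as the tensor analogue of matrix multiplication.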
For conciseness, the summation symbol is often dropped, and generalized Einstein summation is used instead. In this notation, all indices appearing in the output tensor are iterated over, and all other indices (those appearing only in one or more of the inputs) are summed over. This results in three distinct classes of indices for unary operations (A, B, and AB, where A is the input and B is the output) and seven for binary operations (A, B, C, AB, AC, BC, and ABC, where A and B are the inputs and C is the output). Based on which types of indices appear, different operations can be defined (note that if a tensor has no indices then it is replaced by a scalar, for example in a unary operation with no B or AB indices):
Index Types | Unary Operations | Index Types | Binary Operations |
---|---|---|---|
A | reduction | AB | dot product |
AB | transpose/copy | ABC | Hadamard product |
B | set to scalar | AC, BC | outer product |
A, AB | trace | AC, BC, ABC | weighting |
AB, B | replication | AB, AC, BC | contraction |
all | addition | all | multiplication |
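The index classes can be computed mechanically from the index lists of the operands. The sketch below assumes, purely for illustration, that each index is a single character and that a binary operation is described by three strings (for A, B, and the output C); `index_class` and `is_summed` are hypothetical helper names.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

/* Assumption for this sketch: every index is one character, and each tensor
   is described by the string of its indices, e.g. "ptu" for A_{ptu}. */

/* Class of index `idx` in a binary operation C = A * B, i.e. which of the
   three tensors contain it. Returns NULL if none of them contains it. */
static const char *index_class(char idx, const char *a, const char *b,
                               const char *c)
{
    static const char *const names[8] = {
        NULL, "A", "B", "AB", "C", "AC", "BC", "ABC"
    };
    /* strchr(s, '\0') would match the string terminator, so forbid it. */
    assert(idx != '\0');
    unsigned mask = (strchr(a, idx) != NULL)
                  | (strchr(b, idx) != NULL) << 1
                  | (strchr(c, idx) != NULL) << 2;
    return names[mask];
}

/* Generalized Einstein summation: an input index is summed over exactly
   when it does not appear in the output. */
static bool is_summed(char idx, const char *c)
{
    assert(idx != '\0');
    return strchr(c, idx) == NULL;
}
```

For the hypothetical contraction $C_{pq} = \sum_{t,u} A_{ptu} B_{tuq}$, `p` is classified AC, `q` BC, and `t` and `u` AB, matching the contraction row of the table; only `t` and `u` are summed because they are absent from the output.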
Currently, operations that are special cases of another operation (except where a tensor is demoted to a scalar) are not separately accessible in the API, although they may be in the future. Collecting these base cases, plus one more in which the input and output of a copy are the same tensor (which gives a scale operation), yields:
Operation | Classification | Definition |
---|---|---|
add | level-1t | |
dot | level-1t | |
reduce | level-1t | |
scale | level-1t | |
set | level-1t | |
shift | level-1t | |
mult | level-3t | |
where ... denotes an arbitrary set of indices. Note that `dot` has been moved to level-1t to continue the analogy with BLAS, and that the definition of `reduce` has been expanded to include other reductions such as min/max and $\ell^p$ norms. These are the basic operations exported by TBLIS, in addition to their vector (level-1v, i.e. BLAS1) and matrix (level-1m, and level-3m = GEMM) equivalents.
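When every tensor in a level-1t operation lists its indices (the "...") in the same order with the same layout, the operation reduces to a single loop over a flat array of elements. The sketch below assumes BLAS-style definitions, for example $B \leftarrow \alpha A + \beta B$ for add and $A \leftarrow \alpha + \beta A$ for shift; these definitions, the `t_`-prefixed names, and the particular reductions offered are assumptions for illustration, not TBLIS's actual signatures.

```c
#include <assert.h>
#include <stddef.h>

/* Level-1t operations on tensors that store their indices ("...") in the
   same order and layout, so each is a flat array of n elements.
   The BLAS-style definitions in the comments are assumptions for this
   sketch, not TBLIS function signatures. */

/* add (assumed):   B[...] := alpha * A[...] + beta * B[...] */
static void t_add(size_t n, double alpha, const double *A,
                  double beta, double *B)
{
    for (size_t i = 0; i < n; i++)
        B[i] = alpha * A[i] + beta * B[i];
}

/* dot: returns the sum over ... of A[...] * B[...] */
static double t_dot(size_t n, const double *A, const double *B)
{
    double sum = 0.0;
    for (size_t i = 0; i < n; i++)
        sum += A[i] * B[i];
    return sum;
}

/* reduce: a few of the possible reductions over all elements. */
typedef enum { RED_SUM, RED_MAX, RED_MAX_ABS, RED_SUM_SQ } reduce_op;

static double t_reduce(reduce_op op, size_t n, const double *A)
{
    assert(n > 0);
    double r = (op == RED_MAX) ? A[0] : 0.0;
    for (size_t i = 0; i < n; i++) {
        double x = A[i];
        double ax = x < 0 ? -x : x;
        switch (op) {
        case RED_SUM:     r += x; break;
        case RED_MAX:     if (x > r) r = x; break;
        case RED_MAX_ABS: if (ax > r) r = ax; break; /* infinity norm */
        case RED_SUM_SQ:  r += x * x; break;         /* squared 2-norm */
        }
    }
    return r;
}

/* scale: A[...] := alpha * A[...] */
static void t_scale(size_t n, double alpha, double *A)
{
    for (size_t i = 0; i < n; i++)
        A[i] *= alpha;
}

/* set: A[...] := alpha */
static void t_set(size_t n, double alpha, double *A)
{
    for (size_t i = 0; i < n; i++)
        A[i] = alpha;
}

/* shift (assumed): A[...] := alpha + beta * A[...] */
static void t_shift(size_t n, double alpha, double beta, double *A)
{
    for (size_t i = 0; i < n; i++)
        A[i] = alpha + beta * A[i];
}
```

`mult` does not reduce to a single flat loop in general, because A, B, and C carry different sets of indices.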