Towards High-Performance AI Compilers.
Problem Statement.
Making code written in interpreted languages run fast on specialized hardware like ASICs is difficult.
Solution.
A killer compiler. In particular, let's take a look at this paper from Intel.
Multi-Level Intermediate Representation (MLIR).
MLIR is intended to be a hybrid IR that lets us build reusable and extensible compiler infrastructure. It came from the realization that modern ML frameworks like PyTorch and TensorFlow are composed of many different compilers, graph technologies, and runtime systems that do not share a common infrastructure or design point. On top of that, most of these systems did not follow good compiler design practices during their development. The downstream effects were poor error messages, failures in edge cases, unpredictable performance, and difficulty generalizing the software stack to support new hardware. MLIR allows us to do a few things:
- Represent and manage dataflow graphs, including dynamic shapes and user-extensible operations for different ML frameworks.
- Perform a wide range of optimizations and transformations, such as loop optimizations, memory layout adjustments, and quantization for deep learning models.
- Handle code generation and hardware-specific transformations, including DMA insertion, cache management, vectorization, and support for hardware synthesis tools.
We also need a good understanding of dialects. A dialect is the basic unit that lets MLIR implement a stack of reusable abstractions; each dialect groups its own operations, types, attributes, etc. Each abstraction encodes and preserves the preconditions for valid transformations directly in its IR, which reduces the complexity and cost of analysis passes. Each dialect models a specific domain. In our case, it makes sense to use the Linalg dialect, which captures linear-algebra operations on either tensor or buffer operands, together with the Tensor dialect for tensor creation. Below is an example of a simple multi-layer perceptron (MLP) layer (matmul, bias add, ReLU) represented as Linalg generic operations on tensors:
// Affine maps for the matmul: (M,K) * (K,N) -> (M,N)
#map_mk = affine_map<(d0, d1, d2) -> (d0, d2)>
#map_kn = affine_map<(d0, d1, d2) -> (d2, d1)>
#map_mn = affine_map<(d0, d1, d2) -> (d0, d1)>
// Affine maps for element-wise and broadcast operands.
// Note: MLIR attribute aliases must be defined at the top level, before their first use.
#map_ew = affine_map<(d0, d1) -> (d0, d1)>
#map_bc = affine_map<(d0, d1) -> (d1)>
// Illustrative wrapper function (name and signature added) so that every operand is defined.
func.func @mlp(%A: tensor<128x256xf32>, %B: tensor<256x512xf32>,
               %C: tensor<128x512xf32>, %BIAS: tensor<512xf32>) -> tensor<128x512xf32> {
  // Nested affine fused multiply and accumulate operation (matmul).
  %0 = linalg.generic {
    indexing_maps = [#map_mk, #map_kn, #map_mn],
    // The reduction iterator is the third one, i.e. d2, which is the K dimension.
    iterator_types = ["parallel", "parallel", "reduction"]
  }
  // Inputs are the A and B matrices.
  ins(%A, %B : tensor<128x256xf32>, tensor<256x512xf32>)
  // Output is the C matrix, which also provides the initial value (generally zero): C += A * B.
  outs(%C : tensor<128x512xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %3 = arith.mulf %in, %in_1 : f32
    %4 = arith.addf %out, %3 : f32
    linalg.yield %4 : f32
  } -> tensor<128x512xf32>
  // A binary operation on the output of the matmul above (e.g., bias add).
  %1 = linalg.generic {
    // Maps follow operand order: ins (%BIAS, broadcast) first, then outs (%0, element-wise).
    indexing_maps = [#map_bc, #map_ew],
    iterator_types = ["parallel", "parallel"]
  }
  // The input is the bias: a 1D vector broadcast across rows and added element-wise.
  // Note: the matmul result is the initializer of the output, so it goes in `outs`.
  ins(%BIAS : tensor<512xf32>)
  outs(%0 : tensor<128x512xf32>) {
  ^bb0(%in: f32, %out: f32):
    %4 = arith.addf %in, %out : f32
    linalg.yield %4 : f32
  } -> tensor<128x512xf32>
  // A unary operation on the output of the binary above (e.g., ReLU).
  %ZERO = arith.constant 0.000000e+00 : f32
  %2 = linalg.generic {
    // Element-wise parallel operation: a single MN map.
    indexing_maps = [#map_ew],
    iterator_types = ["parallel", "parallel"]
  }
  // No `ins`: the result above is both the input and the output initializer, so it goes in `outs`.
  outs(%1 : tensor<128x512xf32>) {
  ^bb0(%out: f32):
    %4 = arith.maximumf %out, %ZERO : f32
    linalg.yield %4 : f32
  } -> tensor<128x512xf32>
  return %2 : tensor<128x512xf32>
}
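One way to read `linalg.generic` is as a loop nest over an iteration space: the indexing maps say which element of each operand (ins first, then outs) is touched at every loop point (d0, d1, d2), and the region computes the new output value from those elements. The pure-Python sketch below interprets the three operations above under that reading, on tiny shapes. It is an illustrative reference model only: `generic` and `get` are hypothetical helpers, not MLIR APIs, and because the loops run sequentially, `iterator_types` is omitted. Those types matter to transformations such as tiling and parallelization, not to the computed result.

```python
from itertools import product


def get(tensor, index):
    """Read one element of a nested-list tensor at a tuple index."""
    for i in index:
        tensor = tensor[i]
    return tensor


def generic(loop_ranges, indexing_maps, ins, out, body):
    """Sequential reference model of a linalg.generic-style op (a sketch).

    loop_ranges:   trip count of each loop dimension d0, d1, ...
    indexing_maps: one function per operand, `ins` first and then the output,
                   mapping a loop point (d0, d1, ...) to that operand's index.
    ins:           input tensors as nested lists.
    out:           2-D output initializer; copied, like a tensor-semantics `outs`.
    body:          scalar function (in_0, ..., in_k, out) -> new output value.
    """
    result = [list(row) for row in out]
    *in_maps, out_map = indexing_maps
    for point in product(*(range(n) for n in loop_ranges)):
        args = [get(x, m(point)) for x, m in zip(ins, in_maps)]
        i, j = out_map(point)
        result[i][j] = body(*args, result[i][j])
    return result


# Affine maps written as Python functions of the loop point d.
map_mk = lambda d: (d[0], d[2])
map_kn = lambda d: (d[2], d[1])
map_mn = lambda d: (d[0], d[1])
map_ew = lambda d: (d[0], d[1])
map_bc = lambda d: (d[1],)

M, K, N = 2, 3, 2
A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
B = [[1.0, -1.0], [0.5, 0.0], [-2.0, 1.0]]
C = [[0.0] * N for _ in range(M)]
bias = [1.0, 0.5]

# Matmul: loops (d0, d1, d2) = (M, N, K); d2 is the reduction dimension.
mm = generic([M, N, K], [map_mk, map_kn, map_mn], [A, B], C, lambda a, b, c: c + a * b)
# Bias add: the (d0, d1) -> (d1) map broadcasts the vector across rows.
biased = generic([M, N], [map_bc, map_ew], [bias], mm, lambda b, c: b + c)
# ReLU: no inputs, only the output initializer.
relu = generic([M, N], [map_ew], [], biased, lambda c: max(c, 0.0))
print(relu)  # [[0.0, 2.5], [0.0, 2.5]]
```

Note how the bias add differs from the matmul only in its maps and body; that uniformity is what lets later passes tile and fuse all three operations with the same machinery.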
Tiling + Kernel Fusion.
Tiling is a technique that splits large computations into smaller blocks, or tiles, to make processing more efficient. CPUs can compute much faster than they can fetch data from memory, so tiling ensures that the data being worked on fits into the faster CPU cache. This reduces repeated accesses to slower memory and reuses data already in the cache, speeding up the overall computation. In short, tiling improves performance by optimizing data locality and exposing parallelism across dimensions.
When dealing with large datasets or matrices, we often perform multiple operations in sequence. This is where kernel fusion comes into play. The main idea of kernel fusion is to combine multiple tensor operations, such as contractions and element-wise operations, into a single computational kernel, reducing memory accesses by leveraging data locality.
Let's look at how these concepts apply to matrix operations. Matrix multiplication \( C = A \times B \) involves a contraction over the shared dimension \( k \):
\[ C_{ij} = \sum_{k} A_{ik} \cdot B_{kj} \]
Element-wise operations are applied individually to each element of a tensor. For matrices, element-wise addition might be represented as:
\[ E_{ij} = C_{ij} + D_{ij} \]
The key to optimization is fusing the contraction and the element-wise operations. Instead of performing these operations separately, combining them into a single pass reduces the number of memory accesses and improves performance:
\[ E = (A \times B) + D \]
The code below fuses operations along the parallel M and N loops. The tile-and-fuse pass tiles these two dimensions: uppercase MB, NB, and KB select a tile, while lowercase mb, nb, and kb index elements within a tile. Because there are no loop-carried dependencies between tiles, the MB and NB loops become two parallel loops. Below is code showing this:
// Convert tile-wise operations
// (uppercase = tile index, lowercase = index within a tile; loops over mb, nb, kb are implicit)
for (MB, NB) {
  for (KB) {
    C[MB][NB][mb][nb] += A[MB][KB][mb][kb] * B[NB][KB][kb][nb];
  }
  // The bias depends only on the N dimension
  C[MB][NB][mb][nb] = add(C[MB][NB][mb][nb], bias[NB][nb]);
  C[MB][NB][mb][nb] = max(C[MB][NB][mb][nb], 0);
}
// Into a parallel BRGEMM "tile" operation + element-wise tail operations
parallel(MB, NB) {
  // Extract {A, B, C} x [MB][NB][KB] as appropriate
  // Note: This is a batch-reduce GEMM into a "tile" (the KB tiles form the batch)
  C[mb][nb] += A[KB][mb][kb] * B[KB][kb][nb];
  C[mb][nb] = add(C[mb][nb], bias[nb]);
  C[mb][nb] = max(C[mb][nb], 0);
  // Insert into C[MB][NB]
}
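To make the tile-and-fuse pseudocode above concrete, here is a minimal pure-Python sketch. It assumes every dimension is a multiple of its tile size, packs A into A[MB][KB][mb][kb] and B into B[NB][KB][kb][nb] as in the pseudocode, and keeps the bias as a flat length-N vector (so `bias[NB*nb + n]` plays the role of `bias[NB][nb]`). All helper names (`pack_a`, `pack_b`, `fused_mlp`, `unpack_c`, `reference`) are hypothetical, and the loops run sequentially; a real compiler would run the (MB, NB) loops in parallel and call an optimized microkernel per tile.

```python
def pack_a(A, mb, kb):
    """Pack row-major A (M x K) into blocked A[MB][KB][mb][kb]."""
    M, K = len(A), len(A[0])
    return [[[A[MB * mb + m][KB * kb:KB * kb + kb] for m in range(mb)]
             for KB in range(K // kb)] for MB in range(M // mb)]


def pack_b(B, kb, nb):
    """Pack row-major B (K x N) into blocked B[NB][KB][kb][nb]."""
    K, N = len(B), len(B[0])
    return [[[B[KB * kb + k][NB * nb:NB * nb + nb] for k in range(kb)]
             for KB in range(K // kb)] for NB in range(N // nb)]


def fused_mlp(Ap, Bp, bias, mb, nb):
    """Tiled matmul + bias + ReLU on packed operands; returns C[MB][NB][mb][nb]."""
    n_mb, n_nb, n_kb = len(Ap), len(Bp), len(Ap[0])
    C = [[None] * n_nb for _ in range(n_mb)]
    for MB in range(n_mb):          # parallel(MB, NB): tiles are independent
        for NB in range(n_nb):
            tile = [[0.0] * nb for _ in range(mb)]
            for KB in range(n_kb):  # batch-reduce every KB tile into one C tile
                a, b = Ap[MB][KB], Bp[NB][KB]
                for m in range(mb):
                    for n in range(nb):
                        tile[m][n] += sum(a[m][k] * b[k][n] for k in range(len(b)))
            # Fused element-wise tail while the tile is still local.
            C[MB][NB] = [[max(tile[m][n] + bias[NB * nb + n], 0.0) for n in range(nb)]
                         for m in range(mb)]
    return C


def unpack_c(Cb):
    """Convert blocked C[MB][NB][mb][nb] back to a row-major list of rows."""
    return [[x for NB in range(len(Cb[MB])) for x in Cb[MB][NB][m]]
            for MB in range(len(Cb)) for m in range(len(Cb[MB][0]))]


def reference(A, B, bias):
    """Untiled reference: matmul, then bias add, then ReLU."""
    K, N = len(B), len(B[0])
    return [[max(sum(row[k] * B[k][j] for k in range(K)) + bias[j], 0.0)
             for j in range(N)] for row in A]
```

For example, with a 4x8 A and an 8x6 B, `unpack_c(fused_mlp(pack_a(A, 2, 4), pack_b(B, 4, 3), bias, 2, 3))` should match `reference(A, B, bias)`. With floating-point data the two can differ in the last bits because the tiled version sums the K dimension in a different grouping.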
Data Blocking.
Sometimes called packing, data blocking is a well-known transformation in high-performance libraries that copies a non-contiguous block of data into a contiguous block of memory, reducing the number of TLB entries required to access it. While copying, data blocking rearranges the block's elements to decrease the stride between consecutive accesses, which improves spatial locality and cache behavior. By bringing this transformation into our compiler, we can propagate the data layout through the IR instead of paying the packing cost on every execution. The resulting set of dialects, with tiled and fused operations, can then be bufferized by a one-shot bufferization pass, cleaned up, and lowered to lower-level dialects, where library- and hardware-specific passes operate on data that is already in a memory-friendly shape. And voilà, we now have a common representation, which we call Tensor-Linalg, that simplifies the optimization process across tons of frameworks all the way down to the hardware.
XSMM.
Before you get on your way, I'd like to introduce XSMM, a dialect that maps the behavior of libxsmm. libxsmm is a JIT-ing library for sparse and dense matrix operations for deep learning, and its usage is split into two stages:
- Dispatch: The dispatch stage receives the shapes of the buffers, leading dimensions, broadcast flags, and fusion flags, then compiles the microkernel in memory and returns a pointer to its implementation. The second time a dispatch function is called with the same arguments, it simply returns a cached pointer to the same implementation.
- Invoke: The invoke stage calls that function pointer with the actual tensor data (usually a tile of a larger buffer with appropriate strides), which computes the operation and writes the result to the output buffer.
The dialect provides the following operations:
- unary: These are element-wise unary operations like ReLU, but also broadcasts, transposes, and reductions.
- binary: These are element-wise binary operations like add or multiply.
- gemm: General matrix multiplication, mirroring the BLAS GEMM interface.
- brgemm: Batch-reduce GEMM, a more powerful abstraction that carries an extra reduction dimension on the input operands, so that several tiles of A and B are multiplied and accumulated into the same C tile. For example, take a matrix multiplication where A and B are multiplied to produce C. A traditional matrix multiplication might compute each element of C individually. With brgemm, A and B are broken into smaller tiles along the reduction (K) dimension, and the products of all tile pairs are accumulated into the same C tile in a single call, i.e. C += A_0 * B_0 + A_1 * B_1 + ..., which makes more efficient use of computational resources.
- fused_brgemm: This allows the BRGEMM operation to be combined, or "fused", with element-wise operations that occur before (prologue) and after (epilogue) the main BRGEMM computation. The fusion happens at the register level: intermediate results stay in registers instead of being written back to memory, which can improve performance and reduce overhead.
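The dispatch/invoke split and the brgemm semantics can be modeled in a few lines of Python. This is a sketch, not libxsmm's API: `dispatch_brgemm` is a hypothetical helper that uses `functools.lru_cache` as a stand-in for libxsmm's JIT code cache (keyed here only on the tile shape, ignoring leading dimensions and flags), and the closure it returns plays the role of the invoke stage, accumulating the product of every (A_i, B_i) tile pair into one C tile.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def dispatch_brgemm(m, n, k):
    """Dispatch stage (sketch): 'compile' a kernel for one tile shape, then cache it.

    A second call with the same (m, n, k) returns the very same function object,
    mimicking the cached pointer returned by a repeated dispatch.
    """
    def kernel(a_tiles, b_tiles, c_tile):
        # Invoke stage: C += sum over the batch of A_i x B_i, into a single C tile.
        for a, b in zip(a_tiles, b_tiles):
            for i in range(m):
                for j in range(n):
                    c_tile[i][j] += sum(a[i][p] * b[p][j] for p in range(k))
        return c_tile
    return kernel


kernel = dispatch_brgemm(2, 2, 2)
same = dispatch_brgemm(2, 2, 2)   # cache hit: no recompilation
a_tiles = [[[1, 2], [3, 4]], [[1, 0], [0, 1]]]  # two K tiles of A
b_tiles = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]  # matching K tiles of B
c_tile = kernel(a_tiles, b_tiles, [[0, 0], [0, 0]])
```

Because the batch runs over the K tiles, a single brgemm invocation with all KB tiles replaces the entire inner KB loop of the tiled pseudocode shown earlier.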
// Original sequence of XSMM calls
%3 = xsmm.unary.dispatch zero [...] flags = (none)
%4 = xsmm.brgemm.dispatch [none] flags = (none)
%5 = xsmm.binary.dispatch add [...] flags = (bcast_col_in0)
%6 = xsmm.unary.dispatch relu [...] flags = (none)
scf.parallel (MB, NB) {
  %subview_A = memref.subview ... // Subview into matrix A
  %subview_B = memref.subview ... // Subview into matrix B
  %subview_C = memref.subview ... // Subview into matrix C
  // Initialize C[MB][NB] with zeros
  xsmm.zero(..., %3, %subview_C)
  // Perform BRGEMM, reducing over all KB tiles: C[MB][NB] = BRGEMM(A[MB][*], B[NB][*], C[MB][NB])
  xsmm.brgemm(data_type = f32, %4, %subview_A, %subview_B, %subview_C, %c0)
  // Add bias: C[MB][NB] = ADD(broadcast(Bias[NB]), C[MB][NB])
  xsmm.binary add(..., %5, %BIAS, %subview_C, %subview_C)
  // Apply ReLU activation: C[MB][NB] = ReLU(C[MB][NB])
  xsmm.unary relu(..., %6, %subview_C, %subview_C)
}
// Optimized version using fused_brgemm
// beta_0: overwrite C instead of accumulating into it, which replaces the separate zero fill
%3 = xsmm.fused_brgemm.dispatch [...][add,relu] flags = (beta_0)
    binary_flags = (bcast_col_in0) unary_flags = (none)
scf.parallel (MB, NB) {
  %subview_A = memref.subview ... // Subview into matrix A
  %subview_B = memref.subview ... // Subview into matrix B
  %subview_C = memref.subview ... // Subview into matrix C
  // Fused operation combining initialization, BRGEMM, bias addition, and ReLU activation
  // 1. C[MB][NB] = { 0.0 } (Initialization)
  // 2. C[MB][NB] = BRGEMM(A[MB][*], B[NB][*], C[MB][NB]) (MatMul)
  // 3. C[MB][NB] = ADD(broadcast(Bias[NB]), C[MB][NB]) (Bias addition)
  // 4. C[MB][NB] = ReLU(C[MB][NB]) (ReLU activation)
  xsmm.fused_brgemm(..., %3, %subview_A, %subview_B, %subview_C, %BIAS, %c4)
}
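To see why the fused version helps, the sketch below models one (MB, NB) iteration both ways and counts how many times the full C tile is stored to memory. Everything here is a hypothetical pure-Python stand-in (`Memory`, `unfused`, `fused`, `matmul_acc`), not XSMM or libxsmm code: the unfused path mirrors the four calls (zero, brgemm, add, relu), each ending with a store of C, while the fused path starts from zero because of `beta_0`, applies the bias and ReLU epilogue to the local accumulator, and stores once. Both compute the same tile.

```python
def matmul_acc(c, a, b):
    """Return c + a x b for one pair of tiles (c itself is not modified)."""
    return [[c[i][j] + sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


class Memory:
    """Hypothetical model of the C buffer that counts full-tile stores."""

    def __init__(self):
        self.tile, self.stores = None, 0

    def store(self, tile):
        self.tile = [row[:] for row in tile]
        self.stores += 1


def unfused(a_tiles, b_tiles, bias, mem):
    m, n = len(a_tiles[0]), len(b_tiles[0][0])
    mem.store([[0.0] * n for _ in range(m)])   # zero fill
    acc = mem.tile                             # brgemm reads C back from memory
    for a, b in zip(a_tiles, b_tiles):
        acc = matmul_acc(acc, a, b)
    mem.store(acc)                             # brgemm result
    mem.store([[x + bias[j] for j, x in enumerate(row)] for row in mem.tile])  # add
    mem.store([[max(x, 0.0) for x in row] for row in mem.tile])               # relu
    return mem.tile


def fused(a_tiles, b_tiles, bias, mem):
    m, n = len(a_tiles[0]), len(b_tiles[0][0])
    acc = [[0.0] * n for _ in range(m)]        # beta_0: C is never read
    for a, b in zip(a_tiles, b_tiles):
        acc = matmul_acc(acc, a, b)
    # Epilogue (bias add + ReLU) on the local accumulator, then a single store.
    mem.store([[max(x + bias[j], 0.0) for j, x in enumerate(row)] for row in acc])
    return mem.tile
```

In a real kernel the saving is larger than this count suggests, since each unfused call also reloads the C tile before modifying it.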
Closing Thoughts.
There are a lot of people building ML compilers in the open, like IREE, so check them out if you're interested, but I always recommend building one from scratch. Along the way you'll learn existing best practices, e.g. writing and maintaining an IR spec, building an IR verifier, etc. Below is a sample compilation strategy you can follow, using what you've hopefully learned from this post, to build your own AI compiler:
- An ingress layer that extracts MLIR from existing frameworks into our common compiler IR.
- A high-level, hardware-agnostic Tensor-Linalg pipeline.
- Lowering to low-level dialects.
- Lowering to hardware dialects (e.g., XSMM).
- An execution strategy for the generated code, including runtime libraries and wrappers.
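The staged strategy above can be sketched as a driver that runs each stage in order and verifies the IR after every stage, which is the habit that catches a broken lowering right where it happens. This is a toy model under loud assumptions: ops are just (dialect, name) tuples, every op and dialect name below is hypothetical, each stage is a simple rewrite rather than a real MLIR pass pipeline, and the execution stage is left out.

```python
from typing import Callable, Dict, List, Set, Tuple

Op = Tuple[str, str]  # (dialect, op name): a toy stand-in for real IR


def verify(ops: List[Op], legal: Set[str], stage: str) -> None:
    """Toy IR verifier: every op left after `stage` must be in a legal dialect."""
    for dialect, name in ops:
        if dialect not in legal:
            raise ValueError(f"after {stage}: illegal op {dialect}.{name}")


def rewrite(table: Dict[Op, List[Op]]) -> Callable[[List[Op]], List[Op]]:
    """Build a stage that expands each op via `table`; ops not in it are kept."""
    return lambda ops: [new for op in ops for new in table.get(op, [op])]


def fuse_mlp(ops: List[Op]) -> List[Op]:
    """Toy tile-and-fuse: collapse a matmul -> add -> relu chain into one op."""
    if ops == [("linalg", "matmul"), ("linalg", "add"), ("linalg", "relu")]:
        return [("linalg", "fused_mlp_tile")]
    return ops


# (stage name, transform, dialects legal afterwards); all names are hypothetical.
PIPELINE = [
    ("ingress", rewrite({("framework", "linear"): [("linalg", "matmul"), ("linalg", "add")],
                         ("framework", "relu"): [("linalg", "relu")]}), {"linalg"}),
    ("tensor-linalg", fuse_mlp, {"linalg"}),
    ("low-level", rewrite({("linalg", "fused_mlp_tile"): [("scf", "parallel"),
                                                          ("memref", "subview")]}),
     {"scf", "memref"}),
    ("hardware", rewrite({("memref", "subview"): [("memref", "subview"),
                                                  ("xsmm", "fused_brgemm")]}),
     {"scf", "memref", "xsmm"}),
]


def compile_ops(ops: List[Op]) -> List[Op]:
    """Run every stage in order, verifying the IR after each one."""
    for name, transform, legal in PIPELINE:
        ops = transform(ops)
        verify(ops, legal, name)
    return ops
```

An op the ingress layer does not know how to lower survives into the next stage, and the verifier reports it immediately with the name of the stage that let it through.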