OpenXLA-specific changes
chsigg committed Dec 10, 2024
1 parent 88c704e commit b2132de
Showing 39 changed files with 3,514 additions and 940 deletions.
931 changes: 931 additions & 0 deletions BUILD

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion lib/Analysis/AxisInfo.cpp
@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
lhsDivisibility = 1;
}
return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
@@ -1011,6 +1011,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
CastOpAxisInfoVisitor<arith::ExtUIOp>,
CastOpAxisInfoVisitor<arith::TruncIOp>,
CastOpAxisInfoVisitor<arith::IndexCastOp>,
CastOpAxisInfoVisitor<arith::IndexCastUIOp>,
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
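
The `int64_t(1) << shift` change above is an overflow fix: divisibility values are 64-bit, but a plain `1 << shift` is evaluated as 32-bit `int`, so large shifts overflow (and shifts of 32 or more are undefined behavior) before the division ever happens. A minimal sketch of the corrected expression; the wrapper function is illustrative, not part of the patch:

#include <algorithm>
#include <cstdint>

// Illustrative helper mirroring the fixed expression in ShROpAxisInfoVisitor.
// Widening the literal keeps the shift and the division entirely in 64 bits.
int64_t shiftedDivisibility(int64_t lhsDivisibility, int64_t shift) {
  return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}
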
12 changes: 12 additions & 0 deletions lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> Value {
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
// remaining arguments that have been converted to a new type.
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
// 'convert-triton-to-tritongpu'.
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
inputs);
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) -> Value {
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
// remaining uses of values that have been converted to a new type.
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
// 'convert-triton-to-tritongpu'.
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
inputs);
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return {};
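
Both hooks above replace an `llvm_unreachable` with a real materialization so that a partial TTIR-to-TTGIR conversion leaves well-typed IR behind for a later pass (the triton_xla.sparse_dot rewrite) to finish. The diff shows only the lambda bodies; roughly, each one is registered on the type converter during construction, as in this sketch of the surrounding setup (not a literal excerpt):

// Sketch: registering a source materialization on the TritonGPU type
// converter. The lambda bridges a value that still has its pre-conversion
// tensor type to the converted type expected by its users.
addSourceMaterialization([&](mlir::OpBuilder &builder,
                             mlir::RankedTensorType tensorType,
                             mlir::ValueRange inputs,
                             mlir::Location loc) -> mlir::Value {
  return builder.create<mlir::triton::gpu::ConvertLayoutOp>(loc, tensorType,
                                                            inputs);
});
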
4 changes: 2 additions & 2 deletions lib/Dialect/Triton/Transforms/Combine.td
@@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat<
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
@@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat<
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
13 changes: 9 additions & 4 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -3127,6 +3127,11 @@ struct CanonicalizeConvertFromAlloc
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding, so we want to keep this layout conversion.
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
convert.getSrc().getType().getEncoding()))
return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
return mlir::success();
@@ -3189,13 +3194,13 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
return failure();

// for hopper MMAv3
if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {
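
The switch from `isa` to `isa_and_nonnull` in the two checks above matters because `getEncoding()` may return a null attribute for tensor types that carry no layout; `mlir::isa` asserts on a null value, while `isa_and_nonnull` simply reports false so the canonicalization can bail out cleanly. A small sketch of the difference, using a builtin attribute as a stand-in for the Triton encodings:

#include "mlir/IR/BuiltinAttributes.h"

// Illustrative only: `enc` stands in for srcType.getEncoding(), which can be
// a null Attribute. isa<> would assert here; isa_and_nonnull<> returns false.
static bool hasStringEncoding(mlir::Attribute enc) {
  return mlir::isa_and_nonnull<mlir::StringAttr>(enc);
}
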
50 changes: 43 additions & 7 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -20,8 +20,6 @@ namespace mlir {
namespace triton {
namespace gpu {

namespace {

// Get the highest version supported for the hardware and the dot.
static int getMMAVersionSafe(int computeCapability, DotOp op) {
// List supported mma version in order of preference.
@@ -44,8 +42,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
return 0;
}

SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned>
warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto rank = shape.size();
// Early exit for batched matmul
if (rank == 3)
@@ -109,10 +107,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
}

SmallVector<unsigned, 2>
warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
const SmallVector<unsigned, 3> &instrShape) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
mlir::getForwardSlice(dotOp->getResult(0), &slices);
// Contains a chained dot. We prefer to assign warps to one axis
// to facilitate use cases like flash attention, allowing reductions within
// the same warp.
@@ -167,11 +165,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
newLayout, SharedMemorySpace);
rewriter.setInsertionPointAfterValue(arg);

// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding.
if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
argType.getEncoding())) {
// Create a layout conversion from DotOperandEncoding to BlockedEncoding
// then pass it to the LocalAllocOp.
auto newArgType = RankedTensorType::get(
argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
auto dotOperandToBlockedCvt =
rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
dotOperandToBlockedCvt);
}

return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
}

SmallVector<unsigned, 3>
getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
getWarpsPerTile(Operation* dotOp, const ArrayRef<int64_t> shape, int version,
int numWarps, const SmallVector<unsigned, 3> &instrShape) {
switch (version) {
case 2:
@@ -184,18 +197,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
}
}

// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity
// extension.
namespace {

class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
int computeCapability;
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

static bool bwdFilter(Operation *op) {
    // Dot operand layout assignment to predicates is not currently supported
    // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
    // condition limits visibility of the original bit-width so that predicates
    // are not considered; hence, kwidth can never be 32.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
return false;
}
return op->getNumOperands() == 1 &&
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
isPureUnaryInlineAsm(op) ||
op->getDialect()->getTypeID() ==
mlir::TypeID::get<arith::ArithDialect>());
}
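
The new `uitofp`-from-`i1` check runs inside `bwdFilter`, the predicate BlockedToMMA uses when walking backwards from a dot operand to find the original element bit-width; letting an `i1` source through would report a bit-width of 1 and drive kwidth up to 32, which the MMA lowering cannot handle. A hedged sketch of how such a filter is typically plugged into MLIR's backward-slice utility; the wrapper and variable names are illustrative, not copied from the pass:

#include "mlir/Analysis/SliceAnalysis.h"
#include "llvm/ADT/SetVector.h"

// Sketch: collect the shape-preserving producers of `x`, stopping at ops the
// filter rejects (such as casts from i1 predicates), so the smallest
// bit-width seen in the slice reflects real data rather than a predicate.
static void collectOrigProducers(mlir::Value x,
                                 llvm::SetVector<mlir::Operation *> &slice) {
  mlir::BackwardSliceOptions options;
  options.filter = bwdFilter; // the predicate shown in the diff above
  if (mlir::Operation *def = x.getDefiningOp())
    mlir::getBackwardSlice(def, &slice, options);
}
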

public:
// Finds the first different bitwidth in the chain of shape-preserving
// unary ops that x depends on.
// There are two primary scenarios:
@@ -806,6 +833,15 @@ class TritonGPUAccelerateMatmulPass
}
};

// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
int computeOrigBitWidth(Value x) {
return BlockedToMMA::computeOrigBitWidth(x);
}
Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
int opIdx, bool allowTranspose) {
return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
}

} // namespace gpu
} // namespace triton
} // namespace mlir
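
The two free functions added at the end of the namespace exist so that the out-of-tree sparse-dot rewrite can reuse these helpers without depending on the BlockedToMMA class itself. A hedged sketch of how a sparsity pass might call them; the rewrite function and its operand handling are hypothetical, only `computeOrigBitWidth` and `getSharedMemMMAOperand` come from this file, and matching declarations are assumed to be visible to the caller:

// Hypothetical caller in a separate sparsity pass.
mlir::LogicalResult stageSparseDotOperands(mlir::Operation *sparseDot,
                                           mlir::PatternRewriter &rewriter) {
  mlir::Value a = sparseDot->getOperand(0);
  mlir::Value b = sparseDot->getOperand(1);
  // The original element bit-width feeds layout parameters such as kWidth.
  int bitWidth = mlir::triton::gpu::computeOrigBitWidth(a);
  (void)bitWidth;
  // Stage both operands through shared memory before forming the MMA.
  mlir::Value aShared = mlir::triton::gpu::getSharedMemMMAOperand(
      a, rewriter, /*opIdx=*/0, /*allowTranspose=*/true);
  mlir::Value bShared = mlir::triton::gpu::getSharedMemMMAOperand(
      b, rewriter, /*opIdx=*/1, /*allowTranspose=*/true);
  sparseDot->setOperand(0, aShared);
  sparseDot->setOperand(1, bShared);
  return mlir::success();
}
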
26 changes: 22 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
@@ -17,9 +17,11 @@ namespace {
// dot(a, b, inputPrecision="tf32x3") ->
// let aBig = f32ToTF32(a), aSmall = a - aBig;
// let bBig = f32ToTF32(b), bSmall = b - bBig;
// dot(aSmall, bBig, inputPrecision="tf32") +
// dot(aBig, bSmall, inputPrecision="tf32") +
// dot(aBig, bBig, inputPrecision="tf32")
// let small = dot(aSmall, bBig, inputPrecision="tf32") +
// dot(aBig, bSmall, inputPrecision="tf32")
// let masked_nans = replaceNansWithZeros(small)
// let big = dot(aBig, bBig, inputPrecision="tf32")
// return big + masked_nans;
class TF32x3 : public OpRewritePattern<DotOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -62,6 +64,13 @@ class TF32x3 : public OpRewritePattern<DotOp> {
InputPrecision::TF32,
dotOp.getMaxNumImpreciseAcc());
};
auto replaceNansWithZeros = [&](Value value) -> Value {
auto nans = rewriter.create<arith::CmpFOp>(
dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
auto zero = zeroLike(value);
return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
value);
};

auto aBig = f32ToTF32(dotOp.getA());
auto aSmall = sub(dotOp.getA(), aBig);
@@ -73,7 +82,16 @@ class TF32x3 : public OpRewritePattern<DotOp> {

auto dot1 = dot(aSmall, bBig, zero);
auto dot2 = dot(aBig, bSmall, dot1);
auto dot3 = dot(aBig, bBig, dot2);

// If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0.
// If rhs is +infinity, we will have:
// +infinity * 1.0 = +infinity
// +infinity * 0.0 = NaN
// We would get the wrong result if we sum these partial products. Instead,
// we must override any accumulated result if the last partial product is
// non-finite.
auto dot2withZeroedNans = replaceNansWithZeros(dot2);
auto dot3 = dot(aBig, bBig, dot2withZeroedNans);

auto sum = add(dot3, dotOp.getC());
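
A scalar sketch of the failure mode described in the comment above: with a = 1.0 and b = +infinity, the "small" partial products pick up NaNs, and only masking them lets the dominant big*big product carry the correct +infinity through. The TF32 truncation below is a bit-masking approximation, not the exact hardware rounding:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

// Approximate f32ToTF32 by dropping the low 13 mantissa bits.
static float toTF32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= ~std::uint32_t(0x1FFF);
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  float a = 1.0f;
  float b = std::numeric_limits<float>::infinity();

  float aBig = toTF32(a), aSmall = a - aBig; // 1.0 and 0.0
  float bBig = toTF32(b), bSmall = b - bBig; // +inf and NaN (inf - inf)

  float small = aSmall * bBig + aBig * bSmall; // NaN (0 * inf, 1 * NaN)
  float masked = std::isnan(small) ? 0.0f : small;

  std::printf("without mask: %f\n", aBig * bBig + small);  // nan
  std::printf("with mask:    %f\n", aBig * bBig + masked); // inf
  return 0;
}
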

(file header not rendered for this diff)
@@ -111,6 +111,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,

Value zero = builder.createWithStage<arith::ConstantIntOp>(
forOp.getLoc(), stage, clusterId, 0, 32);

// Replace the load with insert/extract slice.
builder.setInsertionPoint(loadOp);
Location loc = loadOp.getLoc();
@@ -468,7 +469,8 @@
}
});

loadsToPipeline.insert(&op);
// TODO: b/381421713 - Uncomment this once pipelining is fixed.
// loadsToPipeline.insert(&op);
LoadInfo loadInfo;
for (auto use : users) {
if (use->hasTrait<OpTrait::DotLike>()) {
@@ -508,6 +510,11 @@
getBlockedEncoding(loadOp, axisInfoAnalysis);
}
}

// TODO: b/381421713 - Remove this once pipelining is fixed.
if (!loadInfo.sharedEncoding) continue;
loadsToPipeline.insert(&op);

loadToInfo[&op] = loadInfo;
}
// Make sure all loads in loadsToPipeline are in loadToInfo.
26 changes: 24 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
// opIdx: 0 => a, 1 => b
auto type = cast<triton::gpu::MemDescType>(v.getType());
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
SmallVector<int64_t> offset{0, 0};
SmallVector<int64_t> offset(shape.size(), 0);
Type elementType = type.getElementType();

// k => (prefetchWidth, k - prefetchWidth)
@@ -141,8 +141,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
type.getMutableMemory(), type.getAllocShape()),
v, offsetsVal);

  // We need to set kwidth to zero when the parent layout is Blocked;
  // otherwise the verifier emits a failure. The parent layout is Blocked
  // only when Tensor Cores are disabled.
int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
? 0
: prefetchWidth / 8;
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
builder.getContext(), opIdx, dotEncoding, kwidth);
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
@@ -191,6 +197,22 @@ LogicalResult Prefetcher::initialize() {
break;
if (!op->getResult(0).hasOneUse())
break;
      // Similar to the issue handled by the HoistLayoutConversion pattern in
      // OptimizeDotOperands.cpp, we can't propagate through type casts from
      // predicates, as they aren't supported in Triton when encoded with a
      // dot_op layout.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
break;
}
      // Propagation through ExpandDims is currently not supported: this blindly
      // replaces the encoding with a dot operand encoding, but ExpandDims
      // requires a SliceEncoding. It could be rewritten to support this case,
      // but that isn't trivial and currently crashes.
if (isa<ExpandDimsOp>(op)) {
break;
}
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
foundConvertFromShared = true;
30 changes: 19 additions & 11 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -953,18 +953,26 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
} else {
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
return std::nullopt;
auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
.getEncoding());
if (!dotOpEnc)
auto enc =
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType()).getEncoding();
if (isa<ttg::DotOperandEncodingAttr>(enc)) {
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
auto order = ttg::getOrder(srcTy.getEncoding());
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
tempAttr = ttg::SharedEncodingAttr::get(
val.getContext(), cast<ttg::DotOperandEncodingAttr>(enc),
srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false);
} else if (enc.getAbstractAttribute().getName().str() ==
"triton.gpu.sparse_dot_meta_encoding") {
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
tempAttr = ttg::SharedEncodingAttr::get(
val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1,
ttg::getOrder(srcTy.getEncoding()),
ttg::getCTALayout(srcTy.getEncoding()));
} else {
return std::nullopt;
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
auto order = ttg::getOrder(srcTy.getEncoding());
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
tempAttr = ttg::SharedEncodingAttr::get(
val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
bitWidth, /*needTrans=*/false);
}
}
// Check that the shared encodings needed by the users are compatible.
if (attr != nullptr && attr != tempAttr) {
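
In the branch above, the sparse-metadata layout is recognized by its registered attribute name rather than with `isa<>`, which keeps this Triton pass free of a compile-time dependency on the dialect that defines the sparse encoding. A minimal sketch of that kind of name-based check; the wrapper function is illustrative:

#include "mlir/IR/Attributes.h"

// Illustrative: recognize an attribute from a dialect this file cannot link
// against by its registered name instead of its C++ class.
static bool isSparseDotMetaEncoding(mlir::Attribute enc) {
  return enc && enc.getAbstractAttribute().getName() ==
                    "triton.gpu.sparse_dot_meta_encoding";
}
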
2 changes: 1 addition & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
@@ -44,7 +44,7 @@ struct FenceInsertionPass
return;
ModuleOp mod = getOperation();
mod.walk([&](Operation *op) {
if (!isa<ttng::WarpGroupDotOp>(op))
if (!op->hasTrait<OpTrait::DotLike>())
return WalkResult::advance();
OpBuilder builder(op);
auto a = op->getOperand(0);
(remaining file diffs not rendered)
