OpenXLA-specific changes

Moerafaat committed Oct 29, 2024
1 parent a33e3ac commit da8895b
Showing 40 changed files with 2,418 additions and 70 deletions.
908 changes: 908 additions & 0 deletions BUILD

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions lib/Analysis/AxisInfo.cpp
@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
lhsDivisibility = 1;
}
- return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+ return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
@@ -1084,8 +1084,8 @@ LogicalResult AxisInfoAnalysis::visitOperation(

void AxisInfoAnalysis::visitForOpInductionVar(
scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
- auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue();
- auto step = getLatticeElementFor(op, op.getStep())->getValue();
+ auto lb = getLatticeElementFor(getProgramPointAfter(op), op.getLowerBound())->getValue();
+ auto step = getLatticeElementFor(getProgramPointAfter(op), op.getStep())->getValue();

AxisInfo::DimVectorT knownContiguity(1, 1);
AxisInfo::DimVectorT knownDivisibility(1, 1);
6 changes: 3 additions & 3 deletions lib/Analysis/Utility.cpp
@@ -896,15 +896,15 @@ class ConstantAnalysis : public DataFlowAnalysis {

LogicalResult initialize(Operation *top) override {
WalkResult result = top->walk([&](Operation *op) {
- if (failed(visit(op)))
+ if (failed(visit(getProgramPointAfter(op))))
return WalkResult::interrupt();
return WalkResult::advance();
});
return success(!result.wasInterrupted());
}

- LogicalResult visit(ProgramPoint point) override {
-   Operation *op = point.get<Operation *>();
+ LogicalResult visit(ProgramPoint* point) override {
+   Operation *op = point->getPrevOp();
Attribute value;
if (matchPattern(op, m_Constant(&value))) {
auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
6 changes: 4 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
auto ouEltTy = ouTensorTy.getElementType();
if (inBitWidth == ouBitWidth)
return values;
- if (inBitWidth == 16 && ouBitWidth == 32) {
+ if ((inBitWidth == 16 && ouBitWidth == 32) ||
+     (inBitWidth == 32 && ouBitWidth == 16)) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
@@ -54,7 +55,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
}
return ret;
}
- if (inBitWidth == 8 && ouBitWidth == 16) {
+ if ((inBitWidth == 8 && ouBitWidth == 16) ||
+     (inBitWidth == 16 && ouBitWidth == 8)) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i + 0]);
10 changes: 5 additions & 5 deletions lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -56,20 +56,20 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// This will create newArg, and map(origArg, newArg)
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ Location loc) -> Value {
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
- return std::nullopt;
+ return Value();
});

// If the origValue still has live user(s), use this to
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ Location loc) -> Value {
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
- return std::nullopt;
+ return Value();
});

// This will be called when (desiredType != newOperandType)
@@ -79,7 +79,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
ValueRange inputs, Location loc) {
auto cast =
builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
- return std::optional<Value>(cast.getResult());
+ return Value(cast.getResult());
});
}

5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2770,6 +2770,11 @@ struct CanonicalizeConvertFromAlloc
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
+ // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+ // to SharedEncoding, so we want to keep this layout conversion.
+ if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+         convert.getSrc().getType().getEncoding()))
+   return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
return mlir::success();
24 changes: 24 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
newLayout, SharedMemorySpace);
rewriter.setInsertionPointAfterValue(arg);

+ // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+ // to SharedEncoding.
+ if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+         argType.getEncoding())) {
+   // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+   // then pass it to the LocalAllocOp.
+   auto newArgType = RankedTensorType::get(
+       argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+   auto dotOperandToBlockedCvt =
+       rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                        dotOperandToBlockedCvt);
+ }

return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
}

@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

static bool bwdFilter(Operation *op) {
+ // Dot operand layout assignment to predicates is not currently supported
+ // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+ // condition limits visibility of the original bit-width so that predicates
+ // are not considered, and hence kwidth can never be 32.
+ if (isa<arith::UIToFPOp>(op)) {
+   Type srcType = getElementTypeOrSelf(op->getOperand(0));
+   if (srcType.isInteger(1))
+     return false;
+ }
return op->getNumOperands() == 1 &&
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
isPureUnaryInlineAsm(op) ||
17 changes: 16 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cast<RankedTensorType>(cvt.getType());
- if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
+ auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
+ if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
[](Type ty) { return isa<RankedTensorType>(ty); }))
return failure();

+ // Quick fix for loading issues when the original-bitwidth computation
+ // fails to realize that there is a mixed-precision dot (hence kWidth = 1)
+ // but still wants to hoist through the type conversion.
+ if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
+   return failure();

// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
return failure();

+ // Don't hoist through u1 -> fp casts as they aren't supported in
+ // ElementwiseOpToLLVM::reorderValues().
+ if (isa<arith::UIToFPOp>(src)) {
+   Type srcType = getElementTypeOrSelf(src->getOperand(0));
+   if (srcType.isInteger(1))
+     return failure();
+ }

// Check that the conversion is transitively dependent on a load, and all
// operations between the load and the conversion are layout preserving.
//
26 changes: 24 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
// opIdx: 0 => a, 1 => b
auto type = cast<triton::MemDescType>(v.getType());
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
- SmallVector<int64_t> offset{0, 0};
+ SmallVector<int64_t> offset(shape.size(), 0);
Type elementType = type.getElementType();

// k => (prefetchWidth, k - prefetchWidth)
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
type.getMemorySpace()),
v, offsetsVal);

+ // We need to assign kwidth to zero when the parent layout is Blocked;
+ // otherwise the verifier emits a failure. The parent layout is Blocked
+ // only when Tensor Cores are disabled.
+ int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                  ? 0
+                  : prefetchWidth / 8;
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-     builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+     builder.getContext(), opIdx, dotEncoding, kwidth);
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
@@ -190,6 +196,22 @@ LogicalResult Prefetcher::initialize() {
break;
if (!op->getResult(0).hasOneUse())
break;
+ // Similar to issues faced by the HoistLayoutConversion pattern in
+ // OptimizeDotOperands.cpp, we can't propagate through type casts from
+ // predicates as they aren't supported in Triton when encoded with dot_op
+ // layout.
+ if (isa<arith::UIToFPOp>(op)) {
+   Type srcType = getElementTypeOrSelf(op->getOperand(0));
+   if (srcType.isInteger(1))
+     break;
+ }
+ // Propagation through ExpandDims is currently not supported. This blindly
+ // replaces the encoding with dot encoding, but ExpandDims requires a
+ // SliceEncoding. This could be rewritten to support it somehow, but I
+ // don't think it's trivial and it's currently crashing.
+ if (isa<ExpandDimsOp>(op)) {
+   break;
+ }
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
foundConvertFromShared = true;
3 changes: 2 additions & 1 deletion lib/Target/LLVMIR/LLVMDIScope.cpp
@@ -104,7 +104,8 @@ struct LLVMDIScopePass : public LLVMDIScopeBase<LLVMDIScopePass> {
auto subprogramAttr = LLVM::DISubprogramAttr::get(
context, distinctId, compileUnitAttr, fileAttr, funcNameAttr,
funcNameAttr, fileAttr, /*line=*/line, /*scopeline=*/line,
- subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{});
+ subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{},
+ /*annotations=*/{});
funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr));
}

77 changes: 77 additions & 0 deletions python/BUILD
@@ -0,0 +1,77 @@
# NOTE: Do not depend on any targets from this directory,
# but use //third_party/py/triton instead.

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(
default_applicable_licenses = ["//:license"],
default_visibility = [
"//third_party/py/triton:__pkg__",
"@triton//python:__subpackages__",
],
)

cc_library(
name = "passes",
hdrs = ["src/passes.h"],
includes = ["src"],
visibility = ["@triton//third_party:__subpackages__"],
)

pybind_extension(
name = "libtriton",
srcs = [
"src/interpreter.cc",
"src/ir.cc",
"src/llvm.cc",
"src/main.cc",
"src/passes.cc",
],
copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"],
deps = [
":passes",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:IPO",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:InstCombine",
"@llvm-project//llvm:Linker",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:Passes",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:IndexDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMIRTransforms",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:ToLLVMIRTranslation",
"@llvm-project//mlir:Transforms",
"//:TritonAnalysis",
"//:TritonDialects",
"//:TritonGPUToLLVM",
"//:TritonGPUTransforms",
"//:TritonHSACO",
"//:TritonLLVMIR",
"//:TritonNvidiaGPUTransforms",
"//:TritonPTX",
"//:TritonToTritonGPU",
"//:TritonTools",
"//:TritonTransforms",
"@triton//third_party/nvidia:triton_nvidia",
],
)

filegroup(
name = "files",
srcs = glob(
include = ["triton/**/*.py"],
),
)
26 changes: 26 additions & 0 deletions python/test/regression/BUILD
@@ -0,0 +1,26 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
name = "tests",
size = "large",
srcs = ["conftest.py"],
shard_count = 10,
tags = [
"config-cuda-only",
"requires-gpu-sm80",
],
tests = glob(
include = ["test_*.py"],
exclude = [
"test_performance.py", #TODO(b/321005767): fix failing test
],
),
deps = [
"//third_party/py/torch:pytorch",
"//third_party/py/triton",
],
)
12 changes: 12 additions & 0 deletions python/test/regression/conftest.py
@@ -0,0 +1,12 @@
# content of conftest.py

import pytest


def pytest_addoption(parser):
parser.addoption("--device", action="store", default='cuda')


@pytest.fixture
def device(request):
return request.config.getoption("--device")
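
For context, a minimal sketch (not part of this commit) of how a regression test would consume the device fixture defined above via pytest's fixture injection; the test name and computation are hypothetical, while torch is already declared as a test dependency in the BUILD file above.

# Hypothetical usage sketch, not part of this commit: pytest injects the
# "device" fixture from conftest.py (default "cuda", overridable with
# --device) and the test allocates a small tensor on that device.
import torch


def test_device_fixture_smoke(device):
    x = torch.arange(4, device=device)  # tensor([0, 1, 2, 3]) on the device
    assert x.sum().item() == 6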