[torch][quant] Support quantize and dequantize for torch (#2731)
Handle both `torch.dequantize` and `torch.quantize_per_tensor`, including
op-based tracking of the quantization parameters. This also adds `qint32`
to the Torch types, as it was missing from the initial type inclusion.

For testing, only `torch.int8` and `torch.float` types are used on function
boundaries, since the `qint8` types would require passing the scale and
zero-point quantization information, which is not yet supported.
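
For reference, the per-tensor affine scheme these lowerings implement reduces
to two small formulas; below is a minimal Python sketch of the assumed
semantics (the qint8 range defaults are illustrative, not part of this commit):

def quantize_per_tensor(x, scale, zero_point, qmin=-128, qmax=127):
    # q = clamp(round(x / scale) + zero_point, qmin, qmax)
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))

def dequantize(q, scale, zero_point):
    # x = (q - zero_point) * scale
    return (q - zero_point) * scale
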
rsuderman authored Jan 13, 2024
1 parent c7452af commit dc37616
Showing 13 changed files with 496 additions and 8 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Utils.h
@@ -95,6 +95,8 @@ FailureOr<Type>
getBackendTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt);

bool isUnsignedTorchType(Type type);

} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
120 changes: 120 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -14206,6 +14206,126 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
}];
}

def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$scale,
Torch_IntType:$zero_point,
Torch_IntType:$dtype
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenQuantizePerTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenQuantizePerTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenDequantizeSelfOp : Torch_Op<"aten.dequantize.self", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::dequantize.self : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDequantizeSelfOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenDequantizeSelfOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenDequantizeTensorOp : Torch_Op<"aten.dequantize.tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::dequantize.tensor : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$qtensor
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDequantizeTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenDequantizeTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::int_repr : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIntReprOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIntReprOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$scale,
Torch_IntType:$zero_point
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_MakePerTensorQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void Aten_MakePerTensorQuantizedTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
AllowsTypeRefinement,
HasValueSemantics,
11 changes: 11 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -306,6 +306,17 @@ def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
}];
}

def Torch_QInt32Type : Torch_Type<"QInt32", "qint32"> {
let summary = "Type modeling `ScalarType::QInt32`";
let description = [{
This is intended to be a 1:1 match for the Torch `ScalarType` types.

Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of
types, it is deemed preferable to import them as one-off ad-hoc types
instead of a single parameterized type.
}];
}

def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
let summary = "Torch packed linear params type";
let description = [{
153 changes: 147 additions & 6 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1316,6 +1316,106 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::XOrIOp>(loc, payloadArgs[0], allOnesVal);
}

if (isa<AtenDequantizeTensorOp, AtenDequantizeSelfOp>(op)) {
auto value = payloadArgs[0];
auto valueTy = value.getType();
auto qtensor = op->getOperand(0);
auto qtensorTy = qtensor.getType().cast<ValueTensorType>().getDtype();
auto makeQTensor =
qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
if (!makeQTensor) {
op->emitError(
"unimplemented: dequantizing tensor of unknown scale / zero-point");
return nullptr;
}

auto outFpTy = payloadArgs[1].getType();
auto outBw = outFpTy.getIntOrFloatBitWidth();
auto outIntTy = b.getIntegerType(outBw);

if (valueTy != outIntTy) {
if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) {
value = b.create<arith::ExtUIOp>(loc, outIntTy, value);
} else {
value = b.create<arith::ExtSIOp>(loc, outIntTy, value);
}
}

Value zp = makeQTensor.getZeroPoint();
zp = converter->materializeTargetConversion(
b, loc, converter->convertType(zp.getType()),
makeQTensor.getZeroPoint());
auto zpTy = zp.getType();

if (zpTy != outIntTy) {
zp = b.create<arith::TruncIOp>(loc, outIntTy, zp);
}

value = b.create<arith::SubIOp>(loc, value, zp);

if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) {
value = b.create<arith::UIToFPOp>(loc, outFpTy, value);
} else {
value = b.create<arith::SIToFPOp>(loc, outFpTy, value);
}

Value scale = makeQTensor.getScale();
scale = converter->materializeTargetConversion(
b, loc, converter->convertType(scale.getType()),
makeQTensor.getScale());
if (scale.getType() != value.getType()) {
scale = b.create<arith::TruncFOp>(loc, value.getType(), scale);
}
value = b.create<arith::MulFOp>(loc, value, scale);
return value;
}

if (auto quant = dyn_cast<AtenQuantizePerTensorOp>(op)) {
Value value = payloadArgs[0];
Value scale = quant.getScale();
Value zp = quant.getZeroPoint();
auto valueTy = value.getType();

zp = converter->materializeTargetConversion(
b, loc, converter->convertType(zp.getType()), zp);
zp = b.create<arith::SIToFPOp>(loc, valueTy, zp);

scale = converter->materializeTargetConversion(
b, loc, converter->convertType(scale.getType()), scale);
scale = b.create<arith::TruncFOp>(loc, valueTy, scale);

value = b.create<arith::DivFOp>(loc, value, scale);
value = b.create<math::RoundOp>(loc, value);
value = b.create<arith::AddFOp>(loc, value, zp);

auto destTy = payloadArgs[1].getType();
auto bitwidth = destTy.getIntOrFloatBitWidth();
bool isUnsigned = torch_to_linalg::isUnsignedTorchType(quant.getType());
APInt min = isUnsigned ? APInt::getMinValue(bitwidth)
: APInt::getSignedMinValue(bitwidth);
APInt max = isUnsigned ? APInt::getMaxValue(bitwidth)
: APInt::getSignedMaxValue(bitwidth);

Value minVal = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(valueTy, min.getSExtValue()));
Value maxVal = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(valueTy, max.getSExtValue()));
Value minCmp =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, value, minVal);
Value maxCmp =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, value, maxVal);
value = b.create<arith::SelectOp>(loc, minCmp, minVal, value);
value = b.create<arith::SelectOp>(loc, maxCmp, maxVal, value);

if (isUnsigned) {
value = b.create<arith::FPToUIOp>(loc, destTy, value);
} else {
value = b.create<arith::FPToSIOp>(loc, destTy, value);
}

return value;
}

op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForElementwiseOp");
return nullptr;
@@ -1368,9 +1468,10 @@ class ConvertElementwiseOp : public ConversionPattern {
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-        AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
-        AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
-        AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op))
+        AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
+        AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
+        AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp,
+        AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2080,6 +2181,42 @@ class ConvertLogitOp : public OpConversionPattern<AtenLogitOp> {
}
};
} // namespace

namespace {
class ConvertAtenIntReprOp : public OpConversionPattern<AtenIntReprOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
}
};
} // namespace

namespace {
class ConvertMakePerTensorQuantizedTensorOp
: public OpConversionPattern<Aten_MakePerTensorQuantizedTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -2102,9 +2239,9 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
-      AtenTrilOp,
-      AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
-      AtenFillTensorOp, AtenRealOp, AtenImagOp>();
+      AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp,
+      AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp,
+      AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
@@ -2122,4 +2259,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
target.addIllegalOp<TensorStaticInfoCastOp>();
patterns.add<ConvertAtenIntReprOp>(typeConverter, context);
target.addIllegalOp<AtenIntReprOp>();
patterns.add<ConvertMakePerTensorQuantizedTensorOp>(typeConverter, context);
target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
}
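
For context, the eager-mode op chain these patterns target can be reproduced
with standard PyTorch calls; a hedged illustration (not part of this commit;
`_make_per_tensor_quantized_tensor` is a private PyTorch API):

import torch

x = torch.randn(4)
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0,
                              dtype=torch.qint8)           # aten::quantize_per_tensor
raw = q.int_repr()                                         # aten::int_repr
q2 = torch._make_per_tensor_quantized_tensor(raw, 0.1, 0)  # aten::_make_per_tensor_quantized_tensor
y = torch.dequantize(q2)                                   # aten::dequantize
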
17 changes: 17 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -559,3 +559,20 @@ FailureOr<Type> torch_to_linalg::getBackendTypeForScalarType(
}
return type;
}

bool torch_to_linalg::isUnsignedTorchType(Type type) {
if (auto tty = dyn_cast<ValueTensorType>(type))
return isUnsignedTorchType(tty.getDtype());
if (isa<mlir::FloatType>(type))
return false;
if (isa<QInt8Type>(type))
return false;
if (isa<QUInt8Type>(type))
return true;
if (isa<QInt32Type>(type))
return false;
if (auto intTy = dyn_cast<IntegerType>(type))
return intTy.isUnsigned();
llvm_unreachable("Unknown type checked for signedness");
return false;
}
12 changes: 11 additions & 1 deletion lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -184,7 +184,7 @@ static bool isValidTorchDtype(Type dtype) {
dtype = dtype.cast<ComplexType>().getElementType();
}
// Torch quantized types.
-  if (dtype.isa<Torch::QInt8Type, Torch::QUInt8Type>())
+  if (dtype.isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt32Type>())
return true;
// Builtin floating point types.
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
@@ -410,6 +410,16 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
} else if (dtype.isa<mlir::ComplexType>()){
return dtype;
}

if (isa<QUInt8Type>(dtype))
return IntegerType::get(context, 8, IntegerType::Signless);

if (isa<QInt8Type>(dtype))
return IntegerType::get(context, 8, IntegerType::Signless);

if (isa<QInt32Type>(dtype))
return IntegerType::get(context, 32, IntegerType::Signless);

emitError(UnknownLoc::get(context))
<< "unimplemented: conversion of dtype " << dtype
<< " to builtin tensor element type";