Skip to content

Commit

Permalink
Support NMS op lowering (#3871)
Browse files Browse the repository at this point in the history
TODO: support multiple batches and classes
  • Loading branch information
jinchen62 authored Nov 27, 2024
1 parent 7452460 commit c9ed993
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 44 deletions.
67 changes: 35 additions & 32 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3704,9 +3704,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
"attribute value to be 0");

// TODO: Add support for optional arguments to be absent.
if (operands.size() != 5)
if (operands.size() < 4)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected all 5 args to be present");
binder.op, "unimplemented: expected at least 4 arguments");

// Squeeze the boxes and scores tensor.
// In Onnx, the shape of boxes is [BxNx4] while the
Expand Down Expand Up @@ -3734,31 +3734,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
boxes = squeezedBoxes.value();
scores = squeezedScores.value();

// TODO: Add support for handling score_threshold arg.
// If score_threshold > min(scores) then the op can't be lowered since
// the torchvision::nms op doesn't have support for handling the
// score_threshold arg.
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[4]);
Value minScores = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
rewriter.getF32Type()),
scores);
minScores = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);

Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
binder.getLoc(), minScores, scoreThreshold);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), scoresCond,
rewriter.getStringAttr(
"unimplemented: score_threshold should be <= min(scores)"));

// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
if (operands.size() == 5) {
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
operands[4]);
Value minScores = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
SmallVector<int64_t>{},
rewriter.getF32Type()),
scores);
minScores = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);

Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
binder.getLoc(), minScores, scoreThreshold);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), scoresCond,
rewriter.getStringAttr(
"unimplemented: score_threshold should be <= min(scores)"));
}

// TODO: Support default iou_threshold
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
auto nmsTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
binder.getLoc(), resultType, boxes, scores, iouThreshold);
binder.getLoc(), nmsTy, boxes, scores, iouThreshold);

// The result generated by torchvision.nms op is of shape [n], while the
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
Expand Down Expand Up @@ -3805,14 +3812,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
binder.op->getLoc(), listType, SmallVector<Value>{result, zeros});

// TODO: Add support for handling max_output_boxes_per_class arg.
// If numOutputBoxes (N) > max_output_boxes_per_class then the op can't
// be lowered since the torchvision::nms op doesn't have support for
// handling the max_output_boxes_per_class arg. Also, we have already
// constrained the number of classes to be 1 above, so the number of
// output boxes inferred from the result is num_output_boxes_per_class.
binder.getLoc(), listType, SmallVector<Value>{zeros, result});

// TODO: Support max_output_boxes_per_class input
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
Expand Down
270 changes: 270 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10684,6 +10684,273 @@ class DecomposeAtenFloatPowerTensorTensorOp
};
} // namespace

namespace {
// Decomposes `torchvision.nms` into primitive Torch ops.
//
// Greedy NMS algorithm:
//   1. Sort scores in descending order and iterate boxes in that order.
//   2. For each still-unsuppressed box, append its (original) index to the
//      result and compute its IoU against every box.
//   3. Suppress every later box whose IoU with the kept box exceeds
//      `iou_threshold`.
// The kept indices (into `dets`, ordered by decreasing score) are produced
// with two nested `prim.Loop`s over a 0/1 keep-mask, then sliced down to the
// number of boxes actually kept.
class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TorchvisionNmsOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    MLIRContext *context = op->getContext();
    Value boxes = op.getDets();
    Value scores = op.getScores();
    Value iouThreshold = op.getIouThreshold();

    Value cst0 = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Value cst1 = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value cst2 = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(2));
    Value cst4 = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(4));
    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
    Value cstTrue =
        rewriter.create<Torch::ConstantBoolOp>(loc, rewriter.getBoolAttr(true));
    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
        loc, rewriter.getBoolAttr(false));

    // Get number of boxes for the loop count
    auto boxesTensorType = dyn_cast<Torch::ValueTensorType>(boxes.getType());
    auto dType = boxesTensorType.getDtype();
    int64_t boxesSize = boxesTensorType.getSizes()[0];
    Value len = rewriter.create<AtenSizeIntOp>(loc, boxes, /*dim=*/cst0);

    // Calculate the area of each box: (x2 - x1) * (y2 - y1)
    auto sliceTy = rewriter.getType<ValueTensorType>(
        SmallVector<int64_t>{boxesSize, 2}, dType);
    Value lowSlice = rewriter.create<AtenSliceTensorOp>(
        loc, sliceTy, boxes,
        /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1);
    Value highSlice = rewriter.create<AtenSliceTensorOp>(
        loc, sliceTy, boxes,
        /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1);
    Value distance = rewriter.create<Torch::AtenSubTensorOp>(
        loc, sliceTy, highSlice, lowSlice, cst1);
    auto areaTy = rewriter.getType<ValueTensorType>(
        SmallVector<int64_t>{boxesSize}, dType);
    // NOTE: `area` is indexed by the *original* box order, not sorted order.
    Value area = rewriter.create<Torch::AtenProdDimIntOp>(
        loc, areaTy, distance, /*dim=*/cst1, /*keepdim=*/cstFalse,
        /*dtype=*/cstNone);

    // Sort scores in descending order
    // Use the sorted indices to iterate boxes
    auto scoresType = dyn_cast<BaseTensorType>(scores.getType());
    auto intTensorType = scoresType.getWithSizesAndDtype(
        scoresType.getOptionalSizes(),
        IntegerType::get(context, 64, IntegerType::Signed));
    auto sortResult = rewriter.create<Torch::AtenSortOp>(
        loc, TypeRange({scores.getType(), intTensorType}), scores,
        /*dim=*/cst0, /*descending=*/cstTrue);

    // Create a mask to mark if we keep the boxes
    // The mask lives in *sorted* order: mask[i] == 1 iff the box at sorted
    // position i has not been suppressed yet.
    Value lenShapeList = rewriter.create<Torch::PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(context)),
        SmallVector<Value>{len});
    Value mask = rewriter.create<Torch::AtenOnesOp>(
        loc, intTensorType, lenShapeList, cstNone, cstNone, cstNone, cstNone);
    Value zeroShapeList = rewriter.create<Torch::PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(context)),
        SmallVector<Value>{cst1});
    auto zeroTy = rewriter.getType<ValueTensorType>(
        SmallVector<int64_t>{1}, rewriter.getIntegerType(64, /*signed=*/true));
    Value falseMask = rewriter.create<Torch::AtenZerosOp>(
        loc, zeroTy, zeroShapeList, cstNone, cstNone, cstNone, cstNone);

    // Create an empty tensor for result
    // NOTE: cst4 doubles as the dtype code here (torch scalar type 4 ==
    // int64), matching intTensorType.
    Value result = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
        loc, intTensorType, lenShapeList, /*dtype=*/cst4, /*layout=*/cstNone,
        /*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone);

    auto intTy = rewriter.getType<Torch::IntType>();
    auto rowSliceTy =
        rewriter.getType<ValueTensorType>(SmallVector<int64_t>{1, 4}, dType);
    auto pointTy =
        rewriter.getType<ValueTensorType>(SmallVector<int64_t>{1, 2}, dType);
    auto extractTy = rewriter.getType<ValueTensorType>(
        SmallVector<int64_t>{1}, rewriter.getIntegerType(64, true));
    Value float0 = rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getFloatAttr(dType, 0.0));
    auto scalarFloatType = rewriter.getType<Torch::ValueTensorType>(
        SmallVector<int64_t>{1}, dType);
    Value float0Tensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
        loc, scalarFloatType, float0);

    // 1. Loop through the boxes based on sorted indices
    // 2. Add the current box to result if it's not suppressed
    // 3. Calculate the IoUs with all boxes
    // 4. Loop through the rest boxes in sorted indices
    // 5. Suppress the box if the corresponding IoU is larger than threshold
    // Loop-carried values: (keep-mask, result indices, count of kept boxes).
    auto loop1 = rewriter.create<Torch::PrimLoopOp>(
        loc, TypeRange({intTensorType, intTensorType, intTy}), len, cstTrue,
        ValueRange({mask, result, cst0}));
    {
      PatternRewriter::InsertionGuard guard(rewriter);
      Block *loopBody1 = rewriter.createBlock(
          &loop1.getRegion(), loop1.getRegion().begin(),
          TypeRange({intTy, intTensorType, intTensorType, intTy}),
          {loc, loc, loc, loc});
      Value i = loopBody1->getArgument(0);
      Value mask1 = loopBody1->getArgument(1);
      Value curResult = loopBody1->getArgument(2);
      Value curCnt = loopBody1->getArgument(3);

      // Extract the mask to check if the base box is suppressed
      Value extract = rewriter.create<AtenSelectIntOp>(
          loc, extractTy, mask1, /*dim=*/cst0, /*index=*/i);
      Value scalar = rewriter.create<Torch::AtenItemOp>(loc, intTy, extract);
      Value iskept = rewriter.create<Torch::AtenBoolIntOp>(
          loc, rewriter.getType<Torch::BoolType>(), scalar);
      auto ifFilterOthers = rewriter.create<Torch::PrimIfOp>(
          loc, TypeRange({intTensorType, intTensorType, intTy}), iskept);
      {
        PatternRewriter::InsertionGuard guard(rewriter);
        rewriter.createBlock(&ifFilterOthers.getThenRegion(),
                             ifFilterOthers.getThenRegion().begin());

        // Scatter the selected indices into result
        Value extractIdx1 = rewriter.create<AtenSelectIntOp>(
            loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0,
            /*index=*/i);
        Value next = rewriter.create<Torch::AtenAddIntOp>(loc, curCnt, cst1);
        Value updatedResult = rewriter.create<Torch::AtenSliceScatterOp>(
            loc, intTensorType, curResult, extractIdx1, /*dim=*/cst0,
            /*start=*/curCnt, /*end=*/next, /*step=*/cst1);

        // Get the coordinates of base box
        Value idx1 =
            rewriter.create<Torch::AtenItemOp>(loc, intTy, extractIdx1);
        Value idx1End = rewriter.create<Torch::AtenAddIntOp>(loc, idx1, cst1);
        Value curBox = rewriter.create<AtenSliceTensorOp>(
            loc, rowSliceTy, boxes,
            /*dim=*/cst0, /*start=*/idx1, /*end=*/idx1End, /*step=*/cst1);

        // Calculate IoUs: intersectionArea / unionArea
        // Intersection area = intersectionWidth * intersectionHeight
        Value point1 = rewriter.create<AtenSliceTensorOp>(
            loc, pointTy, curBox,
            /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1);
        Value point2 = rewriter.create<AtenSliceTensorOp>(
            loc, pointTy, curBox,
            /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1);
        Value innerLow = rewriter.create<Torch::AtenMaximumOp>(
            loc, sliceTy, lowSlice, point1);
        Value innerHigh = rewriter.create<Torch::AtenMinimumOp>(
            loc, sliceTy, highSlice, point2);
        Value innerDistance = rewriter.create<Torch::AtenSubTensorOp>(
            loc, sliceTy, innerHigh, innerLow, cst1);
        innerDistance = rewriter.create<Torch::AtenMaximumOp>(
            loc, sliceTy, innerDistance, float0Tensor);
        Value intersectionArea = rewriter.create<Torch::AtenProdDimIntOp>(
            loc, areaTy, innerDistance, /*dim=*/cst1, /*keepdim=*/cstFalse,
            /*dtype=*/cstNone);
        // Area of the base box. `area` is indexed by *original* box order,
        // so slice with the sorted index `idx1`, not the loop position `i`:
        // using `i` would pick the wrong box's area whenever the input is
        // not already sorted by score.
        Value curArea = rewriter.create<AtenSliceTensorOp>(
            loc, scalarFloatType, area,
            /*dim=*/cst0, /*start=*/idx1, /*end=*/idx1End, /*step=*/cst1);
        // Union area = area1 + area2 - intersectionArea
        Value unionArea = rewriter.create<Torch::AtenAddTensorOp>(
            loc, areaTy, area, curArea, cst1);
        unionArea = rewriter.create<Torch::AtenSubTensorOp>(
            loc, areaTy, unionArea, intersectionArea, cst1);
        Value iou = rewriter.create<Torch::AtenDivTensorOp>(
            loc, areaTy, intersectionArea, unionArea);

        // Loop through the rest of boxes in sorted indices
        auto loop2 = rewriter.create<Torch::PrimLoopOp>(loc, intTensorType, len,
                                                        cstTrue, mask1);
        {
          PatternRewriter::InsertionGuard guard(rewriter);
          Block *loopBody2 = rewriter.createBlock(
              &loop2.getRegion(), loop2.getRegion().begin(),
              TypeRange({intTy, intTensorType}), {loc, loc});
          Value j = loopBody2->getArgument(0);
          Value mask2 = loopBody2->getArgument(1);

          // Check if current index is out of range
          // Shift j so it starts right after the base box (i + 1).
          j = rewriter.create<Torch::AtenAddIntOp>(loc, j, i);
          j = rewriter.create<Torch::AtenAddIntOp>(loc, j, cst1);
          Value isInRange = rewriter.create<Torch::AtenLtIntOp>(loc, j, len);
          auto ifCalculateIou = rewriter.create<Torch::PrimIfOp>(
              loc, TypeRange({intTensorType}), isInRange);
          {
            PatternRewriter::InsertionGuard guard(rewriter);
            rewriter.createBlock(&ifCalculateIou.getThenRegion(),
                                 ifCalculateIou.getThenRegion().begin());

            // Retrieve IoU and check if suppress the box
            // `iou` is indexed by original box order, hence the lookup via
            // the sorted index idx2.
            Value extractIdx2 = rewriter.create<AtenSelectIntOp>(
                loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0,
                /*index=*/j);
            Value idx2 =
                rewriter.create<Torch::AtenItemOp>(loc, intTy, extractIdx2);
            Value idx2End =
                rewriter.create<Torch::AtenAddIntOp>(loc, idx2, cst1);
            Value curIoU = rewriter.create<AtenSliceTensorOp>(
                loc, scalarFloatType, iou,
                /*dim=*/cst0, /*start=*/idx2, /*end=*/idx2End, /*step=*/cst1);
            curIoU = rewriter.create<Torch::AtenItemOp>(
                loc, rewriter.getType<Torch::FloatType>(), curIoU);
            Value isSuppressed = rewriter.create<Torch::AtenGtFloatOp>(
                loc, curIoU, iouThreshold);

            auto ifUnmask = rewriter.create<Torch::PrimIfOp>(
                loc, TypeRange({intTensorType}), isSuppressed);
            {
              PatternRewriter::InsertionGuard guard(rewriter);
              rewriter.createBlock(&ifUnmask.getThenRegion(),
                                   ifUnmask.getThenRegion().begin());

              // Update the mask if suppress
              Value jEnd = rewriter.create<Torch::AtenAddIntOp>(loc, j, cst1);
              Value updatedMask = rewriter.create<Torch::AtenSliceScatterOp>(
                  loc, intTensorType, mask2, falseMask, /*dim=*/cst0,
                  /*start=*/j, /*end=*/jEnd, /*step=*/cst1);
              rewriter.create<Torch::PrimIfYieldOp>(loc, updatedMask);
            }
            {
              PatternRewriter::InsertionGuard guard(rewriter);
              rewriter.createBlock(&ifUnmask.getElseRegion(),
                                   ifUnmask.getElseRegion().begin());
              rewriter.create<Torch::PrimIfYieldOp>(loc, mask2);
            }

            rewriter.create<Torch::PrimIfYieldOp>(loc, ifUnmask.getResult(0));
          }
          {
            PatternRewriter::InsertionGuard guard(rewriter);
            rewriter.createBlock(&ifCalculateIou.getElseRegion(),
                                 ifCalculateIou.getElseRegion().begin());
            rewriter.create<Torch::PrimIfYieldOp>(loc, mask2);
          }

          rewriter.create<Torch::PrimLoopConditionOp>(
              loc, cstTrue, ifCalculateIou.getResult(0));
        }

        rewriter.create<Torch::PrimIfYieldOp>(
            loc, ValueRange({loop2.getResult(0), updatedResult, next}));
      }
      {
        PatternRewriter::InsertionGuard guard(rewriter);
        rewriter.createBlock(&ifFilterOthers.getElseRegion(),
                             ifFilterOthers.getElseRegion().begin());
        rewriter.create<Torch::PrimIfYieldOp>(
            loc, ValueRange({mask1, curResult, curCnt}));
      }

      rewriter.create<Torch::PrimLoopConditionOp>(loc, cstTrue,
                                                  ifFilterOthers.getResults());
    }

    // Trim the result buffer down to the number of boxes actually kept.
    rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
        op, op.getType(), loop1.getResult(1), /*dim=*/cst0, /*start=*/cst0,
        /*end=*/loop1.getResult(2), /*step=*/cst1);
    return success();
  }
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -10968,6 +11235,9 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<
DecomposeAtenFMaxMinOp<AtenFminOp, AtenMinimumOp>>(patterns);

// Torchvision ops
addPatternIfTargetOpIsIllegal<DecomposeTorchvisionNmsOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
Expand Down
Loading

0 comments on commit c9ed993

Please sign in to comment.