Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Calyx] Lower Arith CmpFOp to Calyx #7860

Merged
merged 14 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,28 @@ class BuildCallInstance : public calyx::FuncOpPartialLoweringPattern {
ComponentOp getCallComponent(mlir::func::CallOp callOp) const;
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
};

/// Predicate information for the floating point comparisons
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the description in the initial comment to be more declarative rather than "This patch tries to...", e.g.,

Adds floating points comparison op with the Berkeley HardFloat interface.

Also, if the issue below it is no longer relevant, I'd remove that too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I'll change the final squash commit message to be more informative as you suggested

struct PredicateInfo {
struct InputPorts {
// Relevant ports to extract from the `std_compareFN`. For example, we
// extract the `lt` and the `unordered` ports when the predicate is `oge`.
enum class Port { Eq, Gt, Lt, Unordered };
Port port;
// Whether we should invert the port before passing as inputs to the `op`
// field. For example, we should invert both the `lt` and the `unordered`
// port just extracted for predicate `oge`.
bool invert;
};

// The combinational logic to apply to the input ports. For example, we should
// apply `And` to the two input ports for predicate `oge`.
enum class CombLogic { None, And, Or };
CombLogic logic;
SmallVector<InputPorts> inputPorts;
};

PredicateInfo getPredicateInfo(mlir::arith::CmpFPredicate pred);

} // namespace calyx
} // namespace circt

Expand Down
53 changes: 47 additions & 6 deletions include/circt/Dialect/Calyx/CalyxPrimitives.td
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,15 @@ def AndLibOp : CombinationalArithBinaryLibraryOp<"and"> {}
def OrLibOp : CombinationalArithBinaryLibraryOp<"or"> {}
def XorLibOp : CombinationalArithBinaryLibraryOp<"xor"> {}

class ArithBinaryFloatingPointLibraryOp<string mnemonic> :
ArithBinaryLibraryOp<mnemonic, "", [
class ArithBinaryFloatingPointLibraryOp<string mnemonic, list<Trait> traits = []> :
ArithBinaryLibraryOp<mnemonic, "", !listconcat(traits, [
DeclareOpInterfaceMethods<FloatingPointOpInterface>,
SameTypeConstraint<"left", "out">
]> {}
SameTypeConstraint<"left", "right">
])> {}

def AddFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.add"> {
def AddFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.add", [
SameTypeConstraint<"left", "out">
]> {
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, I1:$subOp,
AnySignlessInteger:$left, AnySignlessInteger:$right, AnySignlessInteger:$roundingMode, AnySignlessInteger:$out,
AnySignlessInteger:$exceptionalFlags, I1:$done);
Expand Down Expand Up @@ -378,7 +380,9 @@ def AddFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.add"> {
}];
}

def MulFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.mul"> {
def MulFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.mul", [
SameTypeConstraint<"left", "out">
]> {
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control,
AnySignlessInteger:$left, AnySignlessInteger:$right, AnySignlessInteger:$roundingMode, AnySignlessInteger:$out,
AnySignlessInteger:$exceptionalFlags, I1:$done);
Expand Down Expand Up @@ -413,6 +417,43 @@ def MulFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.mul"> {
}];
}

// This models the compare operation interface in Berkeley HardFloat implementation.
def CompareFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.compare", []> {
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
let results = (outs I1:$clk, I1:$reset, I1:$go,
AnySignlessInteger:$left, AnySignlessInteger:$right, I1:$signaling,
I1:$lt, I1: $eq, I1: $gt, I1: $unordered, AnySignlessInteger: $exceptionalFlags, I1: $done);
let assemblyFormat = "$sym_name attr-dict `:` qualified(type(results))";
let extraClassDefinition = [{
SmallVector<StringRef> $cppClass::portNames() {
return {clkPort, resetPort, goPort, "left", "right", "signaling",
"lt", "eq", "gt", "unordered", "exceptionalFlags", donePort
};
}
SmallVector<Direction> $cppClass::portDirections() {
return {Input, Input, Input, Input, Input, Input, Output, Output, Output, Output, Output, Output};
}
void $cppClass::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
getCellAsmResultNames(setNameFn, *this, this->portNames());
}
bool $cppClass::isCombinational() { return false; }
SmallVector<DictionaryAttr> $cppClass::portAttributes() {
IntegerAttr isSet = IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
NamedAttrList go, clk, reset, done;
go.append(goPort, isSet);
clk.append(clkPort, isSet);
reset.append(resetPort, isSet);
done.append(donePort, isSet);
return {clk.getDictionary(getContext()), reset.getDictionary(getContext()),
go.getDictionary(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), done.getDictionary(getContext())
};
}
}];
}

def MuxLibOp : CalyxLibraryOp<"mux", "std_", [
Combinational, SameTypeConstraint<"tru", "fal">, SameTypeConstraint<"tru", "out">
]> {
Expand Down
209 changes: 207 additions & 2 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
/// floating point
AddFOp, MulFOp,
AddFOp, MulFOp, CmpFOp,
/// others
SelectOp, IndexCastOp, CallOp>(
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
Expand Down Expand Up @@ -326,6 +326,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, MulFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, CmpFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const;
Expand Down Expand Up @@ -502,6 +503,33 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
address.value());
}
}

calyx::RegisterOp createSignalRegister(PatternRewriter &rewriter,
Value signal, bool invert,
StringRef nameSuffix,
calyx::CompareFOpIEEE754 calyxCmpFOp,
calyx::GroupOp group) const {
Location loc = calyxCmpFOp.getLoc();
IntegerType one = rewriter.getI1Type();
auto component = getComponent();
OpBuilder builder(group->getRegion(0));
auto reg = createRegister(
loc, rewriter, component, 1,
getState<ComponentLoweringState>().getUniqueName(nameSuffix));
rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(),
calyxCmpFOp.getDone());
if (invert) {
auto notLibOp = getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::NotLibOp>(
rewriter, loc, {one, one});
rewriter.create<calyx::AssignOp>(loc, notLibOp.getIn(), signal);
rewriter.create<calyx::AssignOp>(loc, reg.getIn(), notLibOp.getOut());
getState<ComponentLoweringState>().registerEvaluatingGroup(
notLibOp.getOut(), group);
} else
rewriter.create<calyx::AssignOp>(loc, reg.getIn(), signal);
return reg;
};
};

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
Expand Down Expand Up @@ -729,6 +757,183 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
mulFOp.getOut());
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CmpFOp cmpf) const {
Location loc = cmpf.getLoc();
IntegerType one = rewriter.getI1Type(), five = rewriter.getIntegerType(5),
width = rewriter.getIntegerType(
cmpf.getLhs().getType().getIntOrFloatBitWidth());
auto calyxCmpFOp = getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::CompareFOpIEEE754>(
rewriter, loc,
{one, one, one, width, width, one, one, one, one,
one, five, one});
hw::ConstantOp c0 = createConstant(loc, rewriter, getComponent(), 1, 0);
hw::ConstantOp c1 = createConstant(loc, rewriter, getComponent(), 1, 1);
rewriter.setInsertionPointToStart(getComponent().getBodyBlock());

using calyx::PredicateInfo;
using CombLogic = PredicateInfo::CombLogic;
using Port = PredicateInfo::InputPorts::Port;
PredicateInfo info = calyx::getPredicateInfo(cmpf.getPredicate());
if (info.logic == CombLogic::None) {
if (cmpf.getPredicate() == CmpFPredicate::AlwaysTrue) {
cmpf.getResult().replaceAllUsesWith(c1);
return success();
}

if (cmpf.getPredicate() == CmpFPredicate::AlwaysFalse) {
cmpf.getResult().replaceAllUsesWith(c0);
return success();
}
}

// General case
StringRef opName = cmpf.getOperationName().split(".").second;
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
auto reg =
createRegister(loc, rewriter, getComponent(), 1,
getState<ComponentLoweringState>().getUniqueName(opName));

// Operation pipelines are not combinational, so a GroupOp is required.
auto group = createGroupForOp<calyx::GroupOp>(rewriter, cmpf);
OpBuilder builder(group->getRegion(0));
getState<ComponentLoweringState>().addBlockScheduleable(cmpf->getBlock(),
group);

rewriter.setInsertionPointToEnd(group.getBodyBlock());
rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getLeft(), cmpf.getLhs());
rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getRight(), cmpf.getRhs());

bool signalingFlag = false;
switch (cmpf.getPredicate()) {
case CmpFPredicate::UGT:
case CmpFPredicate::UGE:
case CmpFPredicate::ULT:
case CmpFPredicate::ULE:
case CmpFPredicate::OGT:
case CmpFPredicate::OGE:
case CmpFPredicate::OLT:
case CmpFPredicate::OLE:
signalingFlag = true;
break;
case CmpFPredicate::UEQ:
case CmpFPredicate::UNE:
case CmpFPredicate::OEQ:
case CmpFPredicate::ONE:
case CmpFPredicate::UNO:
case CmpFPredicate::ORD:
case CmpFPredicate::AlwaysTrue:
case CmpFPredicate::AlwaysFalse:
signalingFlag = false;
break;
}

// The IEEE Standard mandates that equality comparisons ordinarily are quiet,
// while inequality comparisons ordinarily are signaling.
rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getSignaling(),
signalingFlag ? c1 : c0);

// Prepare signals and create registers
SmallVector<calyx::RegisterOp> inputRegs;
for (const auto &input : info.inputPorts) {
Value signal;
switch (input.port) {
case Port::Eq: {
signal = calyxCmpFOp.getEq();
break;
}
case Port::Gt: {
signal = calyxCmpFOp.getGt();
break;
}
case Port::Lt: {
signal = calyxCmpFOp.getLt();
break;
}
case Port::Unordered: {
signal = calyxCmpFOp.getUnordered();
break;
}
}
std::string nameSuffix =
(input.port == PredicateInfo::InputPorts::Port::Unordered)
? "unordered_port"
: "compare_port";
auto signalReg = createSignalRegister(rewriter, signal, input.invert,
nameSuffix, calyxCmpFOp, group);
inputRegs.push_back(signalReg);
}

// Create the output logical operation
Value outputValue, doneValue;
switch (info.logic) {
case CombLogic::None: {
// it's guaranteed to be either ORD or UNO
outputValue = inputRegs[0].getOut();
doneValue = inputRegs[0].getOut();
break;
}
case CombLogic::And: {
auto outputLibOp = getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::AndLibOp>(
rewriter, loc, {one, one, one});
rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
inputRegs[0].getOut());
rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
inputRegs[1].getOut());

outputValue = outputLibOp.getOut();
break;
}
case CombLogic::Or: {
auto outputLibOp = getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::OrLibOp>(
rewriter, loc, {one, one, one});
rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
inputRegs[0].getOut());
rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
inputRegs[1].getOut());

outputValue = outputLibOp.getOut();
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
break;
}
}

if (info.logic != CombLogic::None) {
auto doneLibOp = getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::AndLibOp>(
rewriter, loc, {one, one, one});
rewriter.create<calyx::AssignOp>(loc, doneLibOp.getLeft(),
inputRegs[0].getDone());
rewriter.create<calyx::AssignOp>(loc, doneLibOp.getRight(),
inputRegs[1].getDone());
doneValue = doneLibOp.getOut();
}

// Write to the output register
rewriter.create<calyx::AssignOp>(loc, reg.getIn(), outputValue);
rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), doneValue);

// Set the go and done signal
rewriter.create<calyx::AssignOp>(
loc, calyxCmpFOp.getGo(), c1,
comb::createOrFoldNot(loc, calyxCmpFOp.getDone(), builder));
rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());

cmpf.getResult().replaceAllUsesWith(reg.getOut());

// Register evaluating groups
getState<ComponentLoweringState>().registerEvaluatingGroup(outputValue,
group);
getState<ComponentLoweringState>().registerEvaluatingGroup(doneValue, group);
getState<ComponentLoweringState>().registerEvaluatingGroup(
calyxCmpFOp.getLeft(), group);
getState<ComponentLoweringState>().registerEvaluatingGroup(
calyxCmpFOp.getRight(), group);

return success();
}

template <typename TAllocOp>
static LogicalResult buildAllocOp(ComponentLoweringState &componentState,
PatternRewriter &rewriter, TAllocOp allocOp) {
Expand Down Expand Up @@ -2113,7 +2318,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
ExtSIOp, CallOp, AddFOp, MulFOp>();
ExtSIOp, CallOp, AddFOp, MulFOp, CmpFOp>();

RewritePatternSet legalizePatterns(&getContext());
legalizePatterns.add<DummyPattern>(&getContext());
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/Calyx/CalyxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,10 +1209,16 @@ FloatingPointStandard MulFOpIEEE754::getFloatingPointStandard() {
return FloatingPointStandard::IEEE754;
}

FloatingPointStandard CompareFOpIEEE754::getFloatingPointStandard() {
return FloatingPointStandard::IEEE754;
}

std::string AddFOpIEEE754::getCalyxLibraryName() { return "std_addFN"; }

std::string MulFOpIEEE754::getCalyxLibraryName() { return "std_mulFN"; }

std::string CompareFOpIEEE754::getCalyxLibraryName() { return "std_compareFN"; }

//===----------------------------------------------------------------------===//
// GroupInterface
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 8 additions & 2 deletions lib/Dialect/Calyx/Export/CalyxEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ struct ImportTracker {
static constexpr std::string_view sFloatingPoint = "float/mulFN";
return {sFloatingPoint};
})
.Case<CompareFOpIEEE754>([&](auto op) -> FailureOr<StringRef> {
static constexpr std::string_view sFloatingPoint = "float/compareFN";
return {sFloatingPoint};
})
.Default([&](auto op) {
auto diag = op->emitOpError() << "not supported for emission";
return diag;
Expand Down Expand Up @@ -679,7 +683,7 @@ void Emitter::emitComponent(ComponentInterface op) {
emitLibraryPrimTypedByFirstOutputPort(
op, /*calyxLibName=*/{"std_sdiv_pipe"});
})
.Case<AddFOpIEEE754, MulFOpIEEE754>(
.Case<AddFOpIEEE754, MulFOpIEEE754, CompareFOpIEEE754>(
[&](auto op) { emitLibraryFloatingPoint(op); })
.Default([&](auto op) {
emitOpError(op, "not supported for emission inside component");
Expand Down Expand Up @@ -996,8 +1000,10 @@ void Emitter::emitLibraryPrimTypedByFirstOutputPort(

void Emitter::emitLibraryFloatingPoint(Operation *op) {
auto cell = cast<CellInterface>(op);
// magic number for the index of `left/right` input port
size_t inputPortIndex = cell.getInputPorts().size() - 3;
unsigned bitWidth =
cell.getOutputPorts()[0].getType().getIntOrFloatBitWidth();
cell.getInputPorts()[inputPortIndex].getType().getIntOrFloatBitWidth();
// Since Calyx interacts with HardFloat, we'll also only be using expWidth and
// sigWidth. See
// http://www.jhauser.us/arithmetic/HardFloat-1/doc/HardFloat-Verilog.html
Expand Down
Loading