-
Notifications
You must be signed in to change notification settings - Fork 302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Calyx] Lower Arith CmpFOp to Calyx #7860
base: main
Are you sure you want to change the base?
Changes from 6 commits
38bdc55
c28790d
1e43cf9
506ac35
5ef5d11
0db4e8e
397823d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -779,6 +779,26 @@ class BuildCallInstance : public calyx::FuncOpPartialLoweringPattern { | |
ComponentOp getCallComponent(mlir::func::CallOp callOp) const; | ||
}; | ||
|
||
/// Predicate information for the floating point comparisons | ||
struct PredicateInfo { | ||
struct InputPorts { | ||
// Relevant ports to extract from the `std_compareFN`. For example, we | ||
// extract the `lt` and the `unordered` ports when the predicate is `oge`. | ||
enum Port { EQ, GT, LT, UNORDERED } port; | ||
jiahanxie353 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Whether we should invert the port before passing as inputs to the `op` | ||
// field. For example, we should invert both the `lt` and the `unordered` | ||
// port just extracted for predicate `oge`. | ||
bool invert; | ||
}; | ||
|
||
// The combinational logic to apply to the input ports. For example, we should | ||
// apply `AND` logic to the two input ports for predicate `oge`. | ||
enum CombLogic { AND, OR, SPECIAL } logic; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed it to |
||
SmallVector<InputPorts> inputPorts; | ||
}; | ||
|
||
PredicateInfo getPredicateInfo(mlir::arith::CmpFPredicate pred); | ||
|
||
} // namespace calyx | ||
} // namespace circt | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
#include "llvm/Support/LogicalResult.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
|
||
#include <variant> | ||
|
||
|
@@ -289,7 +290,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { | |
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp, | ||
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp, | ||
/// floating point | ||
AddFOp, MulFOp, | ||
AddFOp, MulFOp, CmpFOp, | ||
/// others | ||
SelectOp, IndexCastOp, CallOp>( | ||
[&](auto op) { return buildOp(rewriter, op).succeeded(); }) | ||
|
@@ -326,6 +327,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { | |
LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op) const; | ||
LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op) const; | ||
LogicalResult buildOp(PatternRewriter &rewriter, MulFOp op) const; | ||
LogicalResult buildOp(PatternRewriter &rewriter, CmpFOp op) const; | ||
LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const; | ||
LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const; | ||
LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const; | ||
|
@@ -729,6 +731,164 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, | |
mulFOp.getOut()); | ||
} | ||
|
||
LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, | ||
CmpFOp cmpf) const { | ||
Location loc = cmpf.getLoc(); | ||
IntegerType one = rewriter.getI1Type(), five = rewriter.getIntegerType(5), | ||
width = rewriter.getIntegerType( | ||
cmpf.getLhs().getType().getIntOrFloatBitWidth()); | ||
auto calyxCmpFOp = getState<ComponentLoweringState>() | ||
.getNewLibraryOpInstance<calyx::CompareFOpIEEE754>( | ||
rewriter, loc, | ||
{one, one, one, width, width, one, one, one, one, | ||
one, five, one}); | ||
hw::ConstantOp c1 = createConstant(loc, rewriter, getComponent(), 1, 1); | ||
rewriter.setInsertionPointToStart(getComponent().getBodyBlock()); | ||
|
||
auto createSignalRegister = | ||
[&](Value signal, bool invert, StringRef nameSuffix, Location loc, | ||
PatternRewriter &rewriter, calyx::ComponentOp component, | ||
calyx::CompareFOpIEEE754 calyxCmpFOp, calyx::GroupOp group, | ||
OpBuilder &builder) -> calyx::RegisterOp { | ||
auto reg = createRegister( | ||
loc, rewriter, component, 1, | ||
getState<ComponentLoweringState>().getUniqueName(nameSuffix)); | ||
rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), | ||
calyxCmpFOp.getDone()); | ||
if (invert) | ||
rewriter.create<calyx::AssignOp>( | ||
loc, reg.getIn(), c1, comb::createOrFoldNot(loc, signal, builder)); | ||
else | ||
rewriter.create<calyx::AssignOp>(loc, reg.getIn(), signal); | ||
return reg; | ||
}; | ||
|
||
using calyx::PredicateInfo; | ||
PredicateInfo info = calyx::getPredicateInfo(cmpf.getPredicate()); | ||
if (info.logic == PredicateInfo::SPECIAL) { | ||
if (cmpf.getPredicate() == CmpFPredicate::AlwaysTrue) { | ||
cmpf.getResult().replaceAllUsesWith(c1); | ||
return success(); | ||
} | ||
|
||
if (cmpf.getPredicate() == CmpFPredicate::AlwaysFalse) { | ||
Value constantZero = createConstant(loc, rewriter, getComponent(), 1, 0); | ||
cmpf.getResult().replaceAllUsesWith(constantZero); | ||
return success(); | ||
} | ||
} | ||
|
||
// General case | ||
StringRef opName = cmpf.getOperationName().split(".").second; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe I should refactor with |
||
auto reg = | ||
createRegister(loc, rewriter, getComponent(), 1, | ||
getState<ComponentLoweringState>().getUniqueName(opName)); | ||
|
||
// Operation pipelines are not combinational, so a GroupOp is required. | ||
auto group = createGroupForOp<calyx::GroupOp>(rewriter, cmpf); | ||
OpBuilder builder(group->getRegion(0)); | ||
getState<ComponentLoweringState>().addBlockScheduleable(cmpf->getBlock(), | ||
group); | ||
|
||
rewriter.setInsertionPointToEnd(group.getBodyBlock()); | ||
rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getLeft(), cmpf.getLhs()); | ||
rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getRight(), cmpf.getRhs()); | ||
|
||
// Prepare signals and create registers | ||
SmallVector<calyx::RegisterOp> inputRegs; | ||
for (const auto &input : info.inputPorts) { | ||
Value signal; | ||
switch (input.port) { | ||
case PredicateInfo::InputPorts::Port::EQ: { | ||
signal = calyxCmpFOp.getEq(); | ||
break; | ||
} | ||
case PredicateInfo::InputPorts::Port::GT: { | ||
signal = calyxCmpFOp.getGt(); | ||
break; | ||
} | ||
case PredicateInfo::InputPorts::Port::LT: { | ||
signal = calyxCmpFOp.getLt(); | ||
break; | ||
} | ||
case PredicateInfo::InputPorts::Port::UNORDERED: { | ||
signal = calyxCmpFOp.getUnordered(); | ||
break; | ||
} | ||
} | ||
std::string nameSuffix = | ||
(input.port == PredicateInfo::InputPorts::Port::UNORDERED) | ||
? "unordered_port" | ||
: "compare_port"; | ||
auto signalReg = | ||
createSignalRegister(signal, input.invert, nameSuffix, loc, rewriter, | ||
getComponent(), calyxCmpFOp, group, builder); | ||
inputRegs.push_back(signalReg); | ||
} | ||
|
||
// Create the output logical operation | ||
Value outputValue, doneValue; | ||
if (info.logic == PredicateInfo::SPECIAL) { | ||
jiahanxie353 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// it's guaranteed to be either ORD or UNO | ||
outputValue = inputRegs[0].getOut(); | ||
doneValue = inputRegs[0].getOut(); | ||
} else { | ||
if (info.logic == PredicateInfo::AND) { | ||
auto outputLibOp = getState<ComponentLoweringState>() | ||
.getNewLibraryOpInstance<calyx::AndLibOp>( | ||
rewriter, loc, {one, one, one}); | ||
rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(), | ||
inputRegs[0].getOut()); | ||
rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(), | ||
inputRegs[1].getOut()); | ||
|
||
outputValue = outputLibOp.getOut(); | ||
} else /*info.op == PredicateInfo::OR*/ { | ||
auto outputLibOp = getState<ComponentLoweringState>() | ||
.getNewLibraryOpInstance<calyx::OrLibOp>( | ||
rewriter, loc, {one, one, one}); | ||
rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(), | ||
inputRegs[0].getOut()); | ||
rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(), | ||
inputRegs[1].getOut()); | ||
|
||
outputValue = outputLibOp.getOut(); | ||
} | ||
|
||
auto doneLibOp = getState<ComponentLoweringState>() | ||
.getNewLibraryOpInstance<calyx::AndLibOp>( | ||
rewriter, loc, {one, one, one}); | ||
rewriter.create<calyx::AssignOp>(loc, doneLibOp.getLeft(), | ||
inputRegs[0].getDone()); | ||
rewriter.create<calyx::AssignOp>(loc, doneLibOp.getRight(), | ||
inputRegs[1].getDone()); | ||
doneValue = doneLibOp.getOut(); | ||
} | ||
|
||
// Write to the output register | ||
rewriter.create<calyx::AssignOp>(loc, reg.getIn(), outputValue); | ||
rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), doneValue); | ||
|
||
// Set the go and done signal | ||
rewriter.create<calyx::AssignOp>( | ||
loc, calyxCmpFOp.getGo(), c1, | ||
comb::createOrFoldNot(loc, calyxCmpFOp.getDone(), builder)); | ||
rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone()); | ||
|
||
cmpf.getResult().replaceAllUsesWith(reg.getOut()); | ||
|
||
// Register evaluating groups | ||
getState<ComponentLoweringState>().registerEvaluatingGroup(outputValue, | ||
group); | ||
getState<ComponentLoweringState>().registerEvaluatingGroup(doneValue, group); | ||
getState<ComponentLoweringState>().registerEvaluatingGroup( | ||
calyxCmpFOp.getLeft(), group); | ||
getState<ComponentLoweringState>().registerEvaluatingGroup( | ||
calyxCmpFOp.getRight(), group); | ||
|
||
return success(); | ||
} | ||
|
||
template <typename TAllocOp> | ||
static LogicalResult buildAllocOp(ComponentLoweringState &componentState, | ||
PatternRewriter &rewriter, TAllocOp allocOp) { | ||
|
@@ -2113,7 +2273,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> { | |
ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp, | ||
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp, | ||
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp, | ||
ExtSIOp, CallOp, AddFOp, MulFOp>(); | ||
ExtSIOp, CallOp, AddFOp, MulFOp, CmpFOp>(); | ||
|
||
RewritePatternSet legalizePatterns(&getContext()); | ||
legalizePatterns.add<DummyPattern>(&getContext()); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, looking much better; I've added a few smaller comments. I am curious if you've verified at all the output of this code besides manually checking it, e.g., have you run the emitted code through the native Calyx compiler and checked the outputs via some backend like Icarus or SysVerilog?