Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Calyx] Lower Arith CmpFOp to Calyx #7860

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,26 @@ class BuildCallInstance : public calyx::FuncOpPartialLoweringPattern {
ComponentOp getCallComponent(mlir::func::CallOp callOp) const;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, looking much better; I've added a few smaller comments. I'm curious whether you've verified the output of this code beyond manually inspecting it, e.g., have you run the emitted code through the native Calyx compiler and checked the results with a backend such as Icarus or SysVerilog?

};

/// Predicate information for the floating point comparisons: which output
/// ports of the comparison cell to read, whether each one is inverted, and how
/// the (possibly inverted) ports are combined into the final 1-bit result.
struct PredicateInfo {
struct InputPorts {
// Relevant ports to extract from the `std_compareFN`. For example, we
// extract the `lt` and the `unordered` ports when the predicate is `oge`.
enum Port { EQ, GT, LT, UNORDERED } port;
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
// Whether we should invert the port before it is combined according to the
// `logic` field. For example, we should invert both the `lt` and the
// `unordered` port just extracted for predicate `oge`.
bool invert;
};

// The combinational logic to apply to the input ports. For example, we should
// apply `AND` logic to the two input ports for predicate `oge`.
// `SPECIAL` means that neither an AND nor an OR cell is applied: the SCFToCalyx
// lowering replaces the always-true/always-false predicates with constants and
// otherwise uses the single entry `inputPorts[0]` directly (that lowering notes
// the remaining cases are the ordered/unordered predicates).
enum CombLogic { AND, OR, SPECIAL } logic;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is SPECIAL? It isn't immediatley clear from your documentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed it to noComb (hopefully the intention that it stands for "no combinational logic" is more obvious)

// Ports to read from the comparison cell, each with its inversion flag. The
// lowering reads entries 0 and 1 for `AND`/`OR` and only entry 0 for
// `SPECIAL`.
SmallVector<InputPorts> inputPorts;
};

PredicateInfo getPredicateInfo(mlir::arith::CmpFPredicate pred);

} // namespace calyx
} // namespace circt

Expand Down
52 changes: 46 additions & 6 deletions include/circt/Dialect/Calyx/CalyxPrimitives.td
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,15 @@ def AndLibOp : CombinationalArithBinaryLibraryOp<"and"> {}
def OrLibOp : CombinationalArithBinaryLibraryOp<"or"> {}
def XorLibOp : CombinationalArithBinaryLibraryOp<"xor"> {}

class ArithBinaryFloatingPointLibraryOp<string mnemonic> :
ArithBinaryLibraryOp<mnemonic, "", [
class ArithBinaryFloatingPointLibraryOp<string mnemonic, list<Trait> traits = []> :
ArithBinaryLibraryOp<mnemonic, "", !listconcat(traits, [
DeclareOpInterfaceMethods<FloatingPointOpInterface>,
SameTypeConstraint<"left", "out">
]> {}
SameTypeConstraint<"left", "right">
])> {}

def AddFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.add"> {
def AddFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.add", [
SameTypeConstraint<"left", "out">
]> {
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, I1:$subOp,
AnySignlessInteger:$left, AnySignlessInteger:$right, AnySignlessInteger:$roundingMode, AnySignlessInteger:$out,
AnySignlessInteger:$exceptionalFlags, I1:$done);
Expand Down Expand Up @@ -378,7 +380,9 @@ def AddFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.add"> {
}];
}

def MulFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.mul"> {
def MulFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.mul", [
SameTypeConstraint<"left", "out">
]> {
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control,
AnySignlessInteger:$left, AnySignlessInteger:$right, AnySignlessInteger:$roundingMode, AnySignlessInteger:$out,
AnySignlessInteger:$exceptionalFlags, I1:$done);
Expand Down Expand Up @@ -413,6 +417,42 @@ def MulFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.mul"> {
}];
}

// Floating-point comparison library cell (mnemonic `ieee754.compare`). It
// passes an empty extra-trait list to the base class and exposes four 1-bit
// comparison outputs (`lt`, `eq`, `gt`, `unordered`).
def CompareFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.compare", []> {
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
// One result per port, listed in the same order as portNames() below.
let results = (outs I1:$clk, I1:$reset, I1:$go,
AnySignlessInteger:$left, AnySignlessInteger:$right, I1:$signaling,
I1:$lt, I1: $eq, I1: $gt, I1: $unordered, AnySignlessInteger: $exceptionalFlags, I1: $done);
let assemblyFormat = "$sym_name attr-dict `:` qualified(type(results))";
let extraClassDefinition = [{
// Port names, in result order (12 entries).
SmallVector<StringRef> $cppClass::portNames() {
return {clkPort, resetPort, goPort, "left", "right", "signaling",
"lt", "eq", "gt", "unordered", "exceptionalFlags", donePort
};
}
// Directions matching portNames(): the first six ports (clk .. signaling) are
// inputs, the remaining six (lt .. done) are outputs.
SmallVector<Direction> $cppClass::portDirections() {
return {Input, Input, Input, Input, Input, Input, Output, Output, Output, Output, Output, Output};
}
// Names the op's results after its ports.
void $cppClass::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
getCellAsmResultNames(setNameFn, *this, this->portNames());
}
// This cell is reported as non-combinational.
bool $cppClass::isCombinational() { return false; }
// Per-port attribute dictionaries, in port order: clk, reset, go and done each
// carry an attribute keyed by their port-name constant with the 1-bit value 1;
// the eight ports in between (left .. exceptionalFlags) get empty dictionaries.
SmallVector<DictionaryAttr> $cppClass::portAttributes() {
IntegerAttr isSet = IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
NamedAttrList go, clk, reset, done;
go.append(goPort, isSet);
clk.append(clkPort, isSet);
reset.append(resetPort, isSet);
done.append(donePort, isSet);
return {clk.getDictionary(getContext()), reset.getDictionary(getContext()),
go.getDictionary(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), done.getDictionary(getContext())
};
}
}];
}

def MuxLibOp : CalyxLibraryOp<"mux", "std_", [
Combinational, SameTypeConstraint<"tru", "fal">, SameTypeConstraint<"tru", "out">
]> {
Expand Down
164 changes: 162 additions & 2 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

#include <variant>

Expand Down Expand Up @@ -289,7 +290,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
/// floating point
AddFOp, MulFOp,
AddFOp, MulFOp, CmpFOp,
/// others
SelectOp, IndexCastOp, CallOp>(
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
Expand Down Expand Up @@ -326,6 +327,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, MulFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, CmpFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const;
Expand Down Expand Up @@ -729,6 +731,164 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
mulFOp.getOut());
}

/// Lowers an `arith.cmpf` onto a `calyx::CompareFOpIEEE754` library cell.
///
/// The `AlwaysTrue` / `AlwaysFalse` predicates are replaced by 1-bit constants.
/// For every other predicate a group is created that feeds both operands to
/// the compare cell. Each port listed by `calyx::getPredicateInfo` is captured
/// in its own 1-bit register (write-enabled by the cell's `done`), the captured
/// values are combined with an AND/OR cell (or used directly when the
/// predicate needs a single port), and the result is written to an output
/// register whose `done` finishes the group. All uses of the `cmpf` result are
/// redirected to that output register.
///
/// \param rewriter Rewriter used to create the Calyx operations.
/// \param cmpf     The floating point comparison being lowered.
/// \returns success() on every path.
LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
                                     CmpFOp cmpf) const {
  Location loc = cmpf.getLoc();
  IntegerType one = rewriter.getI1Type(), five = rewriter.getIntegerType(5),
              width = rewriter.getIntegerType(
                  cmpf.getLhs().getType().getIntOrFloatBitWidth());
  // Twelve result types, one per port of the compare cell: three 1-bit
  // control ports, the two operands, then seven more ports of which only the
  // second-to-last is 5 bits wide.
  auto calyxCmpFOp = getState<ComponentLoweringState>()
                         .getNewLibraryOpInstance<calyx::CompareFOpIEEE754>(
                             rewriter, loc,
                             {one, one, one, width, width, one, one, one, one,
                              one, five, one});
  hw::ConstantOp c1 = createConstant(loc, rewriter, getComponent(), 1, 1);
  rewriter.setInsertionPointToStart(getComponent().getBodyBlock());

  // Captures `signal` in a fresh 1-bit register that is write-enabled by the
  // compare cell's `done`. With `invert` set, the register input is instead
  // assigned the constant 1 under the guard `!signal`.
  auto createSignalRegister = [&](Value signal, bool invert,
                                  StringRef nameSuffix,
                                  OpBuilder &builder) -> calyx::RegisterOp {
    auto portReg = createRegister(
        loc, rewriter, getComponent(), 1,
        getState<ComponentLoweringState>().getUniqueName(nameSuffix));
    rewriter.create<calyx::AssignOp>(loc, portReg.getWriteEn(),
                                     calyxCmpFOp.getDone());
    if (invert)
      rewriter.create<calyx::AssignOp>(
          loc, portReg.getIn(), c1,
          comb::createOrFoldNot(loc, signal, builder));
    else
      rewriter.create<calyx::AssignOp>(loc, portReg.getIn(), signal);
    return portReg;
  };

  using calyx::PredicateInfo;
  PredicateInfo info = calyx::getPredicateInfo(cmpf.getPredicate());
  if (info.logic == PredicateInfo::SPECIAL) {
    // Constant predicates need no hardware beyond a constant.
    if (cmpf.getPredicate() == CmpFPredicate::AlwaysTrue) {
      cmpf.getResult().replaceAllUsesWith(c1);
      return success();
    }

    if (cmpf.getPredicate() == CmpFPredicate::AlwaysFalse) {
      Value constantZero = createConstant(loc, rewriter, getComponent(), 1, 0);
      cmpf.getResult().replaceAllUsesWith(constantZero);
      return success();
    }
  }

  // General case
  StringRef opName = cmpf.getOperationName().split(".").second;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I should refactor with buildLibraryBinaryPipeOp, it's tricky though

  auto reg =
      createRegister(loc, rewriter, getComponent(), 1,
                     getState<ComponentLoweringState>().getUniqueName(opName));

  // Operation pipelines are not combinational, so a GroupOp is required.
  auto group = createGroupForOp<calyx::GroupOp>(rewriter, cmpf);
  OpBuilder builder(group->getRegion(0));
  getState<ComponentLoweringState>().addBlockScheduleable(cmpf->getBlock(),
                                                          group);

  rewriter.setInsertionPointToEnd(group.getBodyBlock());
  rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getLeft(), cmpf.getLhs());
  rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getRight(), cmpf.getRhs());

  // Prepare signals and create registers
  SmallVector<calyx::RegisterOp> inputRegs;
  for (const auto &input : info.inputPorts) {
    Value signal;
    switch (input.port) {
    case PredicateInfo::InputPorts::Port::EQ: {
      signal = calyxCmpFOp.getEq();
      break;
    }
    case PredicateInfo::InputPorts::Port::GT: {
      signal = calyxCmpFOp.getGt();
      break;
    }
    case PredicateInfo::InputPorts::Port::LT: {
      signal = calyxCmpFOp.getLt();
      break;
    }
    case PredicateInfo::InputPorts::Port::UNORDERED: {
      signal = calyxCmpFOp.getUnordered();
      break;
    }
    }
    StringRef nameSuffix =
        (input.port == PredicateInfo::InputPorts::Port::UNORDERED)
            ? "unordered_port"
            : "compare_port";
    auto signalReg =
        createSignalRegister(signal, input.invert, nameSuffix, builder);
    inputRegs.push_back(signalReg);
  }

  // Create the output logical operation
  Value outputValue, doneValue;
  if (info.logic == PredicateInfo::SPECIAL) {
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
    // it's guaranteed to be either ORD or UNO
    outputValue = inputRegs[0].getOut();
    // The write-enable of the output register must follow the port register's
    // `done`, not its value: driving it from `out` would never write the
    // output register (and never finish the group) when the comparison result
    // is 0.
    doneValue = inputRegs[0].getDone();
  } else {
    if (info.logic == PredicateInfo::AND) {
      auto outputLibOp = getState<ComponentLoweringState>()
                             .getNewLibraryOpInstance<calyx::AndLibOp>(
                                 rewriter, loc, {one, one, one});
      rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
                                       inputRegs[0].getOut());
      rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
                                       inputRegs[1].getOut());

      outputValue = outputLibOp.getOut();
    } else /*info.logic == PredicateInfo::OR*/ {
      auto outputLibOp = getState<ComponentLoweringState>()
                             .getNewLibraryOpInstance<calyx::OrLibOp>(
                                 rewriter, loc, {one, one, one});
      rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
                                       inputRegs[0].getOut());
      rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
                                       inputRegs[1].getOut());

      outputValue = outputLibOp.getOut();
    }

    // Both port registers must have been written before the result is valid.
    auto doneLibOp = getState<ComponentLoweringState>()
                         .getNewLibraryOpInstance<calyx::AndLibOp>(
                             rewriter, loc, {one, one, one});
    rewriter.create<calyx::AssignOp>(loc, doneLibOp.getLeft(),
                                     inputRegs[0].getDone());
    rewriter.create<calyx::AssignOp>(loc, doneLibOp.getRight(),
                                     inputRegs[1].getDone());
    doneValue = doneLibOp.getOut();
  }

  // Write to the output register
  rewriter.create<calyx::AssignOp>(loc, reg.getIn(), outputValue);
  rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), doneValue);

  // Set the go and done signal
  rewriter.create<calyx::AssignOp>(
      loc, calyxCmpFOp.getGo(), c1,
      comb::createOrFoldNot(loc, calyxCmpFOp.getDone(), builder));
  rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());

  cmpf.getResult().replaceAllUsesWith(reg.getOut());

  // Register evaluating groups
  getState<ComponentLoweringState>().registerEvaluatingGroup(outputValue,
                                                             group);
  getState<ComponentLoweringState>().registerEvaluatingGroup(doneValue, group);
  getState<ComponentLoweringState>().registerEvaluatingGroup(
      calyxCmpFOp.getLeft(), group);
  getState<ComponentLoweringState>().registerEvaluatingGroup(
      calyxCmpFOp.getRight(), group);

  return success();
}

template <typename TAllocOp>
static LogicalResult buildAllocOp(ComponentLoweringState &componentState,
PatternRewriter &rewriter, TAllocOp allocOp) {
Expand Down Expand Up @@ -2113,7 +2273,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
ExtSIOp, CallOp, AddFOp, MulFOp>();
ExtSIOp, CallOp, AddFOp, MulFOp, CmpFOp>();

RewritePatternSet legalizePatterns(&getContext());
legalizePatterns.add<DummyPattern>(&getContext());
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/Calyx/CalyxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,10 +1209,16 @@ FloatingPointStandard MulFOpIEEE754::getFloatingPointStandard() {
return FloatingPointStandard::IEEE754;
}

/// Reports the floating point standard of the compare op: IEEE 754.
FloatingPointStandard CompareFOpIEEE754::getFloatingPointStandard() {
return FloatingPointStandard::IEEE754;
}

/// Returns the Calyx library name used for the floating point add op.
std::string AddFOpIEEE754::getCalyxLibraryName() { return "std_addFN"; }

/// Returns the Calyx library name used for the floating point multiply op.
std::string MulFOpIEEE754::getCalyxLibraryName() { return "std_mulFN"; }

/// Returns the Calyx library name used for the floating point compare op.
std::string CompareFOpIEEE754::getCalyxLibraryName() { return "std_compareFN"; }

//===----------------------------------------------------------------------===//
// GroupInterface
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 9 additions & 2 deletions lib/Dialect/Calyx/Export/CalyxEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <bitset>
#include <string>

Expand Down Expand Up @@ -157,6 +158,10 @@ struct ImportTracker {
static constexpr std::string_view sFloatingPoint = "float/mulFN";
return {sFloatingPoint};
})
.Case<CompareFOpIEEE754>([&](auto op) -> FailureOr<StringRef> {
static constexpr std::string_view sFloatingPoint = "float/compareFN";
return {sFloatingPoint};
})
.Default([&](auto op) {
auto diag = op->emitOpError() << "not supported for emission";
return diag;
Expand Down Expand Up @@ -679,7 +684,7 @@ void Emitter::emitComponent(ComponentInterface op) {
emitLibraryPrimTypedByFirstOutputPort(
op, /*calyxLibName=*/{"std_sdiv_pipe"});
})
.Case<AddFOpIEEE754, MulFOpIEEE754>(
.Case<AddFOpIEEE754, MulFOpIEEE754, CompareFOpIEEE754>(
[&](auto op) { emitLibraryFloatingPoint(op); })
.Default([&](auto op) {
emitOpError(op, "not supported for emission inside component");
Expand Down Expand Up @@ -996,8 +1001,10 @@ void Emitter::emitLibraryPrimTypedByFirstOutputPort(

void Emitter::emitLibraryFloatingPoint(Operation *op) {
auto cell = cast<CellInterface>(op);
// magic number for the index of `left/right` input port
size_t inputPortIndex = cell.getInputPorts().size() - 3;
unsigned bitWidth =
cell.getOutputPorts()[0].getType().getIntOrFloatBitWidth();
cell.getInputPorts()[inputPortIndex].getType().getIntOrFloatBitWidth();
// Since Calyx interacts with HardFloat, we'll also only be using expWidth and
// sigWidth. See
// http://www.jhauser.us/arithmetic/HardFloat-1/doc/HardFloat-Verilog.html
Expand Down
Loading
Loading