diff --git a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp index 0bb8eff2f5a4..ca4d8acf0d11 100644 --- a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp @@ -2870,7 +2870,7 @@ struct FoldRegMems : public mlir::RewritePattern { if (hasDontTouch(mem) || info.depth != 1) return failure(); - auto *block = mem->getBlock(); + auto memModule = mem->getParentOfType(); // Find the clock of the register-to-be, all write ports should share it. Value clock; @@ -2926,17 +2926,11 @@ struct FoldRegMems : public mlir::RewritePattern { } // Create a new register to store the data. - auto clockWire = rewriter.create(mem.getLoc(), clock.getType()); auto ty = mem.getDataType(); - auto reg = rewriter - .create(mem.getLoc(), ty, clockWire.getResult(), - mem.getName()) + rewriter.setInsertionPointAfterValue(clock); + auto reg = rewriter.create(mem.getLoc(), ty, clock, mem.getName()) .getResult(); - rewriter.setInsertionPointToEnd(block); - rewriter.create(mem.getLoc(), clockWire.getResult(), - clock); - // Helper to insert a given number of pipeline stages through registers. auto pipeline = [&](Value value, Value clock, const Twine &name, unsigned latency) { @@ -2970,7 +2964,7 @@ struct FoldRegMems : public mlir::RewritePattern { auto portPipeline = [&, port = port](StringRef field, unsigned stages) { Value value = getPortFieldValue(port, field); assert(value); - rewriter.setInsertionPointAfterValue(reg); + rewriter.setInsertionPointAfterValue(value); return pipeline(value, portClock, name + "_" + field, stages); }; @@ -3004,7 +2998,7 @@ struct FoldRegMems : public mlir::RewritePattern { Value en = getPortFieldValue(port, "en"); Value wmode = getPortFieldValue(port, "wmode"); - rewriter.setInsertionPointToEnd(block); + rewriter.setInsertionPointToEnd(memModule.getBodyBlock()); auto wen = rewriter.create(port.getLoc(), en, wmode); auto wenPipelined = @@ -3016,7 +3010,7 @@ struct FoldRegMems : public mlir::RewritePattern { } // Regardless of `writeUnderWrite`, always implement PortOrder. - rewriter.setInsertionPointToEnd(block); + rewriter.setInsertionPointToEnd(memModule.getBodyBlock()); Value next = reg; for (auto &[data, en, mask] : writes) { Value masked; diff --git a/test/Dialect/FIRRTL/simplify-mems.mlir b/test/Dialect/FIRRTL/simplify-mems.mlir index 8785f0fa4275..45aa83f88277 100644 --- a/test/Dialect/FIRRTL/simplify-mems.mlir +++ b/test/Dialect/FIRRTL/simplify-mems.mlir @@ -408,21 +408,10 @@ firrtl.circuit "OneAddressNoMask" { in %in_rwen: !firrtl.uint<1>, out %result_read: !firrtl.uint<32>, out %result_rw: !firrtl.uint<32>) { - %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> - %Memory_read, %Memory_rw, %Memory_write = firrtl.mem Undefined - { - depth = 1 : i64, - name = "Memory", - portNames = ["read", "rw", "write"], - readLatency = 2 : i32, - writeLatency = 4 : i32 - } : - !firrtl.bundle, en: uint<1>, clk: clock, data flip: uint<32>>, - !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>, - !firrtl.bundle, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>> - // CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32> - + // Pipeline the inputs. + // TODO: It would be good to de-duplicate these either in the pass or in a canonicalizer. + // CHECK: %Memory_write_en_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> // CHECK: firrtl.matchingconnect %Memory_write_en_0, %in_wen : !firrtl.uint<1> // CHECK: %Memory_write_en_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> @@ -444,6 +433,22 @@ firrtl.circuit "OneAddressNoMask" { // CHECK: %Memory_rw_wdata_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32> // CHECK: firrtl.matchingconnect %Memory_rw_wdata_2, %Memory_rw_wdata_1 : !firrtl.uint<32> + %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> + + %Memory_read, %Memory_rw, %Memory_write = firrtl.mem Undefined + { + depth = 1 : i64, + name = "Memory", + portNames = ["read", "rw", "write"], + readLatency = 2 : i32, + writeLatency = 4 : i32 + } : + !firrtl.bundle, en: uint<1>, clk: clock, data flip: uint<32>>, + !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>, + !firrtl.bundle, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>> + + // CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32> + // CHECK: firrtl.matchingconnect %result_read, %Memory : !firrtl.uint<32> %read_addr = firrtl.subfield %Memory_read[addr] : !firrtl.bundle, en: uint<1>, clk: clock, data flip: uint<32>> firrtl.connect %read_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1> @@ -492,105 +497,3 @@ firrtl.circuit "OneAddressNoMask" { firrtl.connect %write_mask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1> } } - -// ----- - -// This test ensures that the FoldRegMems canonicalization correctly -// folds memories under layerblocks. -firrtl.circuit "Rewrite1ElementMemoryToRegisterUnderLayerblock" { - firrtl.layer @A bind {} - - firrtl.module public @Rewrite1ElementMemoryToRegisterUnderLayerblock( - in %clock: !firrtl.clock, - in %addr: !firrtl.uint<1>, - in %in_data: !firrtl.uint<32>, - in %wmode_rw: !firrtl.uint<1>, - in %in_wen: !firrtl.uint<1>, - in %in_rwen: !firrtl.uint<1>) { - - %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> - - // CHECK firrtl.layerblock @A - firrtl.layerblock @A { - // CHECK: %result_read = firrtl.wire : !firrtl.uint<32> - // CHECK: %result_rw = firrtl.wire : !firrtl.uint<32> - %result_read = firrtl.wire : !firrtl.uint<32> - %result_rw = firrtl.wire : !firrtl.uint<32> - - %Memory_read, %Memory_rw, %Memory_write = firrtl.mem Undefined - { - depth = 1 : i64, - name = "Memory", - portNames = ["read", "rw", "write"], - readLatency = 2 : i32, - writeLatency = 2 : i32 - } : - !firrtl.bundle, en: uint<1>, clk: clock, data flip: uint<32>>, - !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>, - !firrtl.bundle, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>> - - // CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32> - // CHECK: %Memory_write_mask_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> - // CHECK: firrtl.matchingconnect %Memory_write_mask_0, %c1_ui1 : !firrtl.uint<1> - - // CHECK: %Memory_write_en_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> - // CHECK: firrtl.matchingconnect %Memory_write_en_0, %in_wen : !firrtl.uint<1> - - // CHECK: %Memory_write_data_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32> - // CHECK: firrtl.matchingconnect %Memory_write_data_0, %in_data : !firrtl.uint<32> - - // CHECK: %Memory_rw_wmask_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> - // CHECK: firrtl.matchingconnect %Memory_rw_wmask_0, %c1_ui1 : !firrtl.uint<1> - - // CHECK: %Memory_rw_wdata_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32> - // CHECK: firrtl.matchingconnect %Memory_rw_wdata_0, %in_data : !firrtl.uint<32> - - // CHECK: firrtl.matchingconnect %result_read, %Memory : !firrtl.uint<32> - // CHECK: firrtl.matchingconnect %result_rw, %Memory : !firrtl.uint<32> - - // CHECK: %0 = firrtl.and %in_rwen, %wmode_rw : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> - // CHECK: %Memory_rw_wen_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> - // CHECK: firrtl.matchingconnect %Memory_rw_wen_0, %0 : !firrtl.uint<1> - // CHECK: %1 = firrtl.and %Memory_rw_wen_0, %Memory_rw_wmask_0 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> - // CHECK: %2 = firrtl.mux(%1, %Memory_rw_wdata_0, %Memory) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32> - // CHECK: %3 = firrtl.and %Memory_write_en_0, %Memory_write_mask_0 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> - // CHECK: %4 = firrtl.mux(%3, %Memory_write_data_0, %2) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32> - // CHECK: firrtl.matchingconnect %Memory, %4 : !firrtl.uint<32> - - %read_addr = firrtl.subfield %Memory_read[addr] : !firrtl.bundle, en: uint<1>, clk: clock, data flip: uint<32>> - firrtl.connect %read_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1> - %read_en = firrtl.subfield %Memory_read[en] : !firrtl.bundle, en: uint<1>, clk: clock, data flip: uint<32>> - firrtl.connect %read_en, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1> - %read_clk = firrtl.subfield %Memory_read[clk] : !firrtl.bundle, en: uint<1>, clk: clock, data flip: uint<32>> - firrtl.connect %read_clk, %clock : !firrtl.clock, !firrtl.clock - %read_data = firrtl.subfield %Memory_read[data] : !firrtl.bundle, en: uint<1>, clk: clock, data flip: uint<32>> - firrtl.connect %result_read, %read_data : !firrtl.uint<32>, !firrtl.uint<32> - - %rw_addr = firrtl.subfield %Memory_rw[addr] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>> - firrtl.connect %rw_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1> - %rw_en = firrtl.subfield %Memory_rw[en] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>> - firrtl.connect %rw_en, %in_rwen : !firrtl.uint<1>, !firrtl.uint<1> - %rw_clk = firrtl.subfield %Memory_rw[clk] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>> - firrtl.connect %rw_clk, %clock : !firrtl.clock, !firrtl.clock - %rw_rdata = firrtl.subfield %Memory_rw[rdata] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>> - firrtl.connect %result_rw, %rw_rdata : !firrtl.uint<32>, !firrtl.uint<32> - %rw_wmode = firrtl.subfield %Memory_rw[wmode] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>> - firrtl.connect %rw_wmode, %wmode_rw : !firrtl.uint<1>, !firrtl.uint<1> - %rw_wdata = firrtl.subfield %Memory_rw[wdata] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>> - firrtl.connect %rw_wdata, %in_data : !firrtl.uint<32>, !firrtl.uint<32> - %rw_wmask = firrtl.subfield %Memory_rw[wmask] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>> - firrtl.connect %rw_wmask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1> - - %write_addr = firrtl.subfield %Memory_write[addr] : !firrtl.bundle, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>> - firrtl.connect %write_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1> - %write_en = firrtl.subfield %Memory_write[en] : !firrtl.bundle, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>> - firrtl.connect %write_en, %in_wen : !firrtl.uint<1>, !firrtl.uint<1> - %write_clk = firrtl.subfield %Memory_write[clk] : !firrtl.bundle, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>> - firrtl.connect %write_clk, %clock : !firrtl.clock, !firrtl.clock - %write_data = firrtl.subfield %Memory_write[data] : !firrtl.bundle, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>> - firrtl.connect %write_data, %in_data : !firrtl.uint<32>, !firrtl.uint<32> - %write_mask = firrtl.subfield %Memory_write[mask] : !firrtl.bundle, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>> - firrtl.connect %write_mask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1> - } - } -}