From e72d55dab12cec80c308a3250bf8098c7a26b2d3 Mon Sep 17 00:00:00 2001 From: orizi <104711814+orizi@users.noreply.github.com> Date: Thu, 4 Apr 2024 16:51:04 +0300 Subject: [PATCH] Simplified inlining algorithm. (#5359) --- crates/cairo-lang-lowering/src/inline/mod.rs | 115 +++++++------------ 1 file changed, 39 insertions(+), 76 deletions(-) diff --git a/crates/cairo-lang-lowering/src/inline/mod.rs b/crates/cairo-lang-lowering/src/inline/mod.rs index 66c60ded2f5..7537939fb1e 100644 --- a/crates/cairo-lang-lowering/src/inline/mod.rs +++ b/crates/cairo-lang-lowering/src/inline/mod.rs @@ -10,7 +10,7 @@ use cairo_lang_diagnostics::{Diagnostics, Maybe}; use cairo_lang_semantic::items::functions::InlineConfiguration; use cairo_lang_utils::casts::IntoOrPanic; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; -use itertools::{izip, zip_eq, Itertools}; +use itertools::{izip, zip_eq}; use statements_weights::InlineWeight; use self::statements_weights::ApproxCasmInlineWeight; @@ -113,73 +113,47 @@ pub struct FunctionInlinerRewriter<'db> { /// The LoweringContext were we are building the new blocks. variables: VariableAllocator<'db>, /// A Queue of blocks on which we want to apply the FunctionInlinerRewriter. - block_queue: BlockQueue, + block_queue: BlockRewriteQueue, /// rewritten statements. statements: Vec, /// The end of the current block. block_end: FlatBlockEnd, - /// The current block id. - current_block_id: BlockId, - /// stack for statements that require rewriting. - statement_rewrite_stack: StatementStack, + /// The processed statements of the current block. + unprocessed_statements: as IntoIterator>::IntoIter, /// Indicates that the inlining process was successful. inlining_success: Maybe<()>, - /// A map between blocks and the parent block that created them. - block_to_parent: HashMap, - /// A map between blocks and the function that originally contained them. - block_to_function: HashMap, + /// The id of the function calling the possibly inlined functions. + calling_function_id: ConcreteFunctionWithBodyId, } -#[derive(Default)] -pub struct StatementStack { - stack: Vec, -} - -impl StatementStack { - /// Pushes multiple statement into the stack. - /// - /// Note that to keep the order of the statements when they are popped from the stack - /// they need to be pushed in reverse order. - fn push_statements(&mut self, statements: impl DoubleEndedIterator) { - self.stack.extend(statements.rev()); - } - - // Consumes all the statements in the stack. - fn consume(&mut self) -> Vec { - self.stack.drain(..).rev().collect_vec() - } - - fn pop_statement(&mut self) -> Option { - self.stack.pop() - } -} - -pub struct BlockQueue { +pub struct BlockRewriteQueue { /// A Queue of blocks that require processing, and their id. - block_queue: VecDeque, + block_queue: VecDeque<(FlatBlock, bool)>, /// The new blocks that were created during the inlining. flat_blocks: FlatBlocksBuilder, } -impl BlockQueue { +impl BlockRewriteQueue { /// Enqueues the block for processing and returns the block_id that this /// block is going to get in self.flat_blocks. - fn enqueue_block(&mut self, block: FlatBlock) -> BlockId { - self.block_queue.push_back(block); + fn enqueue_block(&mut self, block: FlatBlock, requires_rewrite: bool) -> BlockId { + self.block_queue.push_back((block, requires_rewrite)); BlockId(self.flat_blocks.len() + self.block_queue.len()) } - // Pops a block from the queue. + /// Pops a block requiring rewrites from the queue. + /// If the block doesn't require rewrites, it is finalized and added to the flat_blocks. fn dequeue(&mut self) -> Option { - self.block_queue.pop_front() + while let Some((block, requires_rewrite)) = self.block_queue.pop_front() { + if requires_rewrite { + return Some(block); + } + self.finalize(block); + } + None } /// Finalizes a block. - fn finalize(&mut self, block: FlatBlock) -> BlockId { - self.flat_blocks.alloc(block) - } -} -impl Default for BlockQueue { - fn default() -> Self { - Self { block_queue: Default::default(), flat_blocks: FlatBlocksBuilder::new() } + fn finalize(&mut self, block: FlatBlock) { + self.flat_blocks.alloc(block); } } @@ -240,27 +214,23 @@ impl<'db> FunctionInlinerRewriter<'db> { ) -> Maybe { let mut rewriter = Self { variables, - block_queue: BlockQueue { - block_queue: VecDeque::from(flat_lower.blocks.get().clone()), + block_queue: BlockRewriteQueue { + block_queue: flat_lower.blocks.iter().map(|(_, b)| (b.clone(), true)).collect(), flat_blocks: FlatBlocksBuilder::new(), }, statements: vec![], block_end: FlatBlockEnd::NotSet, - current_block_id: BlockId::root(), - statement_rewrite_stack: StatementStack::default(), + unprocessed_statements: Default::default(), inlining_success: flat_lower.blocks.has_root(), - block_to_parent: HashMap::new(), - block_to_function: (0..flat_lower.blocks.len()) - .map(|i| (BlockId(i), calling_function_id)) - .collect(), + calling_function_id, }; rewriter.variables.variables = flat_lower.variables.clone(); while let Some(block) = rewriter.block_queue.dequeue() { rewriter.block_end = block.end; - rewriter.statement_rewrite_stack.push_statements(block.statements.into_iter()); + rewriter.unprocessed_statements = block.statements.into_iter(); - while let Some(statement) = rewriter.statement_rewrite_stack.pop_statement() { + while let Some(statement) = rewriter.unprocessed_statements.next() { rewriter.rewrite(statement)?; } @@ -268,7 +238,6 @@ impl<'db> FunctionInlinerRewriter<'db> { statements: std::mem::take(&mut rewriter.statements), end: rewriter.block_end, }); - rewriter.current_block_id = rewriter.current_block_id.next_block_id(); } let blocks = rewriter @@ -290,12 +259,9 @@ impl<'db> FunctionInlinerRewriter<'db> { fn rewrite(&mut self, statement: Statement) -> Maybe<()> { if let Statement::Call(ref stmt) = statement { if let Some(called_func) = stmt.function.body(self.variables.db)? { - let orig_func = self.block_to_function[&BlockId::root()]; - // TODO: Implement better logic to avoid inlining of destructors that call // themselves. - if called_func != orig_func - && orig_func == self.block_to_function[&self.current_block_id] + if called_func != self.calling_function_id && self.variables.db.priv_should_inline(called_func)? { return self.inline_function(called_func, &stmt.inputs, &stmt.outputs); @@ -322,15 +288,13 @@ impl<'db> FunctionInlinerRewriter<'db> { lowered.blocks.has_root()?; // Create a new block with all the statements that follow the call statement. - let return_block_id = self.block_queue.enqueue_block(FlatBlock { - statements: self.statement_rewrite_stack.consume(), - end: self.block_end.clone(), - }); - if let Some(parent_block_id) = self.block_to_parent.get(&self.current_block_id) { - self.block_to_parent.insert(return_block_id, *parent_block_id); - } - self.block_to_function - .insert(return_block_id, self.block_to_function[&self.current_block_id]); + let return_block_id = self.block_queue.enqueue_block( + FlatBlock { + statements: std::mem::take(&mut self.unprocessed_statements).collect(), + end: self.block_end.clone(), + }, + true, + ); // As the block_ids and variable_ids are per function, we need to rename all // the blocks and variables before we enqueue the blocks from the function that @@ -361,11 +325,10 @@ impl<'db> FunctionInlinerRewriter<'db> { for (block_id, block) in lowered.blocks.iter() { let block = mapper.rebuild_block(block); - - let new_block_id = self.block_queue.enqueue_block(block); + // Inlining is top down - so need to perform further inlining on the inlined function + // blocks. + let new_block_id = self.block_queue.enqueue_block(block, false); assert_eq!(mapper.map_block_id(block_id), new_block_id, "Unexpected block_id."); - self.block_to_parent.insert(new_block_id, self.current_block_id); - self.block_to_function.insert(new_block_id, function_id); } Ok(())