Skip to content

Commit

Permalink
Simplified inlining algorithm. (#5359)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Apr 4, 2024
1 parent dd066ac commit e72d55d
Showing 1 changed file with 39 additions and 76 deletions.
115 changes: 39 additions & 76 deletions crates/cairo-lang-lowering/src/inline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use cairo_lang_diagnostics::{Diagnostics, Maybe};
use cairo_lang_semantic::items::functions::InlineConfiguration;
use cairo_lang_utils::casts::IntoOrPanic;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::{izip, zip_eq, Itertools};
use itertools::{izip, zip_eq};
use statements_weights::InlineWeight;

use self::statements_weights::ApproxCasmInlineWeight;
Expand Down Expand Up @@ -113,73 +113,47 @@ pub struct FunctionInlinerRewriter<'db> {
/// The LoweringContext were we are building the new blocks.
variables: VariableAllocator<'db>,
/// A Queue of blocks on which we want to apply the FunctionInlinerRewriter.
block_queue: BlockQueue,
block_queue: BlockRewriteQueue,
/// rewritten statements.
statements: Vec<Statement>,

/// The end of the current block.
block_end: FlatBlockEnd,
/// The current block id.
current_block_id: BlockId,
/// stack for statements that require rewriting.
statement_rewrite_stack: StatementStack,
/// The processed statements of the current block.
unprocessed_statements: <Vec<Statement> as IntoIterator>::IntoIter,
/// Indicates that the inlining process was successful.
inlining_success: Maybe<()>,
/// A map between blocks and the parent block that created them.
block_to_parent: HashMap<BlockId, BlockId>,
/// A map between blocks and the function that originally contained them.
block_to_function: HashMap<BlockId, ConcreteFunctionWithBodyId>,
/// The id of the function calling the possibly inlined functions.
calling_function_id: ConcreteFunctionWithBodyId,
}

#[derive(Default)]
pub struct StatementStack {
stack: Vec<Statement>,
}

impl StatementStack {
/// Pushes multiple statement into the stack.
///
/// Note that to keep the order of the statements when they are popped from the stack
/// they need to be pushed in reverse order.
fn push_statements(&mut self, statements: impl DoubleEndedIterator<Item = Statement>) {
self.stack.extend(statements.rev());
}

// Consumes all the statements in the stack.
fn consume(&mut self) -> Vec<Statement> {
self.stack.drain(..).rev().collect_vec()
}

fn pop_statement(&mut self) -> Option<Statement> {
self.stack.pop()
}
}

pub struct BlockQueue {
pub struct BlockRewriteQueue {
/// A Queue of blocks that require processing, and their id.
block_queue: VecDeque<FlatBlock>,
block_queue: VecDeque<(FlatBlock, bool)>,
/// The new blocks that were created during the inlining.
flat_blocks: FlatBlocksBuilder,
}
impl BlockQueue {
impl BlockRewriteQueue {
/// Enqueues the block for processing and returns the block_id that this
/// block is going to get in self.flat_blocks.
fn enqueue_block(&mut self, block: FlatBlock) -> BlockId {
self.block_queue.push_back(block);
fn enqueue_block(&mut self, block: FlatBlock, requires_rewrite: bool) -> BlockId {
self.block_queue.push_back((block, requires_rewrite));
BlockId(self.flat_blocks.len() + self.block_queue.len())
}
// Pops a block from the queue.
/// Pops a block requiring rewrites from the queue.
/// If the block doesn't require rewrites, it is finalized and added to the flat_blocks.
fn dequeue(&mut self) -> Option<FlatBlock> {
self.block_queue.pop_front()
while let Some((block, requires_rewrite)) = self.block_queue.pop_front() {
if requires_rewrite {
return Some(block);
}
self.finalize(block);
}
None
}
/// Finalizes a block.
fn finalize(&mut self, block: FlatBlock) -> BlockId {
self.flat_blocks.alloc(block)
}
}
impl Default for BlockQueue {
fn default() -> Self {
Self { block_queue: Default::default(), flat_blocks: FlatBlocksBuilder::new() }
fn finalize(&mut self, block: FlatBlock) {
self.flat_blocks.alloc(block);
}
}

Expand Down Expand Up @@ -240,35 +214,30 @@ impl<'db> FunctionInlinerRewriter<'db> {
) -> Maybe<FlatLowered> {
let mut rewriter = Self {
variables,
block_queue: BlockQueue {
block_queue: VecDeque::from(flat_lower.blocks.get().clone()),
block_queue: BlockRewriteQueue {
block_queue: flat_lower.blocks.iter().map(|(_, b)| (b.clone(), true)).collect(),
flat_blocks: FlatBlocksBuilder::new(),
},
statements: vec![],
block_end: FlatBlockEnd::NotSet,
current_block_id: BlockId::root(),
statement_rewrite_stack: StatementStack::default(),
unprocessed_statements: Default::default(),
inlining_success: flat_lower.blocks.has_root(),
block_to_parent: HashMap::new(),
block_to_function: (0..flat_lower.blocks.len())
.map(|i| (BlockId(i), calling_function_id))
.collect(),
calling_function_id,
};

rewriter.variables.variables = flat_lower.variables.clone();
while let Some(block) = rewriter.block_queue.dequeue() {
rewriter.block_end = block.end;
rewriter.statement_rewrite_stack.push_statements(block.statements.into_iter());
rewriter.unprocessed_statements = block.statements.into_iter();

while let Some(statement) = rewriter.statement_rewrite_stack.pop_statement() {
while let Some(statement) = rewriter.unprocessed_statements.next() {
rewriter.rewrite(statement)?;
}

rewriter.block_queue.finalize(FlatBlock {
statements: std::mem::take(&mut rewriter.statements),
end: rewriter.block_end,
});
rewriter.current_block_id = rewriter.current_block_id.next_block_id();
}

let blocks = rewriter
Expand All @@ -290,12 +259,9 @@ impl<'db> FunctionInlinerRewriter<'db> {
fn rewrite(&mut self, statement: Statement) -> Maybe<()> {
if let Statement::Call(ref stmt) = statement {
if let Some(called_func) = stmt.function.body(self.variables.db)? {
let orig_func = self.block_to_function[&BlockId::root()];

// TODO: Implement better logic to avoid inlining of destructors that call
// themselves.
if called_func != orig_func
&& orig_func == self.block_to_function[&self.current_block_id]
if called_func != self.calling_function_id
&& self.variables.db.priv_should_inline(called_func)?
{
return self.inline_function(called_func, &stmt.inputs, &stmt.outputs);
Expand All @@ -322,15 +288,13 @@ impl<'db> FunctionInlinerRewriter<'db> {
lowered.blocks.has_root()?;

// Create a new block with all the statements that follow the call statement.
let return_block_id = self.block_queue.enqueue_block(FlatBlock {
statements: self.statement_rewrite_stack.consume(),
end: self.block_end.clone(),
});
if let Some(parent_block_id) = self.block_to_parent.get(&self.current_block_id) {
self.block_to_parent.insert(return_block_id, *parent_block_id);
}
self.block_to_function
.insert(return_block_id, self.block_to_function[&self.current_block_id]);
let return_block_id = self.block_queue.enqueue_block(
FlatBlock {
statements: std::mem::take(&mut self.unprocessed_statements).collect(),
end: self.block_end.clone(),
},
true,
);

// As the block_ids and variable_ids are per function, we need to rename all
// the blocks and variables before we enqueue the blocks from the function that
Expand Down Expand Up @@ -361,11 +325,10 @@ impl<'db> FunctionInlinerRewriter<'db> {

for (block_id, block) in lowered.blocks.iter() {
let block = mapper.rebuild_block(block);

let new_block_id = self.block_queue.enqueue_block(block);
// Inlining is top down - so need to perform further inlining on the inlined function
// blocks.
let new_block_id = self.block_queue.enqueue_block(block, false);
assert_eq!(mapper.map_block_id(block_id), new_block_id, "Unexpected block_id.");
self.block_to_parent.insert(new_block_id, self.current_block_id);
self.block_to_function.insert(new_block_id, function_id);
}

Ok(())
Expand Down

0 comments on commit e72d55d

Please sign in to comment.