Skip to content

Commit

Permalink
Added const-folding 0 index access. (#6818)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Dec 4, 2024
1 parent c503fd2 commit b826522
Show file tree
Hide file tree
Showing 8 changed files with 1,528 additions and 1,463 deletions.
14 changes: 11 additions & 3 deletions crates/cairo-lang-lowering/src/ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,18 @@ impl FunctionId {
pub fn semantic_full_path(&self, db: &dyn LoweringGroup) -> String {
self.lookup_intern(db).semantic_full_path(db)
}
pub fn get_extern(&self, db: &dyn LoweringGroup) -> Option<ExternFunctionId> {
/// Returns the function as an `ExternFunctionId` and its generic arguments, if it is an
/// `extern` functions.
pub fn get_extern(
&self,
db: &dyn LoweringGroup,
) -> Option<(ExternFunctionId, Vec<GenericArgumentId>)> {
let semantic = try_extract_matches!(self.lookup_intern(db), FunctionLongId::Semantic)?;
let generic = semantic.get_concrete(db.upcast()).generic_function;
try_extract_matches!(generic, GenericFunctionId::Extern)
let concrete = semantic.get_concrete(db.upcast());
Some((
try_extract_matches!(concrete.generic_function, GenericFunctionId::Extern)?,
concrete.generic_args,
))
}
}
pub trait SemanticFunctionIdEx {
Expand Down
36 changes: 28 additions & 8 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ impl ConstFoldingContext<'_> {
stmt: &mut StatementCall,
additional_consts: &mut Vec<StatementConst>,
) -> Option<StatementConst> {
let id = stmt.function.get_extern(self.db)?;
let (id, _generic_args) = stmt.function.get_extern(self.db)?;
if id == self.felt_sub {
// (a - 0) can be replaced by a.
let val = self.as_int(stmt.inputs[1].var_id)?;
Expand Down Expand Up @@ -331,7 +331,7 @@ impl ConstFoldingContext<'_> {
&mut self,
info: &mut MatchExternInfo,
) -> Option<(Option<StatementConst>, FlatBlockEnd)> {
let id = info.function.get_extern(self.db)?;
let (id, generic_args) = info.function.get_extern(self.db)?;
if self.nz_fns.contains(&id) {
let val = self.as_const(info.inputs[0].var_id)?;
let is_zero = match val {
Expand Down Expand Up @@ -461,9 +461,7 @@ impl ConstFoldingContext<'_> {
} else if id == self.bounded_int_constrain {
let input_var = info.inputs[0].var_id;
let (value, nz_ty) = self.as_int_ex(input_var)?;
let semantic_id =
extract_matches!(info.function.lookup_intern(self.db), FunctionLongId::Semantic);
let generic_arg = semantic_id.get_concrete(self.db.upcast()).generic_args[1];
let generic_arg = generic_args[1];
let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
.lookup_intern(self.db)
.into_int()
Expand All @@ -474,6 +472,20 @@ impl ConstFoldingContext<'_> {
Some(self.propagate_const_and_get_statement(value.clone(), output, nz_ty)),
FlatBlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()),
))
} else if id == self.array_get {
if self.as_int(info.inputs[1].var_id)?.is_zero() {
if let [success, failure] = info.arms.as_mut_slice() {
let arr = info.inputs[0].var_id;
let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
info.inputs.truncate(1);
info.function = ModuleHelper { db: self.db, id: self.array_module }
.function_id("array_snapshot_pop_front", generic_args);
success.var_ids.insert(0, unused_arr_output0);
failure.var_ids.insert(0, unused_arr_output1);
}
}
None
} else {
None
}
Expand Down Expand Up @@ -584,8 +596,6 @@ pub struct ConstFoldingLibfuncInfo {
upcast: ExternFunctionId,
/// The `downcast` libfunc.
downcast: ExternFunctionId,
/// The `storage_base_address_from_felt252` libfunc.
storage_base_address_from_felt252: ExternFunctionId,
/// The set of functions that check if a number is zero.
nz_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions that check if numbers are equal.
Expand All @@ -610,8 +620,14 @@ pub struct ConstFoldingLibfuncInfo {
bounded_int_sub: ExternFunctionId,
/// The `bounded_int_constrain` libfunc.
bounded_int_constrain: ExternFunctionId,
/// The array module.
array_module: ModuleId,
/// The `array_get` libfunc.
array_get: ExternFunctionId,
/// The storage access module.
storage_access_module: ModuleId,
/// The `storage_base_address_from_felt252` libfunc.
storage_base_address_from_felt252: ExternFunctionId,
/// Type ranges.
type_value_ranges: OrderedHashMap<TypeId, TypeInfo>,
}
Expand All @@ -625,6 +641,8 @@ impl ConstFoldingLibfuncInfo {
let bounded_int_module = core.submodule("internal").submodule("bounded_int");
let upcast = integer_module.extern_function_id("upcast");
let downcast = integer_module.extern_function_id("downcast");
let array_module = core.submodule("array");
let array_get = array_module.extern_function_id("array_get");
let starknet_module = core.submodule("starknet");
let storage_access_module = starknet_module.submodule("storage_access");
let storage_base_address_from_felt252 =
Expand Down Expand Up @@ -699,7 +717,6 @@ impl ConstFoldingLibfuncInfo {
into_box,
upcast,
downcast,
storage_base_address_from_felt252,
nz_fns,
eq_fns,
uadd_fns,
Expand All @@ -712,7 +729,10 @@ impl ConstFoldingLibfuncInfo {
bounded_int_add,
bounded_int_sub,
bounded_int_constrain,
array_module: array_module.id,
array_get,
storage_access_module: storage_access_module.id,
storage_base_address_from_felt252,
type_value_ranges,
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4449,3 +4449,82 @@ End:
Return(v10)

//! > lowering_diagnostics

//! > ==========================================================================

//! > Array get at known index 0.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo(x: @Array<u8>) -> Option<Box<@u8>> {
x.get(0)
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > before
Parameters: v0: @core::array::Array::<core::integer::u8>
blk0 (root):
Statements:
(v1: core::integer::u32) <- 0
End:
Match(match core::array::array_get::<core::integer::u8>(v0, v1) {
Option::Some(v2) => blk1,
Option::None => blk2,
})

blk1:
Statements:
(v3: core::option::Option::<core::box::Box::<@core::integer::u8>>) <- Option::Some(v2)
End:
Goto(blk3, {v3 -> v4})

blk2:
Statements:
(v5: ()) <- struct_construct()
(v6: core::option::Option::<core::box::Box::<@core::integer::u8>>) <- Option::None(v5)
End:
Goto(blk3, {v6 -> v4})

blk3:
Statements:
End:
Return(v4)

//! > after
Parameters: v0: @core::array::Array::<core::integer::u8>
blk0 (root):
Statements:
(v1: core::integer::u32) <- 0
End:
Match(match core::array::array_snapshot_pop_front::<core::integer::u8>(v0) {
Option::Some(v7, v2) => blk1,
Option::None(v8) => blk2,
})

blk1:
Statements:
(v3: core::option::Option::<core::box::Box::<@core::integer::u8>>) <- Option::Some(v2)
End:
Goto(blk3, {v3 -> v4})

blk2:
Statements:
(v5: ()) <- struct_construct()
(v6: core::option::Option::<core::box::Box::<@core::integer::u8>>) <- Option::None(v5)
End:
Goto(blk3, {v6 -> v4})

blk3:
Statements:
End:
Return(v4)

//! > lowering_diagnostics
4 changes: 2 additions & 2 deletions crates/cairo-lang-lowering/src/test_data/match
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ test_function_lowering(expect_diagnostics: false)

//! > function
fn foo(a: @Array::<felt252>) -> Option<Box<@felt252>> {
core::array::array_get(a, 0_u32)
core::array::array_get(a, 1_u32)
}

//! > function_name
Expand All @@ -180,7 +180,7 @@ foo
Parameters: v0: core::RangeCheck, v1: @core::array::Array::<core::felt252>
blk0 (root):
Statements:
(v2: core::integer::u32) <- 0
(v2: core::integer::u32) <- 1
End:
Match(match core::array::array_get::<core::felt252>(v0, v1, v2) {
Option::Some(v3, v4) => blk1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ fn test_contract_libfuncs_coverage(name: &str) {

/// Tests that compiled_class_hash() returns the correct hash, by comparing it to hard-coded
/// constant that was computed by other implementations.
#[test_case("account__account", "1663b22c467591b6288c2e063fbad4cda6285ebe4861df9aa3d5bab3f479eb6")]
#[test_case("account__account", "5191417f7d4b2560c387dd09a4a5aa2dfae8204c4f8a324d684a42fd32e46bc")]
fn test_compiled_class_hash(name: &str, expected_hash: &str) {
let compiled_json_path =
get_example_file_path(format!("{name}.compiled_contract_class.json").as_str());
Expand Down
Loading

0 comments on commit b826522

Please sign in to comment.