From c99b9ad30202b4b3950668177a26138356a4cb02 Mon Sep 17 00:00:00 2001 From: Luke Wagner Date: Wed, 16 Oct 2024 16:34:58 -0700 Subject: [PATCH] CABI: fix flatten_functype() to match canon_lift() and remove extraneous core arg from dtor --- design/mvp/CanonicalABI.md | 21 ++++++++++++++----- design/mvp/canonical-abi/definitions.py | 27 ++++++++++++++++++++++--- design/mvp/canonical-abi/run_tests.py | 20 +++++++++--------- 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/design/mvp/CanonicalABI.md b/design/mvp/CanonicalABI.md index 0608461b..b4f8947c 100644 --- a/design/mvp/CanonicalABI.md +++ b/design/mvp/CanonicalABI.md @@ -1629,8 +1629,10 @@ def flatten_functype(opts, ft, context): else: match context: case 'lift': - flat_params = [] - flat_results = [] + if opts.callback: + flat_results = ['i32'] + else: + flat_results = [] case 'lower': if len(flat_params) > 1: flat_params = ['i32'] @@ -2077,16 +2079,21 @@ Based on this, `canon_lift` is defined: async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_block = default_on_block): task = Task(opts, inst, ft, caller, on_return, on_block) flat_args = await task.enter(on_start) + flat_ft = flatten_functype(opts, ft, 'lift') + assert(types_match_values(flat_ft.params, flat_args)) if opts.sync: flat_results = await call_and_trap_on_throw(callee, task, flat_args) + assert(types_match_values(flat_ft.results, flat_results)) task.return_(flat_results) if opts.post_return is not None: [] = await call_and_trap_on_throw(opts.post_return, task, flat_results) else: if not opts.callback: [] = await call_and_trap_on_throw(callee, task, flat_args) + assert(types_match_values(flat_ft.results, [])) else: [packed_ctx] = await call_and_trap_on_throw(callee, task, flat_args) + assert(types_match_values(flat_ft.results, [packed_ctx])) while packed_ctx != 0: is_yield = bool(packed_ctx & 1) ctx = packed_ctx & ~1 @@ -2144,6 +2151,8 @@ Given this, `canon_lower` is defined: ```python async def canon_lower(opts, ft, callee, task, flat_args): trap_if(not task.inst.may_leave) + flat_ft = flatten_functype(opts, ft, 'lower') + assert(types_match_values(flat_ft.params, flat_args)) subtask = Subtask(opts, ft, task, flat_args) if opts.sync: await task.call_sync(callee, task, subtask.on_start, subtask.on_return) @@ -2162,6 +2171,7 @@ async def canon_lower(opts, ft, callee, task, flat_args): flat_results = [i | (int(subtask.state) << 30)] case Returned(): flat_results = [0] + assert(types_match_values(flat_ft.results, flat_results)) return flat_results ``` In the asynchronous case, if `do_call` blocks before `Subtask.finish` @@ -2252,7 +2262,7 @@ async def canon_resource_drop(rt, sync, task, i): callee_opts = CanonicalOptions(sync = rt.dtor_sync, callback = rt.dtor_callback) ft = FuncType([U32Type()],[]) callee = partial(canon_lift, callee_opts, rt.impl, ft, rt.dtor) - flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep, 0]) + flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep]) else: task.trap_if_on_the_stack(rt.impl) else: @@ -2384,12 +2394,13 @@ Calling `$f` does a non-blocking check for whether an event is already available, returning whether or not there was such an event as a boolean and, if there was an event, storing the `i32` event+payload pair as an outparam. ```python -async def canon_task_poll(task, ptr): +async def canon_task_poll(opts, task, ptr): trap_if(not task.inst.may_leave) ret = await task.poll() if ret is None: return [0] - store(task, ret, TupleType([U32Type(), U32Type()]), ptr) + cx = CallContext(opts, task.inst, task) + store(cx, ret, TupleType([U32Type(), U32Type()]), ptr) return [1] ``` Note that the `await` of `task.poll` indicates that `task.poll` can yield to diff --git a/design/mvp/canonical-abi/definitions.py b/design/mvp/canonical-abi/definitions.py index abd973bf..ef165dca 100644 --- a/design/mvp/canonical-abi/definitions.py +++ b/design/mvp/canonical-abi/definitions.py @@ -52,6 +52,17 @@ class CoreFuncType(CoreExternType): def __eq__(self, other): return self.params == other.params and self.results == other.results +def types_match_values(ts, vs): + if len(ts) != len(vs): + return False + return all(type_matches_value(t, v) for t,v in zip(ts, vs)) + +def type_matches_value(t, v): + match t: + case 'i32' | 'i64': return type(v) == int + case 'f32' | 'f64': return type(v) == float + assert(False) + @dataclass class CoreMemoryType(CoreExternType): initial: list[int] @@ -1138,8 +1149,10 @@ def flatten_functype(opts, ft, context): else: match context: case 'lift': - flat_params = [] - flat_results = [] + if opts.callback: + flat_results = ['i32'] + else: + flat_results = [] case 'lower': if len(flat_params) > 1: flat_params = ['i32'] @@ -1421,16 +1434,21 @@ def lower_heap_values(cx, vs, ts, out_param): async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_block = default_on_block): task = Task(opts, inst, ft, caller, on_return, on_block) flat_args = await task.enter(on_start) + flat_ft = flatten_functype(opts, ft, 'lift') + assert(types_match_values(flat_ft.params, flat_args)) if opts.sync: flat_results = await call_and_trap_on_throw(callee, task, flat_args) + assert(types_match_values(flat_ft.results, flat_results)) task.return_(flat_results) if opts.post_return is not None: [] = await call_and_trap_on_throw(opts.post_return, task, flat_results) else: if not opts.callback: [] = await call_and_trap_on_throw(callee, task, flat_args) + assert(types_match_values(flat_ft.results, [])) else: [packed_ctx] = await call_and_trap_on_throw(callee, task, flat_args) + assert(types_match_values(flat_ft.results, [packed_ctx])) while packed_ctx != 0: is_yield = bool(packed_ctx & 1) ctx = packed_ctx & ~1 @@ -1452,6 +1470,8 @@ async def call_and_trap_on_throw(callee, task, args): async def canon_lower(opts, ft, callee, task, flat_args): trap_if(not task.inst.may_leave) + flat_ft = flatten_functype(opts, ft, 'lower') + assert(types_match_values(flat_ft.params, flat_args)) subtask = Subtask(opts, ft, task, flat_args) if opts.sync: await task.call_sync(callee, task, subtask.on_start, subtask.on_return) @@ -1470,6 +1490,7 @@ async def do_call(on_block): flat_results = [i | (int(subtask.state) << 30)] case Returned(): flat_results = [0] + assert(types_match_values(flat_ft.results, flat_results)) return flat_results ### `canon resource.new` @@ -1499,7 +1520,7 @@ async def canon_resource_drop(rt, sync, task, i): callee_opts = CanonicalOptions(sync = rt.dtor_sync, callback = rt.dtor_callback) ft = FuncType([U32Type()],[]) callee = partial(canon_lift, callee_opts, rt.impl, ft, rt.dtor) - flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep, 0]) + flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep]) else: task.trap_if_on_the_stack(rt.impl) else: diff --git a/design/mvp/canonical-abi/run_tests.py b/design/mvp/canonical-abi/run_tests.py index f20ad8d8..653a76d4 100644 --- a/design/mvp/canonical-abi/run_tests.py +++ b/design/mvp/canonical-abi/run_tests.py @@ -419,6 +419,7 @@ async def dtor(task, args): nonlocal dtor_value assert(len(args) == 1) dtor_value = args[0] + return [] rt = ResourceType(ComponentInstance(), dtor) # usable in imports and exports inst = ComponentInstance() @@ -558,7 +559,7 @@ async def core_blocking_producer(task, args): async def consumer(task, args): [b] = args ptr = consumer_heap.realloc(0, 0, 1, 1) - [ret] = await canon_lower(consumer_opts, eager_ft, eager_callee, task, [0, ptr]) + [ret] = await canon_lower(consumer_opts, eager_ft, eager_callee, task, [ptr]) assert(ret == 0) u8 = consumer_heap.memory[ptr] assert(u8 == 43) @@ -596,6 +597,7 @@ async def dtor(task, args): assert(len(args) == 1) await task.on_block(dtor_fut) dtor_value = args[0] + return [] rt = ResourceType(producer_inst, dtor) [i] = await canon_resource_new(rt, task, 50) @@ -652,10 +654,10 @@ async def core_producer_pre(fut, task, args): async def consumer(task, args): assert(len(args) == 0) - [ret] = await canon_lower(opts, producer_ft, producer1, task, [0, 0]) + [ret] = await canon_lower(opts, producer_ft, producer1, task, []) assert(ret == (1 | (CallState.STARTED << 30))) - [ret] = await canon_lower(opts, producer_ft, producer2, task, [0, 0]) + [ret] = await canon_lower(opts, producer_ft, producer2, task, []) assert(ret == (2 | (CallState.STARTED << 30))) fut1.set_result(None) @@ -730,10 +732,10 @@ async def producer2_core(task, args): async def consumer(task, args): assert(len(args) == 0) - [ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [0, 0]) + [ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, []) assert(ret == (1 | (CallState.STARTED << 30))) - [ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [0, 0]) + [ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, []) assert(ret == (2 | (CallState.STARTING << 30))) assert(await task.poll() is None) @@ -808,10 +810,10 @@ async def producer2_core(task, args): async def consumer(task, args): assert(len(args) == 0) - [ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [0, 0]) + [ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, []) assert(ret == (1 | (CallState.RETURNED << 30))) - [ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [0, 0]) + [ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, []) assert(ret == (2 | (CallState.STARTING << 30))) assert(await task.poll() is None) @@ -872,9 +874,9 @@ async def core_hostcall_pre(fut, task, args): lower_opts.sync = False async def core_func(task, args): - [ret] = await canon_lower(lower_opts, ft, hostcall1, task, [0,0]) + [ret] = await canon_lower(lower_opts, ft, hostcall1, task, []) assert(ret == (1 | (CallState.STARTED << 30))) - [ret] = await canon_lower(lower_opts, ft, hostcall2, task, [0,0]) + [ret] = await canon_lower(lower_opts, ft, hostcall2, task, []) assert(ret == (2 | (CallState.STARTED << 30))) fut1.set_result(None)