Skip to content

Commit

Permalink
CABI: fix flatten_functype() to match canon_lift() and remove extrane…
Browse files Browse the repository at this point in the history
…ous core arg from dtor
  • Loading branch information
lukewagner committed Oct 16, 2024
1 parent 42e78f3 commit c99b9ad
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
21 changes: 16 additions & 5 deletions design/mvp/CanonicalABI.md
Original file line number Diff line number Diff line change
Expand Up @@ -1629,8 +1629,10 @@ def flatten_functype(opts, ft, context):
else:
match context:
case 'lift':
flat_params = []
flat_results = []
if opts.callback:
flat_results = ['i32']
else:
flat_results = []
case 'lower':
if len(flat_params) > 1:
flat_params = ['i32']
Expand Down Expand Up @@ -2077,16 +2079,21 @@ Based on this, `canon_lift` is defined:
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_block = default_on_block):
task = Task(opts, inst, ft, caller, on_return, on_block)
flat_args = await task.enter(on_start)
flat_ft = flatten_functype(opts, ft, 'lift')
assert(types_match_values(flat_ft.params, flat_args))
if opts.sync:
flat_results = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, flat_results))
task.return_(flat_results)
if opts.post_return is not None:
[] = await call_and_trap_on_throw(opts.post_return, task, flat_results)
else:
if not opts.callback:
[] = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, []))
else:
[packed_ctx] = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, [packed_ctx]))
while packed_ctx != 0:
is_yield = bool(packed_ctx & 1)
ctx = packed_ctx & ~1
Expand Down Expand Up @@ -2144,6 +2151,8 @@ Given this, `canon_lower` is defined:
```python
async def canon_lower(opts, ft, callee, task, flat_args):
trap_if(not task.inst.may_leave)
flat_ft = flatten_functype(opts, ft, 'lower')
assert(types_match_values(flat_ft.params, flat_args))
subtask = Subtask(opts, ft, task, flat_args)
if opts.sync:
await task.call_sync(callee, task, subtask.on_start, subtask.on_return)
Expand All @@ -2162,6 +2171,7 @@ async def canon_lower(opts, ft, callee, task, flat_args):
flat_results = [i | (int(subtask.state) << 30)]
case Returned():
flat_results = [0]
assert(types_match_values(flat_ft.results, flat_results))
return flat_results
```
In the asynchronous case, if `do_call` blocks before `Subtask.finish`
Expand Down Expand Up @@ -2252,7 +2262,7 @@ async def canon_resource_drop(rt, sync, task, i):
callee_opts = CanonicalOptions(sync = rt.dtor_sync, callback = rt.dtor_callback)
ft = FuncType([U32Type()],[])
callee = partial(canon_lift, callee_opts, rt.impl, ft, rt.dtor)
flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep, 0])
flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep])
else:
task.trap_if_on_the_stack(rt.impl)
else:
Expand Down Expand Up @@ -2384,12 +2394,13 @@ Calling `$f` does a non-blocking check for whether an event is already
available, returning whether or not there was such an event as a boolean and,
if there was an event, storing the `i32` event+payload pair as an outparam.
```python
async def canon_task_poll(task, ptr):
async def canon_task_poll(opts, task, ptr):
trap_if(not task.inst.may_leave)
ret = await task.poll()
if ret is None:
return [0]
store(task, ret, TupleType([U32Type(), U32Type()]), ptr)
cx = CallContext(opts, task.inst, task)
store(cx, ret, TupleType([U32Type(), U32Type()]), ptr)
return [1]
```
Note that the `await` of `task.poll` indicates that `task.poll` can yield to
Expand Down
27 changes: 24 additions & 3 deletions design/mvp/canonical-abi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ class CoreFuncType(CoreExternType):
def __eq__(self, other):
return self.params == other.params and self.results == other.results

def types_match_values(ts, vs):
if len(ts) != len(vs):
return False
return all(type_matches_value(t, v) for t,v in zip(ts, vs))

def type_matches_value(t, v):
match t:
case 'i32' | 'i64': return type(v) == int
case 'f32' | 'f64': return type(v) == float
assert(False)

@dataclass
class CoreMemoryType(CoreExternType):
initial: list[int]
Expand Down Expand Up @@ -1138,8 +1149,10 @@ def flatten_functype(opts, ft, context):
else:
match context:
case 'lift':
flat_params = []
flat_results = []
if opts.callback:
flat_results = ['i32']
else:
flat_results = []
case 'lower':
if len(flat_params) > 1:
flat_params = ['i32']
Expand Down Expand Up @@ -1421,16 +1434,21 @@ def lower_heap_values(cx, vs, ts, out_param):
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_block = default_on_block):
task = Task(opts, inst, ft, caller, on_return, on_block)
flat_args = await task.enter(on_start)
flat_ft = flatten_functype(opts, ft, 'lift')
assert(types_match_values(flat_ft.params, flat_args))
if opts.sync:
flat_results = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, flat_results))
task.return_(flat_results)
if opts.post_return is not None:
[] = await call_and_trap_on_throw(opts.post_return, task, flat_results)
else:
if not opts.callback:
[] = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, []))
else:
[packed_ctx] = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, [packed_ctx]))
while packed_ctx != 0:
is_yield = bool(packed_ctx & 1)
ctx = packed_ctx & ~1
Expand All @@ -1452,6 +1470,8 @@ async def call_and_trap_on_throw(callee, task, args):

async def canon_lower(opts, ft, callee, task, flat_args):
trap_if(not task.inst.may_leave)
flat_ft = flatten_functype(opts, ft, 'lower')
assert(types_match_values(flat_ft.params, flat_args))
subtask = Subtask(opts, ft, task, flat_args)
if opts.sync:
await task.call_sync(callee, task, subtask.on_start, subtask.on_return)
Expand All @@ -1470,6 +1490,7 @@ async def do_call(on_block):
flat_results = [i | (int(subtask.state) << 30)]
case Returned():
flat_results = [0]
assert(types_match_values(flat_ft.results, flat_results))
return flat_results

### `canon resource.new`
Expand Down Expand Up @@ -1499,7 +1520,7 @@ async def canon_resource_drop(rt, sync, task, i):
callee_opts = CanonicalOptions(sync = rt.dtor_sync, callback = rt.dtor_callback)
ft = FuncType([U32Type()],[])
callee = partial(canon_lift, callee_opts, rt.impl, ft, rt.dtor)
flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep, 0])
flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep])
else:
task.trap_if_on_the_stack(rt.impl)
else:
Expand Down
20 changes: 11 additions & 9 deletions design/mvp/canonical-abi/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ async def dtor(task, args):
nonlocal dtor_value
assert(len(args) == 1)
dtor_value = args[0]
return []

rt = ResourceType(ComponentInstance(), dtor) # usable in imports and exports
inst = ComponentInstance()
Expand Down Expand Up @@ -558,7 +559,7 @@ async def core_blocking_producer(task, args):
async def consumer(task, args):
[b] = args
ptr = consumer_heap.realloc(0, 0, 1, 1)
[ret] = await canon_lower(consumer_opts, eager_ft, eager_callee, task, [0, ptr])
[ret] = await canon_lower(consumer_opts, eager_ft, eager_callee, task, [ptr])
assert(ret == 0)
u8 = consumer_heap.memory[ptr]
assert(u8 == 43)
Expand Down Expand Up @@ -596,6 +597,7 @@ async def dtor(task, args):
assert(len(args) == 1)
await task.on_block(dtor_fut)
dtor_value = args[0]
return []
rt = ResourceType(producer_inst, dtor)

[i] = await canon_resource_new(rt, task, 50)
Expand Down Expand Up @@ -652,10 +654,10 @@ async def core_producer_pre(fut, task, args):
async def consumer(task, args):
assert(len(args) == 0)

[ret] = await canon_lower(opts, producer_ft, producer1, task, [0, 0])
[ret] = await canon_lower(opts, producer_ft, producer1, task, [])
assert(ret == (1 | (CallState.STARTED << 30)))

[ret] = await canon_lower(opts, producer_ft, producer2, task, [0, 0])
[ret] = await canon_lower(opts, producer_ft, producer2, task, [])
assert(ret == (2 | (CallState.STARTED << 30)))

fut1.set_result(None)
Expand Down Expand Up @@ -730,10 +732,10 @@ async def producer2_core(task, args):
async def consumer(task, args):
assert(len(args) == 0)

[ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [0, 0])
[ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [])
assert(ret == (1 | (CallState.STARTED << 30)))

[ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [0, 0])
[ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [])
assert(ret == (2 | (CallState.STARTING << 30)))

assert(await task.poll() is None)
Expand Down Expand Up @@ -808,10 +810,10 @@ async def producer2_core(task, args):
async def consumer(task, args):
assert(len(args) == 0)

[ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [0, 0])
[ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [])
assert(ret == (1 | (CallState.RETURNED << 30)))

[ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [0, 0])
[ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [])
assert(ret == (2 | (CallState.STARTING << 30)))

assert(await task.poll() is None)
Expand Down Expand Up @@ -872,9 +874,9 @@ async def core_hostcall_pre(fut, task, args):
lower_opts.sync = False

async def core_func(task, args):
[ret] = await canon_lower(lower_opts, ft, hostcall1, task, [0,0])
[ret] = await canon_lower(lower_opts, ft, hostcall1, task, [])
assert(ret == (1 | (CallState.STARTED << 30)))
[ret] = await canon_lower(lower_opts, ft, hostcall2, task, [0,0])
[ret] = await canon_lower(lower_opts, ft, hostcall2, task, [])
assert(ret == (2 | (CallState.STARTED << 30)))

fut1.set_result(None)
Expand Down

0 comments on commit c99b9ad

Please sign in to comment.