Skip to content

Commit

Permalink
call multi-arg closures correctly!
Browse files Browse the repository at this point in the history
  • Loading branch information
jmgrosen committed Jun 6, 2024
1 parent 3812415 commit 807710a
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 17 deletions.
29 changes: 27 additions & 2 deletions src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,36 @@ pub fn compile(code: String) -> Result<Vec<u8>, String> {
translator2.globals[def_idx] = ir2::GlobalDef::ClosedExpr { body: expr_ir2 };
}

for (i, func) in translator2.globals.iter().enumerate() {
let mut global_defs = translator2.globals;
for (i, func) in global_defs.iter().enumerate() {
println!("global {i}: {func:?}");
}

let wasm_bytes = wasm::translate(&translator2.globals, main.unwrap());
// generate partial application functions
let partial_app_def_offset = global_defs.len() as u32;
let max_arity = global_defs.iter().map(|def| def.arity().unwrap_or(0)).max().unwrap();
// let partial_app_defs = Vec::with_capacity(max_arity * (max_arity + 1) / 2 - max_arity);
for arity in 2..=max_arity {
for n_args in 1..arity {
let n_remaining_args = arity - n_args;
let args_to_call = (n_remaining_args..arity).map(|i| {
expr2_arena.alloc(ir2::Expr::Var(ir1::DebruijnIndex(i+1)))
}).chain((0..n_remaining_args).map(|i| {
expr2_arena.alloc(ir2::Expr::Var(ir1::DebruijnIndex(i)))
}));
global_defs.push(ir2::GlobalDef::Func {
rec: false,
arity: n_remaining_args,
env_size: n_args + 1,
body: expr2_arena.alloc(ir2::Expr::CallIndirect(
expr2_arena.alloc(ir2::Expr::Var(ir1::DebruijnIndex(n_remaining_args))),
expr2_arena.alloc_slice_r(args_to_call)
)),
});
}
}

let wasm_bytes = wasm::translate(&global_defs, partial_app_def_offset, main.unwrap());

Ok(wasm_bytes)
}
11 changes: 11 additions & 0 deletions src/ir2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ pub enum GlobalDef<'a> {
},
}

impl<'a> GlobalDef<'a> {
pub fn arity(&self) -> Option<u32> {
match *self {
GlobalDef::Func { arity, .. } =>
Some(arity),
GlobalDef::ClosedExpr { .. } =>
None,
}
}
}

pub struct Translator<'a> {
pub arena: &'a ArenaPlus<'a, Expr<'a>>,
pub globals: Vec<GlobalDef<'a>>,
Expand Down
29 changes: 27 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,36 @@ fn cmd_compile<'a>(toplevel: &mut TopLevel<'a>, file: Option<PathBuf>, out: Opti
println!("{}: {:?}", toplevel.interner.resolve(name).unwrap(), expr_ir2);
}

for (i, func) in translator2.globals.iter().enumerate() {
let mut global_defs = translator2.globals;
for (i, func) in global_defs.iter().enumerate() {
println!("global {i}: {func:?}");
}

let wasm_bytes = wasm::translate(&translator2.globals, main.unwrap());
// generate partial application functions
let partial_app_def_offset = global_defs.len() as u32;
let max_arity = global_defs.iter().map(|def| def.arity().unwrap_or(0)).max().unwrap();
// let partial_app_defs = Vec::with_capacity(max_arity * (max_arity + 1) / 2 - max_arity);
for arity in 2..=max_arity {
for n_args in 1..arity {
let n_remaining_args = arity - n_args;
let args_to_call = (n_remaining_args..arity).map(|i| {
expr2_arena.alloc(ir2::Expr::Var(ir1::DebruijnIndex(i+1)))
}).chain((0..n_remaining_args).map(|i| {
expr2_arena.alloc(ir2::Expr::Var(ir1::DebruijnIndex(i)))
}));
global_defs.push(ir2::GlobalDef::Func {
rec: false,
arity: n_remaining_args,
env_size: n_args + 1,
body: expr2_arena.alloc(ir2::Expr::CallIndirect(
expr2_arena.alloc(ir2::Expr::Var(ir1::DebruijnIndex(n_remaining_args))),
expr2_arena.alloc_slice_r(args_to_call)
)),
});
}
}

let wasm_bytes = wasm::translate(&global_defs, partial_app_def_offset, main.unwrap());
let orig_wasm_bytes = wasm_bytes.clone();

write_file(out.as_deref(), &wasm_bytes)?;
Expand Down
73 changes: 60 additions & 13 deletions src/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::runtime::Runtime;

const RUNTIME_BYTES: &'static [u8] = include_bytes!(env!("CARGO_CDYLIB_FILE_CLOCKY_RUNTIME"));

pub fn translate<'a>(global_defs: &[GlobalDef<'a>], main: usize) -> Vec<u8> {
pub fn translate<'a>(global_defs: &[GlobalDef<'a>], partial_app_def_offset: u32, main: usize) -> Vec<u8> {
// TODO: can we parse more of this at compile time?
// probably... would have to be a build script though, I imagine
let runtime = Runtime::from_bytes(RUNTIME_BYTES);
Expand Down Expand Up @@ -44,11 +44,13 @@ pub fn translate<'a>(global_defs: &[GlobalDef<'a>], main: usize) -> Vec<u8> {
runtime.emit_code(&mut codes);
runtime.emit_data(&mut data);


let func_offset = runtime.functions.len() as u32;
let mut translator = Translator {
globals: global_defs,
globals_offset,
func_table_offset: func_offset, // TODO: is this right?
partial_app_table_offset: partial_app_def_offset + func_offset,
function_types: &mut function_types,
bad_global,
runtime_exports: &runtime.exports,
Expand Down Expand Up @@ -183,6 +185,7 @@ struct Translator<'a> {
globals: &'a [GlobalDef<'a>],
globals_offset: u32,
func_table_offset: u32,
partial_app_table_offset: u32,
function_types: &'a mut FunctionTypes,
bad_global: u32,
runtime_exports: &'a HashMap<String, (wasmparser::ExternalKind, u32)>,
Expand Down Expand Up @@ -294,30 +297,74 @@ impl<'a, 'b> FuncTranslator<'a, 'b> {
self.translate(ctx.clone(), arg);
}
self.translate(ctx, target);
let t0 = self.temp(wasm::ValType::I32, 0);
self.insns.push(wasm::Instruction::LocalTee(t0));
let call_target_closure = self.temp(wasm::ValType::I32, 0);
self.insns.push(wasm::Instruction::LocalTee(call_target_closure));
self.insns.push(wasm::Instruction::I32Load(wasm::MemArg { offset: 4, align: 2, memory_index: 0 }));
let t1 = self.temp(wasm::ValType::I32, 1);
self.insns.push(wasm::Instruction::LocalSet(t1));
let arity = self.temp(wasm::ValType::I32, 1);
self.insns.push(wasm::Instruction::LocalTee(arity));

self.insns.push(wasm::Instruction::LocalGet(t1));
// TODO: should this logic be outsourced to a single
// function per number of args? probably
self.insns.push(wasm::Instruction::I32Const(args.len() as i32));
self.insns.push(wasm::Instruction::I32Eq);
let if_idx = self.translator.function_types.for_args(args.len() as u32);
self.insns.push(wasm::Instruction::If(wasm::BlockType::FunctionType(if_idx)));
self.insns.push(wasm::Instruction::LocalGet(t0));
self.insns.push(wasm::Instruction::LocalGet(t0));

self.insns.push(wasm::Instruction::LocalGet(call_target_closure));
self.insns.push(wasm::Instruction::LocalGet(call_target_closure));
self.insns.push(wasm::Instruction::I32Load(wasm::MemArg { offset: 0, align: 2, memory_index: 0 }));
let funty_idx = self.translator.function_types.for_args(args.len() as u32 + 1);
self.insns.push(wasm::Instruction::CallIndirect { ty: funty_idx, table: 0 });

self.insns.push(wasm::Instruction::Else);
// TODO: create closure!

// we'll have to create a partial application closure.
let closure_size = (args.len() as i32 + 3) * 4;
self.insns.push(wasm::Instruction::I32Const(closure_size));
self.alloc();
let partial_app_closure = self.temp(wasm::ValType::I32, 2);
self.insns.push(wasm::Instruction::LocalTee(partial_app_closure));

// TODO: move this calculation logic somewhere else to couple this less badly.
//
// if n is the arity of the closure we're trying to
// call and m is the number of arguments we want to
// apply (where m < n), the index of the function
// should be (((n-1) * (n-2))/2 + (m - 1))
self.insns.push(wasm::Instruction::LocalGet(arity));
self.insns.push(wasm::Instruction::I32Const(1));
self.insns.push(wasm::Instruction::I32Sub);
self.insns.push(wasm::Instruction::LocalGet(arity));
self.insns.push(wasm::Instruction::I32Const(2));
self.insns.push(wasm::Instruction::I32Sub);
self.insns.push(wasm::Instruction::I32Mul);
self.insns.push(wasm::Instruction::I32Const(1));
self.insns.push(wasm::Instruction::GlobalSet(self.translator.bad_global));
for _ in 0..args.len() {
self.insns.push(wasm::Instruction::Drop);
self.insns.push(wasm::Instruction::I32ShrU);
let offset = self.translator.partial_app_table_offset as i32 + args.len() as i32 - 1;
self.insns.push(wasm::Instruction::I32Const(offset));
self.insns.push(wasm::Instruction::I32Add);
self.insns.push(wasm::Instruction::I32Store(wasm::MemArg { offset: 0, align: 2, memory_index: 0 }));

self.insns.push(wasm::Instruction::LocalGet(partial_app_closure));
self.insns.push(wasm::Instruction::LocalGet(arity));
self.insns.push(wasm::Instruction::I32Const(args.len() as i32));
self.insns.push(wasm::Instruction::I32Sub);
self.insns.push(wasm::Instruction::I32Store(wasm::MemArg { offset: 4, align: 2, memory_index: 0 }));

self.insns.push(wasm::Instruction::LocalGet(partial_app_closure));
self.insns.push(wasm::Instruction::LocalGet(call_target_closure));
self.insns.push(wasm::Instruction::I32Store(wasm::MemArg { offset: 8, align: 2, memory_index: 0 }));

let arg_temp = self.temp(wasm::ValType::I32, 3);
for i in 0..args.len() {
self.insns.push(wasm::Instruction::LocalSet(arg_temp));
self.insns.push(wasm::Instruction::LocalGet(partial_app_closure));
self.insns.push(wasm::Instruction::LocalGet(arg_temp));
self.insns.push(wasm::Instruction::I32Store(wasm::MemArg { offset: 12 + (i as u64) * 4, align: 2, memory_index: 0 }));
}
self.insns.push(wasm::Instruction::I32Const(0));

self.insns.push(wasm::Instruction::LocalGet(partial_app_closure));

self.insns.push(wasm::Instruction::End);
},
}
Expand Down

0 comments on commit 807710a

Please sign in to comment.