From c8466d74dfd8e1d2f2db66b0efd307052941d2d6 Mon Sep 17 00:00:00 2001 From: Edd Barrett Date: Wed, 13 Dec 2023 13:19:31 +0000 Subject: [PATCH 1/2] Make a start on the JIT IR and the trace builder. The JIT IR is designed to be small. There are two kinds of instruction: - short instructions, with inlined operands. - long ones (unimplemented as of yet) Ideally, function, block and instruction IDs (which are indices read from the on-disk AOT IR) would be either converted to references as we decode, but this would require changes to the decoder. We may have to ditch deku: https://github.com/sharksforarms/deku/issues/383 Co-authored-by: Lukas Diekmann --- ykrt/Cargo.toml | 1 + ykrt/src/compile/jitc_yk/aot_ir.rs | 148 +++++++- ykrt/src/compile/jitc_yk/jit_ir.rs | 440 ++++++++++++++++++++++ ykrt/src/compile/jitc_yk/mod.rs | 1 + ykrt/src/compile/jitc_yk/trace_builder.rs | 129 ++++++- 5 files changed, 692 insertions(+), 27 deletions(-) create mode 100644 ykrt/src/compile/jitc_yk/jit_ir.rs diff --git a/ykrt/Cargo.toml b/ykrt/Cargo.toml index 21fea4eda..9098324ea 100644 --- a/ykrt/Cargo.toml +++ b/ykrt/Cargo.toml @@ -19,6 +19,7 @@ ykaddr = { path = "../ykaddr" } yksmp = { path = "../yksmp" } strum = { version = "0.25", features = ["derive"] } yktracec = { path = "../yktracec" } +strum_macros = "0.25.3" [dependencies.llvm-sys] # note: using a git version to get llvm linkage features in llvm-sys (not in a diff --git a/ykrt/src/compile/jitc_yk/aot_ir.rs b/ykrt/src/compile/jitc_yk/aot_ir.rs index 4e63d6849..69f4d8d3b 100644 --- a/ykrt/src/compile/jitc_yk/aot_ir.rs +++ b/ykrt/src/compile/jitc_yk/aot_ir.rs @@ -5,13 +5,24 @@ use byteorder::{NativeEndian, ReadBytesExt}; use deku::prelude::*; -use std::{cell::RefCell, error::Error, ffi::CStr, fs, io::Cursor, path::PathBuf}; +use std::{ + cell::RefCell, + error::Error, + ffi::CStr, + fs, + io::Cursor, + ops::{Deref, DerefMut}, + path::PathBuf, +}; /// A magic number that all bytecode payloads begin with. const MAGIC: u32 = 0xedd5f00d; /// The version of the bytecode format. const FORMAT_VERSION: u32 = 0; +/// The symbol name of the control point function (after ykllvm has transformed it). +const CONTROL_POINT_NAME: &str = "__ykrt_control_point"; + fn deserialise_string(v: Vec) -> Result { let err = Err(DekuError::Parse("failed to parse string".to_owned())); match CStr::from_bytes_until_nul(v.as_slice()) { @@ -45,7 +56,7 @@ pub(crate) trait IRDisplay { /// An instruction opcode. #[deku_derive(DekuRead)] -#[derive(Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] #[deku(type = "u8")] pub(crate) enum Opcode { Nop = 0, @@ -82,14 +93,53 @@ impl IRDisplay for ConstantOperand { } #[deku_derive(DekuRead)] -#[derive(Debug)] -pub(crate) struct LocalVariableOperand { +#[derive(Debug, Hash, Eq, PartialEq)] +pub(crate) struct InstructionID { #[deku(skip)] // computed after deserialisation. func_idx: usize, bb_idx: usize, inst_idx: usize, } +impl InstructionID { + pub(crate) fn new(func_idx: usize, bb_idx: usize, inst_idx: usize) -> Self { + Self { + func_idx, + bb_idx, + inst_idx, + } + } +} + +#[derive(Debug)] +pub(crate) struct BlockID { + pub(crate) func_idx: usize, + pub(crate) bb_idx: usize, +} + +impl BlockID { + pub(crate) fn new(func_idx: usize, bb_idx: usize) -> Self { + Self { func_idx, bb_idx } + } +} + +#[deku_derive(DekuRead)] +#[derive(Debug, Hash, Eq, PartialEq)] +pub(crate) struct LocalVariableOperand(InstructionID); + +impl Deref for LocalVariableOperand { + type Target = InstructionID; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for LocalVariableOperand { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + impl IRDisplay for LocalVariableOperand { fn to_str(&self, _m: &Module) -> String { format!("${}_{}", self.bb_idx, self.inst_idx,) @@ -161,6 +211,31 @@ pub(crate) enum Operand { Unimplemented(#[deku(until = "|v: &u8| *v == 0", map = "deserialise_string")] String), } +impl Operand { + /// For a [Self::LocalVariable] operand return the instruction that defines the variable. + /// + /// Panics for other kinds of operand. + /// + /// OPT: This is expensive. + pub(crate) fn to_instr<'a>(&self, aotmod: &'a Module) -> &'a Instruction { + match self { + Self::LocalVariable(lvo) => { + &aotmod.funcs[lvo.func_idx].blocks[lvo.bb_idx].instrs[lvo.inst_idx] + } + _ => panic!(), + } + } + + /// Return the `InstructionID` of a local variable operand. Panics if called on other kinds of + /// operands. + pub(crate) fn to_instr_id(&self) -> InstructionID { + match self { + Self::LocalVariable(lvo) => InstructionID::new(lvo.func_idx, lvo.bb_idx, lvo.inst_idx), + _ => panic!(), + } + } +} + impl IRDisplay for Operand { fn to_str(&self, m: &Module) -> String { match self { @@ -190,6 +265,45 @@ pub(crate) struct Instruction { name: RefCell>, } +impl Instruction { + /// Returns the operand at the specified index. Panics if the index is out of bounds. + pub(crate) fn get_operand(&self, index: usize) -> &Operand { + &self.operands[index] + } + + pub(crate) fn opcode(&self) -> Opcode { + self.opcode + } + + pub(crate) fn is_store(&self) -> bool { + self.opcode == Opcode::Store + } + + pub(crate) fn is_gep(&self) -> bool { + self.opcode == Opcode::GetElementPtr + } + + pub(crate) fn is_control_point(&self, aot_mod: &Module) -> bool { + if self.opcode == Opcode::Call { + // Call instructions always have at least one operand (the callee), so this is safe. + let op = &self.operands[0]; + match op { + Operand::Function(fop) => { + return aot_mod.funcs[fop.func_idx].name == CONTROL_POINT_NAME; + } + _ => todo!(), + } + } + false + } + + /// Determine if two instructions in the (immutable) AOT IR are the same based on pointer + /// identity. + pub(crate) fn ptr_eq(&self, other: &Self) -> bool { + std::ptr::eq(self, other) + } +} + impl IRDisplay for Instruction { fn to_str(&self, m: &Module) -> String { if self.opcode == Opcode::Unimplemented { @@ -249,8 +363,9 @@ impl IRDisplay for Instruction { pub(crate) struct Block { #[deku(temp)] num_instrs: usize, + // FIXME: unpub #[deku(count = "num_instrs")] - instrs: Vec, + pub instrs: Vec, } impl IRDisplay for Block { @@ -546,7 +661,20 @@ impl Module { *self.var_names_computed.borrow_mut() = true; } - /// Fill in the function index of local variable operands of instructions.o + pub(crate) fn func_index(&self, find_func: &str) -> Option { + // OPT: create a cache in the Module. + self.funcs + .iter() + .enumerate() + .find(|(_, f)| f.name == find_func) + .map(|(f_idx, _)| f_idx) + } + + pub(crate) fn block(&self, bid: &BlockID) -> Option<&Block> { + self.funcs.get(bid.func_idx)?.block(bid.bb_idx) + } + + /// Fill in the function index of local variable operands of instructions. /// /// FIXME: It may be possible to do this as we deserialise, instead of after the fact: /// https://github.com/sharksforarms/deku/issues/363 @@ -571,16 +699,12 @@ impl Module { &self.types[instr.type_index] } + // FIXME: rename this to `is_def()`, which we've decided is a beter name. + // FIXME: also move this to the `Instruction` type. fn instr_generates_value(&self, i: &Instruction) -> bool { self.instr_type(i) != &Type::Void } - /// Retrieve the named function from the AOT module. - pub(crate) fn func_by_name(&self, name: &str) -> Option<&Function> { - // OPT: Cache function indices somewhere for faster lookup. - self.funcs.iter().find(|f| f.name == name) - } - pub(crate) fn to_str(&self) -> String { let mut ret = String::new(); ret.push_str(&format!("# IR format version: {}\n", self.version)); diff --git a/ykrt/src/compile/jitc_yk/jit_ir.rs b/ykrt/src/compile/jitc_yk/jit_ir.rs new file mode 100644 index 000000000..b1a5427de --- /dev/null +++ b/ykrt/src/compile/jitc_yk/jit_ir.rs @@ -0,0 +1,440 @@ +//! The Yk JIT IR +//! +//! This is the in-memory trace IR constructed by the trace builder and mutated by optimisations. +//! +//! Design notes: +//! +//! - This module uses `u64` extensively for bit-fields. This is not a consequence of any +//! particular hardware platform, we just chose a 64-bit field. +//! +//! - We avoid heap allocations at all costs. + +use std::fmt; +use strum_macros::FromRepr; + +/// Number of bits used to encode an opcode. +const OPCODE_SIZE: u64 = 8; + +/// Max number of operands in a short instruction. +const SHORT_INSTR_MAX_OPERANDS: u64 = 3; + +/// Bit fiddling. +/// +/// In the constants below: +/// - `*_SIZE`: the size of a field in bits. +/// - `*_MASK`: a mask with one bits occupying the field in question. +/// +/// Bit fiddling for a short operands: +const SHORT_OPERAND_SIZE: u64 = 18; +const SHORT_OPERAND_KIND_SIZE: u64 = 3; +const SHORT_OPERAND_KIND_MASK: u64 = 7; +const SHORT_OPERAND_VALUE_SIZE: u64 = 15; +const SHORT_OPERAND_MASK: u64 = 0x3ffff; +/// Bit fiddling for instructions. +const INSTR_ISSHORT_SIZE: u64 = 1; +const INSTR_ISSHORT_MASK: u64 = 1; +const INSTR_OPCODE_MASK: u64 = 0xe; + +/// An instruction is identified by its index in the instruction vector. +#[derive(Copy, Clone)] +pub(crate) struct InstructionID(usize); + +impl InstructionID { + pub(crate) fn new(v: usize) -> Self { + Self(v) + } + + pub(crate) fn get(&self) -> usize { + self.0 + } +} + +/// An operand kind. +#[repr(u64)] +#[derive(Debug, FromRepr, PartialEq)] +pub enum OpKind { + /// The operand is not present. + /// + /// This is used in short instructions where 3 operands are inlined. If the instruction + /// requires fewer then 3 operands, then it can use this variant to express that. + /// + /// By using the zero discriminant, this means that a freshly created short instruction has + /// with zero operands until they are explicitly filled in. + NotPresent = 0, + /// The operand references a previously defined local variable. + Local, +} + +impl From for OpKind { + fn from(v: u64) -> Self { + // unwrap safe assuming only valid discriminant numbers are used. + Self::from_repr(v).unwrap() + } +} + +#[derive(Debug, FromRepr, PartialEq)] +#[repr(u64)] +pub enum OpCode { + Load, + LoadArg, +} + +impl From for OpCode { + fn from(v: u64) -> Self { + // unwrap safe assuming only valid discriminant numbers are used. + Self::from_repr(v).unwrap() + } +} + +#[derive(Debug, PartialEq)] +pub enum Operand { + Long(LongOperand), + Short(ShortOperand), +} + +impl Operand { + pub(crate) fn new(kind: OpKind, val: u64) -> Self { + // check if the operand's value can fit in a short operand. + if val & (u64::MAX << SHORT_OPERAND_VALUE_SIZE) == 0 { + Self::Short(ShortOperand::new(kind, val)) + } else { + todo!() + } + } + + fn raw(&self) -> u64 { + match self { + Self::Long(_) => todo!(), + Self::Short(op) => op.0, + } + } + + fn kind(&self) -> OpKind { + match self { + Self::Long(_) => todo!(), + Self::Short(op) => op.kind(), + } + } + + fn val(&self) -> u64 { + match self { + Self::Long(_) => todo!(), + Self::Short(op) => op.val(), + } + } + + fn is_short(&self) -> bool { + matches!(self, Self::Short(_)) + } +} + +impl fmt::Display for Operand { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.kind() { + OpKind::Local => write!(f, " %{}", self.val())?, + OpKind::NotPresent => (), + } + Ok(()) + } +} + +#[derive(Debug, PartialEq)] +pub struct LongOperand(u64); + +#[derive(Debug, PartialEq)] +pub struct ShortOperand(u64); + +impl ShortOperand { + fn new(kind: OpKind, val: u64) -> ShortOperand { + ShortOperand((kind as u64) | (val << SHORT_OPERAND_KIND_SIZE)) + } + + fn kind(&self) -> OpKind { + OpKind::from(self.0 & SHORT_OPERAND_KIND_MASK) + } + + fn val(&self) -> u64 { + self.0 >> SHORT_OPERAND_KIND_SIZE + } +} + +/// An instruction. +/// +/// An instruction is either a short instruction or a long instruction. +/// +/// ## Short instruction +/// +/// - A 64-bit bit-field that encodes the entire instruction inline. +/// - Can encode up to three short operands. +/// - Is designed to encode the most commonly encountered instructions. +/// +/// Encoding (LSB first): +/// ```ignore +/// field bit-size +/// ------------------------ +/// is_short=1 1 +/// opcode 8 +/// short_operand0 18 +/// short_operand1 18 +/// short_operand2 18 +/// reserved 1 +/// ``` +/// +/// Where a short operand is encoded like this (LSB first): +/// ```ignore +/// field bit-size +/// -------------------- +/// kind 3 +/// payload 15 +/// ``` +/// +/// ## Long instruction +/// +/// - A pointer to an instruction description. +/// - Can encode an arbitrary number of long operands. +/// +/// The pointer is assumed to be at least 2-byte aligned, thus guaranteeing the LSB to be 0. +#[derive(Debug)] +pub(crate) struct Instruction(u64); + +impl Instruction { + fn new_short(opcode: OpCode) -> Self { + Self(((opcode as u64) << INSTR_ISSHORT_SIZE) | INSTR_ISSHORT_MASK) + } + + /// Returns true if the instruction is short. + fn is_short(&self) -> bool { + self.0 & INSTR_ISSHORT_MASK != 0 + } + + /// Returns the opcode. + fn opcode(&self) -> OpCode { + debug_assert!(self.is_short()); + OpCode::from((self.0 & INSTR_OPCODE_MASK) >> INSTR_ISSHORT_SIZE) + } + + /// Returns the specified operand. + fn operand(&self, index: u64) -> Operand { + if self.is_short() { + // Shift operand down the the LSB. + let op = self.0 >> (INSTR_ISSHORT_SIZE + OPCODE_SIZE + SHORT_OPERAND_SIZE * index); + // Then mask it out. + Operand::Short(ShortOperand(op & SHORT_OPERAND_MASK)) + } else { + todo!() + } + } + + /// Create a new `Load` instruction. + /// + /// ## Operands + /// + /// - ``: The pointer to load from. + /// + /// ## Semantics + /// + /// Return the value obtained by dereferencing the operand (which must be pointer-typed). + pub(crate) fn create_load(op: Operand) -> Self { + if op.is_short() { + let mut instr = Instruction::new_short(OpCode::Load); + instr.set_short_operand(op, 0); + instr + } else { + todo!(); + } + } + + /// Create a new `LoadArg` instruction. + /// + /// ## Operands + /// + /// FIXME + /// + /// ## Semantics + /// + /// FIXME + pub(crate) fn create_loadarg() -> Self { + Instruction::new_short(OpCode::LoadArg) + } + + /// Set the short operand at the specified index. + fn set_short_operand(&mut self, op: Operand, idx: u64) { + debug_assert!(self.is_short()); + debug_assert!(idx < SHORT_INSTR_MAX_OPERANDS); + self.0 |= op.raw() << (INSTR_ISSHORT_SIZE + OPCODE_SIZE + SHORT_OPERAND_SIZE * idx); + } + + /// Returns `true` if the instruction defines a local variable. + pub(crate) fn is_def(&self) -> bool { + match self.opcode() { + OpCode::Load => true, + OpCode::LoadArg => true, + } + } +} + +impl fmt::Display for Instruction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let opc = self.opcode(); + write!(f, "{:?}", opc)?; + if self.is_short() { + for i in 0..=2 { + let op = self.operand(i); + write!(f, "{}", op)?; + } + } + Ok(()) + } +} + +/// The `Module` is the top-level container for JIT IR. +#[derive(Debug)] +pub(crate) struct Module { + /// The name of the module and the eventual symbol name for the JITted code. + name: String, + /// The IR trace as a linear sequence of instructions. + instrs: Vec, +} + +impl Module { + /// Create a new [Module] with the specified name. + pub fn new(name: String) -> Self { + Self { + name, + instrs: Vec::new(), + } + } + + /// Push an instruction to the end of the [Module]. + pub(crate) fn push(&mut self, instr: Instruction) { + self.instrs.push(instr); + } + + /// Returns the number of [Instruction]s in the [Module]. + pub(crate) fn len(&self) -> usize { + self.instrs.len() + } + + /// Print the [Module] to `stderr`. + pub(crate) fn dump(&self) { + eprintln!("{}", self); + } +} + +impl fmt::Display for Module { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "; {}", self.name)?; + for (i, instr) in self.instrs.iter().enumerate() { + if instr.is_def() { + write!(f, "%{} = ", i)?; + } + writeln!(f, "{}", instr)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn short_instruction() { + let op = Operand::new(OpKind::Local, 10); + let instr = Instruction::create_load(op); + assert_eq!(instr.opcode(), OpCode::Load); + assert_eq!(instr.operand(0).kind(), OpKind::Local); + assert_eq!(instr.operand(0).val(), 10); + assert!(instr.is_def()); + assert_eq!(instr.0, 0xa201); + assert!(instr.is_short()); + } + + #[test] + fn long_instruction() { + // FIXME: expand when long instructions are implemented. + let instr = Instruction(0); + assert!(!instr.is_short()); + } + + /// The IR encoding uses a LSB tag to determine if an instruction is short or not, and if it + /// isn't short then it's interpreted as a box pointer. So a box pointer had better be at least + /// 2-byte aligned! + /// + /// This test (somewhat) proves that we are safe by allocating a bunch of `Box` (which in + /// theory could be stored contiguously) and then checks their addresses don't have the LSB set + /// (as this would indicate 1-byte alignment!). + #[test] + fn tagging_valid() { + let mut boxes = Vec::new(); + for i in 0..8192 { + boxes.push(Box::new(i as u8)); + } + + for b in boxes { + assert_eq!((&*b as *const u8 as usize) & 1, 0); + } + } + + #[test] + fn short_operand_getters() { + let mut word = 1; // short instruction. + + let skip_lsbits = INSTR_ISSHORT_SIZE + OPCODE_SIZE; + + // operand0: + word |= 0x0aaa8 << skip_lsbits; + // operand1: + word |= 0x1bbb1 << skip_lsbits + SHORT_OPERAND_SIZE; + // operand2: + word |= 0x2ccc8 << skip_lsbits + SHORT_OPERAND_SIZE * 2; + + let inst = Instruction(word); + + assert_eq!(inst.operand(0), Operand::Short(ShortOperand(0x0aaa8))); + assert_eq!(inst.operand(0).kind() as u64, 0); + assert_eq!(inst.operand(0).val() as u64, 0x1555); + + assert_eq!(inst.operand(1), Operand::Short(ShortOperand(0x1bbb1))); + assert_eq!(inst.operand(1).kind() as u64, 1); + assert_eq!(inst.operand(1).val() as u64, 0x3776); + + assert_eq!(inst.operand(2), Operand::Short(ShortOperand(0x2ccc8))); + assert_eq!(inst.operand(2).kind() as u64, 0); + assert_eq!(inst.operand(2).val() as u64, 0x5999); + } + + #[test] + fn short_operand_setters() { + let mut inst = Instruction::new_short(OpCode::Load); + inst.set_short_operand(Operand::Short(ShortOperand(0x3ffff)), 0); + debug_assert_eq!(inst.operand(0), Operand::Short(ShortOperand(0x3ffff))); + debug_assert_eq!(inst.operand(1), Operand::Short(ShortOperand(0))); + debug_assert_eq!(inst.operand(2), Operand::Short(ShortOperand(0))); + + let mut inst = Instruction::new_short(OpCode::Load); + inst.set_short_operand(Operand::Short(ShortOperand(0x3ffff)), 1); + debug_assert_eq!(inst.operand(0), Operand::Short(ShortOperand(0))); + debug_assert_eq!(inst.operand(1), Operand::Short(ShortOperand(0x3ffff))); + debug_assert_eq!(inst.operand(2), Operand::Short(ShortOperand(0))); + + let mut inst = Instruction::new_short(OpCode::Load); + inst.set_short_operand(Operand::Short(ShortOperand(0x3ffff)), 2); + debug_assert_eq!(inst.operand(0), Operand::Short(ShortOperand(0))); + debug_assert_eq!(inst.operand(1), Operand::Short(ShortOperand(0))); + debug_assert_eq!(inst.operand(2), Operand::Short(ShortOperand(0x3ffff))); + } + + #[test] + fn does_fit_short_operand() { + for i in 0..SHORT_OPERAND_VALUE_SIZE { + Operand::new(OpKind::Local, 1 << i); + } + } + + #[test] + #[should_panic] + fn doesnt_fit_short_operand() { + Operand::new(OpKind::Local, 1 << SHORT_OPERAND_VALUE_SIZE); + } +} diff --git a/ykrt/src/compile/jitc_yk/mod.rs b/ykrt/src/compile/jitc_yk/mod.rs index 05c0417fc..1225c5270 100644 --- a/ykrt/src/compile/jitc_yk/mod.rs +++ b/ykrt/src/compile/jitc_yk/mod.rs @@ -49,6 +49,7 @@ static PHASES_TO_PRINT: LazyLock> = LazyLock::new(|| { }); pub mod aot_ir; +pub mod jit_ir; mod trace_builder; pub(crate) struct JITCYk; diff --git a/ykrt/src/compile/jitc_yk/trace_builder.rs b/ykrt/src/compile/jitc_yk/trace_builder.rs index 490793b4c..678ee1f24 100644 --- a/ykrt/src/compile/jitc_yk/trace_builder.rs +++ b/ykrt/src/compile/jitc_yk/trace_builder.rs @@ -1,44 +1,142 @@ //! The trace builder. -//! -//! Given a mapped trace and an AOT module, assembles an in-memory Yk IR trace by copying blocks -//! from the AOT IR. The output of this process will be the input to the code generator. -use super::aot_ir::{self, Module}; +use super::aot_ir::{self, IRDisplay, Module}; +use super::jit_ir; use crate::trace::TracedAOTBlock; +use std::collections::HashMap; use std::error::Error; +/// The argument index of the trace inputs struct in the control point call. +const CTRL_POINT_ARGIDX_INPUTS: usize = 3; + +/// Given a mapped trace and an AOT module, assembles an in-memory Yk IR trace by copying blocks +/// from the AOT IR. The output of this process will be the input to the code generator. struct TraceBuilder<'a> { + /// The AOR IR. aot_mod: &'a Module, - jit_mod: Module, + /// The JIT IR this struct builds. + jit_mod: jit_ir::Module, + /// The mapped trace. mtrace: &'a Vec, + // Maps an AOT instruction to a jit instruction via their index-based IDs. + local_map: HashMap, } impl<'a> TraceBuilder<'a> { - fn new(aot_mod: &'a Module, mtrace: &'a Vec) -> Self { + /// Create a trace builder. + /// + /// Arguments: + /// - `trace_name`: The eventual symbol name for the JITted code. + /// - `aot_mod`: The AOT IR module that the trace flows through. + /// - `mtrace`: The mapped trace. + fn new(trace_name: String, aot_mod: &'a Module, mtrace: &'a Vec) -> Self { Self { aot_mod, mtrace, - jit_mod: Module::default(), + jit_mod: jit_ir::Module::new(trace_name), + local_map: HashMap::new(), } } - fn lookup_aot_block(&self, tb: &TracedAOTBlock) -> Option<&aot_ir::Block> { + // Given a mapped block, find the AOT block ID, or return `None` if it is unmapped. + fn lookup_aot_block(&self, tb: &TracedAOTBlock) -> Option { match tb { TracedAOTBlock::Mapped { func_name, bb } => { let func_name = func_name.to_str().unwrap(); // safe: func names are valid UTF-8. - let func = self.aot_mod.func_by_name(func_name)?; - func.block(*bb) + let func = self.aot_mod.func_index(func_name)?; + Some(aot_ir::BlockID::new(func, *bb)) } TracedAOTBlock::Unmappable { .. } => None, } } - fn build(self) -> Result> { + /// Create the prolog of the trace. + fn create_trace_header(&mut self, blk: &aot_ir::Block) { + // Find trace input variables and emit `LoadArg` instructions for them. + let mut last_store = None; + let mut trace_input = None; + let mut input = Vec::new(); + for inst in blk.instrs.iter().rev() { + if inst.is_control_point(self.aot_mod) { + trace_input = Some(inst.get_operand(CTRL_POINT_ARGIDX_INPUTS)); + } + if inst.is_store() { + last_store = Some(inst); + } + if inst.is_gep() { + let op = inst.get_operand(0); + // unwrap safe: we know the AOT code was produced by ykllvm. + if trace_input + .unwrap() + .to_instr(self.aot_mod) + .ptr_eq(op.to_instr(self.aot_mod)) + { + // Found a trace input. + // unwrap safe: we know the AOT code was produced by ykllvm. + let inp = last_store.unwrap().get_operand(0); + input.insert(0, inp.to_instr(self.aot_mod)); + let load_arg = jit_ir::Instruction::create_loadarg(); + self.local_map + .insert(inp.to_instr_id(), self.next_instr_id()); + self.jit_mod.push(load_arg); + } + } + } + } + + /// Walk over a traced AOT block, translating the constituent instructions into the JIT module. + fn process_block(&mut self, bid: aot_ir::BlockID) { + // unwrap safe: can't trace a block not in the AOT module. + let blk = self.aot_mod.block(&bid).unwrap(); + + // Decide how to translate each AOT instruction based upon its opcode. + for (inst_idx, inst) in blk.instrs.iter().enumerate() { + let jit_inst = match inst.opcode() { + aot_ir::Opcode::Load => self.handle_load(inst), + _ => todo!("{:?}", inst), + }; + + // If the AOT instruction defines a new value, then add it to the local map. + if jit_inst.is_def() { + let aot_iid = aot_ir::InstructionID::new(bid.func_idx, bid.bb_idx, inst_idx); + *self.local_map.get_mut(&aot_iid).unwrap() = self.next_instr_id(); + } + + // Insert the newly-translated instruction into the JIT module. + self.jit_mod.push(jit_inst); + } + } + + fn next_instr_id(&self) -> jit_ir::InstructionID { + jit_ir::InstructionID::new(self.jit_mod.len()) + } + + // Translate a `Load` instruction. + fn handle_load(&self, inst: &aot_ir::Instruction) -> jit_ir::Instruction { + let aot_op = inst.get_operand(0); + let jit_op = match aot_op { + aot_ir::Operand::LocalVariable(aot_iid) => self.local_map[aot_iid], + _ => todo!("{}", aot_op.to_str(self.aot_mod)), + }; + jit_ir::Instruction::create_load(jit_ir::Operand::new( + jit_ir::OpKind::Local, + u64::try_from(jit_op.get()).unwrap(), + )) + } + + /// Entry point for building an IR trace. + /// + /// Consumes the trace builder, returning a JIT module. + fn build(mut self) -> Result> { + let firstblk = self.lookup_aot_block(&self.mtrace[0]); + debug_assert!(firstblk.is_some()); + self.create_trace_header(self.aot_mod.block(&firstblk.unwrap()).unwrap()); + for tblk in self.mtrace { match self.lookup_aot_block(tblk) { - Some(_blk) => { + Some(bid) => { // Mapped block - todo!(); + self.process_block(bid); } None => { // Unmappable block @@ -54,6 +152,7 @@ impl<'a> TraceBuilder<'a> { pub(super) fn build( aot_mod: &Module, mtrace: &Vec, -) -> Result> { - TraceBuilder::new(aot_mod, mtrace).build() +) -> Result> { + // FIXME: the XXX below should be a thread-safe monotonically incrementing integer. + TraceBuilder::new("__yk_compiled_trace_XXX".into(), aot_mod, mtrace).build() } From 80279aefec91bf8b4ef4dbfba49e22688642cd97 Mon Sep 17 00:00:00 2001 From: Edd Barrett Date: Wed, 20 Dec 2023 11:01:28 +0000 Subject: [PATCH 2/2] Rename: get_operand() -> operand(). --- ykrt/src/compile/jitc_yk/aot_ir.rs | 2 +- ykrt/src/compile/jitc_yk/trace_builder.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ykrt/src/compile/jitc_yk/aot_ir.rs b/ykrt/src/compile/jitc_yk/aot_ir.rs index 69f4d8d3b..d636dcb8b 100644 --- a/ykrt/src/compile/jitc_yk/aot_ir.rs +++ b/ykrt/src/compile/jitc_yk/aot_ir.rs @@ -267,7 +267,7 @@ pub(crate) struct Instruction { impl Instruction { /// Returns the operand at the specified index. Panics if the index is out of bounds. - pub(crate) fn get_operand(&self, index: usize) -> &Operand { + pub(crate) fn operand(&self, index: usize) -> &Operand { &self.operands[index] } diff --git a/ykrt/src/compile/jitc_yk/trace_builder.rs b/ykrt/src/compile/jitc_yk/trace_builder.rs index 678ee1f24..7a914e6b4 100644 --- a/ykrt/src/compile/jitc_yk/trace_builder.rs +++ b/ykrt/src/compile/jitc_yk/trace_builder.rs @@ -58,13 +58,13 @@ impl<'a> TraceBuilder<'a> { let mut input = Vec::new(); for inst in blk.instrs.iter().rev() { if inst.is_control_point(self.aot_mod) { - trace_input = Some(inst.get_operand(CTRL_POINT_ARGIDX_INPUTS)); + trace_input = Some(inst.operand(CTRL_POINT_ARGIDX_INPUTS)); } if inst.is_store() { last_store = Some(inst); } if inst.is_gep() { - let op = inst.get_operand(0); + let op = inst.operand(0); // unwrap safe: we know the AOT code was produced by ykllvm. if trace_input .unwrap() @@ -73,7 +73,7 @@ impl<'a> TraceBuilder<'a> { { // Found a trace input. // unwrap safe: we know the AOT code was produced by ykllvm. - let inp = last_store.unwrap().get_operand(0); + let inp = last_store.unwrap().operand(0); input.insert(0, inp.to_instr(self.aot_mod)); let load_arg = jit_ir::Instruction::create_loadarg(); self.local_map @@ -113,7 +113,7 @@ impl<'a> TraceBuilder<'a> { // Translate a `Load` instruction. fn handle_load(&self, inst: &aot_ir::Instruction) -> jit_ir::Instruction { - let aot_op = inst.get_operand(0); + let aot_op = inst.operand(0); let jit_op = match aot_op { aot_ir::Operand::LocalVariable(aot_iid) => self.local_map[aot_iid], _ => todo!("{}", aot_op.to_str(self.aot_mod)),