Skip to content

Commit

Permalink
implement f32 and f64
Browse files Browse the repository at this point in the history
  • Loading branch information
tertsdiepraam committed Feb 26, 2025
1 parent 0d8a9ac commit d6b5cb4
Show file tree
Hide file tree
Showing 19 changed files with 564 additions and 99 deletions.
1 change: 1 addition & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ pub enum Literal {
Asn(Asn),
IpAddress(std::net::IpAddr),
Integer(i64),
Float(f64),
Bool(bool),
}

Expand Down
8 changes: 8 additions & 0 deletions src/codegen/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ fn check_roto_type(
let I16: TypeId = TypeId::of::<i16>();
let I32: TypeId = TypeId::of::<i32>();
let I64: TypeId = TypeId::of::<i64>();
let F32: TypeId = TypeId::of::<f32>();
let F64: TypeId = TypeId::of::<f64>();
let UNIT: TypeId = TypeId::of::<()>();
let ASN: TypeId = TypeId::of::<Asn>();
let IPADDR: TypeId = TypeId::of::<IpAddr>();
Expand All @@ -108,6 +110,10 @@ fn check_roto_type(
roto_ty = Type::Primitive(Primitive::I32);
}

if let Type::FloatVar(_) = roto_ty {
roto_ty = Type::Primitive(Primitive::F64);
}

match rust_ty.description {
TypeDescription::Leaf => {
let expected_roto = match rust_ty.type_id {
Expand All @@ -120,6 +126,8 @@ fn check_roto_type(
x if x == I16 => Type::Primitive(Primitive::I16),
x if x == I32 => Type::Primitive(Primitive::I32),
x if x == I64 => Type::Primitive(Primitive::I64),
x if x == F32 => Type::Primitive(Primitive::F32),
x if x == F64 => Type::Primitive(Primitive::F64),
x if x == UNIT => Type::Primitive(Primitive::Unit),
x if x == ASN => Type::Primitive(Primitive::Asn),
x if x == IPADDR => Type::Primitive(Primitive::IpAddr),
Expand Down
115 changes: 91 additions & 24 deletions src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{
use crate::{
ast::Identifier,
lower::{
ir::{self, IntCmp, Operand, Var, VarKind},
ir::{self, FloatCmp, IntCmp, Operand, Var, VarKind},
label::{LabelRef, LabelStore},
value::IrType,
IrFunction,
Expand Down Expand Up @@ -42,7 +42,7 @@ use cranelift::{
},
jit::{JITBuilder, JITModule},
module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module as _},
prelude::Signature,
prelude::{FloatCC, Signature},
};
use cranelift_codegen::ir::SigRef;
use log::info;
Expand Down Expand Up @@ -514,6 +514,8 @@ impl ModuleBuilder {
IrType::U16 | IrType::I16 => I16,
IrType::U32 | IrType::I32 | IrType::Asn => I32,
IrType::U64 | IrType::I64 => I64,
IrType::F32 => F32,
IrType::F64 => F64,
IrType::Pointer | IrType::ExtPointer => self.isa.pointer_type(),
IrType::ExtValue => todo!(),
}
Expand Down Expand Up @@ -704,7 +706,7 @@ impl<'c> FuncGen<'c> {
ir::Instruction::Return(None) => {
self.ins().return_(&[]);
}
ir::Instruction::Cmp {
ir::Instruction::IntCmp {
to,
cmp,
left,
Expand All @@ -713,7 +715,19 @@ impl<'c> FuncGen<'c> {
let (l, _) = self.operand(left);
let (r, _) = self.operand(right);
let var = self.variable(to, I8);
let val = self.binop(l, r, cmp);
let val = self.int_cmp(l, r, cmp);
self.def(var, val);
}
ir::Instruction::FloatCmp {
to,
cmp,
left,
right,
} => {
let (l, _) = self.operand(left);
let (r, _) = self.operand(right);
let var = self.variable(to, I8);
let val = self.float_cmp(l, r, cmp);
self.def(var, val);
}
ir::Instruction::Not { to, val } => {
Expand Down Expand Up @@ -743,7 +757,11 @@ impl<'c> FuncGen<'c> {
let var = self.variable(to, left_ty);
// Possibly interesting note for later: this is wrapping
// addition
let val = self.ins().iadd(l, r);
let val = if let F32 | F64 = left_ty {
self.ins().fadd(l, r)
} else {
self.ins().iadd(l, r)
};
self.def(var, val)
}
ir::Instruction::Sub { to, left, right } => {
Expand All @@ -753,7 +771,11 @@ impl<'c> FuncGen<'c> {
let var = self.variable(to, left_ty);
// Possibly interesting note for later: this is wrapping
// subtraction
let val = self.ins().isub(l, r);
let val = if let F32 | F64 = left_ty {
self.ins().fsub(l, r)
} else {
self.ins().isub(l, r)
};
self.def(var, val)
}
ir::Instruction::Mul { to, left, right } => {
Expand All @@ -763,7 +785,11 @@ impl<'c> FuncGen<'c> {
let var = self.variable(to, left_ty);
// Possibly interesting note for later: this is wrapping
// multiplication
let val = self.ins().imul(l, r);
let val = if let F32 | F64 = left_ty {
self.ins().fmul(l, r)
} else {
self.ins().imul(l, r)
};
self.def(var, val)
}
ir::Instruction::Div {
Expand All @@ -784,6 +810,7 @@ impl<'c> FuncGen<'c> {
IrType::U8 | IrType::U16 | IrType::U32 | IrType::U64 => {
self.ins().udiv(l, r)
}
IrType::F32 | IrType::F64 => self.ins().fdiv(l, r),
_ => panic!(),
};
self.def(var, val)
Expand Down Expand Up @@ -1019,7 +1046,6 @@ impl<'c> FuncGen<'c> {
}

fn operand(&mut self, val: &Operand) -> (Value, Type) {
let pointer_ty = self.module.isa.pointer_type();
match val {
ir::Operand::Place(p) => {
let (var, ty) = self.module.variable_map.get(p).map_or_else(
Expand All @@ -1034,25 +1060,49 @@ impl<'c> FuncGen<'c> {
(self.builder.use_var(*var), *ty)
}
ir::Operand::Value(v) => {
let (ty, val) = match v {
IrValue::Bool(x) => (I8, *x as i64),
IrValue::U8(x) => (I8, *x as i64),
IrValue::U16(x) => (I16, *x as i64),
IrValue::U32(x) => (I32, *x as i64),
IrValue::U64(x) => (I64, *x as i64),
IrValue::I8(x) => (I8, *x as i64),
IrValue::I16(x) => (I16, *x as i64),
IrValue::I32(x) => (I32, *x as i64),
IrValue::I64(x) => (I64, *x),
IrValue::Asn(x) => (I32, x.into_u32() as i64),
IrValue::Pointer(x) => (pointer_ty, *x as i64),
_ => todo!(),
};
(self.ins().iconst(ty, val), ty)
if let Some((ty, val)) = self.integer_operand(v) {
(self.ins().iconst(ty, val), ty)
} else if let Some((ty, val)) = self.float_operand(v) {
if ty == F32 {
(self.ins().f32const(val as f32), ty)
} else if ty == F64 {
(self.ins().f64const(val), ty)
} else {
panic!()
}
} else {
todo!()
}
}
}
}

fn integer_operand(&self, val: &IrValue) -> Option<(Type, i64)> {
let pointer_ty = self.module.isa.pointer_type();
Some(match val {
IrValue::Bool(x) => (I8, *x as i64),
IrValue::U8(x) => (I8, *x as i64),
IrValue::U16(x) => (I16, *x as i64),
IrValue::U32(x) => (I32, *x as i64),
IrValue::U64(x) => (I64, *x as i64),
IrValue::I8(x) => (I8, *x as i64),
IrValue::I16(x) => (I16, *x as i64),
IrValue::I32(x) => (I32, *x as i64),
IrValue::I64(x) => (I64, *x),
IrValue::Asn(x) => (I32, x.into_u32() as i64),
IrValue::Pointer(x) => (pointer_ty, *x as i64),
_ => return None,
})
}

fn float_operand(&self, val: &IrValue) -> Option<(Type, f64)> {
Some(match val {
IrValue::F32(x) => (F32, *x as f64),
IrValue::F64(x) => (F64, *x),
_ => return None,
})
}

fn variable(&mut self, var: &Var, ty: Type) -> Variable {
let len = self.module.variable_map.len();
let (var, _ty) =
Expand All @@ -1066,7 +1116,7 @@ impl<'c> FuncGen<'c> {
var
}

fn binop(&mut self, left: Value, right: Value, op: &IntCmp) -> Value {
fn int_cmp(&mut self, left: Value, right: Value, op: &IntCmp) -> Value {
let cc = match op {
IntCmp::Eq => IntCC::Equal,
IntCmp::Ne => IntCC::NotEqual,
Expand All @@ -1081,6 +1131,23 @@ impl<'c> FuncGen<'c> {
};
self.ins().icmp(cc, left, right)
}

fn float_cmp(
&mut self,
left: Value,
right: Value,
op: &FloatCmp,
) -> Value {
let cc = match op {
FloatCmp::Eq => FloatCC::Equal,
FloatCmp::Ne => FloatCC::NotEqual,
FloatCmp::Lt => FloatCC::LessThan,
FloatCmp::Le => FloatCC::LessThanOrEqual,
FloatCmp::Gt => FloatCC::GreaterThan,
FloatCmp::Ge => FloatCC::GreaterThanOrEqual,
};
self.ins().fcmp(cc, left, right)
}
}

impl Module {
Expand Down
95 changes: 95 additions & 0 deletions src/codegen/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,101 @@ fn multiply() {
assert_eq!(res, Verdict::Accept(40));
}

#[test]
fn float_mul() {
let s = src!(
"
filtermap main(x: f32) {
accept 2.0 * x
}
"
);

let mut p = compile(s);
let f = p
.get_function::<(), (f32,), Verdict<f32, ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call(&mut (), 20.0);
assert_eq!(res, Verdict::Accept(40.0));
}

#[test]
fn float_add() {
let s = src!(
"
filtermap main(x: f32) {
accept 2.0 + x
}
"
);

let mut p = compile(s);
let f = p
.get_function::<(), (f32,), Verdict<f32, ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call(&mut (), 20.0);
assert_eq!(res, Verdict::Accept(22.0));
}

#[test]
fn float_sub() {
let s = src!(
"
filtermap main(x: f32) {
accept 20.0 - x
}
"
);

let mut p = compile(s);
let f = p
.get_function::<(), (f32,), Verdict<f32, ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call(&mut (), 2.0);
assert_eq!(res, Verdict::Accept(18.0));
}

#[test]
fn float_cmp() {
let s = src!(
"
filtermap main(x: f32) {
accept x == 20.0
}
"
);

let mut p = compile(s);
let f = p
.get_function::<(), (f32,), Verdict<bool, ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call(&mut (), 20.0);
assert_eq!(res, Verdict::Accept(true));
}

#[test]
fn float_add_f64() {
let s = src!(
"
filtermap main(x: f64) {
accept x + 20.0
}
"
);

let mut p = compile(s);
let f = p
.get_function::<(), (f64,), Verdict<f64, ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call(&mut (), 20.0);
assert_eq!(res, Verdict::Accept(40.0));
}

#[test]
fn ip_output() {
let s = src!(
Expand Down
22 changes: 20 additions & 2 deletions src/lower/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::ir::{Function, Operand, Var};
use crate::{
ast::Identifier,
lower::{
ir::{Instruction, IntCmp, VarKind},
ir::{FloatCmp, Instruction, IntCmp, VarKind},
value::IrValue,
},
runtime::{RuntimeConstant, RuntimeFunctionRef},
Expand Down Expand Up @@ -429,7 +429,7 @@ pub fn eval(
return val;
}
}
Instruction::Cmp {
Instruction::IntCmp {
to,
cmp,
left,
Expand All @@ -451,6 +451,24 @@ pub fn eval(
};
vars.insert(to.clone(), IrValue::Bool(res));
}
Instruction::FloatCmp {
to,
cmp,
left,
right,
} => {
let left = eval_operand(&vars, left);
let right = eval_operand(&vars, right);
let res = match cmp {
FloatCmp::Eq => left == right,
FloatCmp::Ne => left != right,
FloatCmp::Lt => left.as_f64() < right.as_f64(),
FloatCmp::Le => left.as_f64() <= right.as_f64(),
FloatCmp::Gt => left.as_f64() > right.as_f64(),
FloatCmp::Ge => left.as_f64() >= right.as_f64(),
};
vars.insert(to.clone(), IrValue::Bool(res));
}
Instruction::Eq { to, left, right } => {
let left = eval_operand(&vars, left);
let right = eval_operand(&vars, right);
Expand Down
Loading

0 comments on commit d6b5cb4

Please sign in to comment.