Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: sum aggregation #61

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions core/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ fn translate_select(schema: &Schema, select: Select) -> Result<Program> {
AggregationFunc::Max => todo!(),
AggregationFunc::Min => todo!(),
AggregationFunc::StringAgg => todo!(),
AggregationFunc::Sum => todo!(),
AggregationFunc::Sum => AggFunc::Sum,
AggregationFunc::Total => todo!(),
};
program.emit_insn(Insn::AggFinal {
Expand Down Expand Up @@ -453,7 +453,20 @@ fn translate_aggregation(
AggregationFunc::Max => todo!(),
AggregationFunc::Min => todo!(),
AggregationFunc::StringAgg => todo!(),
AggregationFunc::Sum => todo!(),
AggregationFunc::Sum => {
if args.len() != 1 {
anyhow::bail!("Parse error: sum bad number of arguments");
}
let expr = &args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, cursor_id, table, &expr, expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
func: crate::vdbe::AggFunc::Sum,
});
target_register
}
AggregationFunc::Total => todo!(),
};
Ok(dest)
Expand Down
108 changes: 101 additions & 7 deletions core/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,112 @@ pub enum Value<'a> {
Blob(&'a Vec<u8>),
}

#[derive(Debug, Clone, PartialEq)]
pub enum AggContext {
Avg(f64, usize), // acc and count
}

#[derive(Debug, Clone, PartialEq)]
pub enum OwnedValue {
Null,
Integer(i64),
Float(f64),
Text(Rc<String>),
Blob(Rc<Vec<u8>>),
Agg(Box<AggContext>),
Agg(Box<AggContext>), // TODO(pere): make this without Box. Currently this might cause cache miss but let's leave it for future analysis
}

#[derive(Debug, Clone, PartialEq)]
pub enum AggContext {
Avg(OwnedValue, OwnedValue), // acc and count
Sum(OwnedValue),
}

impl std::ops::Add<OwnedValue> for OwnedValue {
type Output = OwnedValue;

fn add(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(OwnedValue::Integer(int_left), OwnedValue::Integer(int_right)) => {
OwnedValue::Integer(int_left + int_right)
}
(OwnedValue::Integer(int_left), OwnedValue::Float(float_right)) => {
OwnedValue::Float(int_left as f64 + float_right)
}
(OwnedValue::Float(float_left), OwnedValue::Integer(int_right)) => {
OwnedValue::Float(float_left + int_right as f64)
}
(OwnedValue::Float(float_left), OwnedValue::Float(float_right)) => {
OwnedValue::Float(float_left + float_right)
}
_ => unreachable!(),
}
}
}

impl std::ops::Add<f64> for OwnedValue {
type Output = OwnedValue;

fn add(self, rhs: f64) -> Self::Output {
match self {
OwnedValue::Integer(int_left) => OwnedValue::Float(int_left as f64 + rhs),
OwnedValue::Float(float_left) => OwnedValue::Float(float_left + rhs),
_ => unreachable!(),
}
}
}

impl std::ops::Add<i64> for OwnedValue {
type Output = OwnedValue;

fn add(self, rhs: i64) -> Self::Output {
match self {
OwnedValue::Integer(int_left) => OwnedValue::Integer(int_left + rhs),
OwnedValue::Float(float_left) => OwnedValue::Float(float_left + rhs as f64),
_ => unreachable!(),
}
}
}

impl std::ops::AddAssign for OwnedValue {
fn add_assign(&mut self, rhs: Self) {
*self = self.clone() + rhs;
}
}

impl std::ops::AddAssign<i64> for OwnedValue {
fn add_assign(&mut self, rhs: i64) {
*self = self.clone() + rhs;
}
}

impl std::ops::AddAssign<f64> for OwnedValue {
fn add_assign(&mut self, rhs: f64) {
*self = self.clone() + rhs;
}
}

impl std::ops::Div<OwnedValue> for OwnedValue {
type Output = OwnedValue;

fn div(self, rhs: OwnedValue) -> Self::Output {
match (self, rhs) {
(OwnedValue::Integer(int_left), OwnedValue::Integer(int_right)) => {
OwnedValue::Integer(int_left / int_right)
}
(OwnedValue::Integer(int_left), OwnedValue::Float(float_right)) => {
OwnedValue::Float(int_left as f64 / float_right)
}
(OwnedValue::Float(float_left), OwnedValue::Integer(int_right)) => {
OwnedValue::Float(float_left / int_right as f64)
}
(OwnedValue::Float(float_left), OwnedValue::Float(float_right)) => {
OwnedValue::Float(float_left / float_right)
}
_ => unreachable!(),
}
}
}

impl std::ops::DivAssign<OwnedValue> for OwnedValue {
fn div_assign(&mut self, rhs: OwnedValue) {
*self = self.clone() / rhs;
}
}

pub fn to_value(value: &OwnedValue) -> Value<'_> {
Expand All @@ -34,7 +127,8 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> {
OwnedValue::Text(s) => Value::Text(s),
OwnedValue::Blob(b) => Value::Blob(b),
OwnedValue::Agg(a) => match a.as_ref() {
AggContext::Avg(acc, _count) => Value::Float(*acc), // we assume aggfinal was called
AggContext::Avg(acc, _count) => to_value(acc), // we assume aggfinal was called
AggContext::Sum(acc) => to_value(acc),
_ => todo!(),
},
}
Expand Down
41 changes: 31 additions & 10 deletions core/vdbe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ pub enum Insn {

pub enum AggFunc {
Avg,
Sum,
}

impl AggFunc {
fn to_string(&self) -> &str {
match self {
AggFunc::Avg => "avg",
AggFunc::Sum => "sum",
_ => "unknown",
}
}
Expand Down Expand Up @@ -367,8 +369,15 @@ impl Program {
},
Insn::AggStep { acc_reg, col, func } => {
if let OwnedValue::Null = &state.registers[*acc_reg] {
state.registers[*acc_reg] =
OwnedValue::Agg(Box::new(AggContext::Avg(0.0, 0)));
state.registers[*acc_reg] = match func {
AggFunc::Avg => OwnedValue::Agg(Box::new(AggContext::Avg(
OwnedValue::Float(0.0),
OwnedValue::Integer(0),
))),
AggFunc::Sum => {
OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.0))))
}
};
}
match func {
AggFunc::Avg => {
Expand All @@ -377,14 +386,23 @@ impl Program {
else {
unreachable!();
};
let AggContext::Avg(acc, count) = agg.borrow_mut();
match col {
OwnedValue::Integer(i) => *acc += i as f64,
OwnedValue::Float(f) => *acc += f,
_ => unreachable!(),
}
let AggContext::Avg(acc, count) = agg.borrow_mut() else {
unreachable!();
};
*acc += col;
*count += 1;
}
AggFunc::Sum => {
let col = state.registers[*col].clone();
let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut()
else {
unreachable!();
};
let AggContext::Sum(acc) = agg.borrow_mut() else {
unreachable!();
};
*acc += col;
}
};
state.pc += 1;
}
Expand All @@ -395,9 +413,12 @@ impl Program {
else {
unreachable!();
};
let AggContext::Avg(acc, count) = agg.borrow_mut();
*acc /= *count as f64
let AggContext::Avg(acc, count) = agg.borrow_mut() else {
unreachable!();
};
*acc /= count.clone();
}
AggFunc::Sum => {}
};
state.pc += 1;
}
Expand Down
Loading