diff --git a/backend/src/evaluator/mod.rs b/backend/src/evaluator/mod.rs index 02f2551..9d2e4a8 100644 --- a/backend/src/evaluator/mod.rs +++ b/backend/src/evaluator/mod.rs @@ -3,13 +3,14 @@ use serde::{Deserialize, Serialize}; use crate::{ cell::CellRef, common::{LeadErr, LeadErrCode, Literal}, - evaluator::utils::*, + evaluator::{numerics::*, utils::*}, grid::Grid, parser::*, }; use std::{collections::HashSet, f64, fmt}; +mod numerics; mod utils; #[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] @@ -124,81 +125,29 @@ fn evaluate_expr( precs, grid, |nums| { - let mut res = 0.0; - let mut count = 0; - - for eval in nums { - match eval { - Eval::Literal(Literal::Number(num)) => { - res += num; - count += 1; - } - Eval::Unset => {} - _ => unreachable!(), - } - } - - if count == 0 { + if nums.is_empty() { Err(LeadErr { title: "Evaluation error.".into(), desc: "Attempted to divide by zero.".into(), code: LeadErrCode::DivZero, }) } else { - Ok(res / count as f64) + Ok(nums.iter().sum::() / nums.len() as f64) } }, - "AVG".into(), - )?, - "SUM" => eval_numeric_func( - args, - precs, - grid, - |nums| { - Ok(nums - .iter() - .filter_map(|e| { - if let Eval::Literal(Literal::Number(n)) = e { - Some(n) - } else { - None - } - }) - .sum()) - }, - "SUM".into(), - )?, - "PROD" => eval_numeric_func( - args, - precs, - grid, - |nums| { - Ok(nums - .iter() - .filter_map(|e| { - if let Eval::Literal(Literal::Number(n)) = e { - Some(n) - } else { - None - } - }) - .product()) - }, - "PROD".into(), + "AVG", )?, + "SUM" => eval_numeric_func(args, precs, grid, |nums| Ok(nums.iter().sum()), "SUM")?, + "PROD" => { + eval_numeric_func(args, precs, grid, |nums| Ok(nums.iter().product()), "PROD")? + } "MAX" => eval_numeric_func( args, precs, grid, |nums| { nums.iter() - .filter_map(|e| { - if let Eval::Literal(Literal::Number(n)) = e { - Some(*n) // deref to f64 - } else { - None - } - }) + .cloned() .max_by(|a, b| a.partial_cmp(b).unwrap()) .ok_or(LeadErr { title: "Evaluation error.".into(), @@ -206,7 +155,7 @@ fn evaluate_expr( code: LeadErrCode::Unsupported, }) }, - "MAX".into(), + "MAX", )?, "MIN" => eval_numeric_func( args, @@ -214,13 +163,7 @@ fn evaluate_expr( grid, |nums| { nums.iter() - .filter_map(|e| { - if let Eval::Literal(Literal::Number(n)) = e { - Some(*n) // deref to f64 - } else { - None - } - }) + .cloned() .min_by(|a, b| a.partial_cmp(b).unwrap()) .ok_or(LeadErr { title: "Evaluation error.".into(), @@ -228,17 +171,21 @@ fn evaluate_expr( code: LeadErrCode::Unsupported, }) }, - "MIN".into(), + "MIN", )?, - "EXP" => eval_single_arg_numeric(args, precs, grid, |x| x.exp(), "EXP".into())?, - "SIN" => eval_single_arg_numeric(args, precs, grid, |x| x.sin(), "SIN".into())?, - "COS" => eval_single_arg_numeric(args, precs, grid, |x| x.cos(), "COS".into())?, - "TAN" => eval_single_arg_numeric(args, precs, grid, |x| x.tan(), "TAN".into())?, - "ASIN" => eval_single_arg_numeric(args, precs, grid, |x| x.asin(), "ASIN".into())?, - "ACOS" => eval_single_arg_numeric(args, precs, grid, |x| x.acos(), "ACOS".into())?, - "ATAN" => eval_single_arg_numeric(args, precs, grid, |x| x.atan(), "ATAN".into())?, - "PI" => eval_const(args, Eval::Literal(Literal::Number(f64::consts::PI)))?, - "TAU" => eval_const(args, Eval::Literal(Literal::Number(f64::consts::TAU)))?, + "ABS" => eval_abs(args, precs, grid)?, + "LOG" => eval_log(args, precs, grid)?, + "SQRT" => eval_sqrt(args, precs, grid)?, + "EXP" => eval_exp(args, precs, grid)?, + "SIN" => eval_sin(args, precs, grid)?, + "COS" => eval_cos(args, precs, grid)?, + "TAN" => eval_tan(args, precs, grid)?, + "ASIN" => eval_asin(args, precs, grid)?, + "ACOS" => eval_acos(args, precs, grid)?, + "ATAN" => eval_atan(args, precs, grid)?, + "PI" => eval_pi(args)?, + "TAU" => eval_tau(args)?, + "SQRT2" => eval_sqrt2(args)?, it => { return Err(LeadErr { title: "Evaluation error.".into(), @@ -319,77 +266,6 @@ fn eval_range( } } -fn eval_avg( - args: &Vec, - precs: &mut HashSet, - grid: Option<&Grid>, -) -> Result { - let mut res = 0.0; - let mut count = 0; - - for arg in args { - match evaluate_expr(arg, precs, grid)? { - Eval::Literal(Literal::Number(num)) => { - res += num; - count += 1; - } - Eval::Range(range) => { - for cell in range { - let Eval::CellRef { eval, reference: _ } = cell else { - return Err(LeadErr { - title: "Evaluation error.".into(), - desc: "Found non-cellref in RANGE during AVG evaluation.".into(), - code: LeadErrCode::Server, - }); - }; - - if let Eval::Literal(Literal::Number(num)) = *eval { - res += num; - count += 1; - } else if matches!(*eval, Eval::Unset) { - continue; - } else { - return Err(LeadErr { - title: "Evaluation error.".into(), - desc: "Expected numeric types for AVG function.".into(), - code: LeadErrCode::Unsupported, - }); - } - } - } - _ => { - return Err(LeadErr { - title: "Evaluation error.".into(), - desc: "Expected numeric types for AVG function.".into(), - code: LeadErrCode::Unsupported, - }); - } - } - } - - if count == 0 { - Err(LeadErr { - title: "Evaluation error.".into(), - desc: "Attempted to divide by zero.".into(), - code: LeadErrCode::DivZero, - }) - } else { - Ok(Eval::Literal(Literal::Number(res / count as f64))) - } -} - -fn eval_const(args: &Vec, value: Eval) -> Result { - if args.len() != 0 { - return Err(LeadErr { - title: "Evaluation error.".into(), - desc: format!("PI function requires no arguments."), - code: LeadErrCode::Invalid, - }); - } - - Ok(value) -} - fn eval_add(lval: &Eval, rval: &Eval) -> Result { match (lval, rval) { (Eval::Literal(a), Eval::Literal(b)) => { diff --git a/backend/src/evaluator/numerics.rs b/backend/src/evaluator/numerics.rs new file mode 100644 index 0000000..7f5ded2 --- /dev/null +++ b/backend/src/evaluator/numerics.rs @@ -0,0 +1,107 @@ +use std::collections::HashSet; + +use crate::{ + cell::CellRef, + common::{LeadErr, LeadErrCode, Literal}, + evaluator::{Eval, evaluate_expr}, + grid::Grid, + parser::Expr, +}; + +// -------------------------------------------------- // + +fn eval_unary_numeric( + args: &Vec, + precs: &mut HashSet, + grid: Option<&Grid>, + func: fn(f64) -> f64, + func_name: &str, +) -> Result { + if args.len() != 1 { + return Err(LeadErr { + title: "Evaluation error.".into(), + desc: format!("{func_name} function requires a single argument."), + code: LeadErrCode::Invalid, + }); + } + let err = LeadErr { + title: "Evaluation error.".into(), + desc: format!("{func_name} function requires a numeric argument."), + code: LeadErrCode::TypeErr, + }; + match evaluate_expr(&args[0], precs, grid)? { + Eval::Literal(Literal::Number(num)) => Ok(Eval::Literal(Literal::Number(func(num)))), + Eval::CellRef { eval, .. } => match *eval { + Eval::Literal(Literal::Number(n)) => Ok(Eval::Literal(Literal::Number(func(n)))), + _ => Err(err), + }, + _ => Err(err), + } +} + +macro_rules! unary_numeric_func { + ($fn_name:ident, $func:expr, $label:expr) => { + pub fn $fn_name( + args: &Vec, + precs: &mut HashSet, + grid: Option<&Grid>, + ) -> Result { + eval_unary_numeric(args, precs, grid, $func, $label) + } + }; +} + +unary_numeric_func!(eval_exp, |x| x.exp(), "EXP"); +unary_numeric_func!(eval_log, |x| x.ln(), "LOG"); +unary_numeric_func!(eval_sqrt, |x| x.sqrt(), "SQRT"); +unary_numeric_func!(eval_abs, |x| x.abs(), "ABS"); + +unary_numeric_func!(eval_sin, |x| x.sin(), "SIN"); +unary_numeric_func!(eval_cos, |x| x.cos(), "COS"); +unary_numeric_func!(eval_tan, |x| x.tan(), "TAN"); + +unary_numeric_func!(eval_asin, |x| x.asin(), "ASIN"); +unary_numeric_func!(eval_acos, |x| x.acos(), "ACOS"); +unary_numeric_func!(eval_atan, |x| x.atan(), "ATAN"); + +// -------------------------------------------------- // + +fn eval_const(args: &Vec, value: Eval, label: &str) -> Result { + if args.len() != 0 { + return Err(LeadErr { + title: "Evaluation error.".into(), + desc: format!("{label} function requires no arguments."), + code: LeadErrCode::Invalid, + }); + } + + Ok(value) +} + +macro_rules! const_numeric_func { + ($fn_name:ident, $value:expr, $label:expr) => { + pub fn $fn_name(args: &Vec) -> Result { + eval_const(args, $value, $label) + } + }; +} + +const_numeric_func!( + eval_pi, + Eval::Literal(Literal::Number(std::f64::consts::PI)), + "PI" +); + +const_numeric_func!( + eval_tau, + Eval::Literal(Literal::Number(std::f64::consts::TAU)), + "TAU" +); + +const_numeric_func!( + eval_sqrt2, + Eval::Literal(Literal::Number(std::f64::consts::SQRT_2)), + "SQRT2" +); + +// -------------------------------------------------- // diff --git a/backend/src/evaluator/utils.rs b/backend/src/evaluator/utils.rs index 21b0983..dec8dde 100644 --- a/backend/src/evaluator/utils.rs +++ b/backend/src/evaluator/utils.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, default}; +use std::collections::HashSet; use crate::{ cell::CellRef, @@ -8,34 +8,6 @@ use crate::{ parser::Expr, }; -pub fn eval_single_arg_numeric( - args: &Vec, - precs: &mut HashSet, - grid: Option<&Grid>, - func: fn(f64) -> f64, - func_name: String, -) -> Result { - if args.len() != 1 { - return Err(LeadErr { - title: "Evaluation error.".into(), - desc: format!("{func_name} function requires a single argument."), - code: LeadErrCode::Invalid, - }); - } - let err = LeadErr { - title: "Evaluation error.".into(), - desc: format!("{func_name} function requires a numeric argument."), - code: LeadErrCode::TypeErr, - }; - match evaluate_expr(&args[0], precs, grid)? { - Eval::Literal(Literal::Number(num)) => Ok(Eval::Literal(Literal::Number(func(num)))), - Eval::CellRef { eval, .. } => match *eval { - Eval::Literal(Literal::Number(n)) => Ok(Eval::Literal(Literal::Number(func(n)))), - _ => Err(err), - }, - _ => Err(err), - } -} pub fn eval_n_arg_numeric( n: usize, @@ -81,53 +53,55 @@ pub fn eval_numeric_func( args: &Vec, precs: &mut HashSet, grid: Option<&Grid>, - func: fn(Vec) -> Result, - func_name: String, + func: impl Fn(&[f64]) -> Result, + func_name: &str, ) -> Result { - let mut numeric_args = Vec::new(); + let mut numbers = Vec::new(); for arg in args { let eval = evaluate_expr(arg, precs, grid)?; - if matches!(eval, Eval::Literal(Literal::Number(_)) | Eval::Unset) { - numeric_args.push(eval); - } else if matches!(eval, Eval::Range(_)) { - if let Eval::Range(range) = eval { + match eval { + Eval::Literal(Literal::Number(n)) => numbers.push(n), + Eval::Unset => {} // skip + Eval::Range(range) => { for cell in range { - let Eval::CellRef { - eval: eval2, - reference: _, - } = cell - else { - return Err(LeadErr { - title: "Evaluation error.".into(), - desc: format!( - "Found non-cellref in RANGE during {func_name} evaluation." - ), - code: LeadErrCode::Server, - }); - }; - - if matches!(*eval2, Eval::Literal(Literal::Number(_)) | Eval::Unset) { - numeric_args.push(*eval2); - } else { - return Err(LeadErr { - title: "Evaluation error.".into(), - desc: format!("Expected numeric types for {func_name} function."), - code: LeadErrCode::Unsupported, - }); + match cell { + Eval::CellRef { eval: boxed, .. } => match *boxed { + Eval::Literal(Literal::Number(n)) => numbers.push(n), + Eval::Unset => {} + _ => { + return Err(LeadErr { + title: "Evaluation error.".into(), + desc: format!( + "Expected numeric types for {func_name} function." + ), + code: LeadErrCode::Unsupported, + }); + } + }, + _ => { + return Err(LeadErr { + title: "Evaluation error.".into(), + desc: format!( + "Found non-cellref in RANGE during {func_name} evaluation." + ), + code: LeadErrCode::Server, + }); + } } } } - } else { - return Err(LeadErr { - title: "Evaluation error.".into(), - desc: format!("Expected numeric types for {func_name} function."), - code: LeadErrCode::Unsupported, - }); + _ => { + return Err(LeadErr { + title: "Evaluation error.".into(), + desc: format!("Expected numeric types for {func_name} function."), + code: LeadErrCode::Unsupported, + }); + } } } - let res = func(numeric_args)?; + let res = func(&numbers)?; Ok(Eval::Literal(Literal::Number(res))) }