diff --git a/src/lib.rs b/src/lib.rs index c143ff5..ad7320c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ use cel_interpreter::extractors::This; use cel_interpreter::objects::{Key, Map, TryIntoValue}; use cel_interpreter::{Context, ExecutionError, Expression, FunctionContext, Program, Value}; use cel_parser::parse; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::error::Error; use std::fmt; use std::fmt::Debug; @@ -74,9 +74,16 @@ pub fn evaluate_ast_with_context(definition: String, host: Arc) } }; let host = host.clone(); + + // Convert to Expression first to extract skip set + let expr: Expression = data.expression.into(); + + // Extract variables that are compared against string literals + let skip_normalization = extract_string_compared_variables(&expr); + // Transform the expression for null-safe property access let transformed_expr = transform_expression_for_null_safety( - data.expression.into(), + expr, SUPPORTED_FUNCTIONS, &data.device.clone().unwrap_or_default(), &data.computed.clone().unwrap_or_default(), @@ -87,6 +94,7 @@ pub fn evaluate_ast_with_context(definition: String, host: Arc) data.computed, data.device, host, + skip_normalization, ) .map(|val| val.to_passable()) .map_err(|err| err.to_string()); @@ -142,6 +150,10 @@ pub fn evaluate_with_context(definition: String, host: Arc) -> let parsed_expr = parse(data.expression.as_str()); let result = match parsed_expr { Ok(expr) => { + // Extract variables that are compared against string literals + // These should not have string-to-number normalization applied + let skip_normalization = extract_string_compared_variables(&expr); + let transformed_expr = transform_expression_for_null_safety( expr, SUPPORTED_FUNCTIONS, @@ -154,6 +166,7 @@ pub fn evaluate_with_context(definition: String, host: Arc) -> data.computed, data.device, host, + skip_normalization, ) .map(|val| val.to_passable()) .map_err(|err| err.to_string()) @@ -195,6 +208,7 @@ fn execute_with( computed: Option>>, device: Option>>, host: Arc, + skip_normalization: HashSet, ) -> Result { let supported_fn = SUPPORTED_FUNCTIONS; let host = host.clone(); @@ -210,11 +224,12 @@ fn execute_with( .clone(); // Add predefined variables locally to the context + // Skip string-to-number normalization for variables compared against string literals let standardized_variables = variables .map .iter() .map(|it| { - let next = normalize_variables(it.1.clone()); + let next = normalize_variables_with_skip(it.1.clone(), &skip_normalization, it.0); (it.0.clone(), next) }) .collect(); @@ -426,7 +441,9 @@ fn execute_with( ) }) .chain(total_device_properties.iter().map(|(k, v)| { - let mapped_val = normalize_variables(v.clone()); + // Use the skip set when normalizing device properties + let path = format!("device.{}", k); + let mapped_val = normalize_variables_with_skip(v.clone(), &skip_normalization, &path); ( Key::String(Arc::new(k.clone())), mapped_val.to_cel().clone(), @@ -570,26 +587,55 @@ fn execute_with( * - `u64` => `PassableValue::UInt` * - `f64` => `PassableValue::Float` * - All other variants are returned unchanged + * + * Variables in `skip_normalization` set will NOT have string-to-number conversion applied. + * This is used for variables that are compared against string literals in the expression. */ pub fn normalize_variables(passable_value: PassableValue) -> PassableValue { + normalize_variables_with_skip(passable_value, &HashSet::new(), "") +} + +fn normalize_variables_with_skip( + passable_value: PassableValue, + skip_normalization: &HashSet, + current_path: &str, +) -> PassableValue { match passable_value.clone() { PassableValue::String(data) => { let res = match data.as_str() { "true" => PassableValue::Bool(true), "false" => PassableValue::Bool(false), - _ => is_number(passable_value), + _ => { + // Skip number conversion if this variable is compared against a string + if skip_normalization.contains(current_path) { + passable_value + } else { + is_number(passable_value) + } + } }; res } PassableValue::PMap(map) => { let mut new_map = HashMap::new(); for (key, value) in map { - new_map.insert(key, normalize_variables(value)); + let child_path = if current_path.is_empty() { + key.clone() + } else { + format!("{}.{}", current_path, key) + }; + new_map.insert( + key, + normalize_variables_with_skip(value, skip_normalization, &child_path), + ); } PassableValue::PMap(new_map) } PassableValue::List(list) => { - let new_list = list.into_iter().map(normalize_variables).collect(); + let new_list = list + .into_iter() + .map(|v| normalize_variables_with_skip(v, skip_normalization, current_path)) + .collect(); PassableValue::List(new_list) } _ => passable_value, @@ -619,7 +665,7 @@ pub fn normalize_ast_variables(atom: cel_parser::Atom) -> cel_parser::Atom { } /** -* Tries parsing a string atom using numbers, and if it is a number, treats it as such. +* Tries parsing a string atom as a number, and if it is a number, converts it. */ fn is_atom_number(atom: cel_parser::Atom) -> cel_parser::Atom { match atom.clone() { @@ -633,18 +679,18 @@ fn is_atom_number(atom: cel_parser::Atom) -> cel_parser::Atom { _ => {} } match data.parse::() { - Ok(i) => { - if i.fract() == 0.0 { - let as_i64 = i as i64; - if as_i64 as f64 == i { + Ok(f) => { + if f.fract() == 0.0 { + let as_i64 = f as i64; + if as_i64 as f64 == f { return cel_parser::Atom::Int(as_i64); } - let as_u64 = i as u64; - if as_u64 as f64 == i { + let as_u64 = f as u64; + if as_u64 as f64 == f { return cel_parser::Atom::UInt(as_u64); } } - return cel_parser::Atom::Float(i); + return cel_parser::Atom::Float(f); } _ => {} } @@ -655,7 +701,7 @@ fn is_atom_number(atom: cel_parser::Atom) -> cel_parser::Atom { } /** -* Tries parsing a string value using numbers, and if it is a number, treats it as such. +* Tries parsing a string value as a number, and if it is a number, converts it. */ fn is_number(passable: PassableValue) -> PassableValue { match passable.clone() { @@ -669,18 +715,18 @@ fn is_number(passable: PassableValue) -> PassableValue { _ => {} } match data.parse::() { - Ok(i) => { - if i.fract() == 0.0 { - let as_i64 = i as i64; - if as_i64 as f64 == i { + Ok(f) => { + if f.fract() == 0.0 { + let as_i64 = f as i64; + if as_i64 as f64 == f { return PassableValue::Int(as_i64); } - let as_u64 = i as u64; - if as_u64 as f64 == i { + let as_u64 = f as u64; + if as_u64 as f64 == f { return PassableValue::UInt(as_u64); } } - return PassableValue::Float(i); + return PassableValue::Float(f); } _ => {} } @@ -690,6 +736,108 @@ fn is_number(passable: PassableValue) -> PassableValue { } } +/** + * Extracts variable paths that are compared against string literals in the expression. + * These variables should NOT have their string values converted to numbers during normalization, + * as they are meant to be compared as strings (e.g., version comparisons like "009.000" > "007.003.001"). + */ +fn extract_string_compared_variables(expr: &Expression) -> HashSet { + let mut result = HashSet::new(); + extract_string_compared_variables_internal(expr, &mut result); + result +} + +fn extract_string_compared_variables_internal(expr: &Expression, result: &mut HashSet) { + match expr { + Expression::Relation(lhs, _op, rhs) => { + // Check if RHS is a string literal + let rhs_is_string = matches!(rhs.as_ref(), Expression::Atom(cel_parser::Atom::String(_))); + // Check if LHS is a string literal + let lhs_is_string = matches!(lhs.as_ref(), Expression::Atom(cel_parser::Atom::String(_))); + + if rhs_is_string { + // Extract variable path from LHS + if let Some(path) = extract_variable_path(lhs) { + result.insert(path); + } + } + if lhs_is_string { + // Extract variable path from RHS + if let Some(path) = extract_variable_path(rhs) { + result.insert(path); + } + } + + // Continue recursing into both sides + extract_string_compared_variables_internal(lhs, result); + extract_string_compared_variables_internal(rhs, result); + } + Expression::And(lhs, rhs) | Expression::Or(lhs, rhs) => { + extract_string_compared_variables_internal(lhs, result); + extract_string_compared_variables_internal(rhs, result); + } + Expression::Ternary(cond, if_true, if_false) => { + extract_string_compared_variables_internal(cond, result); + extract_string_compared_variables_internal(if_true, result); + extract_string_compared_variables_internal(if_false, result); + } + Expression::Unary(_, operand) => { + extract_string_compared_variables_internal(operand, result); + } + Expression::List(elements) => { + for elem in elements { + extract_string_compared_variables_internal(elem, result); + } + } + Expression::Map(entries) => { + for (k, v) in entries { + extract_string_compared_variables_internal(k, result); + extract_string_compared_variables_internal(v, result); + } + } + Expression::FunctionCall(func, args, _) => { + extract_string_compared_variables_internal(func, result); + for arg in args { + extract_string_compared_variables_internal(arg, result); + } + } + Expression::Member(operand, _) => { + extract_string_compared_variables_internal(operand, result); + } + Expression::Arithmetic(lhs, _, rhs) => { + extract_string_compared_variables_internal(lhs, result); + extract_string_compared_variables_internal(rhs, result); + } + _ => {} + } +} + +/** + * Extracts the variable path from a member access expression (e.g., "device.appVersionPadded"). + * Returns None if the expression is not a simple variable path. + */ +fn extract_variable_path(expr: &Expression) -> Option { + match expr { + Expression::Ident(name) => Some(name.to_string()), + Expression::Member(operand, member) => { + let parent_path = extract_variable_path(operand)?; + match member.as_ref() { + cel_parser::Member::Attribute(attr) => Some(format!("{}.{}", parent_path, attr)), + cel_parser::Member::Index(idx_expr) => { + // For index access like device["key"], try to extract the key + if let Expression::Atom(cel_parser::Atom::String(s)) = idx_expr.as_ref() { + Some(format!("{}.{}", parent_path, s)) + } else { + None + } + } + _ => None, + } + } + _ => None, + } +} + /** * Check if an expression is an atomic value (string, int, float, bool, etc.) */ @@ -2743,6 +2891,109 @@ mod tests { assert!(!result2.is_empty()); } + #[test] + fn test_extract_string_compared_variables() { + // Test that we correctly extract variable paths from expressions + use cel_parser::parse; + + let expr = parse(r#"device.appVersionPadded > "007.003.001""#).unwrap(); + let skip_set = super::extract_string_compared_variables(&expr); + + println!("Skip set: {:?}", skip_set); + assert!(skip_set.contains("device.appVersionPadded"), "Should contain device.appVersionPadded, got: {:?}", skip_set); + } + + #[test] + fn test_version_string_comparison_not_normalized_to_number() { + // Test that version strings like "009.000" are NOT converted to numbers + // when compared against string literals like "007.003.001" + // This was a bug where "009.000" was normalized to Int(9), causing type mismatch + let ctx = Arc::new(TestContext { + map: HashMap::new(), + }); + + // Test case: device.appVersionPadded > "007.003.001" where appVersionPadded = "009.000" + let res = evaluate_with_context( + r#"{ + "variables": { + "map": { + "device": { + "type": "map", + "value": { + "appVersionPadded": {"type": "string", "value": "009.000"} + } + } + } + }, + "expression": "device.appVersionPadded > \"007.003.001\"" + }"#.to_string(), + ctx.clone(), + ); + // Should return true because "009.000" > "007.003.001" lexicographically + assert!(res.contains("true"), "Expected true but got: {}", res); + + // Test the reverse - should be false + let res2 = evaluate_with_context( + r#"{ + "variables": { + "map": { + "device": { + "type": "map", + "value": { + "appVersionPadded": {"type": "string", "value": "005.000"} + } + } + } + }, + "expression": "device.appVersionPadded > \"007.003.001\"" + }"#.to_string(), + ctx.clone(), + ); + // Should return false because "005.000" < "007.003.001" lexicographically + assert!(res2.contains("false"), "Expected false but got: {}", res2); + + // Test equality + let res3 = evaluate_with_context( + r#"{ + "variables": { + "map": { + "device": { + "type": "map", + "value": { + "appVersionPadded": {"type": "string", "value": "007.003.001"} + } + } + } + }, + "expression": "device.appVersionPadded == \"007.003.001\"" + }"#.to_string(), + ctx.clone(), + ); + assert!(res3.contains("true"), "Expected true for equality but got: {}", res3); + + // Test with 2-component version (the original bug case) + let res4 = evaluate_with_context( + r#"{ + "variables": { + "map": { + "device": { + "type": "map", + "value": { + "appVersionPadded": {"type": "string", "value": "009.000"} + } + } + } + }, + "expression": "device.appVersionPadded > \"007.003.001\"" + }"#.to_string(), + ctx.clone(), + ); + // This was the bug - "009.000" should NOT become Int(9) + // It should stay as String("009.000") and compare correctly + assert!(!res4.contains("Err"), "Should not error, got: {}", res4); + assert!(res4.contains("true"), "Expected true but got: {}", res4); + } + #[test] fn test_error_handling_in_property_resolution() { // Test error path in property resolution (lines 500-507)