-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FEATURE: use a safer eval code, with limited arithmetic capabilities.
- Loading branch information
1 parent
5f4146d
commit 71c909a
Showing
2 changed files
with
219 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
""" | ||
This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator. | ||
SPDX-FileCopyrightText: 2024 Amilcar do Carmo Lucas <[email protected]> | ||
SPDX-License-Identifier: GPL-3.0-or-later | ||
""" | ||
|
||
import ast | ||
import logging | ||
import math | ||
import operator | ||
from typing import Callable, Union, cast | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Type aliases | ||
Number = Union[int, float] | ||
MathFunc = Callable[..., Number] | ||
BinOperator = Callable[[Number, Number], Number] | ||
UnOperator = Callable[[Number], Number] | ||
|
||
def safe_eval(s: str) -> Number: | ||
def checkmath(x: str, *args: Number) -> Number: | ||
if x not in [x for x in dir(math) if "__" not in x]: | ||
msg = f"Unknown func {x}()" | ||
raise SyntaxError(msg) | ||
fun = cast(MathFunc, getattr(math, x)) | ||
try: | ||
return fun(*args) | ||
except TypeError as e: | ||
msg = f"Invalid arguments for {x}(): {e!s}" | ||
raise SyntaxError(msg) from e | ||
|
||
bin_ops: dict[type[ast.operator], BinOperator] = { | ||
ast.Add: operator.add, | ||
ast.Sub: operator.sub, | ||
ast.Mult: operator.mul, | ||
ast.Div: operator.truediv, | ||
ast.Mod: operator.mod, | ||
ast.Pow: operator.pow, | ||
ast.Call: checkmath, | ||
Check failure on line 42 in ardupilot_methodic_configurator/safe_eval.py
|
||
ast.BinOp: ast.BinOp, | ||
} | ||
|
||
un_ops: dict[type[ast.UnaryOp], UnOperator] = { | ||
ast.USub: operator.neg, | ||
Check failure on line 47 in ardupilot_methodic_configurator/safe_eval.py
|
||
ast.UAdd: operator.pos, | ||
Check failure on line 48 in ardupilot_methodic_configurator/safe_eval.py
|
||
ast.UnaryOp: ast.UnaryOp, | ||
} | ||
|
||
tree = ast.parse(s, mode="eval") | ||
|
||
def _eval(node: ast.AST) -> Number: | ||
if isinstance(node, ast.Expression): | ||
logger.debug("Expr") | ||
return _eval(node.body) | ||
if isinstance(node, ast.Constant): | ||
logger.info("Const") | ||
return cast(Number, node.value) | ||
if isinstance(node, ast.Name): | ||
# Handle math constants like pi, e, etc. | ||
logger.info("MathConst") | ||
if hasattr(math, node.id): | ||
return cast(Number, getattr(math, node.id)) | ||
msg = f"Unknown constant: {node.id}" | ||
raise SyntaxError(msg) | ||
if isinstance(node, ast.BinOp): | ||
logger.debug("BinOp") | ||
left = _eval(node.left) | ||
right = _eval(node.right) | ||
if type(node.op) not in bin_ops: | ||
msg = f"Unsupported operator: {type(node.op)}" | ||
raise SyntaxError(msg) | ||
return bin_ops[type(node.op)](left, right) | ||
if isinstance(node, ast.UnaryOp): | ||
logger.debug("UpOp") | ||
operand = _eval(node.operand) | ||
if type(node.op) not in un_ops: | ||
msg = f"Unsupported operator: {type(node.op)}" | ||
raise SyntaxError(msg) | ||
return un_ops[type(node.op)](operand) | ||
if isinstance(node, ast.Call): | ||
if not isinstance(node.func, ast.Name): | ||
msg = "Only direct math function calls allowed" | ||
raise SyntaxError(msg) | ||
args = [_eval(x) for x in node.args] | ||
return checkmath(node.func.id, *args) | ||
msg = f"Bad syntax, {type(node)}" | ||
raise SyntaxError(msg) | ||
|
||
return _eval(tree) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
#!/usr/bin/env python3 | ||
""" | ||
Tests for safe_eval.py. | ||
This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator | ||
SPDX-FileCopyrightText: 2024 Amilcar do Carmo Lucas <[email protected]> | ||
SPDX-License-Identifier: GPL-3.0-or-later | ||
""" | ||
|
||
import math | ||
|
||
import pytest | ||
|
||
from ardupilot_methodic_configurator.safe_eval import safe_eval | ||
|
||
|
||
def test_basic_arithmetic() -> None: | ||
"""Test basic arithmetic operations.""" | ||
assert safe_eval("1+1") == 2 | ||
assert safe_eval("1+-5") == -4 | ||
assert safe_eval("-1") == -1 | ||
assert safe_eval("-+1") == -1 | ||
assert safe_eval("(100*10)+6") == 1006 | ||
assert safe_eval("100*(10+6)") == 1600 | ||
assert safe_eval("2**4") == 16 | ||
assert pytest.approx(safe_eval("1.2345 * 10")) == 12.345 | ||
|
||
|
||
def test_math_functions() -> None: | ||
"""Test mathematical functions.""" | ||
assert safe_eval("sqrt(16)+1") == 5 | ||
assert safe_eval("sin(0)") == 0 | ||
assert safe_eval("cos(0)") == 1 | ||
assert safe_eval("tan(0)") == 0 | ||
assert safe_eval("log(1)") == 0 | ||
assert safe_eval("exp(0)") == 1 | ||
assert safe_eval("pi") == math.pi | ||
|
||
|
||
def test_complex_expressions() -> None: | ||
"""Test more complex mathematical expressions.""" | ||
assert safe_eval("2 * (3 + 4)") == 14 | ||
assert safe_eval("2 ** 3 * 4") == 32 | ||
assert safe_eval("sqrt(16) + sqrt(9)") == 7 | ||
assert safe_eval("sin(pi/2)") == 1 | ||
|
||
|
||
def test_error_cases() -> None: | ||
"""Test error conditions.""" | ||
with pytest.raises(SyntaxError): | ||
safe_eval("1 + ") # Incomplete expression | ||
|
||
with pytest.raises(SyntaxError): | ||
safe_eval("unknown_func(10)") # Unknown function | ||
|
||
with pytest.raises(SyntaxError): | ||
safe_eval("1 = 1") # Invalid operator | ||
|
||
with pytest.raises(SyntaxError): | ||
safe_eval("import os") # Attempted import | ||
|
||
|
||
def test_nested_expressions() -> None: | ||
"""Test nested mathematical expressions.""" | ||
assert safe_eval("sqrt(pow(3,2) + pow(4,2))") == 5 # Pythagorean theorem | ||
assert safe_eval("log(exp(1))") == 1 | ||
assert safe_eval("sin(pi/6)**2 + cos(pi/6)**2") == pytest.approx(1) | ||
|
||
|
||
def test_division_by_zero() -> None: | ||
"""Test division by zero handling.""" | ||
with pytest.raises(ZeroDivisionError): | ||
safe_eval("1/0") | ||
with pytest.raises(ZeroDivisionError): | ||
safe_eval("10 % 0") | ||
|
||
|
||
def test_invalid_math_functions() -> None: | ||
"""Test invalid math function calls.""" | ||
with pytest.raises(SyntaxError, match=r".*takes exactly one argument.*"): | ||
safe_eval("sin()") # Missing argument | ||
with pytest.raises(SyntaxError, match=r".*takes exactly one argument.*"): | ||
safe_eval("sin(1,2)") # Too many arguments | ||
with pytest.raises(ValueError, match=r"math domain error"): | ||
safe_eval("sqrt(-1)") # Domain error | ||
with pytest.raises(ValueError, match=r"math domain error"): | ||
safe_eval("log(-1)") # Range error | ||
with pytest.raises(SyntaxError, match=r"Unknown func.*"): | ||
safe_eval("unknown(1)") # Unknown function | ||
|
||
|
||
def test_security() -> None: | ||
"""Test against code injection attempts.""" | ||
with pytest.raises(SyntaxError): | ||
safe_eval("__import__('os').system('ls')") | ||
with pytest.raises(SyntaxError): | ||
safe_eval("open('/etc/passwd')") | ||
with pytest.raises(SyntaxError): | ||
safe_eval("eval('1+1')") | ||
|
||
|
||
def test_operator_precedence() -> None: | ||
"""Test operator precedence rules.""" | ||
assert safe_eval("2 + 3 * 4") == 14 | ||
assert safe_eval("(2 + 3) * 4") == 20 | ||
assert safe_eval("-2 ** 2") == -4 # Exponentiation before negation | ||
assert safe_eval("-(2 ** 2)") == -4 | ||
|
||
|
||
def test_float_precision() -> None: | ||
"""Test floating point precision handling.""" | ||
assert pytest.approx(safe_eval("0.1 + 0.2")) == 0.3 | ||
assert pytest.approx(safe_eval("sin(pi/2)")) == 1.0 | ||
assert pytest.approx(safe_eval("cos(pi)")) == -1.0 | ||
assert pytest.approx(safe_eval("exp(log(2.718281828))")) == math.e | ||
|
||
|
||
def test_math_constants() -> None: | ||
"""Test mathematical constants.""" | ||
assert safe_eval("pi") == math.pi | ||
assert safe_eval("e") == math.e | ||
assert safe_eval("tau") == math.tau | ||
assert safe_eval("inf") == math.inf | ||
with pytest.raises(SyntaxError): | ||
safe_eval("not_a_constant") |