Skip to content

Commit

Permalink
FEATURE: use a safer eval code, with limited arithmetic capabilities.
Browse files Browse the repository at this point in the history
  • Loading branch information
amilcarlucas committed Feb 6, 2025
1 parent 5f4146d commit 71c909a
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 0 deletions.
92 changes: 92 additions & 0 deletions ardupilot_methodic_configurator/safe_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator.
SPDX-FileCopyrightText: 2024 Amilcar do Carmo Lucas <[email protected]>
SPDX-License-Identifier: GPL-3.0-or-later
"""

import ast
import logging
import math
import operator
from typing import Callable, Union, cast

logger = logging.getLogger(__name__)

# Type aliases
Number = Union[int, float]
MathFunc = Callable[..., Number]
BinOperator = Callable[[Number, Number], Number]
UnOperator = Callable[[Number], Number]

def safe_eval(s: str) -> Number:
def checkmath(x: str, *args: Number) -> Number:
if x not in [x for x in dir(math) if "__" not in x]:
msg = f"Unknown func {x}()"
raise SyntaxError(msg)
fun = cast(MathFunc, getattr(math, x))
try:
return fun(*args)
except TypeError as e:
msg = f"Invalid arguments for {x}(): {e!s}"
raise SyntaxError(msg) from e

bin_ops: dict[type[ast.operator], BinOperator] = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
ast.Call: checkmath,

Check failure on line 42 in ardupilot_methodic_configurator/safe_eval.py

View workflow job for this annotation

GitHub Actions / mypy

ardupilot_methodic_configurator/safe_eval.py#L42

[error: Dict entry 6 has incompatible type "type[Call]": "Callable[[str, VarArg(int | float)], int | float]"; expected "type[operator]"]
ast.BinOp: ast.BinOp,

Check failure on line 43 in ardupilot_methodic_configurator/safe_eval.py

View workflow job for this annotation

GitHub Actions / mypy

ardupilot_methodic_configurator/safe_eval.py#L43

[error: Dict entry 7 has incompatible type "type[BinOp]": "type[BinOp]"; expected "type[operator]"]
}

un_ops: dict[type[ast.UnaryOp], UnOperator] = {
ast.USub: operator.neg,

Check failure on line 47 in ardupilot_methodic_configurator/safe_eval.py

View workflow job for this annotation

GitHub Actions / mypy

ardupilot_methodic_configurator/safe_eval.py#L47

[error: Dict entry 0 has incompatible type "type[USub]": "Callable[[_SupportsNeg[_T_co]], _T_co]"; expected "type[UnaryOp]"]
ast.UAdd: operator.pos,

Check failure on line 48 in ardupilot_methodic_configurator/safe_eval.py

View workflow job for this annotation

GitHub Actions / mypy

ardupilot_methodic_configurator/safe_eval.py#L48

[error: Dict entry 1 has incompatible type "type[UAdd]": "Callable[[_SupportsPos[_T_co]], _T_co]"; expected "type[UnaryOp]"]
ast.UnaryOp: ast.UnaryOp,

Check failure on line 49 in ardupilot_methodic_configurator/safe_eval.py

View workflow job for this annotation

GitHub Actions / mypy

ardupilot_methodic_configurator/safe_eval.py#L49

[error: Dict entry 2 has incompatible type "type[UnaryOp]": "type[UnaryOp]"; expected "type[UnaryOp]"]
}

tree = ast.parse(s, mode="eval")

def _eval(node: ast.AST) -> Number:
if isinstance(node, ast.Expression):
logger.debug("Expr")
return _eval(node.body)
if isinstance(node, ast.Constant):
logger.info("Const")
return cast(Number, node.value)
if isinstance(node, ast.Name):
# Handle math constants like pi, e, etc.
logger.info("MathConst")
if hasattr(math, node.id):
return cast(Number, getattr(math, node.id))
msg = f"Unknown constant: {node.id}"
raise SyntaxError(msg)
if isinstance(node, ast.BinOp):
logger.debug("BinOp")
left = _eval(node.left)
right = _eval(node.right)
if type(node.op) not in bin_ops:
msg = f"Unsupported operator: {type(node.op)}"
raise SyntaxError(msg)
return bin_ops[type(node.op)](left, right)
if isinstance(node, ast.UnaryOp):
logger.debug("UpOp")
operand = _eval(node.operand)
if type(node.op) not in un_ops:

Check failure on line 79 in ardupilot_methodic_configurator/safe_eval.py

View workflow job for this annotation

GitHub Actions / mypy

ardupilot_methodic_configurator/safe_eval.py#L79

[error: Non-overlapping container check (element type: "type[unaryop]", container item type]
msg = f"Unsupported operator: {type(node.op)}"
raise SyntaxError(msg)
return un_ops[type(node.op)](operand)

Check failure on line 82 in ardupilot_methodic_configurator/safe_eval.py

View workflow job for this annotation

GitHub Actions / mypy

ardupilot_methodic_configurator/safe_eval.py#L82

[error] Invalid index type
if isinstance(node, ast.Call):
if not isinstance(node.func, ast.Name):
msg = "Only direct math function calls allowed"
raise SyntaxError(msg)
args = [_eval(x) for x in node.args]
return checkmath(node.func.id, *args)
msg = f"Bad syntax, {type(node)}"
raise SyntaxError(msg)

return _eval(tree)
127 changes: 127 additions & 0 deletions tests/test_safe_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
"""
Tests for safe_eval.py.
This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator
SPDX-FileCopyrightText: 2024 Amilcar do Carmo Lucas <[email protected]>
SPDX-License-Identifier: GPL-3.0-or-later
"""

import math

import pytest

from ardupilot_methodic_configurator.safe_eval import safe_eval


def test_basic_arithmetic() -> None:
"""Test basic arithmetic operations."""
assert safe_eval("1+1") == 2
assert safe_eval("1+-5") == -4
assert safe_eval("-1") == -1
assert safe_eval("-+1") == -1
assert safe_eval("(100*10)+6") == 1006
assert safe_eval("100*(10+6)") == 1600
assert safe_eval("2**4") == 16
assert pytest.approx(safe_eval("1.2345 * 10")) == 12.345


def test_math_functions() -> None:
"""Test mathematical functions."""
assert safe_eval("sqrt(16)+1") == 5
assert safe_eval("sin(0)") == 0
assert safe_eval("cos(0)") == 1
assert safe_eval("tan(0)") == 0
assert safe_eval("log(1)") == 0
assert safe_eval("exp(0)") == 1
assert safe_eval("pi") == math.pi


def test_complex_expressions() -> None:
"""Test more complex mathematical expressions."""
assert safe_eval("2 * (3 + 4)") == 14
assert safe_eval("2 ** 3 * 4") == 32
assert safe_eval("sqrt(16) + sqrt(9)") == 7
assert safe_eval("sin(pi/2)") == 1


def test_error_cases() -> None:
"""Test error conditions."""
with pytest.raises(SyntaxError):
safe_eval("1 + ") # Incomplete expression

with pytest.raises(SyntaxError):
safe_eval("unknown_func(10)") # Unknown function

with pytest.raises(SyntaxError):
safe_eval("1 = 1") # Invalid operator

with pytest.raises(SyntaxError):
safe_eval("import os") # Attempted import


def test_nested_expressions() -> None:
"""Test nested mathematical expressions."""
assert safe_eval("sqrt(pow(3,2) + pow(4,2))") == 5 # Pythagorean theorem
assert safe_eval("log(exp(1))") == 1
assert safe_eval("sin(pi/6)**2 + cos(pi/6)**2") == pytest.approx(1)


def test_division_by_zero() -> None:
"""Test division by zero handling."""
with pytest.raises(ZeroDivisionError):
safe_eval("1/0")
with pytest.raises(ZeroDivisionError):
safe_eval("10 % 0")


def test_invalid_math_functions() -> None:
"""Test invalid math function calls."""
with pytest.raises(SyntaxError, match=r".*takes exactly one argument.*"):
safe_eval("sin()") # Missing argument
with pytest.raises(SyntaxError, match=r".*takes exactly one argument.*"):
safe_eval("sin(1,2)") # Too many arguments
with pytest.raises(ValueError, match=r"math domain error"):
safe_eval("sqrt(-1)") # Domain error
with pytest.raises(ValueError, match=r"math domain error"):
safe_eval("log(-1)") # Range error
with pytest.raises(SyntaxError, match=r"Unknown func.*"):
safe_eval("unknown(1)") # Unknown function


def test_security() -> None:
"""Test against code injection attempts."""
with pytest.raises(SyntaxError):
safe_eval("__import__('os').system('ls')")
with pytest.raises(SyntaxError):
safe_eval("open('/etc/passwd')")
with pytest.raises(SyntaxError):
safe_eval("eval('1+1')")


def test_operator_precedence() -> None:
"""Test operator precedence rules."""
assert safe_eval("2 + 3 * 4") == 14
assert safe_eval("(2 + 3) * 4") == 20
assert safe_eval("-2 ** 2") == -4 # Exponentiation before negation
assert safe_eval("-(2 ** 2)") == -4


def test_float_precision() -> None:
"""Test floating point precision handling."""
assert pytest.approx(safe_eval("0.1 + 0.2")) == 0.3
assert pytest.approx(safe_eval("sin(pi/2)")) == 1.0
assert pytest.approx(safe_eval("cos(pi)")) == -1.0
assert pytest.approx(safe_eval("exp(log(2.718281828))")) == math.e


def test_math_constants() -> None:
"""Test mathematical constants."""
assert safe_eval("pi") == math.pi
assert safe_eval("e") == math.e
assert safe_eval("tau") == math.tau
assert safe_eval("inf") == math.inf
with pytest.raises(SyntaxError):
safe_eval("not_a_constant")

0 comments on commit 71c909a

Please sign in to comment.