Skip to content
Open
26 changes: 26 additions & 0 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,32 @@ class Relational(Boolean):
def is_Relational(self):
return True

def __bool__(self):
# We will narrow down the boolean value of our relational with some simple checks
# Get the Left- and Right-hand-sides of the relation, since two expressions are equal if their difference
# is equal to 0.
# If the expand method will not cancel out free symbols in the given expression, then this
# will throw a TypeError.
lhs, rhs = self.args
difference = (lhs - rhs).expand()

if len(difference.free_symbols):
# If there are any free symbols, then boolean evaluation is ambiguous in most cases. Throw a Type Error
raise TypeError(f'Relational with free symbols cannot be cast as bool: {self}')
else:
# Instantiating relationals that are obviously True or False (according to symengine) will automatically
# simplify to BooleanTrue or BooleanFalse
relational_type = type(self)
simplified = relational_type(difference, S.Zero)
if isinstance(simplified, BooleanAtom):
return bool(simplified)
# If we still cannot determine whether or not the relational is true, then we can either outsource the
# evaluation to sympy (if available) or raise a ValueError expressing that the evaluation is unclear.
try:
return bool(self.simplify())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will return True for 2*(x + 1) - 2 which is different from what we get for 2*x which is a TypeError.

except ImportError:
raise ValueError(f'Boolean evaluation is unclear for relational: {self}')

Rel = Relational


Expand Down
1 change: 1 addition & 0 deletions symengine/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ install(FILES __init__.py
test_matrices.py
test_ntheory.py
test_printing.py
test_relationals.py
test_sage.py
test_series_expansion.py
test_sets.py
Expand Down
2 changes: 1 addition & 1 deletion symengine/tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_eval_double2():
x = Symbol("x")
e = sin(x)**2 + sqrt(2)
raises(RuntimeError, lambda: e.n(real=True))
assert abs(e.n() - x**2 - 1.414) < 1e-3
assert abs(e.n() - sin(x)**2.0 - 1.414) < 1e-3

def test_n():
x = Symbol("x")
Expand Down
138 changes: 138 additions & 0 deletions symengine/tests/test_relationals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from symengine.utilities import raises
from symengine import (Symbol, sympify, Eq, Ne, Lt, Le, Ge, Gt, sqrt, pi)

from unittest.case import SkipTest

try:
import sympy
HAVE_SYMPY = True
except ImportError:
HAVE_SYMPY = False


def assert_equal(x, y):
"""Asserts that x and y are equal. This will test Equality, Unequality, LE, and GE classes."""
assert bool(Eq(x, y))
assert not bool(Ne(x, y))
assert bool(Ge(x, y))
assert bool(Le(x, y))


def assert_not_equal(x, y):
"""Asserts that x and y are not equal. This will test Equality and Unequality"""
assert not bool(Eq(x, y))
assert bool(Ne(x, y))


def assert_less_than(x, y):
"""Asserts that x is less than y. This will test Le, Lt, Ge, Gt classes."""
assert bool(Le(x, y))
assert bool(Lt(x, y))
assert not bool(Ge(x, y))
assert not bool(Gt(x, y))


def assert_greater_than(x, y):
"""Asserts that x is greater than y. This will test Le, Lt, Ge, Gt classes."""
assert not bool(Le(x, y))
assert not bool(Lt(x, y))
assert bool(Ge(x, y))
assert bool(Gt(x, y))


def test_equals_constants_easy():
assert_equal(3, 3)
assert_equal(4, 2 ** 2)


def test_equals_constants_hard():
# Short and long are symbolically equivalent, but sufficiently different in form that expand() does not
# catch it. Ideally, our equality should still catch these, but until symengine supports as robust simplification as
# sympy, we can forgive failing, as long as it raises a ValueError
short = sympify('(3/2)*sqrt(11 + sqrt(21))')
long = sympify('sqrt((33/8 + (1/24)*sqrt(27)*sqrt(63))**2 + ((3/8)*sqrt(27) + (-1/8)*sqrt(63))**2)')
assert_equal(short, short)
assert_equal(long, long)
if HAVE_SYMPY:
assert_equal(short, long)
else:
raises(ValueError, lambda: bool(Eq(short, long)))


def test_not_equals_constants():
assert_not_equal(3, 4)
assert_not_equal(4, 4 - .000000001)


def test_equals_symbols():
x = Symbol("x")
y = Symbol("y")
assert_equal(x, x)
assert_equal(x ** 2, x * x)
assert_equal(x * y, y * x)


def test_not_equals_symbols():
x = Symbol("x")
y = Symbol("y")
assert_not_equal(x, x + 1)
assert_not_equal(x ** 2, x ** 2 + 1)
assert_not_equal(x * y, y * x + 1)


def test_not_equals_symbols_raise_typeerror():
x = Symbol("x")
y = Symbol("y")
raises(TypeError, lambda: bool(Eq(x, 1)))
raises(TypeError, lambda: bool(Eq(x, y)))
raises(TypeError, lambda: bool(Eq(x ** 2, x)))


def test_less_than_constants_easy():
assert_less_than(1, 2)
assert_less_than(-1, 1)


def test_less_than_constants_hard():
# Each of the below pairs are distinct numbers, with the one on the left less than the one on the right.
# Ideally, Less-than will catch this when evaluated, but until symengine has a more robust simplification,
# we can forgive a failure to evaluate as long as it raises a ValueError.
if HAVE_SYMPY:
assert_less_than(sqrt(2), 2)
assert_less_than(3.14, pi)
else:
raises(ValueError, lambda: bool(Lt(sqrt(2), 2)))
raises(ValueError, lambda: bool(Lt(3.14, pi)))


def test_greater_than_constants():
assert_greater_than(2, 1)
assert_greater_than(1, -1)


def test_greater_than_constants_hard():
# Each of the below pairs are distinct numbers, with the one on the left less than the one on the right.
# Ideally, Greater-than will catch this when evaluated, but until symengine has a more robust simplification,
# we can forgive a failure to evaluate as long as it raises a ValueError.
if HAVE_SYMPY:
assert_greater_than(2, sqrt(2))
assert_greater_than(pi, 3.14)
else:
raises(ValueError, lambda: bool(Gt(2, sqrt(2))))
raises(ValueError, lambda: bool(Gt(pi, 3.14)))


def test_less_than_raises_typeerror():
x = Symbol("x")
y = Symbol("y")
raises(TypeError, lambda: bool(Lt(x, 1)))
raises(TypeError, lambda: bool(Lt(x, y)))
raises(TypeError, lambda: bool(Lt(x ** 2, x)))


def test_greater_than_raises_typeerror():
x = Symbol("x")
y = Symbol("y")
raises(TypeError, lambda: bool(Gt(x, 1)))
raises(TypeError, lambda: bool(Gt(x, y)))
raises(TypeError, lambda: bool(Gt(x ** 2, x)))