Skip to content

Commit 8ded17e

Browse files
committed
enh: optimize expression resolver for div
enh: optimize crossmath generator
1 parent ce3d765 commit 8ded17e

File tree

5 files changed

+114
-26
lines changed

5 files changed

+114
-26
lines changed

crossmath.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from expression import Expression, ExpressionValidator
66
from expression_map import ExpressionMap, ExpressionItem, Direction
7-
from expression_resolver import ExpressionResolver
7+
from expression_resolver import ExpressionResolver, ExpressionResolverException
88
from number_factory import NumberFactory
99

1010

@@ -13,7 +13,9 @@ def __init__(self, exp_map: ExpressionMap, number_factory: NumberFactory):
1313
self._map = exp_map
1414
self._number_factory = number_factory
1515
self._expression_resolver = ExpressionResolver(
16-
validator=ExpressionValidator(number_factory=number_factory),
16+
validator=ExpressionValidator(
17+
number_factory=number_factory, minimum=-100, maximum=50
18+
),
1719
number_factory=number_factory,
1820
)
1921

@@ -59,12 +61,16 @@ def _check_x_y_overflow(self, x: int, y: int, length: int) -> bool:
5961
return True
6062

6163
def _find_potential_values(
62-
self, potential_positions: list[Tuple[int, int]]
64+
self,
65+
potential_positions: list[Tuple[int, int]],
66+
dead_positions: list[Tuple[Direction, int, int]],
6367
) -> list[Tuple[Direction, int, int, list]]:
6468
max_expression_length = max(Expression.SUPPORTED_LENGTHS)
6569
for next_position in potential_positions:
6670
x, y = next_position
6771
for direction in Direction.all():
72+
if (direction, x, y) in dead_positions:
73+
continue
6874
for expression_offset in range(0, -max_expression_length - 1, -2):
6975
values_x_offset = (
7076
0 if direction.is_vertical() else expression_offset
@@ -105,33 +111,34 @@ def _init_generate(self):
105111
raise Exception("No expression found")
106112
direction = random.choice([Direction.HORIZONTAL, Direction.VERTICAL])
107113
item = ExpressionItem(
108-
3,
109-
3,
114+
2,
115+
2,
110116
direction,
111117
expression,
112118
)
113119
self._map.put(item)
114120

115121
def generate(self):
116-
start_time = time.time()
117122
self._init_generate()
118-
for i in range(40):
123+
dead_positions = []
124+
while True:
119125
potential_positions = self._find_potential_positions()
120-
potential_values = list(self._find_potential_values(potential_positions))
126+
potential_values = list(
127+
self._find_potential_values(potential_positions, dead_positions)
128+
)
121129
random.shuffle(potential_values)
122130
is_expression_appended = False
123131
for desc in potential_values:
132+
print(".", end="")
124133
direction, _x, _y, values = desc
125134
# print("desc:", desc, "x:", _x, "y:", _y)
126135
try:
127136
expression = self._expression_resolver.resolve(
128137
Expression.from_values(values)
129138
)
130-
except ValueError:
131-
# TODO store dead positions
132-
continue
133-
if expression is None:
134-
# TODO store dead positions
139+
except ExpressionResolverException as e:
140+
print(f"ExpressionResolverException: {e}")
141+
dead_positions.append((direction, _x, _y))
135142
continue
136143
expression_item = ExpressionItem(_x, _y, direction, expression)
137144
self._map.put(expression_item)
@@ -140,7 +147,5 @@ def generate(self):
140147
if not is_expression_appended:
141148
break
142149

143-
print("time: ", time.time() - start_time)
144-
145150
def print(self):
146151
self._map.print(number_factory=self._number_factory)

expression_resolver.py

+73-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,33 @@
55
from number_factory import NumberFactory
66

77

8+
class ExpressionResolverException(Exception):
9+
def __init__(
10+
self,
11+
message: str = "Expression resolver exception",
12+
expression: Expression | None = None,
13+
):
14+
super().__init__(f"{message} (expression={expression})")
15+
16+
17+
class ExpressionResolverNotResolvable(ExpressionResolverException):
18+
def __init__(
19+
self,
20+
message: str = "Expression is not resolvable",
21+
expression: Expression | None = None,
22+
):
23+
super().__init__(message, expression=expression)
24+
25+
26+
class ExpressionResolverMaybeNotResolvable(ExpressionResolverException):
27+
def __init__(
28+
self,
29+
message: str = "Expression is maybe not resolvable",
30+
expression: Expression | None = None,
31+
):
32+
super().__init__(message, expression=expression)
33+
34+
835
class ExpressionResolver:
936
def __init__(self, validator: ExpressionValidator, number_factory: NumberFactory):
1037
self._validator = validator
@@ -22,11 +49,28 @@ def _fix_result(self, expression: Expression):
2249
expression.operand2 = self._number_factory.fix(expression.operand2)
2350
expression.result = self._number_factory.fix(expression.result)
2451

25-
def _resolve_result_is_none(self, expression: Expression) -> Expression | None:
52+
def _next_operator(self, expression: Expression) -> list[Operator]:
53+
if expression.operand1 is None or expression.operand2 is None:
54+
yield random.choice(Operator.get_operators_without_eq())
55+
operators = Operator.get_operators_without_eq()
56+
random.shuffle(operators)
57+
for operator in operators:
58+
yield operator
59+
60+
def _resolve_result_is_none(self, expression: Expression) -> Expression:
61+
start = time.time()
62+
operators = self._next_operator(expression)
2663
while True:
2764
exp_result = expression.clone()
2865
if exp_result.operator is None:
29-
exp_result.operator = random.choice(Operator.get_operators_without_eq())
66+
try:
67+
exp_result.operator = next(operators)
68+
except StopIteration:
69+
raise ExpressionResolverNotResolvable(expression=expression)
70+
if exp_result.operator == Operator.DIV and NumberFactory.is_zero(
71+
exp_result.operand2
72+
):
73+
continue
3074
if exp_result.operand1 is None:
3175
exp_result.operand1 = self._number_factory.next()
3276
if exp_result.operand2 is None:
@@ -51,11 +95,14 @@ def _resolve_result_is_none(self, expression: Expression) -> Expression | None:
5195
self._check_result(expression, exp_result)
5296
if not self._validator.validate(exp_result):
5397
# TODO time or count limit
98+
if time.time() - start > 1.0:
99+
raise ExpressionResolverMaybeNotResolvable(expression=expression)
54100
continue
55101
self._fly_back(exp_result)
56102
return exp_result
57103

58-
def _resolve_result_is_available(self, expression: Expression) -> Expression | None:
104+
def _resolve_result_is_available(self, expression: Expression) -> Expression:
105+
start = time.time()
59106
while True:
60107
exp_calc = expression.clone()
61108
exp_result = expression.clone()
@@ -90,6 +137,12 @@ def _resolve_result_is_available(self, expression: Expression) -> Expression | N
90137
# a * ? = c -> c / a = b
91138
# ? * b = c -> c / b = a
92139
exp_calc.operator = Operator.DIV
140+
if not_has_operands:
141+
exp_result.operand1 = exp_calc.operand1 = self._number_factory.next(
142+
dividable_by=exp_result.result,
143+
zero_allowed=False,
144+
)
145+
print(exp_calc)
93146
elif operator == Operator.DIV:
94147
# a / b = c
95148
if not_has_operands:
@@ -109,20 +162,31 @@ def _resolve_result_is_available(self, expression: Expression) -> Expression | N
109162
else:
110163
raise RuntimeError(f"Invalid operator: {operator}")
111164
if exp_result.operand1 is None:
165+
if ExpressionResolver.is_zero_division(
166+
exp_calc.operator, exp_calc.operand2
167+
):
168+
continue
112169
exp_result.operand1 = self._number_factory.fix(
113170
eval(f"{exp_calc.result} {exp_calc.operator} {exp_calc.operand2}")
114171
)
115172
elif exp_result.operand2 is None:
173+
if ExpressionResolver.is_zero_division(
174+
exp_calc.operator, exp_calc.operand1
175+
):
176+
continue
116177
exp_result.operand2 = self._number_factory.fix(
117178
eval(f"{exp_calc.result} {exp_calc.operator} {exp_calc.operand1}")
118179
)
119180
else:
120-
raise ValueError(f"Operator resolver not supported ({expression})")
181+
raise ExpressionResolverException(
182+
"Operator resolver mode not supported", expression=expression
183+
)
121184
self._check_result(expression, exp_result)
122185

123-
# print("expression:", expression, "exp_result:", exp_result)
124186
if not self._validator.validate(exp_result):
125187
# TODO time or count limit
188+
if time.time() - start > 1.0:
189+
raise ExpressionResolverMaybeNotResolvable(expression=expression)
126190
continue
127191
self._fly_back(exp_result)
128192
return exp_result
@@ -152,3 +216,7 @@ def _check_result(self, expression: Expression, result: Expression):
152216
if expression.operator is not None:
153217
if expression.operator != result.operator:
154218
raise RuntimeError(f"Invalid operator: {expression} -> {result}")
219+
220+
@staticmethod
221+
def is_zero_division(operator: Operator, operand2: float) -> bool:
222+
return operator == Operator.DIV and NumberFactory.is_zero(operand2)

main.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from expression_map import ExpressionMap
55
from number_factory import NumberFactory
66

7-
WIDTH = int(os.environ.get("WIDTH", 20))
8-
HEIGHT = int(os.environ.get("HEIGHT", 20))
9-
NUMBER_FACTORY_MIN = float(os.environ.get("NUMBER_FACTORY_MIN", 0.1))
10-
NUMBER_FACTORY_MAX = float(os.environ.get("NUMBER_FACTORY_MAX", 8.0))
11-
NUMBER_FACTORY_STEP = float(os.environ.get("NUMBER_FACTORY_STEP", 0.5))
7+
WIDTH = int(os.environ.get("WIDTH", 30))
8+
HEIGHT = int(os.environ.get("HEIGHT", 30))
9+
NUMBER_FACTORY_MIN = float(os.environ.get("NUMBER_FACTORY_MIN", 0.01))
10+
NUMBER_FACTORY_MAX = float(os.environ.get("NUMBER_FACTORY_MAX", 10.0))
11+
NUMBER_FACTORY_STEP = float(os.environ.get("NUMBER_FACTORY_STEP", 0.1))
1212

1313
if __name__ == "__main__":
1414
exp_map = ExpressionMap(width=WIDTH, height=HEIGHT)
@@ -20,6 +20,7 @@
2020
cross_math = CrossMath(exp_map=exp_map, number_factory=number_factory)
2121
# try:
2222
cross_math.generate()
23+
print()
2324
cross_math.print()
2425
number_factory.print_statistic()
2526
# except Exception as e:

number_factory.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,16 @@ def next(
6060
value += self._step
6161
while value > (maximum or self._max):
6262
value -= self._step
63-
if not zero_allowed and self.is_equal(value, 0.0):
63+
if not zero_allowed and NumberFactory.is_zero(value):
6464
continue
6565
return value
6666

67+
@staticmethod
68+
def is_zero(value: float | None):
69+
if value is None:
70+
return False
71+
return abs(0.0 - value) < 1e-6
72+
6773
def is_equal(self, value1: float | None, value2: float | None) -> bool:
6874
return self.format(value1, decimals=8) == self.format(value2, decimals=8)
6975

test_number_factory.py

+8
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,11 @@ def test_decimals():
7171
assert number_factory.get_decimals() == 0
7272
number_factory = NumberFactory(step=100)
7373
assert number_factory.get_decimals() == 0
74+
75+
76+
def test_is_zero():
77+
assert NumberFactory.is_zero(0.0)
78+
assert not NumberFactory.is_zero(0.1)
79+
assert not NumberFactory.is_zero(-0.1)
80+
assert not NumberFactory.is_zero(1.0)
81+
assert not NumberFactory.is_zero(-1.0)

0 commit comments

Comments
 (0)