Skip to content

Commit aaa7288

Browse files
committed
enh: performance optimalisation
enh: add more test fix: resolver test enh: simplify code
1 parent c8ef4f1 commit aaa7288

9 files changed

+147
-128
lines changed

crossmath.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,12 @@ def clear(self):
4040

4141

4242
class CrossMath:
43-
def __init__(self, exp_map: ExpressionMap, number_factory: NumberFactory, expression_resolver : ExpressionResolver):
43+
def __init__(
44+
self,
45+
exp_map: ExpressionMap,
46+
number_factory: NumberFactory,
47+
expression_resolver: ExpressionResolver,
48+
):
4449
self._map = exp_map
4550
self._number_factory = number_factory
4651
self._expression_resolver = expression_resolver
@@ -159,7 +164,7 @@ def generate(self):
159164
direction, _x, _y, values = desc
160165
try:
161166
expression = self._expression_resolver.resolve(
162-
Expression.from_values(values)
167+
Expression.from_list(values)
163168
)
164169
except ExpressionResolverException as e:
165170
print(e)

expression.py

+30-23
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from enum import Enum
22
from typing import TypeVar
33

4-
import pandas
5-
6-
from number_factory import NumberFactory
4+
from number_helper import number_is_zero, number_is_equal
75

86
Exp = TypeVar("Exp", bound="Expression")
97
Opr = TypeVar("Opr", bound="Operator")
@@ -47,20 +45,7 @@ def is_empty(self) -> bool:
4745
)
4846

4947
@staticmethod
50-
def from_str(exp: str) -> Exp:
51-
values = exp.split()
52-
result = []
53-
for i in range(0, len(values)):
54-
if values[i] == "?":
55-
result.append(None)
56-
elif i % 2 == 0:
57-
result.append(float(values[i]))
58-
else:
59-
result.append(Operator(values[i]))
60-
return Expression.from_values(result)
61-
62-
@staticmethod
63-
def from_values(values: list) -> Exp:
48+
def from_list(values: list) -> Exp:
6449
if len(values) not in Expression.SUPPORTED_LENGTHS:
6550
raise ValueError(
6651
f"Invalid values ({len(values)} not in {Expression.SUPPORTED_LENGTHS})"
@@ -99,15 +84,39 @@ def __str__(self):
9984
def get_length(self) -> int:
10085
return self._length
10186

87+
def is_match(self, exp: Exp, none_allowed: bool = True) -> bool:
88+
if not none_allowed:
89+
raise ValueError("Not supported")
90+
return (
91+
(
92+
self.operand1 is None
93+
or exp.operand1 is None
94+
or number_is_equal(self.operand1, exp.operand1)
95+
)
96+
and (
97+
self.operator is None
98+
or exp.operator is None
99+
or self.operator == exp.operator
100+
)
101+
and (
102+
self.operand2 is None
103+
or exp.operand2 is None
104+
or number_is_equal(self.operand2, exp.operand2)
105+
)
106+
and (
107+
self.result is None
108+
or exp.result is None
109+
or number_is_equal(self.result, exp.result)
110+
)
111+
)
112+
102113

103114
class ExpressionValidator:
104115
def __init__(
105116
self,
106-
number_factory: NumberFactory,
107117
minimum: float = 0.0,
108118
maximum: float = 100.0,
109119
):
110-
self._number_factory: NumberFactory = number_factory
111120
self._minimum: float = minimum
112121
self._maximum: float = maximum
113122

@@ -139,11 +148,9 @@ def validate(self, expression: Expression) -> bool:
139148
return False
140149
if not self._check_range(expression.result):
141150
return False
142-
if expression.operator == Operator.DIV and NumberFactory.is_zero(
143-
expression.operand2
144-
):
151+
if expression.operator == Operator.DIV and number_is_zero(expression.operand2):
145152
return False
146-
return NumberFactory.is_equal(
153+
return number_is_equal(
147154
eval(f"{expression.operand1} {expression.operator} {expression.operand2}"),
148155
expression.result,
149156
)

expression_resolver.py

+35-56
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import random
2-
import time
32

43
from expression import Expression, ExpressionValidator, Operator
54
from number_factory import NumberFactory
5+
from number_helper import number_is_zero, number_is_equal
66

77

88
class ExpressionResolverException(Exception):
@@ -40,7 +40,7 @@ class ExpressionResolver:
4040
def __init__(self, validator: ExpressionValidator, number_factory: NumberFactory):
4141
self._validator = validator
4242
self._number_factory = number_factory
43-
self._maximum_resolve_time_sec = 0.1
43+
self._resolve_maximum_loop_count = 8
4444

4545
def _fly_back(self, base: Expression, result: Expression):
4646
if base.operand1 is None:
@@ -56,24 +56,26 @@ def _fix_result(self, expression: Expression):
5656
expression.result = self._number_factory.fix(expression.result)
5757

5858
def _next_operator(self, expression: Expression) -> list[Operator]:
59+
operators_allowed = Operator.get_operators_without_eq()
60+
if number_is_zero(expression.operand2):
61+
# ? ? 0 = c
62+
operators_allowed.remove(Operator.DIV)
5963
if expression.operand1 is None or expression.operand2 is None:
60-
yield random.choice(Operator.get_operators_without_eq())
61-
operators = Operator.get_operators_without_eq()
62-
random.shuffle(operators)
63-
for operator in operators:
64+
yield random.choice(operators_allowed)
65+
random.shuffle(operators_allowed)
66+
for operator in operators_allowed:
6467
yield operator
6568

6669
def _resolve_result_is_none(self, expression: Expression) -> Expression:
67-
start = time.time()
6870
operators = self._next_operator(expression)
69-
while time.time() - start < self._maximum_resolve_time_sec:
71+
for _ in range(self._resolve_maximum_loop_count):
7072
exp_result = expression.clone()
7173
if exp_result.operator is None:
7274
try:
7375
exp_result.operator = next(operators)
7476
except StopIteration:
7577
raise ExpressionResolverNotResolvable(expression=expression)
76-
if exp_result.operator == Operator.DIV and NumberFactory.is_zero(
78+
if exp_result.operator == Operator.DIV and number_is_zero(
7779
exp_result.operand2
7880
):
7981
continue
@@ -104,22 +106,20 @@ def _resolve_result_is_none(self, expression: Expression) -> Expression:
104106
)
105107
else:
106108
exp_result.operand2 = self._number_factory.next()
107-
exp_result.result = self._number_factory.fix(
108-
eval(
109-
f"{exp_result.operand1} {exp_result.operator} {exp_result.operand2}"
110-
)
109+
exp_result.result = eval(
110+
f"{exp_result.operand1} {exp_result.operator} {exp_result.operand2}"
111111
)
112112
self._fix_result(exp_result)
113-
self._check_result(expression, exp_result)
113+
if not expression.is_match(exp_result):
114+
raise RuntimeError(f"Result is not match: {expression} vs {exp_result}")
114115
if not self._validator.validate(exp_result):
115-
# TODO time or count limit
116116
continue
117117
self._fly_back(expression, exp_result)
118118
return exp_result
119119

120-
time_diff = time.time() - start
121120
raise ExpressionResolverMaybeNotResolvable(
122-
message=f"Too slow: {time_diff:.1f}s", expression=expression
121+
message=f"Try count is reached ({self._resolve_maximum_loop_count})",
122+
expression=expression,
123123
)
124124

125125
def _resolve_only_operator_missing(self, expression: Expression) -> Expression:
@@ -135,29 +135,28 @@ def _resolve_only_operator_missing(self, expression: Expression) -> Expression:
135135
operators = Operator.get_operators_without_eq()
136136
random.shuffle(operators)
137137
for operator in operators:
138-
if operator == Operator.DIV and NumberFactory.is_zero(expression.operand2):
138+
if operator == Operator.DIV and number_is_zero(expression.operand2):
139139
# zero division
140140
continue
141141
exp_result = expression.clone()
142142
exp_result.operator = operator
143-
exp_result.result = self._number_factory.fix(
144-
eval(
145-
f"{exp_result.operand1} {exp_result.operator} {exp_result.operand2}"
146-
)
143+
exp_result.result = eval(
144+
f"{exp_result.operand1} {exp_result.operator} {exp_result.operand2}"
147145
)
148-
if not NumberFactory.is_equal(exp_result.result, expression.result):
146+
147+
if not number_is_equal(exp_result.result, expression.result):
149148
continue
150149
self._fix_result(exp_result)
151-
self._check_result(expression, exp_result)
150+
if not expression.is_match(exp_result):
151+
raise RuntimeError(f"Result is not match: {expression} vs {exp_result}")
152152
if not self._validator.validate(exp_result):
153153
continue
154154
self._fly_back(expression, exp_result)
155155
return exp_result
156156
raise ExpressionResolverNotResolvable(expression=expression)
157157

158158
def _resolve_result_is_available(self, expression: Expression) -> Expression:
159-
start = time.time()
160-
while time.time() - start < self._maximum_resolve_time_sec:
159+
for _ in range(self._resolve_maximum_loop_count):
161160
exp_calc = expression.clone()
162161
exp_result = expression.clone()
163162
exp_result.result = expression.result
@@ -241,47 +240,27 @@ def _resolve_result_is_available(self, expression: Expression) -> Expression:
241240
)
242241
else:
243242
raise RuntimeError("Invalid state: only one operand is missing")
244-
self._check_result(expression, exp_result)
245243

244+
if not expression.is_match(exp_result):
245+
raise RuntimeError(f"Result is not match: {expression} vs {exp_result}")
246246
if not self._validator.validate(exp_result):
247247
# TODO time or count limit
248248
continue
249249
self._fly_back(expression, exp_result)
250250
return exp_result
251251

252-
time_diff = time.time() - start
253252
raise ExpressionResolverMaybeNotResolvable(
254-
message=f"Too slow: {time_diff:.1f}s", expression=expression
253+
message=f"Try count is reached ({self._resolve_maximum_loop_count})",
254+
expression=expression,
255255
)
256256

257257
def resolve(self, expression: Expression) -> Expression | None:
258-
start_time = time.time()
259-
try:
260-
if expression.result is None:
261-
return self._resolve_result_is_none(expression)
262-
if expression.operand1 is not None and expression.operand2 is not None:
263-
return self._resolve_only_operator_missing(expression)
264-
return self._resolve_result_is_available(expression)
265-
finally:
266-
end_time = time.time()
267-
diff = end_time - start_time
268-
if diff > 0.01:
269-
print(f"ExpressionResolver.resolve: {diff:.2f} seconds")
270-
271-
def _check_result(self, expression: Expression, result: Expression):
272-
if expression.operand1 is not None:
273-
if not NumberFactory.is_equal(expression.operand1, result.operand1):
274-
raise RuntimeError(f"Invalid operand1: {expression} -> {result}")
275-
if expression.operand2 is not None:
276-
if not NumberFactory.is_equal(expression.operand2, result.operand2):
277-
raise RuntimeError(f"Invalid operand2: {expression} -> {result}")
278-
if expression.result is not None:
279-
if not NumberFactory.is_equal(expression.result, result.result):
280-
raise RuntimeError(f"Invalid result: {expression} -> {result}")
281-
if expression.operator is not None:
282-
if expression.operator != result.operator:
283-
raise RuntimeError(f"Invalid operator: {expression} -> {result}")
258+
if expression.result is None:
259+
return self._resolve_result_is_none(expression)
260+
if expression.operand1 is not None and expression.operand2 is not None:
261+
return self._resolve_only_operator_missing(expression)
262+
return self._resolve_result_is_available(expression)
284263

285264
@staticmethod
286265
def is_zero_division(operator: Operator, operand2: float) -> bool:
287-
return operator == Operator.DIV and NumberFactory.is_zero(operand2)
266+
return operator == Operator.DIV and number_is_zero(operand2)

main.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
HEIGHT = int(os.environ.get("HEIGHT", 30))
1010
NUMBER_FACTORY_MIN = float(os.environ.get("NUMBER_FACTORY_MIN", -20.0))
1111
NUMBER_FACTORY_MAX = float(os.environ.get("NUMBER_FACTORY_MAX", 20.0))
12-
NUMBER_FACTORY_STEP = float(os.environ.get("NUMBER_FACTORY_STEP", 0.5))
12+
NUMBER_FACTORY_STEP = float(os.environ.get("NUMBER_FACTORY_STEP", 0.1))
1313

1414
if __name__ == "__main__":
1515
exp_map = ExpressionMap(width=WIDTH, height=HEIGHT)
@@ -19,9 +19,7 @@
1919
step=NUMBER_FACTORY_STEP,
2020
)
2121
resolver = ExpressionResolver(
22-
validator=ExpressionValidator(
23-
number_factory=number_factory, minimum=-100, maximum=100
24-
),
22+
validator=ExpressionValidator(minimum=-100, maximum=100),
2523
number_factory=number_factory,
2624
)
2725
cross_math = CrossMath(

number_factory.py

+8-20
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import random
33
import time
44

5+
from number_helper import number_is_zero, number_is_equal, number_fix
6+
57

68
class RandomGenerator:
79
def __init__(self, seed: int = None):
@@ -61,7 +63,7 @@ def next(
6163
if minimum > maximum:
6264
raise ValueError(f"Minimum is greater than maximum: {minimum} > {maximum}")
6365
if dividable_by is not None:
64-
if NumberFactory.is_zero(dividable_by):
66+
if number_is_zero(dividable_by):
6567
dividable_by = self._step
6668
else:
6769
dividable_by = abs(dividable_by)
@@ -95,39 +97,25 @@ def next(
9597
random_in_range = self._random_generator.next_int(range_int)
9698
value = minimum_start + random_in_range * dividable_by
9799
value_fixed = self.fix(value, dividable_by)
98-
if not NumberFactory.is_equal(value, self.fix(value, dividable_by)):
100+
if not number_is_equal(value, self.fix(value, dividable_by)):
99101
raise RuntimeError(
100102
f"Value is not equal to fixed value: {value} vs {value_fixed}"
101103
)
102-
if value < minimum_start:
104+
if number_fix(value) < number_fix(minimum_start):
103105
raise RuntimeError(
104106
f"Value is less than minimum: {value} < {minimum_start} minimum={minimum}"
105107
)
106-
if value > maximum_end:
108+
if number_fix(value) > number_fix(maximum_end):
107109
raise RuntimeError(
108110
f"Value is greater than maximum: {value} > {maximum_end} maximum={maximum}"
109111
)
110-
if not zero_allowed and NumberFactory.is_zero(value):
112+
if not zero_allowed and number_is_zero(value):
111113
continue
112114
return self.fix(value)
113115
raise RuntimeError(
114116
f"Cannot find random value in {max_runtime_sec} second, paramters: minimum={minimum}, maximum={maximum}, dividable_by={dividable_by}, zero_allowed={zero_allowed}"
115117
)
116118

117-
@staticmethod
118-
def is_zero(value: float | None):
119-
if value is None:
120-
return False
121-
return abs(0.0 - value) < 1e-6
122-
123-
@staticmethod
124-
def is_equal(value1: float | None, value2: float | None) -> bool:
125-
if value1 is None and value2 is None:
126-
return True
127-
if value1 is None or value2 is None:
128-
return False
129-
return NumberFactory.is_zero(value1 - value2)
130-
131119
def format(self, value: float | None, decimals: int | None = None) -> str:
132120
if value is None:
133121
return ""
@@ -142,7 +130,7 @@ def fix(self, value: float, step: float | None = None) -> float:
142130
raise ValueError(
143131
f"Step must be greater than or equal to the factory step: {self._step} vs {step}"
144132
)
145-
return round(round(value / step) * step, self._decimals)
133+
return float(round(round(value / step) * step, self._decimals))
146134

147135
def fly_back(self, value: float):
148136
str_value = self.format(value)

number_helper.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
def number_is_zero(value: float | None):
2+
if value is None:
3+
return False
4+
return abs(0.0 - value) < 1e-6
5+
6+
7+
def number_is_equal(value1: float | None, value2: float | None) -> bool:
8+
if value1 is None and value2 is None:
9+
return True
10+
if value1 is None or value2 is None:
11+
return False
12+
return number_is_zero(value1 - value2)
13+
14+
15+
def number_fix(value: float | None) -> float | None:
16+
if value is None:
17+
return None
18+
return round(value, 6)

0 commit comments

Comments
 (0)