diff --git a/.idea/four-letter-blocks.iml b/.idea/four-letter-blocks.iml
index ac6d40f..e9c1ecd 100644
--- a/.idea/four-letter-blocks.iml
+++ b/.idea/four-letter-blocks.iml
@@ -6,7 +6,7 @@
-
+
diff --git a/.idea/misc.xml b/.idea/misc.xml
index c807d2f..fe25f72 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,5 +3,5 @@
-
+
\ No newline at end of file
diff --git a/four_letter_blocks/block_packer.py b/four_letter_blocks/block_packer.py
index 1499379..a430825 100644
--- a/four_letter_blocks/block_packer.py
+++ b/four_letter_blocks/block_packer.py
@@ -74,7 +74,7 @@ def rotated_positions(self):
rotated_positions[rotated_shape].append((x, y))
return rotated_positions
- def find_slots(self) -> dict[str, np.ndarray]:
+ def find_slots(self, is_rotation_allowed=False) -> dict[str, np.ndarray]:
if self.state is None:
raise RuntimeError('Cannot find slots with invalid state.')
@@ -100,7 +100,14 @@ def find_slots(self) -> dict[str, np.ndarray]:
has_even = np.logical_not(np.any(is_uneven, axis=(2, 3)))
usable_slots = np.logical_and(open_slots, has_even)
- slots[shape] = usable_slots
+ if not is_rotation_allowed or len(shape) == 1:
+ slots[shape] = usable_slots
+ else:
+ reported_shape = shape[0]
+ already_usable = slots.get(reported_shape)
+ if already_usable is not None:
+ usable_slots = np.logical_or(usable_slots, already_usable)
+ slots[reported_shape] = usable_slots
return slots
def display(self, state: np.ndarray | None = None) -> str:
@@ -355,6 +362,13 @@ def shape_coordinates() -> typing.Dict[str, typing.List[np.ndarray]]:
@cache
def build_masks(width: int, height: int) -> dict[str, np.ndarray]:
+ """ Build the masks for each shape in each position.
+
+ :return: {shape_name: mask_array}, where mask_array is a four-dimensional
+ array of occupied spaces with index (start_row, start_col, row, col). In
+ other words, if the shape starts at (start_row, start_col), is
+ (row, col) filled?
+ """
all_coordinates = shape_coordinates()
all_masks = {}
for shape, coordinate_list in all_coordinates.items():
diff --git a/tests/test_block_packer.py b/tests/test_block_packer.py
index 7f2e2d5..39ac93b 100644
--- a/tests/test_block_packer.py
+++ b/tests/test_block_packer.py
@@ -358,6 +358,7 @@ def test_fill_fail():
assert not is_filled
+# noinspection DuplicatedCode
def test_find_slots():
packer = BlockPacker(start_text=dedent("""\
#..#.
@@ -378,6 +379,45 @@ def test_find_slots():
assert np.array_equal(o_slots, expected_o_slots)
+# noinspection DuplicatedCode
+def test_find_slots_rotation_allowed():
+ packer = BlockPacker(start_text=dedent("""\
+ #..#.
+ .....
+ ..#..
+ .....
+ .#..#"""))
+ # Not at (1, 3) or (2, 0), because they cut off something.
+ expected_s0_slots = np.array(object=[[1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0]],
+ dtype=bool)
+ expected_s1_slots = np.array(object=[[0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0]],
+ dtype=bool)
+ expected_s_slots = np.array(object=[[1, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0]],
+ dtype=bool)
+
+ is_rotation_allowed = False
+ s0_slots = packer.find_slots(is_rotation_allowed)['S0']
+ s1_slots = packer.find_slots()['S1']
+ is_rotation_allowed = True
+ s_slots = packer.find_slots(is_rotation_allowed)['S']
+
+ assert np.array_equal(s0_slots, expected_s0_slots)
+ assert np.array_equal(s1_slots, expected_s1_slots)
+ assert np.array_equal(s_slots, expected_s_slots)
+
+
def test_find_slots_after_fail():
packer = BlockPacker(start_text=dedent("""\
#..#.