Skip to content

Commit

Permalink
Merge pull request pygame-community#3053 from Matiiss/matiiss-allow-s…
Browse files Browse the repository at this point in the history
…prite-group-subscripts

Add runtime support for `pygame.sprite.AbstractGroup` subscripts
  • Loading branch information
MyreMylar authored Dec 31, 2024
2 parents 71d8b23 + e3816c5 commit 46a25b7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions buildconfig/stubs/pygame/sprite.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import types
from collections.abc import Callable, Iterable, Iterator
from typing import (
Any,
Expand Down Expand Up @@ -144,6 +145,7 @@ _TDirtySprite = TypeVar("_TDirtySprite", bound=_DirtySpriteSupportsGroup)
class AbstractGroup(Generic[_TSprite]):
spritedict: dict[_TSprite, Optional[Union[FRect, Rect]]]
lostsprites: list[Union[FRect, Rect]]
def __class_getitem__(cls, generic: Any) -> types.GenericAlias: ...
def __init__(self) -> None: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_TSprite]: ...
Expand Down
4 changes: 4 additions & 0 deletions src_py/sprite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
# specific ones that aren't quite so general but fit into common
# specialized cases.

import types
from warnings import warn
from typing import Optional

Expand Down Expand Up @@ -371,6 +372,9 @@ class AbstractGroup:
"""

def __class_getitem__(cls, generic):
return types.GenericAlias(cls, generic)

# protected identifier value to identify sprite groups, and avoid infinite recursion
_spritegroup = True

Expand Down
12 changes: 12 additions & 0 deletions test/sprite_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#################################### IMPORTS ###################################


import types
import typing
import unittest

import pygame
Expand Down Expand Up @@ -660,6 +662,16 @@ def update(self, *args, **kwargs):
self.assertEqual(test_sprite.sink, [1, 2, 3])
self.assertEqual(test_sprite.sink_kwargs, {"foo": 4, "bar": 5})

def test_type_subscript(self):
try:
group_generic_alias = sprite.Group[sprite.Sprite]
except TypeError as e:
self.fail(e)

self.assertIsInstance(group_generic_alias, types.GenericAlias)
self.assertIs(typing.get_origin(group_generic_alias), sprite.Group)
self.assertEqual(typing.get_args(group_generic_alias), (sprite.Sprite,))


################################################################################

Expand Down

0 comments on commit 46a25b7

Please sign in to comment.