diff --git a/src/api/models/Achievement.py b/src/api/models/Achievement.py index 7181e8ecd..3de125674 100644 --- a/src/api/models/Achievement.py +++ b/src/api/models/Achievement.py @@ -1,17 +1,18 @@ from __future__ import annotations +from enum import Enum + from pydantic import Field from src.api.models.AbstractEntity import AbstractEntity -from src.util import StringValuedEnum -class State(StringValuedEnum): +class State(Enum): REVEALED = "REVEALED" UNLOCKED = "UNLOCKED" -class ProgressType(StringValuedEnum): +class ProgressType(Enum): STANDARD = "STANDARD" INCREMENTAL = "INCREMENTAL" @@ -34,8 +35,8 @@ class Achievement(AbstractEntity): @property def init_state(self) -> State: - return State.from_string(self.initial_state) + return State(self.initial_state) @property def progress_type(self) -> ProgressType: - return ProgressType.from_string(self.typ) + return ProgressType(self.typ) diff --git a/src/api/models/MapType.py b/src/api/models/MapType.py index 2262ea7f4..467dfa08c 100644 --- a/src/api/models/MapType.py +++ b/src/api/models/MapType.py @@ -10,8 +10,6 @@ class MapType(Enum): @staticmethod def from_string(map_type: str) -> MapType: - for mtype in list(MapType): - if mtype.value == map_type: - return mtype - else: - return MapType.OTHER + if map_type in MapType: + return MapType(map_type) + return MapType.OTHER diff --git a/src/api/models/ModType.py b/src/api/models/ModType.py index 823e3b478..a8e2893df 100644 --- a/src/api/models/ModType.py +++ b/src/api/models/ModType.py @@ -10,7 +10,6 @@ class ModType(Enum): @staticmethod def from_string(string: str) -> ModType: - for modtype in list(ModType): - if modtype.value == string: - return modtype + if string in ModType: + return ModType(string) return ModType.OTHER diff --git a/src/api/models/PlayerAchievement.py b/src/api/models/PlayerAchievement.py index 578b2fa38..6e2861ff7 100644 --- a/src/api/models/PlayerAchievement.py +++ b/src/api/models/PlayerAchievement.py @@ -15,4 +15,4 @@ class PlayerAchievement(AbstractEntity): @property def current_state(self) -> State: - return State.from_string(self.state) + return State(self.state) diff --git a/src/contextmenu/playercontextmenu.py b/src/contextmenu/playercontextmenu.py index 77824cc61..1441f3895 100644 --- a/src/contextmenu/playercontextmenu.py +++ b/src/contextmenu/playercontextmenu.py @@ -172,11 +172,8 @@ def party_actions( if online_player is None: return - if online_player.id in self._client_window.games.party.memberIds: - if ( - self._me.player.id - == self._client_window.games.party.owner_id - ): + if online_player.id in self._client_window.games.party.member_ids: + if self._me.player.id == self._client_window.games.party.owner_id: yield PlayerMenuItem.KICK_FROM_PARTY elif online_player.currentGame is not None: return diff --git a/src/games/_gameswidget.py b/src/games/_gameswidget.py index d2d66432e..900e5061c 100644 --- a/src/games/_gameswidget.py +++ b/src/games/_gameswidget.py @@ -2,6 +2,7 @@ import logging from typing import TYPE_CHECKING +from typing import Self from PyQt6 import QtWidgets from PyQt6.QtCore import Qt @@ -34,43 +35,29 @@ class Party: - def __init__(self, owner_id=-1, owner=None): + def __init__(self, owner_id: int = -1, owner: PartyMember | None = None) -> None: self.owner_id = owner_id self.members = [owner] if owner else [] @property - def memberCount(self): - return len(self.memberList) + def member_count(self) -> int: + return len(self.members) - @property - def memberList(self): - return self.members - - def addMember(self, member): - self.memberList.append(member) + def add_member(self, member: PartyMember) -> None: + self.members.append(member) @property - def memberIds(self): - uids = [] - if len(self.members) > 0: - for member in self.members: - uids.append(member.id_) - return uids - - def __eq__(self, other): - if ( - sorted(self.memberIds) == sorted(other.memberIds) - and self.owner_id == other.owner_id - ): - return True - else: - return False + def member_ids(self) -> list[int]: + return [member.id_ for member in self.members] + + def __eq__(self, other: Self) -> bool: + return set(self.member_ids) == set(other.member_ids) and self.owner_id == other.owner_id class PartyMember: - def __init__(self, id_=-1, factions=None): + def __init__(self, id_: int = -1, factions: list[str] | None = None) -> None: self.id_ = id_ - self.factions = ["uef", "cybran", "aeon", "seraphim"] + self.factions = factions class GamesWidget(FormClass, BaseClass): @@ -199,7 +186,7 @@ def gameDoubleClicked(self, game): if ( self.party is not None - and self.party.memberCount > 1 + and self.party.member_count > 1 and not self.leave_party() ): return @@ -235,7 +222,7 @@ def hostGameClicked(self, item): if ( self.party is not None - and self.party.memberCount > 1 + and self.party.member_count > 1 and not self.leave_party() ): return @@ -263,9 +250,7 @@ def teamListItemClicked(self, item): menu.popup(QCursor.pos()) def updateParty(self, message): - players_ids = [] - for member in message["members"]: - players_ids.append(member["player"]) + players_ids = [member["player"] for member in message["members"]] old_owner = self.client.players[self.party.owner_id] new_owner = self.client.players[message["owner"]] @@ -283,17 +268,15 @@ def updateParty(self, message): new_party.owner_id = new_owner.id for member in message["members"]: players_id = member["player"] - new_party.addMember( - PartyMember(id_=players_id, factions=member["factions"]), - ) + new_party.add_member(PartyMember(id_=players_id, factions=member["factions"])) else: new_party.owner_id = self._me.id - new_party.addMember(PartyMember(id_=self._me.id)) + new_party.add_member(PartyMember(id_=self._me.id)) if self.party != new_party: self.stopSearch() self.party = new_party - if self.party.memberCount > 1: + if self.party.member_count > 1: self.client._chatMVC.connection.join( "#{}{}".format(new_owner.login, PARTY_CHANNEL_SUFFIX), ) @@ -308,15 +291,15 @@ def showPartyInfo(self): def hidePartyInfo(self): self.partyInfo.hide() - def updatePartyInfoFrame(self): - if self.party.memberCount > 1: + def updatePartyInfoFrame(self) -> None: + if self.party.member_count > 1: self.showPartyInfo() else: self.hidePartyInfo() - def updateTeamList(self): + def updateTeamList(self) -> None: self.teamList.clear() - for member_id in self.party.memberIds: + for member_id in self.party.member_ids: if member_id != self._me.id: item = QtWidgets.QListWidgetItem( self.client.players[member_id].login, diff --git a/src/games/automatchframe.py b/src/games/automatchframe.py index 7fe1a56c8..84a552ec0 100644 --- a/src/games/automatchframe.py +++ b/src/games/automatchframe.py @@ -148,7 +148,7 @@ def updateLabelMatchingIn(self): def startSearchRanked(self): if ( - self.games.party.memberCount > self.teamSize + self.games.party.member_count > self.teamSize or self.games.party.owner_id != self.client.me.id ): return @@ -196,7 +196,7 @@ def stopSearchRanked(self): def handlePartyUpdate(self): if ( - self.games.party.memberCount > self.teamSize + self.games.party.member_count > self.teamSize or self.games.party.owner_id != self.client.me.id ): self.rankedPlay.setEnabled(False) diff --git a/src/replays/_replayswidget.py b/src/replays/_replayswidget.py index a40f03e6f..46bee22c1 100644 --- a/src/replays/_replayswidget.py +++ b/src/replays/_replayswidget.py @@ -257,7 +257,7 @@ def liveTreeDoubleClicked(self, item): if ( self.client.games.party - and self.client.games.party.memberCount > 1 + and self.client.games.party.member_count > 1 ): if not self.client.games.leave_party(): return @@ -918,7 +918,7 @@ def online_tree_clicked(self, item: ReplayItem | QTreeWidgetItem) -> None: def onlineTreeDoubleClicked(self, item): if ( self.client.games.party - and self.client.games.party.memberCount > 1 + and self.client.games.party.member_count > 1 ): if not self.client.games.leave_party(): return diff --git a/src/util/__init__.py b/src/util/__init__.py index 4f58ff856..eef9a315d 100644 --- a/src/util/__init__.py +++ b/src/util/__init__.py @@ -7,8 +7,6 @@ import shutil import subprocess import sys -from enum import Enum -from typing import Self from PyQt6 import QtWidgets from PyQt6.QtCore import QDateTime @@ -529,13 +527,3 @@ def capitalize(string: str) -> str: Capitalize the first letter only, leave the rest as it is """ return f"{string[0].upper()}{string[1:]}" - - -class StringValuedEnum(Enum): - - @classmethod - def from_string(cls, string: str) -> Self: - for member in iter(cls): - if member.value == string: - return member - raise ValueError("Unsupported value")