diff --git a/findmy/errors.py b/findmy/errors.py index 2eda554..fbf88c7 100644 --- a/findmy/errors.py +++ b/findmy/errors.py @@ -5,6 +5,10 @@ class InvalidCredentialsError(Exception): """Raised when credentials are incorrect.""" +class UnauthorizedError(Exception): + """Raised when an authorization error occurs.""" + + class UnhandledProtocolError(RuntimeError): """ Raised when an unexpected error occurs while communicating with Apple servers. diff --git a/findmy/reports/account.py b/findmy/reports/account.py index c6ff3ca..814ee77 100644 --- a/findmy/reports/account.py +++ b/findmy/reports/account.py @@ -27,10 +27,15 @@ import srp._pysrp as srp from typing_extensions import override -from findmy.errors import InvalidCredentialsError, InvalidStateError, UnhandledProtocolError +from findmy.errors import ( + InvalidCredentialsError, + InvalidStateError, + UnauthorizedError, + UnhandledProtocolError, +) from findmy.util import crypto from findmy.util.closable import Closable -from findmy.util.http import HttpSession, decode_plist +from findmy.util.http import HttpResponse, HttpSession, decode_plist from .reports import LocationReport, LocationReportsFetcher from .state import LoginState @@ -585,15 +590,36 @@ async def fetch_raw_reports(self, start: int, end: int, ids: list[str]) -> dict[ ) data = {"search": [{"startDate": start, "endDate": end, "ids": ids}]} - r = await self._http.post( - self._ENDPOINT_REPORTS_FETCH, - auth=auth, - headers=await self.get_anisette_headers(), - json=data, - ) - resp = r.json() - if not r.ok or resp["statusCode"] != "200": - msg = f"Failed to fetch reports: {resp['statusCode']}" + async def _do_request() -> HttpResponse: + return await self._http.post( + self._ENDPOINT_REPORTS_FETCH, + auth=auth, + headers=await self.get_anisette_headers(), + json=data, + ) + + r = await _do_request() + if r.status_code == 401: + logging.info("Got 401 while fetching reports, redoing login") + + new_state = await self._gsa_authenticate() + if new_state != LoginState.AUTHENTICATED: + msg = f"Unexpected login state after reauth: {new_state}. Please log in again." + raise UnauthorizedError(msg) + await self._login_mobileme() + + r = await _do_request() + + if r.status_code == 401: + msg = "Not authorized to fetch reports." + raise UnauthorizedError(msg) + + try: + resp = r.json() + except json.JSONDecodeError: + resp = {} + if not r.ok or resp.get("statusCode") != "200": + msg = f"Failed to fetch reports: {resp.get('statusCode')}" raise UnhandledProtocolError(msg) return resp @@ -679,7 +705,7 @@ async def fetch_last_reports( return await self.fetch_reports(keys, start, end) - @require_login_state(LoginState.LOGGED_OUT, LoginState.REQUIRE_2FA) + @require_login_state(LoginState.LOGGED_OUT, LoginState.REQUIRE_2FA, LoginState.LOGGED_IN) async def _gsa_authenticate( self, username: str | None = None, @@ -805,9 +831,9 @@ async def _login_mobileme(self) -> LoginState: data = resp.plist() mobileme_data = data.get("delegates", {}).get("com.apple.mobileme", {}) - status = mobileme_data.get("status") + status = mobileme_data.get("status") or data.get("status") if status != 0: - status_message = mobileme_data.get("status-message") + status_message = mobileme_data.get("status-message") or data.get("status-message") msg = f"com.apple.mobileme login failed with status {status}: {status_message}" raise UnhandledProtocolError(msg)