diff --git a/logfire/_internal/auth.py b/logfire/_internal/auth.py index 96a01454..3eeac440 100644 --- a/logfire/_internal/auth.py +++ b/logfire/_internal/auth.py @@ -1,17 +1,21 @@ from __future__ import annotations import platform +import re import warnings +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from typing import TypedDict +from typing import TypedDict, cast from urllib.parse import urljoin import requests +from rich.prompt import IntPrompt +from typing_extensions import Self from logfire.exceptions import LogfireConfigError -from .utils import UnexpectedResponse +from .utils import UnexpectedResponse, read_toml_file HOME_LOGFIRE = Path.home() / '.logfire' """Folder used to store global configuration, and user tokens.""" @@ -19,6 +23,29 @@ """File used to store user tokens.""" +PYDANTIC_LOGFIRE_TOKEN_PATTERN = re.compile( + r'^(?Ppylf_v(?P[0-9]+)_(?P[a-z]+)_)(?P[a-zA-Z0-9]+)$' +) + + +class _RegionData(TypedDict): + base_url: str + gcp_region: str + + +REGIONS: dict[str, _RegionData] = { + 'us': { + 'base_url': 'https://logfire-us.pydantic.dev', + 'gcp_region': 'us-east4', + }, + 'eu': { + 'base_url': 'https://logfire-eu.pydantic.dev', + 'gcp_region': 'europe-west4', + }, +} +"""The existing Logfire regions.""" + + class UserTokenData(TypedDict): """User token data.""" @@ -26,12 +53,144 @@ class UserTokenData(TypedDict): expiration: str -class DefaultFile(TypedDict): - """Content of the default.toml file.""" +class UserTokensFileData(TypedDict, total=False): + """Content of the file containing the user tokens.""" tokens: dict[str, UserTokenData] +@dataclass +class UserToken: + """A user token.""" + + token: str + base_url: str + expiration: str + + @classmethod + def from_user_token_data(cls, base_url: str, token: UserTokenData) -> Self: + return cls( + token=token['token'], + base_url=base_url, + expiration=token['expiration'], + ) + + @property + def is_expired(self) -> bool: + """Whether the token is expired.""" + return datetime.now(tz=timezone.utc) >= datetime.fromisoformat(self.expiration.rstrip('Z')).replace( + tzinfo=timezone.utc + ) + + def __str__(self) -> str: + region = 'us' + if match := PYDANTIC_LOGFIRE_TOKEN_PATTERN.match(self.token): + region = match.group('region') + if region not in REGIONS: + region = 'us' + + token_repr = f'{region.upper()} ({self.base_url}) - ' + if match: + token_repr += match.group('safe_part') + match.group('token')[:5] + else: + token_repr += self.token[:5] + token_repr += '****' + return token_repr + + +@dataclass +class UserTokenCollection: + """A collection of user tokens, read from a user tokens file. + + Args: + path: The path where the user tokens will be stored. If the path doesn't exist, + an empty collection is created. Defaults to `~/.logfire/default.toml`. + """ + + user_tokens: dict[str, UserToken] + """A mapping between base URLs and user tokens.""" + + path: Path + """The path where the user tokens are stored.""" + + def __init__(self, path: Path = DEFAULT_FILE) -> None: + self.path = path + try: + data = cast(UserTokensFileData, read_toml_file(path)) + except FileNotFoundError: + data: UserTokensFileData = {} + self.user_tokens = {url: UserToken(base_url=url, **data) for url, data in data.get('tokens', {}).items()} + + def get_token(self, base_url: str | None = None) -> UserToken: + """Get a user token from the collection. + + Args: + base_url: Only look for user tokens valid for this base URL. If not provided, + all the tokens of the collection will be considered: if only one token is + available, it will be used, otherwise the user will be prompted to choose + a token. + + Raises: + LogfireConfigError: If no user token is found (no token matched the base URL, + the collection is empty, or the selected token is expired). + """ + tokens_list = list(self.user_tokens.values()) + + if base_url is not None: + token = self.user_tokens.get(base_url) + if token is None: + raise LogfireConfigError( + f'No user token was found matching the {base_url} Logfire URL. ' + 'Please run `logfire auth` to authenticate.' + ) + elif len(tokens_list) == 1: + token = tokens_list[0] + elif len(tokens_list) >= 2: + choices_str = '\n'.join( + f'{i}. {token} ({"expired" if token.is_expired else "valid"})' + for i, token in enumerate(tokens_list, start=1) + ) + int_choice = IntPrompt.ask( + f'Multiple user tokens found. Please select one:\n{choices_str}\n', + choices=[str(i) for i in range(1, len(tokens_list) + 1)], + ) + token = tokens_list[int_choice - 1] + else: # tokens_list == [] + raise LogfireConfigError('No user tokens are available. Please run `logfire auth` to authenticate.') + + if token.is_expired: + raise LogfireConfigError(f'User token {token} is expired. Please run `logfire auth` to authenticate.') + return token + + def is_logged_in(self, base_url: str | None = None) -> bool: + """Check whether the user token collection contains at least one valid user token. + + Args: + base_url: Only check for user tokens valid for this base URL. If not provided, + all the tokens of the collection will be considered. + """ + if base_url is not None: + tokens = (t for t in self.user_tokens.values() if t.base_url == base_url) + else: + tokens = self.user_tokens.values() + return any(not t.is_expired for t in tokens) + + def add_token(self, base_url: str, token: UserTokenData) -> UserToken: + """Add a user token to the collection.""" + self.user_tokens[base_url] = user_token = UserToken.from_user_token_data(base_url, token) + self._dump() + return user_token + + def _dump(self) -> None: + """Dump the user token collection as TOML to the provided path.""" + # There's no standard library package to write TOML files, so we'll write it manually. + with self.path.open('w') as f: + for base_url, user_token in self.user_tokens.items(): + f.write(f'[tokens."{base_url}"]\n') + f.write(f'token = "{user_token.token}"\n') + f.write(f'expiration = "{user_token.expiration}"\n') + + class NewDeviceFlow(TypedDict): """Matches model of the same name in the backend.""" @@ -91,17 +250,3 @@ def poll_for_token(session: requests.Session, device_code: str, base_api_url: st opt_user_token: UserTokenData | None = res.json() if opt_user_token: return opt_user_token - - -def is_logged_in(data: DefaultFile, logfire_url: str) -> bool: - """Check if the user is logged in. - - Returns: - True if the user is logged in, False otherwise. - """ - for url, info in data['tokens'].items(): # pragma: no branch - # token expirations are in UTC - expiry_date = datetime.fromisoformat(info['expiration'].rstrip('Z')).replace(tzinfo=timezone.utc) - if url == logfire_url and datetime.now(tz=timezone.utc) < expiry_date: # pragma: no branch - return True - return False # pragma: no cover diff --git a/logfire/_internal/cli.py b/logfire/_internal/cli.py index 5c1eafca..bf7cf567 100644 --- a/logfire/_internal/cli.py +++ b/logfire/_internal/cli.py @@ -14,7 +14,7 @@ from collections.abc import Sequence from operator import itemgetter from pathlib import Path -from typing import Any, cast +from typing import Any from urllib.parse import urlparse import requests @@ -24,11 +24,17 @@ from logfire.propagate import ContextCarrier, get_context from ..version import VERSION -from .auth import DEFAULT_FILE, HOME_LOGFIRE, DefaultFile, is_logged_in, poll_for_token, request_device_code +from .auth import ( + DEFAULT_FILE, + HOME_LOGFIRE, + UserTokenCollection, + poll_for_token, + request_device_code, +) +from .client import LogfireClient from .config import REGIONS, LogfireCredentials, get_base_url_from_token from .config_params import ParamManager from .tracer import SDKTracerProvider -from .utils import read_toml_file BASE_OTEL_INTEGRATION_URL = 'https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/' BASE_DOCS_URL = 'https://logfire.pydantic.dev/docs' @@ -51,7 +57,7 @@ def parse_whoami(args: argparse.Namespace) -> None: """Show user authenticated username and the URL to your Logfire project.""" data_dir = Path(args.data_dir) param_manager = ParamManager.create(data_dir) - base_url = param_manager.load_param('base_url', args.logfire_url) + base_url: str | None = param_manager.load_param('base_url', args.logfire_url) token = param_manager.load_param('token') if token: @@ -61,12 +67,15 @@ def parse_whoami(args: argparse.Namespace) -> None: credentials.print_token_summary() return - current_user = LogfireCredentials.get_current_user(session=args._session, logfire_api_url=base_url) - if current_user is None: + try: + client = LogfireClient.from_url(base_url) + except LogfireConfigError: sys.stderr.write('Not logged in. Run `logfire auth` to log in.\n') else: + current_user = client.get_user_information() username = current_user['name'] sys.stderr.write(f'Logged in as: {username}\n') + credentials = LogfireCredentials.load_creds_file(data_dir) if credentials is None: sys.stderr.write(f'No Logfire credentials found in {data_dir.resolve()}\n') @@ -198,16 +207,10 @@ def parse_auth(args: argparse.Namespace) -> None: This will authenticate your machine with Logfire and store the credentials. """ - logfire_url = args.logfire_url - if DEFAULT_FILE.is_file(): - data = cast(DefaultFile, read_toml_file(DEFAULT_FILE)) - else: - data: DefaultFile = {'tokens': {}} + logfire_url: str | None = args.logfire_url - if logfire_url: - logged_in = is_logged_in(data, logfire_url) - else: - logged_in = any(is_logged_in(data, url) for url in data['tokens']) + tokens_collection = UserTokenCollection() + logged_in = tokens_collection.is_logged_in(logfire_url) if logged_in: sys.stderr.writelines( @@ -256,22 +259,16 @@ def parse_auth(args: argparse.Namespace) -> None: ) ) - data['tokens'][logfire_url] = poll_for_token(args._session, device_code, logfire_url) + tokens_collection.add_token(logfire_url, poll_for_token(args._session, device_code, logfire_url)) sys.stderr.write('Successfully authenticated!\n') - - # There's no standard library package to write TOML files, so we'll write it manually. - with DEFAULT_FILE.open('w') as f: - for url, info in data['tokens'].items(): - f.write(f'[tokens."{url}"]\n') - f.write(f'token = "{info["token"]}"\n') - f.write(f'expiration = "{info["expiration"]}"\n') - sys.stderr.write(f'\nYour Logfire credentials are stored in {DEFAULT_FILE}\n') def parse_list_projects(args: argparse.Namespace) -> None: """List user projects.""" - projects = LogfireCredentials.get_user_projects(session=args._session, logfire_api_url=args.logfire_url) + client = LogfireClient.from_url(args.logfire_url) + + projects = client.get_user_projects() if projects: sys.stderr.write( _pretty_table( @@ -300,42 +297,37 @@ def _write_credentials(project_info: dict[str, Any], data_dir: Path, logfire_api def parse_create_new_project(args: argparse.Namespace) -> None: """Create a new project.""" data_dir = Path(args.data_dir) - logfire_url = args.logfire_url - if logfire_url is None: # pragma: no cover - _, logfire_url = LogfireCredentials._get_user_token_data() # type: ignore + client = LogfireClient.from_url(args.logfire_url) + project_name = args.project_name organization = args.org default_organization = args.default_org project_info = LogfireCredentials.create_new_project( - session=args._session, - logfire_api_url=logfire_url, + client=client, organization=organization, default_organization=default_organization, project_name=project_name, ) - credentials = _write_credentials(project_info, data_dir, logfire_url) + credentials = _write_credentials(project_info, data_dir, client.base_url) sys.stderr.write(f'Project created successfully. You will be able to view it at: {credentials.project_url}\n') def parse_use_project(args: argparse.Namespace) -> None: """Use an existing project.""" data_dir = Path(args.data_dir) - logfire_url = args.logfire_url - if logfire_url is None: # pragma: no cover - _, logfire_url = LogfireCredentials._get_user_token_data() # type: ignore + client = LogfireClient.from_url(args.logfire_url) + project_name = args.project_name organization = args.org - - projects = LogfireCredentials.get_user_projects(session=args._session, logfire_api_url=logfire_url) + projects = client.get_user_projects() project_info = LogfireCredentials.use_existing_project( - session=args._session, - logfire_api_url=logfire_url, + client=client, projects=projects, organization=organization, project_name=project_name, ) if project_info: - credentials = _write_credentials(project_info, data_dir, logfire_url) + credentials = _write_credentials(project_info, data_dir, client.base_url) sys.stderr.write( f'Project configured successfully. You will be able to view it at: {credentials.project_url}\n' ) diff --git a/logfire/_internal/client.py b/logfire/_internal/client.py new file mode 100644 index 00000000..528d420c --- /dev/null +++ b/logfire/_internal/client.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from typing import Any +from urllib.parse import urljoin + +from requests import Response, Session +from typing_extensions import Self + +from logfire.exceptions import LogfireConfigError +from logfire.version import VERSION + +from .auth import UserToken, UserTokenCollection +from .utils import UnexpectedResponse + +UA_HEADER = f'logfire/{VERSION}' + + +class ProjectAlreadyExists(Exception): + pass + + +class InvalidProjectName(Exception): + def __init__(self, reason: str, /) -> None: + self.reason = reason + + +class LogfireClient: + """A Logfire HTTP client to interact with the API. + + Args: + user_token: The user token to use when authenticating against the API. + """ + + def __init__(self, user_token: UserToken) -> None: + if user_token.is_expired: + raise RuntimeError('The provided user token is expired') + self.base_url = user_token.base_url + self._token = user_token.token + self._session = Session() + self._session.headers.update({'Authorization': self._token, 'User-Agent': UA_HEADER}) + + @classmethod + def from_url(cls, base_url: str | None) -> Self: + """Create a client from the provided base URL. + + Args: + base_url: The base URL to use when looking for a user token. If `None`, will prompt + the user into selecting a token from the token collection (or, if only one available, + use it directly). The token collection will be created from the `~/.logfire/default.toml` + file (or an empty one if no such file exists). + """ + return cls(user_token=UserTokenCollection().get_token(base_url)) + + def _get_raw(self, endpoint: str) -> Response: + response = self._session.get(urljoin(self.base_url, endpoint)) + UnexpectedResponse.raise_for_status(response) + return response + + def _get(self, endpoint: str, *, error_message: str) -> Any: + try: + return self._get_raw(endpoint).json() + except UnexpectedResponse as e: + raise LogfireConfigError(error_message) from e + + def _post_raw(self, endpoint: str, body: Any | None = None) -> Response: + response = self._session.post(urljoin(self.base_url, endpoint), json=body) + UnexpectedResponse.raise_for_status(response) + return response + + def _post(self, endpoint: str, *, body: Any | None = None, error_message: str) -> Any: + try: + return self._post_raw(endpoint, body).json() + except UnexpectedResponse as e: + raise LogfireConfigError(error_message) from e + + def get_user_organizations(self) -> list[dict[str, Any]]: + """Get the organizations of the logged-in user.""" + return self._get('/v1/organizations/', error_message='Error retrieving list of organizations') + + def get_user_information(self) -> dict[str, Any]: + """Get information about the logged-in user.""" + return self._get('/v1/account/me', error_message='Error retrieving user information') + + def get_user_projects(self) -> list[dict[str, Any]]: + """Get the projects of the logged-in user.""" + return self._get('/v1/projects/', error_message='Error retrieving list of projects') + + def create_new_project(self, organization: str, project_name: str): + """Create a new project. + + Args: + organization: The organization that should hold the new project. + project_name: The name of the project to be created. + + Returns: + The newly created project. + """ + try: + response = self._post_raw(f'/v1/projects/{organization}', body={'project_name': project_name}) + except UnexpectedResponse as e: + r = e.response + if r.status_code == 409: + raise ProjectAlreadyExists + if r.status_code == 422: + error = r.json()['detail'][0] + if error['loc'] == ['body', 'project_name']: # pragma: no branch + raise InvalidProjectName(error['msg']) + + raise LogfireConfigError('Error creating new project') + return response.json() + + def create_write_token(self, organization: str, project_name: str) -> dict[str, Any]: + """Create a write token for the given project in the given organization.""" + return self._post( + f'/v1/organizations/{organization}/projects/{project_name}/write-tokens/', + error_message='Error creating project write token', + ) diff --git a/logfire/_internal/config.py b/logfire/_internal/config.py index 559108a9..6bdb9d44 100644 --- a/logfire/_internal/config.py +++ b/logfire/_internal/config.py @@ -14,7 +14,7 @@ from dataclasses import dataclass, field from pathlib import Path from threading import RLock, Thread -from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, cast +from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict from urllib.parse import urljoin from uuid import uuid4 @@ -55,9 +55,10 @@ from opentelemetry.sdk.trace.id_generator import IdGenerator from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio, Sampler from rich.console import Console -from rich.prompt import Confirm, IntPrompt, Prompt +from rich.prompt import Confirm, Prompt from typing_extensions import Self, Unpack +from logfire._internal.auth import PYDANTIC_LOGFIRE_TOKEN_PATTERN, REGIONS from logfire._internal.baggage import DirectBaggageAttributesSpanProcessor from logfire.exceptions import LogfireConfigError from logfire.sampling import SamplingOptions @@ -65,7 +66,7 @@ from logfire.version import VERSION from ..propagate import NoExtractTraceContextPropagator, WarnOnExtractTraceContextPropagator -from .auth import DEFAULT_FILE, DefaultFile, is_logged_in +from .client import InvalidProjectName, LogfireClient, ProjectAlreadyExists from .config_params import ParamManager, PydanticPluginRecordValues from .constants import ( RESOURCE_ATTRIBUTES_CODE_ROOT_PATH, @@ -103,11 +104,9 @@ from .tracer import OPEN_SPANS, PendingSpanProcessor, ProxyTracerProvider from .utils import ( SeededRandomIdGenerator, - UnexpectedResponse, ensure_data_dir_exists, handle_internal_errors, platform_is_emscripten, - read_toml_file, suppress_instrumentation, ) @@ -122,9 +121,6 @@ COMMON_REQUEST_HEADERS = {'User-Agent': f'logfire/{VERSION}'} """Common request headers for requests to the Logfire API.""" PROJECT_NAME_PATTERN = r'^[a-z0-9]+(?:-[a-z0-9]+)*$' -PYDANTIC_LOGFIRE_TOKEN_PATTERN = re.compile( - r'^(?Ppylf_v(?P[0-9]+)_(?P[a-z]+)_)(?P[a-zA-Z0-9]+)$' -) METRICS_PREFERRED_TEMPORALITY = { Counter: AggregationTemporality.DELTA, @@ -140,24 +136,6 @@ """ -class _RegionData(TypedDict): - base_url: str - gcp_region: str - - -REGIONS: dict[str, _RegionData] = { - 'us': { - 'base_url': 'https://logfire-us.pydantic.dev', - 'gcp_region': 'us-east4', - }, - 'eu': { - 'base_url': 'https://logfire-eu.pydantic.dev', - 'gcp_region': 'europe-west4', - }, -} -"""The existing Logfire regions.""" - - @dataclass class ConsoleOptions: """Options for controlling console output.""" @@ -903,10 +881,8 @@ def add_span_processor(span_processor: SpanProcessor) -> None: # if we still don't have a token, try initializing a new project and writing a new creds file # note, we only do this if `send_to_logfire` is explicitly `True`, not 'if-token-present' if self.send_to_logfire is True and credentials is None: - credentials = LogfireCredentials.initialize_project( - logfire_api_url=self.advanced.base_url, - session=requests.Session(), - ) + client = LogfireClient.from_url(self.advanced.base_url) + credentials = LogfireCredentials.initialize_project(client=client) credentials.write_creds_file(self.data_dir) if credentials is not None: @@ -1352,96 +1328,12 @@ def from_token(cls, token: str, session: requests.Session, base_url: str) -> Sel logfire_api_url=base_url, ) - @classmethod - def _get_user_token_data(cls, logfire_api_url: str | None = None) -> tuple[str, str]: - """Get a token and its associated base API URL. - - Args: - logfire_api_url: An explicitly configured base API URL. If set, the token attached - to this URL will be used. If not provided, will prompt for a token to use if multiple - ones are available, or use the only one available otherwise. - - Returns: - A two-tuple, the first element being the token and the second element being the base API URL. - """ - if DEFAULT_FILE.is_file(): - data = cast(DefaultFile, read_toml_file(DEFAULT_FILE)) - if logfire_api_url is None: - tokens_list = list(data['tokens'].items()) - if len(tokens_list) == 1: - return cls._get_user_token_data(tokens_list[0][0]) - elif len(tokens_list) >= 2: # pragma: no branch - choices_str = '\n'.join( - f'{i}. {_get_token_repr(url, d["token"])}' for i, (url, d) in enumerate(tokens_list, start=1) - ) - int_choice = IntPrompt.ask( - f'Multiple user tokens found. Please select one:\n{choices_str}\n', - choices=[str(i) for i in range(1, len(data['tokens']) + 1)], - ) - url, token_data = tokens_list[int_choice - 1] - if is_logged_in(data, url): # pragma: no branch - return token_data['token'], url - elif is_logged_in(data, logfire_api_url): - return data['tokens'][logfire_api_url]['token'], logfire_api_url - - raise LogfireConfigError( - """You are not authenticated. Please run `logfire auth` to authenticate. - -If you are running in production, you can set the `LOGFIRE_TOKEN` environment variable. -To create a write token, refer to https://logfire.pydantic.dev/docs/guides/advanced/creating_write_tokens/ -""" - ) - - @classmethod - def get_current_user(cls, session: requests.Session, logfire_api_url: str | None = None) -> dict[str, Any] | None: - try: - user_token, logfire_api_url = cls._get_user_token_data(logfire_api_url=logfire_api_url) - except LogfireConfigError: - return None - return cls._get_user_for_token(user_token, session, logfire_api_url) - - @classmethod - def _get_user_for_token(cls, user_token: str, session: requests.Session, logfire_api_url: str) -> dict[str, Any]: - headers = {**COMMON_REQUEST_HEADERS, 'Authorization': user_token} - account_info_url = urljoin(logfire_api_url, '/v1/account/me') - try: - response = session.get(account_info_url, headers=headers) - UnexpectedResponse.raise_for_status(response) - except requests.RequestException as e: - raise LogfireConfigError('Error retrieving user information.') from e - return response.json() - - @classmethod - def get_user_projects(cls, session: requests.Session, logfire_api_url: str | None = None) -> list[dict[str, Any]]: - """Get list of projects that user has access to them. - - Args: - session: HTTP client session used to communicate with the Logfire API. - logfire_api_url: The Logfire API base URL. - - Returns: - List of user projects. - - Raises: - LogfireConfigError: If there was an error retrieving user projects. - """ - user_token, logfire_api_url = cls._get_user_token_data(logfire_api_url=logfire_api_url) - headers = {**COMMON_REQUEST_HEADERS, 'Authorization': user_token} - projects_url = urljoin(logfire_api_url, '/v1/projects/') - try: - response = session.get(projects_url, headers=headers) - UnexpectedResponse.raise_for_status(response) - except requests.RequestException as e: # pragma: no cover - raise LogfireConfigError('Error retrieving list of projects.') from e - return response.json() - @classmethod def use_existing_project( cls, *, - session: requests.Session, + client: LogfireClient, projects: list[dict[str, Any]], - logfire_api_url: str | None = None, organization: str | None = None, project_name: str | None = None, ) -> dict[str, Any] | None: @@ -1451,8 +1343,7 @@ def use_existing_project( the user has access to it. Otherwise, it asks the user to select a project interactively. Args: - session: HTTP client session used to communicate with the Logfire API. - logfire_api_url: The Logfire API base URL. + client: The Logfire client to use when making requests. projects: List of user projects. organization: Project organization. project_name: Name of project that has to be used. @@ -1463,9 +1354,6 @@ def use_existing_project( Raises: LogfireConfigError: If there was an error configuring the project. """ - user_token, logfire_api_url = cls._get_user_token_data(logfire_api_url=logfire_api_url) - headers = {**COMMON_REQUEST_HEADERS, 'Authorization': user_token} - org_message = '' org_flag = '' project_message = 'projects' @@ -1529,28 +1417,17 @@ def use_existing_project( choices=list(project_choices.keys()), default='1', ) - project_info_tuple = project_choices[selected_project_key] + project_info_tuple: tuple[str, str] = project_choices[selected_project_key] organization = project_info_tuple[0] project_name = project_info_tuple[1] - project_write_token_url = urljoin( - logfire_api_url, - f'/v1/organizations/{organization}/projects/{project_name}/write-tokens/', - ) - try: - response = session.post(project_write_token_url, headers=headers) - UnexpectedResponse.raise_for_status(response) - except requests.RequestException as e: - raise LogfireConfigError('Error creating project write token.') from e - - return response.json() + return client.create_write_token(organization, project_name) @classmethod def create_new_project( cls, *, - session: requests.Session, - logfire_api_url: str | None = None, + client: LogfireClient, organization: str | None = None, default_organization: bool = False, project_name: str | None = None, @@ -1561,8 +1438,7 @@ def create_new_project( Otherwise, it asks the user to select organization and enter a valid project name interactively. Args: - session: HTTP client session used to communicate with the Logfire API. - logfire_api_url: The Logfire API base URL. + client: The Logfire client to use when making requests. organization: The organization name of the new project. default_organization: Whether to create the project under the user default organization. project_name: The default name of the project. @@ -1573,24 +1449,15 @@ def create_new_project( Raises: LogfireConfigError: If there was an error creating projects. """ - user_token, logfire_api_url = cls._get_user_token_data(logfire_api_url=logfire_api_url) - headers = {**COMMON_REQUEST_HEADERS, 'Authorization': user_token} - - # Get user organizations - organizations_url = urljoin(logfire_api_url, '/v1/organizations/') - try: - response = session.get(organizations_url, headers=headers) - UnexpectedResponse.raise_for_status(response) - except requests.RequestException as e: - raise LogfireConfigError('Error retrieving list of organizations.') from e - organizations = [item['organization_name'] for item in response.json()] + organizations: list[str] = [item['organization_name'] for item in client.get_user_organizations()] if organization not in organizations: if len(organizations) > 1: # Get user default organization - user_details = cls._get_user_for_token(user_token, session, logfire_api_url) - assert user_details is not None - user_default_organization_name = user_details.get('default_organization', {}).get('organization_name') + user_details = client.get_user_information() + user_default_organization_name: str | None = user_details.get('default_organization', {}).get( + 'organization_name' + ) if default_organization and user_default_organization_name: organization = user_default_organization_name @@ -1599,7 +1466,7 @@ def create_new_project( '\nTo create and use a new project, please provide the following information:\n' 'Select the organization to create the project in', choices=organizations, - default=user_default_organization_name if user_default_organization_name else organizations[0], + default=user_default_organization_name or organizations[0], ) else: organization = organizations[0] @@ -1610,7 +1477,7 @@ def create_new_project( if not confirm: sys.exit(1) - project_name_default: str | None = default_project_name() + project_name_default: str = default_project_name() project_name_prompt = 'Enter the project name' while True: project_name = project_name or Prompt.ask(project_name_prompt, default=project_name_default) @@ -1624,46 +1491,35 @@ def create_new_project( default=project_name_default, ) - url = urljoin(logfire_api_url, f'/v1/projects/{organization}') try: - response = session.post(url, headers=headers, json={'project_name': project_name}) - if response.status_code == 409: - project_name_default = ... # type: ignore # this means the value is required - project_name_prompt = ( - f"\nA project with the name '{project_name}' already exists." - f' Please enter a different project name' - ) - project_name = None - continue - if response.status_code == 422: - error = response.json()['detail'][0] - if error['loc'] == ['body', 'project_name']: # pragma: no branch - project_name_default = ... # type: ignore # this means the value is required - project_name_prompt = ( - f'\nThe project name you entered is invalid:\n' - f'{error["msg"]}\n' - f'Please enter a different project name' - ) - project_name = None - continue - UnexpectedResponse.raise_for_status(response) - except requests.RequestException as e: - raise LogfireConfigError('Error creating new project.') from e + project = client.create_new_project(organization, project_name) + except ProjectAlreadyExists: + project_name_default = ... # type: ignore # this means the value is required + project_name_prompt = ( + f"\nA project with the name '{project_name}' already exists. Please enter a different project name" + ) + project_name = None + continue + except InvalidProjectName as exc: + project_name_default = ... # type: ignore # this means the value is required + project_name_prompt = ( + f'\nThe project name you entered is invalid:\n{exc.reason}\nPlease enter a different project name' + ) + project_name = None + continue else: - return response.json() + return project @classmethod def initialize_project( cls, *, - session: requests.Session, - logfire_api_url: str | None = None, + client: LogfireClient, ) -> Self: """Create a new project or use an existing project on logfire.dev requesting the given project name. Args: - session: HTTP client session used to communicate with the Logfire API. - logfire_api_url: The Logfire API base URL. + client: The Logfire client to use when making requests. Returns: The new credentials. @@ -1678,24 +1534,17 @@ def initialize_project( 'All data sent to Logfire must be associated with a project.\n' ) - _, logfire_api_url = cls._get_user_token_data(logfire_api_url=logfire_api_url) - - projects = cls.get_user_projects(session=session, logfire_api_url=logfire_api_url) + projects = client.get_user_projects() if projects: use_existing_projects = Confirm.ask('Do you want to use one of your existing projects? ', default=True) if use_existing_projects: # pragma: no branch - credentials = cls.use_existing_project( - session=session, logfire_api_url=logfire_api_url, projects=projects - ) + credentials = cls.use_existing_project(client=client, projects=projects) if not credentials: - credentials = cls.create_new_project( - session=session, - logfire_api_url=logfire_api_url, - ) + credentials = cls.create_new_project(client=client) try: - result = cls(**credentials, logfire_api_url=logfire_api_url) + result = cls(**credentials, logfire_api_url=client.base_url) Prompt.ask( f'Project initialized successfully. You will be able to view it at: {result.project_url}\n' 'Press Enter to continue' @@ -1738,23 +1587,6 @@ def _get_creds_file(creds_dir: Path) -> Path: return creds_dir / CREDENTIALS_FILENAME -def _get_token_repr(url: str, token: str) -> str: - region = 'us' - if match := PYDANTIC_LOGFIRE_TOKEN_PATTERN.match(token): - region = match.group('region') - if region not in REGIONS: - region = 'us' - - token_repr = f'{region.upper()} ({url}) - ' - if match: - # new_token, include prefix and 5 chars - token_repr += match.group('safe_part') + match.group('token')[:5] - else: - token_repr += token[:5] - token_repr += '****' - return token_repr - - def get_base_url_from_token(token: str) -> str: """Get the base API URL from the token's region.""" # default to US for tokens that were created before regions were added: diff --git a/logfire/_internal/utils.py b/logfire/_internal/utils.py index 4b71084f..d783174d 100644 --- a/logfire/_internal/utils.py +++ b/logfire/_internal/utils.py @@ -167,6 +167,8 @@ def span_to_dict(span: ReadableSpan) -> ReadableSpanDict: class UnexpectedResponse(RequestException): """An unexpected response was received from the server.""" + response: Response # type: ignore (guaranteed to exist) + def __init__(self, response: Response) -> None: super().__init__(f'Unexpected response: {response.status_code}', response=response) diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 00000000..b7d08382 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +import inline_snapshot.extra +import pytest +from inline_snapshot import snapshot + +from logfire._internal.auth import UserToken, UserTokenCollection +from logfire.exceptions import LogfireConfigError + + +@pytest.mark.parametrize( + ['base_url', 'token', 'expected'], + [ + ( + 'https://logfire-us.pydantic.dev', + 'pylf_v1_us_0kYhc414Ys2FNDRdt5vFB05xFx5NjVcbcBMy4Kp6PH0W', + 'US (https://logfire-us.pydantic.dev) - pylf_v1_us_0kYhc****', + ), + ( + 'https://logfire-eu.pydantic.dev', + 'pylf_v1_eu_0kYhc414Ys2FNDRdt5vFB05xFx5NjVcbcBMy4Kp6PH0W', + 'EU (https://logfire-eu.pydantic.dev) - pylf_v1_eu_0kYhc****', + ), + ( + 'https://logfire-us.pydantic.dev', + '0kYhc414Ys2FNDRdt5vFB05xFx5NjVcbcBMy4Kp6PH0W', + 'US (https://logfire-us.pydantic.dev) - 0kYhc****', + ), + ( + 'https://logfire-us.pydantic.dev', + 'pylf_v1_unknownregion_0kYhc414Ys2FNDRdt5vFB05xFx5NjVcbcBMy4Kp6PH0W', + 'US (https://logfire-us.pydantic.dev) - pylf_v1_unknownregion_0kYhc****', + ), + ], +) +def test_user_token_str(base_url: str, token: str, expected: str) -> None: + user_token = UserToken( + token=token, + base_url=base_url, + expiration='1970-01-01', + ) + assert str(user_token) == expected + + +def test_get_user_token_explicit_url(default_credentials: Path) -> None: + token_collection = UserTokenCollection(default_credentials) + + # https://logfire-us.pydantic.dev is the URL present in the default credentials fixture: + token = token_collection.get_token(base_url='https://logfire-us.pydantic.dev') + assert token.base_url == 'https://logfire-us.pydantic.dev' + + with inline_snapshot.extra.raises( + snapshot( + 'LogfireConfigError: No user token was found matching the https://logfire-eu.pydantic.dev Logfire URL. Please run `logfire auth` to authenticate.' + ) + ): + token_collection.get_token(base_url='https://logfire-eu.pydantic.dev') + + +def test_get_user_token_no_explicit_url(default_credentials: Path) -> None: + token_collection = UserTokenCollection(default_credentials) + + token = token_collection.get_token(base_url=None) + + # https://logfire-us.pydantic.dev is the URL present in the default credentials fixture: + assert token.base_url == 'https://logfire-us.pydantic.dev' + + +def test_get_user_token_input_choice(multiple_credentials: Path) -> None: + token_collection = UserTokenCollection(multiple_credentials) + + with patch('rich.prompt.IntPrompt.ask', side_effect=[1]): + token = token_collection.get_token(base_url=None) + # https://logfire-us.pydantic.dev is the first URL present in the multiple credentials fixture: + assert token.base_url == 'https://logfire-us.pydantic.dev' + + +def test_get_user_token_empty_credentials(tmp_path: Path) -> None: + empty_auth_file = tmp_path / 'default.toml' + empty_auth_file.touch() + + token_collection = UserTokenCollection(empty_auth_file) + with inline_snapshot.extra.raises( + snapshot('LogfireConfigError: No user tokens are available. Please run `logfire auth` to authenticate.') + ): + token_collection.get_token() + + +def test_get_user_token_expired_credentials(expired_credentials: Path) -> None: + token_collection = UserTokenCollection(expired_credentials) + + with inline_snapshot.extra.raises( + snapshot( + 'LogfireConfigError: User token US (https://logfire-us.pydantic.dev) - pylf_v1_us_0kYhc**** is expired. Please run `logfire auth` to authenticate.' + ) + ): + # https://logfire-us.pydantic.dev is the URL present in the expired credentials fixture: + token_collection.get_token(base_url='https://logfire-us.pydantic.dev') + + +def test_get_user_token_not_authenticated(default_credentials: Path) -> None: + token_collection = UserTokenCollection(default_credentials) + + with pytest.raises( + LogfireConfigError, + match=( + 'No user token was found matching the http://localhost:8234 Logfire URL. ' + 'Please run `logfire auth` to authenticate.' + ), + ): + # Use a port that we don't use for local development to reduce conflicts with local configuration + token_collection.get_token(base_url='http://localhost:8234') diff --git a/tests/test_cli.py b/tests/test_cli.py index 5b63dfa1..8f63e310 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -20,6 +20,7 @@ import logfire._internal.cli from logfire import VERSION +from logfire._internal.auth import UserToken from logfire._internal.cli import STANDARD_LIBRARY_PACKAGES, main from logfire._internal.config import LogfireCredentials, sanitize_project_name from logfire.exceptions import LogfireConfigError @@ -128,8 +129,8 @@ def test_whoami_logged_in( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('123', 'http://localhost'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken(token='123', base_url='http://localhost', expiration='2099-12-31T23:59:59'), ) ) @@ -166,7 +167,7 @@ def test_whoami_default_dir( def test_whoami_no_token_no_url(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None: auth_file = tmp_path / 'default.toml' - with patch('logfire._internal.cli.DEFAULT_FILE', auth_file), pytest.raises(SystemExit): + with patch('logfire._internal.auth.DEFAULT_FILE', auth_file), pytest.raises(SystemExit): main(['whoami']) assert 'Not logged in. Run `logfire auth` to log in.' in capsys.readouterr().err @@ -307,6 +308,8 @@ def new_find_spec(name: str) -> ModuleSpec | None: def test_auth(tmp_path: Path, webbrowser_error: bool, capsys: pytest.CaptureFixture[str]) -> None: auth_file = tmp_path / 'default.toml' with ExitStack() as stack: + stack.enter_context(patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) + # Necessary to assert that credentials are written to the `auth_file` (which happens from the `cli` module) stack.enter_context(patch('logfire._internal.cli.DEFAULT_FILE', auth_file)) stack.enter_context(patch('logfire._internal.cli.input')) webbrowser_open = stack.enter_context( @@ -357,7 +360,7 @@ def test_auth(tmp_path: Path, webbrowser_error: bool, capsys: pytest.CaptureFixt def test_auth_temp_failure(tmp_path: Path) -> None: auth_file = tmp_path / 'default.toml' with ExitStack() as stack: - stack.enter_context(patch('logfire._internal.cli.DEFAULT_FILE', auth_file)) + stack.enter_context(patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) stack.enter_context(patch('logfire._internal.cli.input')) stack.enter_context(patch('logfire._internal.cli.webbrowser.open')) @@ -382,7 +385,7 @@ def test_auth_temp_failure(tmp_path: Path) -> None: def test_auth_permanent_failure(tmp_path: Path) -> None: auth_file = tmp_path / 'default.toml' with ExitStack() as stack: - stack.enter_context(patch('logfire._internal.cli.DEFAULT_FILE', auth_file)) + stack.enter_context(patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) stack.enter_context(patch('logfire._internal.cli.input')) stack.enter_context(patch('logfire._internal.cli.webbrowser.open')) @@ -400,7 +403,7 @@ def test_auth_permanent_failure(tmp_path: Path) -> None: def test_auth_on_authenticated_user(default_credentials: Path, capsys: pytest.CaptureFixture[str]) -> None: - with patch('logfire._internal.cli.DEFAULT_FILE', default_credentials): + with patch('logfire._internal.auth.DEFAULT_FILE', default_credentials): # US is the default region in the default credentials fixture: main(['--region', 'us', 'auth']) @@ -411,6 +414,8 @@ def test_auth_on_authenticated_user(default_credentials: Path, capsys: pytest.Ca def test_auth_no_region_specified(tmp_path: Path) -> None: auth_file = tmp_path / 'default.toml' with ExitStack() as stack: + stack.enter_context(patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) + # Necessary to assert that credentials are written to the `auth_file` (which happens from the `cli` module) stack.enter_context(patch('logfire._internal.cli.DEFAULT_FILE', auth_file)) # 'not_an_int' is used as the first input to test that invalid inputs are supported, # '2' will result in the EU region being used: @@ -452,8 +457,10 @@ def test_projects_list(default_credentials: Path, capsys: pytest.CaptureFixture[ with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -480,8 +487,10 @@ def test_projects_list_no_project(default_credentials: Path, capsys: pytest.Capt with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -504,8 +513,10 @@ def test_projects_new_with_project_name_and_org( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -544,8 +555,10 @@ def test_projects_new_with_project_name_without_org( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) confirm_mock = stack.enter_context(patch('rich.prompt.Confirm.ask', side_effect=[True])) @@ -587,8 +600,10 @@ def test_projects_new_with_project_name_and_wrong_org( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) confirm_mock = stack.enter_context(patch('rich.prompt.Confirm.ask', side_effect=[True])) @@ -629,8 +644,10 @@ def test_projects_new_with_project_name_and_default_org( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -667,8 +684,10 @@ def test_projects_new_with_project_name_multiple_organizations( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) prompt_mock = stack.enter_context(patch('rich.prompt.Prompt.ask', side_effect=['fake_org'])) @@ -722,8 +741,10 @@ def test_projects_new_with_project_name_and_default_org_multiple_organizations( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -768,8 +789,10 @@ def test_projects_new_without_project_name( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) prompt_mock = stack.enter_context(patch('rich.prompt.Prompt.ask', side_effect=['myproject', ''])) @@ -811,8 +834,10 @@ def test_projects_new_invalid_project_name( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) prompt_mock = stack.enter_context(patch('rich.prompt.Prompt.ask', side_effect=['myproject', ''])) @@ -859,8 +884,10 @@ def test_projects_new_error(tmp_dir_cwd: Path, default_credentials: Path) -> Non with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) stack.enter_context(patch('logfire._internal.cli.LogfireCredentials.write_creds_file', side_effect=TypeError)) @@ -891,8 +918,10 @@ def test_projects_without_project_name_without_org( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) confirm_mock = stack.enter_context(patch('rich.prompt.Confirm.ask', side_effect=[True])) @@ -936,8 +965,10 @@ def test_projects_new_get_organizations_error(tmp_dir_cwd: Path, default_credent with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -945,7 +976,7 @@ def test_projects_new_get_organizations_error(tmp_dir_cwd: Path, default_credent stack.enter_context(m) m.get('https://logfire-us.pydantic.dev/v1/organizations/', text='Error', status_code=500) - with pytest.raises(LogfireConfigError, match='Error retrieving list of organizations.'): + with pytest.raises(LogfireConfigError, match='Error retrieving list of organizations'): main(['projects', 'new']) @@ -953,8 +984,10 @@ def test_projects_new_get_user_info_error(tmp_dir_cwd: Path, default_credentials with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -967,7 +1000,7 @@ def test_projects_new_get_user_info_error(tmp_dir_cwd: Path, default_credentials ) m.get('https://logfire-us.pydantic.dev/v1/account/me', text='Error', status_code=500) - with pytest.raises(LogfireConfigError, match='Error retrieving user information.'): + with pytest.raises(LogfireConfigError, match='Error retrieving user information'): main(['projects', 'new']) @@ -975,8 +1008,10 @@ def test_projects_new_create_project_error(tmp_dir_cwd: Path, default_credential with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) stack.enter_context(patch('logfire._internal.cli.LogfireCredentials.write_creds_file', side_effect=TypeError)) @@ -987,7 +1022,7 @@ def test_projects_new_create_project_error(tmp_dir_cwd: Path, default_credential m.get('https://logfire-us.pydantic.dev/v1/organizations/', json=[{'organization_name': 'fake_org'}]) m.post('https://logfire-us.pydantic.dev/v1/projects/fake_org', text='Error', status_code=500) - with pytest.raises(LogfireConfigError, match='Error creating new project.'): + with pytest.raises(LogfireConfigError, match='Error creating new project'): main(['projects', 'new', 'myproject', '--org', 'fake_org']) @@ -995,8 +1030,10 @@ def test_projects_use(tmp_dir_cwd: Path, default_credentials: Path, capsys: pyte with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -1038,8 +1075,10 @@ def test_projects_use_without_project_name( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) prompt_mock = stack.enter_context(patch('rich.prompt.Prompt.ask', side_effect=['1'])) @@ -1094,8 +1133,10 @@ def test_projects_use_multiple( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) config_console = stack.enter_context(patch('logfire._internal.config.Console')) @@ -1157,8 +1198,10 @@ def test_projects_use_multiple_with_org( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -1200,8 +1243,10 @@ def test_projects_use_wrong_project( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) prompt_mock = stack.enter_context(patch('rich.prompt.Prompt.ask', side_effect=['y', '1'])) @@ -1254,8 +1299,10 @@ def test_projects_use_wrong_project_give_up( with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) config_console = stack.enter_context(patch('logfire._internal.config.Console')) @@ -1288,8 +1335,10 @@ def test_projects_use_without_projects(tmp_dir_cwd: Path, capsys: pytest.Capture with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) @@ -1312,8 +1361,10 @@ def test_projects_use_error(tmp_dir_cwd: Path, default_credentials: Path) -> Non with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) stack.enter_context(patch('logfire._internal.cli.LogfireCredentials.write_creds_file', side_effect=TypeError)) @@ -1344,8 +1395,10 @@ def test_projects_use_write_token_error(tmp_dir_cwd: Path, default_credentials: with ExitStack() as stack: stack.enter_context( patch( - 'logfire._internal.config.LogfireCredentials._get_user_token_data', - return_value=('', 'https://logfire-us.pydantic.dev'), + 'logfire._internal.auth.UserTokenCollection.get_token', + return_value=UserToken( + token='', base_url='https://logfire-us.pydantic.dev', expiration='2099-12-31T23:59:59' + ), ) ) stack.enter_context(patch('logfire._internal.cli.LogfireCredentials.write_creds_file', side_effect=TypeError)) @@ -1362,7 +1415,7 @@ def test_projects_use_write_token_error(tmp_dir_cwd: Path, default_credentials: status_code=500, ) - with pytest.raises(LogfireConfigError, match='Error creating project write token.'): + with pytest.raises(LogfireConfigError, match='Error creating project write token'): main(['projects', 'use', 'myproject', '--org', 'fake_org']) diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..7f57c9a0 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import pytest + +from logfire._internal.auth import UserToken +from logfire._internal.client import LogfireClient + + +def test_client_expired_token() -> None: + with pytest.raises(RuntimeError): + LogfireClient(user_token=UserToken(token='abc', base_url='http://localhost', expiration='1970-01-01T00:00:00')) diff --git a/tests/test_configure.py b/tests/test_configure.py index 018231cc..1562bcf1 100644 --- a/tests/test_configure.py +++ b/tests/test_configure.py @@ -51,7 +51,6 @@ ConsoleOptions, LogfireConfig, LogfireCredentials, - _get_token_repr, # type: ignore get_base_url_from_token, sanitize_project_name, ) @@ -904,7 +903,7 @@ def test_initialize_project_use_existing_project_no_projects(tmp_dir_cwd: Path, '[tokens."https://logfire-api.pydantic.dev"]\ntoken = "fake_user_token"\nexpiration = "2099-12-31T23:59:59"' ) with ExitStack() as stack: - stack.enter_context(mock.patch('logfire._internal.config.DEFAULT_FILE', auth_file)) + stack.enter_context(mock.patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) confirm_mock = stack.enter_context(mock.patch('rich.prompt.Confirm.ask', side_effect=[True, True])) stack.enter_context(mock.patch('rich.prompt.Prompt.ask', side_effect=['', 'myproject', ''])) @@ -941,7 +940,7 @@ def test_initialize_project_use_existing_project(tmp_dir_cwd: Path, tmp_path: Pa '[tokens."https://logfire-api.pydantic.dev"]\ntoken = "fake_user_token"\nexpiration = "2099-12-31T23:59:59"' ) with ExitStack() as stack: - stack.enter_context(mock.patch('logfire._internal.config.DEFAULT_FILE', auth_file)) + stack.enter_context(mock.patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) confirm_mock = stack.enter_context(mock.patch('rich.prompt.Confirm.ask', side_effect=[True, True])) prompt_mock = stack.enter_context(mock.patch('rich.prompt.Prompt.ask', side_effect=['1', ''])) @@ -999,7 +998,7 @@ def test_initialize_project_not_using_existing_project( '[tokens."https://logfire-api.pydantic.dev"]\ntoken = "fake_user_token"\nexpiration = "2099-12-31T23:59:59"' ) with ExitStack() as stack: - stack.enter_context(mock.patch('logfire._internal.config.DEFAULT_FILE', auth_file)) + stack.enter_context(mock.patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) confirm_mock = stack.enter_context(mock.patch('rich.prompt.Confirm.ask', side_effect=[False, True])) prompt_mock = stack.enter_context(mock.patch('rich.prompt.Prompt.ask', side_effect=['my-project', ''])) @@ -1056,7 +1055,7 @@ def test_initialize_project_not_confirming_organization(tmp_path: Path) -> None: '[tokens."https://logfire-api.pydantic.dev"]\ntoken = "fake_user_token"\nexpiration = "2099-12-31T23:59:59"' ) with ExitStack() as stack: - stack.enter_context(mock.patch('logfire._internal.config.DEFAULT_FILE', auth_file)) + stack.enter_context(mock.patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) confirm_mock = stack.enter_context(mock.patch('rich.prompt.Confirm.ask', side_effect=[False, False])) request_mocker = requests_mock.Mocker() @@ -1086,7 +1085,7 @@ def test_initialize_project_create_project(tmp_dir_cwd: Path, tmp_path: Path, ca '[tokens."https://logfire-api.pydantic.dev"]\ntoken = "fake_user_token"\nexpiration = "2099-12-31T23:59:59"' ) with ExitStack() as stack: - stack.enter_context(mock.patch('logfire._internal.config.DEFAULT_FILE', auth_file)) + stack.enter_context(mock.patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) confirm_mock = stack.enter_context(mock.patch('rich.prompt.Confirm.ask', side_effect=[True, True])) prompt_mock = stack.enter_context( mock.patch( @@ -1211,7 +1210,7 @@ def test_initialize_project_create_project_default_organization(tmp_dir_cwd: Pat '[tokens."https://logfire-api.pydantic.dev"]\ntoken = "fake_user_token"\nexpiration = "2099-12-31T23:59:59"' ) with ExitStack() as stack: - stack.enter_context(mock.patch('logfire._internal.config.DEFAULT_FILE', auth_file)) + stack.enter_context(mock.patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) prompt_mock = stack.enter_context( mock.patch('rich.prompt.Prompt.ask', side_effect=['fake_org', 'mytestproject1', '']) ) @@ -1271,11 +1270,9 @@ def test_send_to_logfire_true(tmp_path: Path) -> None: '[tokens."https://logfire-api.pydantic.dev"]\ntoken = "fake_user_token"\nexpiration = "2099-12-31T23:59:59"' ) with ExitStack() as stack: - stack.enter_context(mock.patch('logfire._internal.config.DEFAULT_FILE', auth_file)) + stack.enter_context(mock.patch('logfire._internal.auth.DEFAULT_FILE', auth_file)) stack.enter_context( - mock.patch( - 'logfire._internal.config.LogfireCredentials.get_user_projects', side_effect=RuntimeError('expected') - ) + mock.patch('logfire._internal.client.LogfireClient.get_user_projects', side_effect=RuntimeError('expected')) ) with pytest.raises(RuntimeError, match='^expected$'): configure(send_to_logfire=True, console=False, data_dir=data_dir) @@ -1445,92 +1442,6 @@ def test_load_creds_file_invalid_key(tmp_path: Path): LogfireCredentials.load_creds_file(creds_dir=tmp_path) -def test_get_user_token_data_explicit_url(default_credentials: Path): - with patch('logfire._internal.config.DEFAULT_FILE', default_credentials): - # https://logfire-us.pydantic.dev is the URL present in the default credentials fixture: - _, url = LogfireCredentials._get_user_token_data(logfire_api_url='https://logfire-us.pydantic.dev') # type: ignore - assert url == 'https://logfire-us.pydantic.dev' - - with pytest.raises(LogfireConfigError): - LogfireCredentials._get_user_token_data(logfire_api_url='https://logfire-eu.pydantic.dev') # type: ignore - - -def test_get_user_token_data_no_explicit_url(default_credentials: Path): - with patch('logfire._internal.config.DEFAULT_FILE', default_credentials): - _, url = LogfireCredentials._get_user_token_data(logfire_api_url=None) # type: ignore - # https://logfire-us.pydantic.dev is the URL present in the default credentials fixture: - assert url == 'https://logfire-us.pydantic.dev' - - -def test_get_user_token_data_input_choice(multiple_credentials: Path): - with ( - patch('logfire._internal.config.DEFAULT_FILE', multiple_credentials), - patch('rich.prompt.IntPrompt.ask', side_effect=[1]), - ): - _, url = LogfireCredentials._get_user_token_data(logfire_api_url=None) # type: ignore - # https://logfire-us.pydantic.dev is the first URL present in the multiple credentials fixture: - assert url == 'https://logfire-us.pydantic.dev' - - -@pytest.mark.parametrize( - ['url', 'token', 'expected'], - [ - ( - 'https://logfire-us.pydantic.dev', - 'pylf_v1_us_0kYhc414Ys2FNDRdt5vFB05xFx5NjVcbcBMy4Kp6PH0W', - 'US (https://logfire-us.pydantic.dev) - pylf_v1_us_0kYhc****', - ), - ( - 'https://logfire-eu.pydantic.dev', - 'pylf_v1_eu_0kYhc414Ys2FNDRdt5vFB05xFx5NjVcbcBMy4Kp6PH0W', - 'EU (https://logfire-eu.pydantic.dev) - pylf_v1_eu_0kYhc****', - ), - ( - 'https://logfire-us.pydantic.dev', - '0kYhc414Ys2FNDRdt5vFB05xFx5NjVcbcBMy4Kp6PH0W', - 'US (https://logfire-us.pydantic.dev) - 0kYhc****', - ), - ( - 'https://logfire-us.pydantic.dev', - 'pylf_v1_unknownregion_0kYhc414Ys2FNDRdt5vFB05xFx5NjVcbcBMy4Kp6PH0W', - 'US (https://logfire-us.pydantic.dev) - pylf_v1_unknownregion_0kYhc****', - ), - ], -) -def test_get_token_repr(url: str, token: str, expected: str): - assert _get_token_repr(url, token) == expected - - -def test_get_user_token_data_no_credentials(tmp_path: Path): - with patch('logfire._internal.config.DEFAULT_FILE', tmp_path): - with pytest.raises(LogfireConfigError): - LogfireCredentials._get_user_token_data() # type: ignore - - -def test_get_user_token_data_empty_credentials(tmp_path: Path): - empty_auth_file = tmp_path / 'default.toml' - empty_auth_file.touch() - with patch('logfire._internal.config.DEFAULT_FILE', tmp_path): - with pytest.raises(LogfireConfigError): - LogfireCredentials._get_user_token_data() # type: ignore - - -def test_get_user_token_data_expired_credentials(expired_credentials: Path): - with patch('logfire._internal.config.DEFAULT_FILE', expired_credentials): - with pytest.raises(LogfireConfigError): - # https://logfire-us.pydantic.dev is the URL present in the expired credentials fixture: - LogfireCredentials._get_user_token_data(logfire_api_url='https://logfire-us.pydantic.dev') # type: ignore - - -def test_get_user_token_data_not_authenticated(default_credentials: Path): - with patch('logfire._internal.config.DEFAULT_FILE', default_credentials): - with pytest.raises( - LogfireConfigError, match='You are not authenticated. Please run `logfire auth` to authenticate.' - ): - # Use a port that we don't use for local development to reduce conflicts with local configuration - LogfireCredentials._get_user_token_data(logfire_api_url='http://localhost:8234') # type: ignore - - def test_initialize_credentials_from_token_unreachable(): with pytest.warns( UserWarning,