Skip to content

Refactor user tokens, introduce Logfire client #981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 156 additions & 18 deletions logfire/_internal/auth.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,189 @@
from __future__ import annotations

import platform
import re
import sys
import warnings
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import TypedDict
from typing import TypedDict, cast
from urllib.parse import urljoin

if sys.version_info >= (3, 9):
from functools import cache
else:
from functools import lru_cache

cache = lru_cache(maxsize=None)

import requests
from rich.prompt import IntPrompt
from typing_extensions import Self

from logfire.exceptions import LogfireConfigError

from .utils import UnexpectedResponse
from .utils import UnexpectedResponse, read_toml_file

HOME_LOGFIRE = Path.home() / '.logfire'
"""Folder used to store global configuration, and user tokens."""
DEFAULT_FILE = HOME_LOGFIRE / 'default.toml'
"""File used to store user tokens."""


PYDANTIC_LOGFIRE_TOKEN_PATTERN = re.compile(
r'^(?P<safe_part>pylf_v(?P<version>[0-9]+)_(?P<region>[a-z]+)_)(?P<token>[a-zA-Z0-9]+)$'
)


class _RegionData(TypedDict):
base_url: str
gcp_region: str


REGIONS: dict[str, _RegionData] = {
'us': {
'base_url': 'https://logfire-us.pydantic.dev',
'gcp_region': 'us-east4',
},
'eu': {
'base_url': 'https://logfire-eu.pydantic.dev',
'gcp_region': 'europe-west4',
},
}
"""The existing Logfire regions."""


class UserTokenData(TypedDict):
"""User token data."""

token: str
expiration: str


class DefaultFile(TypedDict):
"""Content of the default.toml file."""
class UserTokensFileData(TypedDict):
"""Content of the file containing the user tokens."""

tokens: dict[str, UserTokenData]


@dataclass
class UserToken:
"""A user token."""

token: str
base_url: str
expiration: str

@classmethod
def from_user_token_data(cls, base_url: str, token: UserTokenData) -> Self:
return cls(
token=token['token'],
base_url=base_url,
expiration=token['expiration'],
)

@property
def is_expired(self) -> bool:
return datetime.now(tz=timezone.utc) >= datetime.fromisoformat(self.expiration.rstrip('Z')).replace(
tzinfo=timezone.utc
)

def __str__(self) -> str:
# TODO define in this module?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, yes

from .config import PYDANTIC_LOGFIRE_TOKEN_PATTERN, REGIONS

region = 'us'
if match := PYDANTIC_LOGFIRE_TOKEN_PATTERN.match(self.token):
region = match.group('region')
if region not in REGIONS:
region = 'us'
Comment on lines +94 to +100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this should be a separate method/function


token_repr = f'{region.upper()} ({self.base_url}) - '
if match:
token_repr += match.group('safe_part') + match.group('token')[:5]
else:
token_repr += self.token[:5]
token_repr += '****'
return token_repr


@dataclass
class UserTokenCollection:
"""A collection of user tokens."""

user_tokens: dict[str, UserToken]
"""A mapping between base URLs and user tokens."""

@classmethod
def empty(cls) -> Self:
"""Create an empty user token collection."""
return cls(user_tokens={})

@classmethod
def from_file_data(cls, file_data: UserTokensFileData) -> Self:
return cls(user_tokens={url: UserToken(base_url=url, **data) for url, data in file_data['tokens'].items()})

@classmethod
def from_tokens_file(cls, file: Path) -> Self:
return cls.from_file_data(cast(UserTokensFileData, read_toml_file(file)))

def get_token(self, base_url: str | None = None) -> UserToken:
tokens_list = list(self.user_tokens.values())
if base_url is not None:
token = next((t for t in tokens_list if t.base_url == base_url), None)
if token is None:
raise LogfireConfigError(
f'No user token was found matching the {base_url} Logfire URL. '
'Please run `logfire auth` to authenticate.'
)
else:
if len(tokens_list) == 1:
token = tokens_list[0]
elif len(tokens_list) >= 2:
choices_str = '\n'.join(
f'{i}. {token} ({"expired" if token.is_expired else "valid"})'
for i, token in enumerate(tokens_list, start=1)
)
int_choice = IntPrompt.ask(
f'Multiple user tokens found. Please select one:\n{choices_str}\n',
choices=[str(i) for i in range(1, len(tokens_list) + 1)],
)
token = tokens_list[int_choice - 1]
else: # tokens_list == []
raise LogfireConfigError('No user tokens are available. Please run `logfire auth` to authenticate.')

if token.is_expired:
raise LogfireConfigError(f'User token {token} is expired. Pleas run `logfire auth` to authenticate.')
return token

def is_logged_in(self, base_url: str | None = None) -> bool:
if base_url is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The base_url is None case seems weird to me

tokens = (t for t in self.user_tokens.values() if t.base_url == base_url)
else:
tokens = self.user_tokens.values()
return any(not t.is_expired for t in tokens)

def add_token(self, base_url: str, token: UserTokenData) -> UserToken:
user_token = UserToken.from_user_token_data(base_url, token)
self.user_tokens[base_url] = UserToken.from_user_token_data(base_url, token)
return user_token

def dump(self, path: Path) -> None:
# There's no standard library package to write TOML files, so we'll write it manually.
with path.open('w') as f:
for base_url, user_token in self.user_tokens.items():
f.write(f'[tokens."{base_url}"]\n')
f.write(f'token = "{user_token.token}"\n')
f.write(f'expiration = "{user_token.expiration}"\n')


@cache
def default_token_collection() -> UserTokenCollection:
"""The default token collection, created from the `~/.logfire/default.toml` file."""
return UserTokenCollection.from_tokens_file(DEFAULT_FILE)


class NewDeviceFlow(TypedDict):
"""Matches model of the same name in the backend."""

Expand Down Expand Up @@ -91,17 +243,3 @@ def poll_for_token(session: requests.Session, device_code: str, base_api_url: st
opt_user_token: UserTokenData | None = res.json()
if opt_user_token:
return opt_user_token


def is_logged_in(data: DefaultFile, logfire_url: str) -> bool:
"""Check if the user is logged in.

Returns:
True if the user is logged in, False otherwise.
"""
for url, info in data['tokens'].items(): # pragma: no branch
# token expirations are in UTC
expiry_date = datetime.fromisoformat(info['expiration'].rstrip('Z')).replace(tzinfo=timezone.utc)
if url == logfire_url and datetime.now(tz=timezone.utc) < expiry_date: # pragma: no branch
return True
return False # pragma: no cover
65 changes: 32 additions & 33 deletions logfire/_internal/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import webbrowser
from operator import itemgetter
from pathlib import Path
from typing import Any, Sequence, cast
from typing import Any, Sequence
from urllib.parse import urlparse

import requests
Expand All @@ -23,11 +23,18 @@
from logfire.propagate import ContextCarrier, get_context

from ..version import VERSION
from .auth import DEFAULT_FILE, HOME_LOGFIRE, DefaultFile, is_logged_in, poll_for_token, request_device_code
from .auth import (
DEFAULT_FILE,
HOME_LOGFIRE,
UserTokenCollection,
default_token_collection,
poll_for_token,
request_device_code,
)
from .client import LogfireClient
from .config import REGIONS, LogfireCredentials, get_base_url_from_token
from .config_params import ParamManager
from .tracer import SDKTracerProvider
from .utils import read_toml_file

BASE_OTEL_INTEGRATION_URL = 'https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/'
BASE_DOCS_URL = 'https://logfire.pydantic.dev/docs'
Expand Down Expand Up @@ -197,16 +204,14 @@ def parse_auth(args: argparse.Namespace) -> None:

This will authenticate your machine with Logfire and store the credentials.
"""
logfire_url = args.logfire_url
logfire_url: str | None = args.logfire_url

if DEFAULT_FILE.is_file():
data = cast(DefaultFile, read_toml_file(DEFAULT_FILE))
tokens_collection = default_token_collection()
else:
data: DefaultFile = {'tokens': {}}
tokens_collection = UserTokenCollection.empty()

if logfire_url:
logged_in = is_logged_in(data, logfire_url)
else:
logged_in = any(is_logged_in(data, url) for url in data['tokens'])
logged_in = tokens_collection.is_logged_in(logfire_url)

if logged_in:
sys.stderr.writelines(
Expand Down Expand Up @@ -255,22 +260,19 @@ def parse_auth(args: argparse.Namespace) -> None:
)
)

data['tokens'][logfire_url] = poll_for_token(args._session, device_code, logfire_url)
tokens_collection.add_token(logfire_url, poll_for_token(args._session, device_code, logfire_url))
sys.stderr.write('Successfully authenticated!\n')

# There's no standard library package to write TOML files, so we'll write it manually.
with DEFAULT_FILE.open('w') as f:
for url, info in data['tokens'].items():
f.write(f'[tokens."{url}"]\n')
f.write(f'token = "{info["token"]}"\n')
f.write(f'expiration = "{info["expiration"]}"\n')

tokens_collection.dump(DEFAULT_FILE)
sys.stderr.write(f'\nYour Logfire credentials are stored in {DEFAULT_FILE}\n')


def parse_list_projects(args: argparse.Namespace) -> None:
"""List user projects."""
projects = LogfireCredentials.get_user_projects(session=args._session, logfire_api_url=args.logfire_url)
logfire_url: str | None = args.logfire_url
client = LogfireClient.from_url(logfire_url)

projects = client.get_user_projects()
if projects:
sys.stderr.write(
_pretty_table(
Expand Down Expand Up @@ -299,42 +301,39 @@ def _write_credentials(project_info: dict[str, Any], data_dir: Path, logfire_api
def parse_create_new_project(args: argparse.Namespace) -> None:
"""Create a new project."""
data_dir = Path(args.data_dir)
logfire_url = args.logfire_url
if logfire_url is None: # pragma: no cover
_, logfire_url = LogfireCredentials._get_user_token_data() # type: ignore
logfire_url: str | None = args.logfire_url
client = LogfireClient.from_url(logfire_url)

project_name = args.project_name
organization = args.org
default_organization = args.default_org
project_info = LogfireCredentials.create_new_project(
session=args._session,
logfire_api_url=logfire_url,
client=client,
organization=organization,
default_organization=default_organization,
project_name=project_name,
)
credentials = _write_credentials(project_info, data_dir, logfire_url)
credentials = _write_credentials(project_info, data_dir, client.base_url)
sys.stderr.write(f'Project created successfully. You will be able to view it at: {credentials.project_url}\n')


def parse_use_project(args: argparse.Namespace) -> None:
"""Use an existing project."""
data_dir = Path(args.data_dir)
logfire_url = args.logfire_url
if logfire_url is None: # pragma: no cover
_, logfire_url = LogfireCredentials._get_user_token_data() # type: ignore
logfire_url: str | None = args.logfire_url
client = LogfireClient.from_url(logfire_url)

project_name = args.project_name
organization = args.org

projects = LogfireCredentials.get_user_projects(session=args._session, logfire_api_url=logfire_url)
projects = client.get_user_projects()
project_info = LogfireCredentials.use_existing_project(
session=args._session,
logfire_api_url=logfire_url,
client=client,
projects=projects,
organization=organization,
project_name=project_name,
)
if project_info:
credentials = _write_credentials(project_info, data_dir, logfire_url)
credentials = _write_credentials(project_info, data_dir, client.base_url)
sys.stderr.write(
f'Project configured successfully. You will be able to view it at: {credentials.project_url}\n'
)
Expand Down
Loading
Loading