Skip to content

Commit e4ec43e

Browse files
committed
Inject a session instead of an engine
1 parent 0e578f0 commit e4ec43e

File tree

3 files changed

+137
-128
lines changed

3 files changed

+137
-128
lines changed

fastapi_users_db_sqlmodel/__init__.py

+80-96
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
"""FastAPI Users database adapter for SQLModel."""
22
import uuid
3-
from typing import Callable, Generic, Optional, Type, TypeVar
3+
from typing import Generic, Optional, Type, TypeVar
44

55
from fastapi_users.db.base import BaseUserDatabase
66
from fastapi_users.models import BaseOAuthAccount, BaseUserDB
77
from pydantic import UUID4, EmailStr
8-
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
9-
from sqlalchemy.future import Engine
10-
from sqlalchemy.orm import selectinload, sessionmaker
8+
from sqlalchemy.ext.asyncio import AsyncSession
9+
from sqlalchemy.orm import selectinload
1110
from sqlmodel import Field, Session, SQLModel, func, select
1211

1312
__version__ = "0.0.3"
@@ -48,80 +47,74 @@ class SQLModelUserDatabase(Generic[UD, OA], BaseUserDatabase[UD]):
4847
Database adapter for SQLModel.
4948
5049
:param user_db_model: SQLModel model of a DB representation of a user.
51-
:param engine: SQLAlchemy engine.
50+
:param session: SQLAlchemy session.
5251
"""
5352

54-
engine: Engine
53+
session: Session
5554
oauth_account_model: Optional[Type[OA]]
5655

5756
def __init__(
5857
self,
5958
user_db_model: Type[UD],
60-
engine: Engine,
59+
session: Session,
6160
oauth_account_model: Optional[Type[OA]] = None,
6261
):
6362
super().__init__(user_db_model)
64-
self.engine = engine
63+
self.session = session
6564
self.oauth_account_model = oauth_account_model
6665

6766
async def get(self, id: UUID4) -> Optional[UD]:
6867
"""Get a single user by id."""
69-
with Session(self.engine) as session:
70-
return session.get(self.user_db_model, id)
68+
return self.session.get(self.user_db_model, id)
7169

7270
async def get_by_email(self, email: str) -> Optional[UD]:
7371
"""Get a single user by email."""
74-
with Session(self.engine) as session:
75-
statement = select(self.user_db_model).where(
76-
func.lower(self.user_db_model.email) == func.lower(email)
77-
)
78-
results = session.exec(statement)
79-
return results.first()
72+
statement = select(self.user_db_model).where(
73+
func.lower(self.user_db_model.email) == func.lower(email)
74+
)
75+
results = self.session.exec(statement)
76+
return results.first()
8077

8178
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
8279
"""Get a single user by OAuth account id."""
8380
if not self.oauth_account_model:
8481
raise NotSetOAuthAccountTableError()
85-
with Session(self.engine) as session:
86-
statement = (
87-
select(self.oauth_account_model)
88-
.where(self.oauth_account_model.oauth_name == oauth)
89-
.where(self.oauth_account_model.account_id == account_id)
90-
)
91-
results = session.exec(statement)
92-
oauth_account = results.first()
93-
if oauth_account:
94-
user = oauth_account.user # type: ignore
95-
return user
96-
return None
82+
statement = (
83+
select(self.oauth_account_model)
84+
.where(self.oauth_account_model.oauth_name == oauth)
85+
.where(self.oauth_account_model.account_id == account_id)
86+
)
87+
results = self.session.exec(statement)
88+
oauth_account = results.first()
89+
if oauth_account:
90+
user = oauth_account.user # type: ignore
91+
return user
92+
return None
9793

9894
async def create(self, user: UD) -> UD:
9995
"""Create a user."""
100-
with Session(self.engine) as session:
101-
session.add(user)
102-
if self.oauth_account_model is not None:
103-
for oauth_account in user.oauth_accounts: # type: ignore
104-
session.add(oauth_account)
105-
session.commit()
106-
session.refresh(user)
107-
return user
96+
self.session.add(user)
97+
if self.oauth_account_model is not None:
98+
for oauth_account in user.oauth_accounts: # type: ignore
99+
self.session.add(oauth_account)
100+
self.session.commit()
101+
self.session.refresh(user)
102+
return user
108103

109104
async def update(self, user: UD) -> UD:
110105
"""Update a user."""
111-
with Session(self.engine) as session:
112-
session.add(user)
113-
if self.oauth_account_model is not None:
114-
for oauth_account in user.oauth_accounts: # type: ignore
115-
session.add(oauth_account)
116-
session.commit()
117-
session.refresh(user)
118-
return user
106+
self.session.add(user)
107+
if self.oauth_account_model is not None:
108+
for oauth_account in user.oauth_accounts: # type: ignore
109+
self.session.add(oauth_account)
110+
self.session.commit()
111+
self.session.refresh(user)
112+
return user
119113

120114
async def delete(self, user: UD) -> None:
121115
"""Delete a user."""
122-
with Session(self.engine) as session:
123-
session.delete(user)
124-
session.commit()
116+
self.session.delete(user)
117+
self.session.commit()
125118

126119

127120
class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]):
@@ -132,81 +125,72 @@ class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]):
132125
:param engine: SQLAlchemy async engine.
133126
"""
134127

135-
engine: AsyncEngine
128+
session: AsyncSession
136129
oauth_account_model: Optional[Type[OA]]
137130

138131
def __init__(
139132
self,
140133
user_db_model: Type[UD],
141-
engine: AsyncEngine,
134+
session: AsyncSession,
142135
oauth_account_model: Optional[Type[OA]] = None,
143136
):
144137
super().__init__(user_db_model)
145-
self.engine = engine
138+
self.session = session
146139
self.oauth_account_model = oauth_account_model
147-
self.session_maker: Callable[[], AsyncSession] = sessionmaker(
148-
self.engine, class_=AsyncSession, expire_on_commit=False
149-
)
150140

151141
async def get(self, id: UUID4) -> Optional[UD]:
152142
"""Get a single user by id."""
153-
async with self.session_maker() as session:
154-
return await session.get(self.user_db_model, id)
143+
return await self.session.get(self.user_db_model, id)
155144

156145
async def get_by_email(self, email: str) -> Optional[UD]:
157146
"""Get a single user by email."""
158-
async with self.session_maker() as session:
159-
statement = select(self.user_db_model).where(
160-
func.lower(self.user_db_model.email) == func.lower(email)
161-
)
162-
results = await session.execute(statement)
163-
object = results.first()
164-
if object is None:
165-
return None
166-
return object[0]
147+
statement = select(self.user_db_model).where(
148+
func.lower(self.user_db_model.email) == func.lower(email)
149+
)
150+
results = await self.session.execute(statement)
151+
object = results.first()
152+
if object is None:
153+
return None
154+
return object[0]
167155

168156
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
169157
"""Get a single user by OAuth account id."""
170158
if not self.oauth_account_model:
171159
raise NotSetOAuthAccountTableError()
172-
async with self.session_maker() as session:
173-
statement = (
174-
select(self.oauth_account_model)
175-
.where(self.oauth_account_model.oauth_name == oauth)
176-
.where(self.oauth_account_model.account_id == account_id)
177-
.options(selectinload(self.oauth_account_model.user)) # type: ignore
178-
)
179-
results = await session.execute(statement)
180-
oauth_account = results.first()
181-
if oauth_account:
182-
user = oauth_account[0].user
183-
return user
184-
return None
160+
statement = (
161+
select(self.oauth_account_model)
162+
.where(self.oauth_account_model.oauth_name == oauth)
163+
.where(self.oauth_account_model.account_id == account_id)
164+
.options(selectinload(self.oauth_account_model.user)) # type: ignore
165+
)
166+
results = await self.session.execute(statement)
167+
oauth_account = results.first()
168+
if oauth_account:
169+
user = oauth_account[0].user
170+
return user
171+
return None
185172

186173
async def create(self, user: UD) -> UD:
187174
"""Create a user."""
188-
async with self.session_maker() as session:
189-
session.add(user)
190-
if self.oauth_account_model is not None:
191-
for oauth_account in user.oauth_accounts: # type: ignore
192-
session.add(oauth_account)
193-
await session.commit()
194-
await session.refresh(user)
195-
return user
175+
self.session.add(user)
176+
if self.oauth_account_model is not None:
177+
for oauth_account in user.oauth_accounts: # type: ignore
178+
self.session.add(oauth_account)
179+
await self.session.commit()
180+
await self.session.refresh(user)
181+
return user
196182

197183
async def update(self, user: UD) -> UD:
198184
"""Update a user."""
199-
async with self.session_maker() as session:
200-
session.add(user)
201-
if self.oauth_account_model is not None:
202-
for oauth_account in user.oauth_accounts: # type: ignore
203-
session.add(oauth_account)
204-
await session.commit()
205-
await session.refresh(user)
206-
return user
185+
self.session.add(user)
186+
if self.oauth_account_model is not None:
187+
for oauth_account in user.oauth_accounts: # type: ignore
188+
self.session.add(oauth_account)
189+
await self.session.commit()
190+
await self.session.refresh(user)
191+
return user
207192

208193
async def delete(self, user: UD) -> None:
209194
"""Delete a user."""
210-
async with self.session_maker() as session:
211-
await session.delete(user)
212-
await session.commit()
195+
await self.session.delete(user)
196+
await self.session.commit()

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
aiosqlite >= 0.17.0
2-
fastapi-users >= 6.1.2
2+
fastapi-users >= 8.0.0b3
33
sqlmodel >=0.0.4,<0.1.0

0 commit comments

Comments
 (0)