1
1
"""FastAPI Users database adapter for SQLModel."""
2
2
import uuid
3
- from typing import Callable , Generic , Optional , Type , TypeVar
3
+ from typing import Generic , Optional , Type , TypeVar
4
4
5
5
from fastapi_users .db .base import BaseUserDatabase
6
6
from fastapi_users .models import BaseOAuthAccount , BaseUserDB
7
7
from pydantic import UUID4 , EmailStr
8
- from sqlalchemy .ext .asyncio import AsyncEngine , AsyncSession
9
- from sqlalchemy .future import Engine
10
- from sqlalchemy .orm import selectinload , sessionmaker
8
+ from sqlalchemy .ext .asyncio import AsyncSession
9
+ from sqlalchemy .orm import selectinload
11
10
from sqlmodel import Field , Session , SQLModel , func , select
12
11
13
12
__version__ = "0.0.3"
@@ -48,80 +47,74 @@ class SQLModelUserDatabase(Generic[UD, OA], BaseUserDatabase[UD]):
48
47
Database adapter for SQLModel.
49
48
50
49
:param user_db_model: SQLModel model of a DB representation of a user.
51
- :param engine : SQLAlchemy engine .
50
+ :param session : SQLAlchemy session .
52
51
"""
53
52
54
- engine : Engine
53
+ session : Session
55
54
oauth_account_model : Optional [Type [OA ]]
56
55
57
56
def __init__ (
58
57
self ,
59
58
user_db_model : Type [UD ],
60
- engine : Engine ,
59
+ session : Session ,
61
60
oauth_account_model : Optional [Type [OA ]] = None ,
62
61
):
63
62
super ().__init__ (user_db_model )
64
- self .engine = engine
63
+ self .session = session
65
64
self .oauth_account_model = oauth_account_model
66
65
67
66
async def get (self , id : UUID4 ) -> Optional [UD ]:
68
67
"""Get a single user by id."""
69
- with Session (self .engine ) as session :
70
- return session .get (self .user_db_model , id )
68
+ return self .session .get (self .user_db_model , id )
71
69
72
70
async def get_by_email (self , email : str ) -> Optional [UD ]:
73
71
"""Get a single user by email."""
74
- with Session (self .engine ) as session :
75
- statement = select (self .user_db_model ).where (
76
- func .lower (self .user_db_model .email ) == func .lower (email )
77
- )
78
- results = session .exec (statement )
79
- return results .first ()
72
+ statement = select (self .user_db_model ).where (
73
+ func .lower (self .user_db_model .email ) == func .lower (email )
74
+ )
75
+ results = self .session .exec (statement )
76
+ return results .first ()
80
77
81
78
async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UD ]:
82
79
"""Get a single user by OAuth account id."""
83
80
if not self .oauth_account_model :
84
81
raise NotSetOAuthAccountTableError ()
85
- with Session (self .engine ) as session :
86
- statement = (
87
- select (self .oauth_account_model )
88
- .where (self .oauth_account_model .oauth_name == oauth )
89
- .where (self .oauth_account_model .account_id == account_id )
90
- )
91
- results = session .exec (statement )
92
- oauth_account = results .first ()
93
- if oauth_account :
94
- user = oauth_account .user # type: ignore
95
- return user
96
- return None
82
+ statement = (
83
+ select (self .oauth_account_model )
84
+ .where (self .oauth_account_model .oauth_name == oauth )
85
+ .where (self .oauth_account_model .account_id == account_id )
86
+ )
87
+ results = self .session .exec (statement )
88
+ oauth_account = results .first ()
89
+ if oauth_account :
90
+ user = oauth_account .user # type: ignore
91
+ return user
92
+ return None
97
93
98
94
async def create (self , user : UD ) -> UD :
99
95
"""Create a user."""
100
- with Session (self .engine ) as session :
101
- session .add (user )
102
- if self .oauth_account_model is not None :
103
- for oauth_account in user .oauth_accounts : # type: ignore
104
- session .add (oauth_account )
105
- session .commit ()
106
- session .refresh (user )
107
- return user
96
+ self .session .add (user )
97
+ if self .oauth_account_model is not None :
98
+ for oauth_account in user .oauth_accounts : # type: ignore
99
+ self .session .add (oauth_account )
100
+ self .session .commit ()
101
+ self .session .refresh (user )
102
+ return user
108
103
109
104
async def update (self , user : UD ) -> UD :
110
105
"""Update a user."""
111
- with Session (self .engine ) as session :
112
- session .add (user )
113
- if self .oauth_account_model is not None :
114
- for oauth_account in user .oauth_accounts : # type: ignore
115
- session .add (oauth_account )
116
- session .commit ()
117
- session .refresh (user )
118
- return user
106
+ self .session .add (user )
107
+ if self .oauth_account_model is not None :
108
+ for oauth_account in user .oauth_accounts : # type: ignore
109
+ self .session .add (oauth_account )
110
+ self .session .commit ()
111
+ self .session .refresh (user )
112
+ return user
119
113
120
114
async def delete (self , user : UD ) -> None :
121
115
"""Delete a user."""
122
- with Session (self .engine ) as session :
123
- session .delete (user )
124
- session .commit ()
116
+ self .session .delete (user )
117
+ self .session .commit ()
125
118
126
119
127
120
class SQLModelUserDatabaseAsync (Generic [UD , OA ], BaseUserDatabase [UD ]):
@@ -132,81 +125,72 @@ class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]):
132
125
:param engine: SQLAlchemy async engine.
133
126
"""
134
127
135
- engine : AsyncEngine
128
+ session : AsyncSession
136
129
oauth_account_model : Optional [Type [OA ]]
137
130
138
131
def __init__ (
139
132
self ,
140
133
user_db_model : Type [UD ],
141
- engine : AsyncEngine ,
134
+ session : AsyncSession ,
142
135
oauth_account_model : Optional [Type [OA ]] = None ,
143
136
):
144
137
super ().__init__ (user_db_model )
145
- self .engine = engine
138
+ self .session = session
146
139
self .oauth_account_model = oauth_account_model
147
- self .session_maker : Callable [[], AsyncSession ] = sessionmaker (
148
- self .engine , class_ = AsyncSession , expire_on_commit = False
149
- )
150
140
151
141
async def get (self , id : UUID4 ) -> Optional [UD ]:
152
142
"""Get a single user by id."""
153
- async with self .session_maker () as session :
154
- return await session .get (self .user_db_model , id )
143
+ return await self .session .get (self .user_db_model , id )
155
144
156
145
async def get_by_email (self , email : str ) -> Optional [UD ]:
157
146
"""Get a single user by email."""
158
- async with self .session_maker () as session :
159
- statement = select (self .user_db_model ).where (
160
- func .lower (self .user_db_model .email ) == func .lower (email )
161
- )
162
- results = await session .execute (statement )
163
- object = results .first ()
164
- if object is None :
165
- return None
166
- return object [0 ]
147
+ statement = select (self .user_db_model ).where (
148
+ func .lower (self .user_db_model .email ) == func .lower (email )
149
+ )
150
+ results = await self .session .execute (statement )
151
+ object = results .first ()
152
+ if object is None :
153
+ return None
154
+ return object [0 ]
167
155
168
156
async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UD ]:
169
157
"""Get a single user by OAuth account id."""
170
158
if not self .oauth_account_model :
171
159
raise NotSetOAuthAccountTableError ()
172
- async with self .session_maker () as session :
173
- statement = (
174
- select (self .oauth_account_model )
175
- .where (self .oauth_account_model .oauth_name == oauth )
176
- .where (self .oauth_account_model .account_id == account_id )
177
- .options (selectinload (self .oauth_account_model .user )) # type: ignore
178
- )
179
- results = await session .execute (statement )
180
- oauth_account = results .first ()
181
- if oauth_account :
182
- user = oauth_account [0 ].user
183
- return user
184
- return None
160
+ statement = (
161
+ select (self .oauth_account_model )
162
+ .where (self .oauth_account_model .oauth_name == oauth )
163
+ .where (self .oauth_account_model .account_id == account_id )
164
+ .options (selectinload (self .oauth_account_model .user )) # type: ignore
165
+ )
166
+ results = await self .session .execute (statement )
167
+ oauth_account = results .first ()
168
+ if oauth_account :
169
+ user = oauth_account [0 ].user
170
+ return user
171
+ return None
185
172
186
173
async def create (self , user : UD ) -> UD :
187
174
"""Create a user."""
188
- async with self .session_maker () as session :
189
- session .add (user )
190
- if self .oauth_account_model is not None :
191
- for oauth_account in user .oauth_accounts : # type: ignore
192
- session .add (oauth_account )
193
- await session .commit ()
194
- await session .refresh (user )
195
- return user
175
+ self .session .add (user )
176
+ if self .oauth_account_model is not None :
177
+ for oauth_account in user .oauth_accounts : # type: ignore
178
+ self .session .add (oauth_account )
179
+ await self .session .commit ()
180
+ await self .session .refresh (user )
181
+ return user
196
182
197
183
async def update (self , user : UD ) -> UD :
198
184
"""Update a user."""
199
- async with self .session_maker () as session :
200
- session .add (user )
201
- if self .oauth_account_model is not None :
202
- for oauth_account in user .oauth_accounts : # type: ignore
203
- session .add (oauth_account )
204
- await session .commit ()
205
- await session .refresh (user )
206
- return user
185
+ self .session .add (user )
186
+ if self .oauth_account_model is not None :
187
+ for oauth_account in user .oauth_accounts : # type: ignore
188
+ self .session .add (oauth_account )
189
+ await self .session .commit ()
190
+ await self .session .refresh (user )
191
+ return user
207
192
208
193
async def delete (self , user : UD ) -> None :
209
194
"""Delete a user."""
210
- async with self .session_maker () as session :
211
- await session .delete (user )
212
- await session .commit ()
195
+ await self .session .delete (user )
196
+ await self .session .commit ()
0 commit comments