Skip to content

Commit

Permalink
Update database SQLAlchemy ORM declarations
Browse files Browse the repository at this point in the history
  • Loading branch information
vjf committed Aug 30, 2024
1 parent cd08668 commit d8d813d
Showing 1 changed file with 45 additions and 41 deletions.
86 changes: 45 additions & 41 deletions scripts/lib/db/db_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
Uses 'sqlalchemy' library to create a simple 'sqlite' db to hold query results for models
"""
from sqlalchemy import create_engine
from sqlalchemy import Column, Integer, String, Boolean
from sqlalchemy import Integer
from sqlalchemy import select, func
from sqlalchemy.orm import sessionmaker, relationship, scoped_session
from sqlalchemy.orm import declarative_base
from sqlalchemy.schema import ForeignKey, PrimaryKeyConstraint, MetaData
from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped
from sqlalchemy.schema import ForeignKey, MetaData, PrimaryKeyConstraint
from sqlalchemy.exc import DatabaseError

LOGGER = logging.getLogger(__name__)
Expand All @@ -23,7 +24,9 @@

QUERY_DB_FILE = 'query_data.db'

Base = declarative_base()
# Declarative Base class
class Base(DeclarativeBase):
pass

# pylint: disable=R0903
class Query(Base):
Expand All @@ -41,20 +44,22 @@ class Query(Base):
'''
__tablename__ = "query"

model_name = Column(String)
label = Column(String)
segment_info_id = Column(Integer, ForeignKey("segment_info.id"))
part_info_id = Column(Integer, ForeignKey("part_info.id"))
model_info_id = Column(Integer, ForeignKey("model_info.id"))
user_info_id = Column(Integer, ForeignKey("user_info.id"))
model_name: Mapped[str]
label: Mapped[str]

segment_info = relationship('SegmentInfo')
part_info = relationship('PartInfo')
model_info = relationship('ModelInfo')
user_info = relationship('UserInfo')
segment_info_id = mapped_column(Integer, ForeignKey("segment_info.id"))
part_info_id = mapped_column(Integer, ForeignKey("part_info.id"))
model_info_id = mapped_column(Integer, ForeignKey("model_info.id"))
user_info_id = mapped_column(Integer, ForeignKey("user_info.id"))

segment_info = relationship("SegmentInfo", foreign_keys=[segment_info_id])
part_info = relationship('PartInfo', foreign_keys=[part_info_id])
model_info = relationship('ModelInfo', foreign_keys=[model_info_id])
user_info = relationship('UserInfo', foreign_keys=[user_info_id])

__table_args__ = (PrimaryKeyConstraint('model_name', 'label', name='_query_uc'),)


def __repr__(self):
result = "Query:" + \
"\n model_name={0}".format(self.model_name) + \
Expand All @@ -74,8 +79,8 @@ class SegmentInfo(Base):
'''
__tablename__ = "segment_info"

id = Column(Integer, primary_key=True, autoincrement=True)
json = Column(String, unique=True)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
json: Mapped[str] = mapped_column(unique=True)

def __repr__(self):
return "{1}: json={0}\n".format(self.json, self.__class__.__name__)
Expand All @@ -90,8 +95,8 @@ class PartInfo(Base):
'''
__tablename__ = "part_info"

id = Column(Integer, primary_key=True, autoincrement=True)
json = Column(String, unique=True)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
json: Mapped[str] = mapped_column(unique=True)

def __repr__(self):
return "{1}: json={0}\n".format(self.json, self.__class__.__name__)
Expand All @@ -105,8 +110,8 @@ class ModelInfo(Base):
'''
__tablename__ = "model_info"

id = Column(Integer, primary_key=True, autoincrement=True)
json = Column(String, unique=True)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
json: Mapped[str] = mapped_column(unique=True)

def __repr__(self):
return "{1}: json={0}\n".format(self.json, self.__class__.__name__)
Expand All @@ -120,8 +125,8 @@ class UserInfo(Base):
'''
__tablename__ = "user_info"

id = Column(Integer, primary_key=True, autoincrement=True)
json = Column(String, unique=True)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
json: Mapped[str] = mapped_column(unique=True)

def __repr__(self):
return "{1}: json={0}\n".format(self.json, self.__class__.__name__)
Expand All @@ -134,10 +139,10 @@ class KeyValuePairs(Base):
'''
__tablename__ = "keyvaluepairs"

id = Column(Integer, primary_key=True, autoincrement=True)
key = Column(String)
value = Column(String, nullable=False)
is_url = Column(Boolean)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
key: Mapped[str]
value: Mapped[str] = mapped_column(nullable=False)
is_url: Mapped[bool]

#
class QueryDB():
Expand Down Expand Up @@ -184,7 +189,7 @@ def add_segment(self, json_str):
self.ses.add(seginfo_obj)
self.ses.commit()
return True, seginfo_obj
seginfo_obj = self.ses.query(SegmentInfo).filter_by(json=json_str).first()
seginfo_obj = self.ses.scalars(select(SegmentInfo).filter_by(json=json_str).limit(1)).first()
if seginfo_obj is None:
seginfo_obj = SegmentInfo(json=json_str)
self.ses.add(seginfo_obj)
Expand All @@ -207,7 +212,7 @@ def add_part(self, json_str):
self.ses.add(part_obj)
self.ses.commit()
return True, part_obj
part_obj = self.ses.query(PartInfo).filter_by(json=json_str).first()
part_obj = self.ses.scalars(select(PartInfo).filter_by(json=json_str).limit(1)).first()
if part_obj is None:
part_obj = PartInfo(json=json_str)
self.ses.add(part_obj)
Expand All @@ -230,7 +235,7 @@ def add_model(self, json_str):
self.ses.add(model_obj)
self.ses.commit()
return True, model_obj
model_obj = self.ses.query(ModelInfo).filter_by(json=json_str).first()
model_obj = self.ses.scalars(select(ModelInfo).filter_by(json=json_str).limit(1)).first()
if model_obj is None:
model_obj = ModelInfo(json=json_str)
self.ses.add(model_obj)
Expand All @@ -253,7 +258,7 @@ def add_user(self, json_str):
self.ses.add(userinfo_obj)
self.ses.commit()
return True, userinfo_obj
userinfo_obj = self.ses.query(UserInfo).filter_by(json=json_str).first()
userinfo_obj = self.ses.scalars(select(UserInfo).filter_by(json=json_str).limit(1)).first()
if userinfo_obj is None:
userinfo_obj = UserInfo(json=json_str)
self.ses.add(userinfo_obj)
Expand Down Expand Up @@ -290,15 +295,15 @@ def query(self, label, model_name):
else (False, exception string)
"""
try:
result = self.ses.query(Query).filter_by(label=label) \
.filter_by(model_name=model_name).first()
result = self.ses.scalars(select(Query).filter_by(label=label) \
.filter_by(model_name=model_name).limit(1)).first()
except DatabaseError as db_exc:
return False, str(db_exc)
if result is None:
filter_str = label.rpartition('_')[0]
try:
result = self.ses.query(Query).filter_by(model_name=model_name) \
.filter_by(label=filter_str).first()
result = self.ses.scalars(select(Query).filter_by(model_name=model_name) \
.filter_by(label=filter_str).limit(1)).first()
except DatabaseError as db_exc:
return False, str(db_exc)
if result is None:
Expand Down Expand Up @@ -333,8 +338,7 @@ def __del__(self):
assert S2 is not None

# Test for no duplicates
Q = QUERY_DB.ses.query(SegmentInfo)
assert Q.count() == 1
assert QUERY_DB.ses.scalar(select(func.count(SegmentInfo.id))) == 1

OK, S3 = QUERY_DB.add_segment('seg3')
assert OK
Expand All @@ -356,11 +360,11 @@ def __del__(self):
assert OK

# Have added three 'Query' objs? two 'Segment_Info' objs ? etc.
assert QUERY_DB.ses.query(Query).count() == 3
assert QUERY_DB.ses.query(SegmentInfo).count() == 2
assert QUERY_DB.ses.query(PartInfo).count() == 1
assert QUERY_DB.ses.query(ModelInfo).count() == 1
assert QUERY_DB.ses.query(UserInfo).count() == 1
assert QUERY_DB.ses.scalar(select(func.count(Query.model_name))) == 3
assert QUERY_DB.ses.scalar(select(func.count(SegmentInfo.id))) == 2
assert QUERY_DB.ses.scalar(select(func.count(PartInfo.id))) == 1
assert QUERY_DB.ses.scalar(select(func.count(ModelInfo.id))) == 1
assert QUERY_DB.ses.scalar(select(func.count(UserInfo.id))) == 1

# Look for a 'Query' with all info tables
OK, Q1 = QUERY_DB.query('label2', 'model_name2')
Expand Down

0 comments on commit d8d813d

Please sign in to comment.