diff --git a/scripts/lib/db/db_tables.py b/scripts/lib/db/db_tables.py index f3ee885..43e800e 100755 --- a/scripts/lib/db/db_tables.py +++ b/scripts/lib/db/db_tables.py @@ -8,10 +8,11 @@ Uses 'sqlalchemy' library to create a simple 'sqlite' db to hold query results for models """ from sqlalchemy import create_engine -from sqlalchemy import Column, Integer, String, Boolean +from sqlalchemy import Integer +from sqlalchemy import select, func from sqlalchemy.orm import sessionmaker, relationship, scoped_session -from sqlalchemy.orm import declarative_base -from sqlalchemy.schema import ForeignKey, PrimaryKeyConstraint, MetaData +from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped +from sqlalchemy.schema import ForeignKey, MetaData, PrimaryKeyConstraint from sqlalchemy.exc import DatabaseError LOGGER = logging.getLogger(__name__) @@ -23,7 +24,9 @@ QUERY_DB_FILE = 'query_data.db' -Base = declarative_base() +# Declarative Base class +class Base(DeclarativeBase): + pass # pylint: disable=R0903 class Query(Base): @@ -41,20 +44,22 @@ class Query(Base): ''' __tablename__ = "query" - model_name = Column(String) - label = Column(String) - segment_info_id = Column(Integer, ForeignKey("segment_info.id")) - part_info_id = Column(Integer, ForeignKey("part_info.id")) - model_info_id = Column(Integer, ForeignKey("model_info.id")) - user_info_id = Column(Integer, ForeignKey("user_info.id")) + model_name: Mapped[str] + label: Mapped[str] - segment_info = relationship('SegmentInfo') - part_info = relationship('PartInfo') - model_info = relationship('ModelInfo') - user_info = relationship('UserInfo') + segment_info_id = mapped_column(Integer, ForeignKey("segment_info.id")) + part_info_id = mapped_column(Integer, ForeignKey("part_info.id")) + model_info_id = mapped_column(Integer, ForeignKey("model_info.id")) + user_info_id = mapped_column(Integer, ForeignKey("user_info.id")) + + segment_info = relationship("SegmentInfo", foreign_keys=[segment_info_id]) + part_info = relationship('PartInfo', foreign_keys=[part_info_id]) + model_info = relationship('ModelInfo', foreign_keys=[model_info_id]) + user_info = relationship('UserInfo', foreign_keys=[user_info_id]) __table_args__ = (PrimaryKeyConstraint('model_name', 'label', name='_query_uc'),) + def __repr__(self): result = "Query:" + \ "\n model_name={0}".format(self.model_name) + \ @@ -74,8 +79,8 @@ class SegmentInfo(Base): ''' __tablename__ = "segment_info" - id = Column(Integer, primary_key=True, autoincrement=True) - json = Column(String, unique=True) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + json: Mapped[str] = mapped_column(unique=True) def __repr__(self): return "{1}: json={0}\n".format(self.json, self.__class__.__name__) @@ -90,8 +95,8 @@ class PartInfo(Base): ''' __tablename__ = "part_info" - id = Column(Integer, primary_key=True, autoincrement=True) - json = Column(String, unique=True) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + json: Mapped[str] = mapped_column(unique=True) def __repr__(self): return "{1}: json={0}\n".format(self.json, self.__class__.__name__) @@ -105,8 +110,8 @@ class ModelInfo(Base): ''' __tablename__ = "model_info" - id = Column(Integer, primary_key=True, autoincrement=True) - json = Column(String, unique=True) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + json: Mapped[str] = mapped_column(unique=True) def __repr__(self): return "{1}: json={0}\n".format(self.json, self.__class__.__name__) @@ -120,8 +125,8 @@ class UserInfo(Base): ''' __tablename__ = "user_info" - id = Column(Integer, primary_key=True, autoincrement=True) - json = Column(String, unique=True) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + json: Mapped[str] = mapped_column(unique=True) def __repr__(self): return "{1}: json={0}\n".format(self.json, self.__class__.__name__) @@ -134,10 +139,10 @@ class KeyValuePairs(Base): ''' __tablename__ = "keyvaluepairs" - id = Column(Integer, primary_key=True, autoincrement=True) - key = Column(String) - value = Column(String, nullable=False) - is_url = Column(Boolean) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + key: Mapped[str] + value: Mapped[str] = mapped_column(nullable=False) + is_url: Mapped[bool] # class QueryDB(): @@ -184,7 +189,7 @@ def add_segment(self, json_str): self.ses.add(seginfo_obj) self.ses.commit() return True, seginfo_obj - seginfo_obj = self.ses.query(SegmentInfo).filter_by(json=json_str).first() + seginfo_obj = self.ses.scalars(select(SegmentInfo).filter_by(json=json_str).limit(1)).first() if seginfo_obj is None: seginfo_obj = SegmentInfo(json=json_str) self.ses.add(seginfo_obj) @@ -207,7 +212,7 @@ def add_part(self, json_str): self.ses.add(part_obj) self.ses.commit() return True, part_obj - part_obj = self.ses.query(PartInfo).filter_by(json=json_str).first() + part_obj = self.ses.scalars(select(PartInfo).filter_by(json=json_str).limit(1)).first() if part_obj is None: part_obj = PartInfo(json=json_str) self.ses.add(part_obj) @@ -230,7 +235,7 @@ def add_model(self, json_str): self.ses.add(model_obj) self.ses.commit() return True, model_obj - model_obj = self.ses.query(ModelInfo).filter_by(json=json_str).first() + model_obj = self.ses.scalars(select(ModelInfo).filter_by(json=json_str).limit(1)).first() if model_obj is None: model_obj = ModelInfo(json=json_str) self.ses.add(model_obj) @@ -253,7 +258,7 @@ def add_user(self, json_str): self.ses.add(userinfo_obj) self.ses.commit() return True, userinfo_obj - userinfo_obj = self.ses.query(UserInfo).filter_by(json=json_str).first() + userinfo_obj = self.ses.scalars(select(UserInfo).filter_by(json=json_str).limit(1)).first() if userinfo_obj is None: userinfo_obj = UserInfo(json=json_str) self.ses.add(userinfo_obj) @@ -290,15 +295,15 @@ def query(self, label, model_name): else (False, exception string) """ try: - result = self.ses.query(Query).filter_by(label=label) \ - .filter_by(model_name=model_name).first() + result = self.ses.scalars(select(Query).filter_by(label=label) \ + .filter_by(model_name=model_name).limit(1)).first() except DatabaseError as db_exc: return False, str(db_exc) if result is None: filter_str = label.rpartition('_')[0] try: - result = self.ses.query(Query).filter_by(model_name=model_name) \ - .filter_by(label=filter_str).first() + result = self.ses.scalars(select(Query).filter_by(model_name=model_name) \ + .filter_by(label=filter_str).limit(1)).first() except DatabaseError as db_exc: return False, str(db_exc) if result is None: @@ -333,8 +338,7 @@ def __del__(self): assert S2 is not None # Test for no duplicates - Q = QUERY_DB.ses.query(SegmentInfo) - assert Q.count() == 1 + assert QUERY_DB.ses.scalar(select(func.count(SegmentInfo.id))) == 1 OK, S3 = QUERY_DB.add_segment('seg3') assert OK @@ -356,11 +360,11 @@ def __del__(self): assert OK # Have added three 'Query' objs? two 'Segment_Info' objs ? etc. - assert QUERY_DB.ses.query(Query).count() == 3 - assert QUERY_DB.ses.query(SegmentInfo).count() == 2 - assert QUERY_DB.ses.query(PartInfo).count() == 1 - assert QUERY_DB.ses.query(ModelInfo).count() == 1 - assert QUERY_DB.ses.query(UserInfo).count() == 1 + assert QUERY_DB.ses.scalar(select(func.count(Query.model_name))) == 3 + assert QUERY_DB.ses.scalar(select(func.count(SegmentInfo.id))) == 2 + assert QUERY_DB.ses.scalar(select(func.count(PartInfo.id))) == 1 + assert QUERY_DB.ses.scalar(select(func.count(ModelInfo.id))) == 1 + assert QUERY_DB.ses.scalar(select(func.count(UserInfo.id))) == 1 # Look for a 'Query' with all info tables OK, Q1 = QUERY_DB.query('label2', 'model_name2')