Skip to content

Commit

Permalink
Type annotation improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
nsoranzo committed Feb 10, 2025
1 parent e89d747 commit 2837cbe
Show file tree
Hide file tree
Showing 19 changed files with 81 additions and 83 deletions.
10 changes: 7 additions & 3 deletions lib/galaxy/managers/histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)

Expand Down Expand Up @@ -85,6 +86,9 @@
RawTextTerm,
)

if TYPE_CHECKING:
from sqlalchemy.engine import ScalarResult

log = logging.getLogger(__name__)

INDEX_SEARCH_FILTERS = {
Expand All @@ -95,7 +99,7 @@
}


class HistoryManager(sharable.SharableModelManager, deletable.PurgableManagerMixin, SortableManager):
class HistoryManager(sharable.SharableModelManager[model.History], deletable.PurgableManagerMixin, SortableManager):
model_class = model.History
foreign_key_name = "history"
user_share_model = model.HistoryUserShareAssociation
Expand All @@ -120,7 +124,7 @@ def __init__(

def index_query(
self, trans: ProvidesUserContext, payload: HistoryIndexQueryPayload, include_total_count: bool = False
) -> Tuple[List[model.History], int]:
) -> Tuple["ScalarResult"[model.History], Union[int, None]]:
show_deleted = False
show_own = payload.show_own
show_published = payload.show_published
Expand Down Expand Up @@ -234,7 +238,7 @@ def p_tag_filter(term_text: str, quoted: bool):
stmt = stmt.limit(payload.limit)
if payload.offset is not None:
stmt = stmt.offset(payload.offset)
return trans.sa_session.scalars(stmt), total_matches # type:ignore[return-value]
return trans.sa_session.scalars(stmt), total_matches

# .... sharable
# overriding to handle anonymous users' current histories in both cases
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
}


class PageManager(sharable.SharableModelManager, UsesAnnotations):
class PageManager(sharable.SharableModelManager[model.Page], UsesAnnotations):
"""Provides operations for managing a Page."""

model_class = model.Page
Expand Down
5 changes: 4 additions & 1 deletion lib/galaxy/managers/sharable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Optional,
Set,
Type,
TypeVar,
)

from slugify import slugify
Expand Down Expand Up @@ -54,10 +55,12 @@
from galaxy.util.hash_util import md5_hash_str

log = logging.getLogger(__name__)
# Only model classes that have `users_shared_with` field
U = TypeVar("U", model.History, model.Page, model.StoredWorkflow, model.Visualization)


class SharableModelManager(
base.ModelManager,
base.ModelManager[U],
secured.OwnableManagerMixin,
secured.AccessibleManagerMixin,
annotatable.AnnotatableManagerMixin,
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
}


class VisualizationManager(sharable.SharableModelManager):
class VisualizationManager(sharable.SharableModelManager[model.Visualization]):
"""
Handle operations outside and between visualizations and other models.
"""
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
}


class WorkflowsManager(sharable.SharableModelManager, deletable.DeletableManagerMixin):
class WorkflowsManager(sharable.SharableModelManager[model.StoredWorkflow], deletable.DeletableManagerMixin):
"""Handle CRUD type operations related to workflows. More interesting
stuff regarding workflow execution, step sorting, etc... can be found in
the galaxy.workflow module.
Expand Down
3 changes: 0 additions & 3 deletions lib/galaxy/model/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from threading import local
from typing import (
Optional,
Type,
TYPE_CHECKING,
)

Expand All @@ -28,8 +27,6 @@
class GalaxyModelMapping(SharedModelMapping):
security_agent: GalaxyRBACAgent
thread_local_log: Optional[local]
User: Type
GalaxySession: Type


def init(
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/model/store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def flush(self) -> None:
def add(self, obj: model.RepresentById) -> None:
self.objects[obj.__class__][obj.id] = obj

def query(self, model_class: model.RepresentById) -> Bunch:
def query(self, model_class: Type[model.RepresentById]) -> Bunch:
def find(obj_id):
return self.objects.get(model_class, {}).get(obj_id) or None

Expand All @@ -243,7 +243,7 @@ def filter_by(*args, **kwargs):

return Bunch(find=find, get=find, filter_by=filter_by)

def get(self, model_class: model.RepresentById, primary_key: Any): # patch for SQLAlchemy 2.0 compatibility
def get(self, model_class: Type[model.RepresentById], primary_key: Any): # patch for SQLAlchemy 2.0 compatibility
return self.query(model_class).get(primary_key)


Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/webapps/galaxy/services/histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def index_query(
payload: HistoryIndexQueryPayload,
serialization_params: SerializationParams,
include_total_count: bool = False,
) -> Tuple[List[AnyHistoryView], int]:
) -> Tuple[List[AnyHistoryView], Union[int, None]]:
"""Return a list of History accessible by the user
:rtype: list
Expand Down
8 changes: 4 additions & 4 deletions scripts/grt/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import galaxy
import galaxy.app
import galaxy.config
from galaxy import model
from galaxy.model.mapping import init_models_from_config
from galaxy.objectstore import build_object_store_from_config
from galaxy.util import (
Expand All @@ -42,9 +43,9 @@ def _init(args):
"The database connection is empty. If you are using the default value, please uncomment that in your galaxy.yml"
)

model = init_models_from_config(config, object_store=object_store)
sa_session = init_models_from_config(config, object_store=object_store).context
return (
model,
sa_session,
object_store,
config,
)
Expand Down Expand Up @@ -120,11 +121,10 @@ def annotate(label, human_label=None):
last_job_sent = -1

annotate("galaxy_init", "Loading Galaxy...")
model, object_store, gxconfig = _init(args)
sa_session, object_store, gxconfig = _init(args)

# Galaxy overrides our logging level.
logging.getLogger().setLevel(getattr(logging, args.loglevel.upper()))
sa_session = model.context
annotate("galaxy_end")

# Fetch jobs COMPLETED with status OK that have not yet been sent.
Expand Down
6 changes: 3 additions & 3 deletions scripts/set_user_disk_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "lib")))

import galaxy.config
from galaxy import model
from galaxy.model.mapping import init_models_from_config
from galaxy.objectstore import build_object_store_from_config
from galaxy.util import nice_size
Expand Down Expand Up @@ -41,7 +42,7 @@ def init():
config = galaxy.config.Configuration(**app_properties)
object_store = build_object_store_from_config(config)
engine = config.database_connection.split(":")[0]
return init_models_from_config(config, object_store=object_store), object_store, engine
return init_models_from_config(config, object_store=object_store).context, object_store, engine


def quotacheck(sa_session, users, engine, object_store):
Expand Down Expand Up @@ -69,8 +70,7 @@ def quotacheck(sa_session, users, engine, object_store):

if __name__ == "__main__":
print("Loading Galaxy model...")
model, object_store, engine = init()
sa_session = model.context
sa_session, object_store, engine = init()

if not args.username and not args.email:
user_count = sa_session.query(model.User).count()
Expand Down
7 changes: 3 additions & 4 deletions test/integration/oidc/test_auth_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import ClassVar
from urllib import parse

from galaxy import model
from galaxy.util import requests
from galaxy_test.base.api import ApiTestInteractor
from galaxy_test.driver import integration_util
Expand Down Expand Up @@ -205,8 +206,7 @@ def test_oidc_login_new_user(self):
def test_oidc_login_existing_user(self):
# pre-create a user account manually
sa_session = self._app.model.session
User = self._app.model.User
user = User(email="[email protected]", username="precreated_user")
user = model.User(email="[email protected]", username="precreated_user")
user.set_password_cleartext("test123")
sa_session.add(user)
try:
Expand All @@ -230,8 +230,7 @@ def test_oidc_login_existing_user(self):
def test_oidc_login_account_linkup(self):
# pre-create a user account manually
sa_session = self._app.model.session
User = self._app.model.User
user = User(email="[email protected]", username="precreated_user")
user = model.User(email="[email protected]", username="precreated_user")
user.set_password_cleartext("test123")
sa_session.add(user)
try:
Expand Down
2 changes: 2 additions & 0 deletions test/integration/test_page_revision_json_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_page_encoding(self, history_id: str):
api_asserts.assert_status_code_is_ok(page_response)
sa_session = self._app.model.session
page_revision = sa_session.scalars(select(model.PageRevision).filter_by(content_format="html")).all()[0]
assert page_revision.content is not None
assert history_num_re.search(page_revision.content), page_revision.content
assert f'''id="History-{history_id}"''' not in page_revision.content, page_revision.content

Expand All @@ -59,6 +60,7 @@ def test_page_encoding_markdown(self, history_id: str):
api_asserts.assert_status_code_is_ok(page_response)
sa_session = self._app.model.session
page_revision = sa_session.scalars(select(model.PageRevision).filter_by(content_format="markdown")).all()[0]
assert page_revision.content is not None
assert (
"""```galaxy
history_dataset_display(history_dataset_id=1)
Expand Down
1 change: 1 addition & 0 deletions test/integration/test_remote_files_posix.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def test_links_by_default(self):
assert content == "a\n", content
stmt = select(Dataset).order_by(Dataset.create_time.desc()).limit(1)
dataset = self._app.model.session.execute(stmt).unique().scalar_one()
assert dataset.external_filename is not None
assert dataset.external_filename.endswith("/root/a")
assert os.path.exists(dataset.external_filename)
assert open(dataset.external_filename).read() == "a\n"
Expand Down
4 changes: 3 additions & 1 deletion test/integration/test_workflow_handler_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time
from json import dumps

from galaxy import model
from galaxy_test.base.populators import (
DatasetPopulator,
WorkflowPopulator,
Expand Down Expand Up @@ -127,7 +128,8 @@ def _get_workflow_invocations(self, history_id: str):
# into Galaxy's internal state.
app = self._app
history_id = app.security.decode_id(history_id)
history = app.model.session.get(app.model.History, history_id)
history = app.model.session.get(model.History, history_id)
assert history is not None
workflow_invocations = history.workflow_invocations
return workflow_invocations

Expand Down
2 changes: 1 addition & 1 deletion test/unit/app/managers/test_user_file_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def _assert_secret_absent(self, user_file_source: UserFileSourceModel, secret_na
assert sec_val in ["", None]

def _assert_modify_throws_exception(
self, user_file_source: UserFileSourceModel, modify: ModifyInstancePayload, exception_type: Type
self, user_file_source: UserFileSourceModel, modify: ModifyInstancePayload, exception_type: Type[Exception]
):
exception_thrown = False
try:
Expand Down
5 changes: 4 additions & 1 deletion test/unit/app/managers/test_user_object_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,10 @@ def _init_managers(self, tmp_path, config_dict=None):
self.manager = manager

def _assert_modify_throws_exception(
self, user_object_store: UserConcreteObjectStoreModel, modify: ModifyInstancePayload, exception_type: Type
self,
user_object_store: UserConcreteObjectStoreModel,
modify: ModifyInstancePayload,
exception_type: Type[Exception],
):
exception_thrown = False
try:
Expand Down
2 changes: 2 additions & 0 deletions test/unit/data/model/test_model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from galaxy import model
from galaxy.model import store
from galaxy.model.metadata import MetadataTempFile
from galaxy.model.store import SessionlessContext
from galaxy.model.unittest_utils import GalaxyDataTestApp
from galaxy.model.unittest_utils.store_fixtures import (
deferred_hda_model_store_dict,
Expand Down Expand Up @@ -922,6 +923,7 @@ def test_sessionless_import_edit_datasets():
import_model_store.perform_import()
# Not using app.sa_session but a session mock that has a query/find pattern emulating usage
# of real sa_session.
assert isinstance(import_model_store.sa_session, SessionlessContext)
d1 = import_model_store.sa_session.query(model.HistoryDatasetAssociation).find(h.datasets[0].id)
d2 = import_model_store.sa_session.query(model.HistoryDatasetAssociation).find(h.datasets[1].id)
assert d1 is not None
Expand Down
Loading

0 comments on commit 2837cbe

Please sign in to comment.