Skip to content

Commit

Permalink
fix: Fix for SQL registry initialization fails #4543
Browse files Browse the repository at this point in the history
Signed-off-by: Bhargav Dodla <[email protected]>
  • Loading branch information
Bhargav Dodla committed Sep 19, 2024
1 parent 1b92803 commit cf8387d
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 17 deletions.
38 changes: 23 additions & 15 deletions sdk/python/feast/infra/registry/caching_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from feast.permissions.permission import Permission
from feast.project import Project
from feast.project_metadata import ProjectMetadata
from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
from feast.saved_dataset import SavedDataset, ValidationReference
from feast.stream_feature_view import StreamFeatureView
from feast.utils import _utc_now
Expand All @@ -28,13 +29,14 @@

class CachingRegistry(BaseRegistry):
def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str):
self.cached_registry_proto = self.proto()
self.cached_registry_proto_created = _utc_now()
self.cache_mode = cache_mode
self.cached_registry_proto = RegistryProto()
self._refresh_lock = Lock()
self.cached_registry_proto_ttl = timedelta(
seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0
)
self.cache_mode = cache_mode
self.cached_registry_proto = self.proto()
self.cached_registry_proto_created = _utc_now()
if cache_mode == "thread":
self._start_thread_async_refresh(cache_ttl_seconds)
atexit.register(self._exit_handler)
Expand Down Expand Up @@ -429,20 +431,26 @@ def refresh(self, project: Optional[str] = None):
def _refresh_cached_registry_if_necessary(self):
if self.cache_mode == "sync":
with self._refresh_lock:
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
_utc_now()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
if self.cached_registry_proto == RegistryProto():
# Avoids the need to refresh the registry when cache is not populated yet
# Specially during the __init__ phase
# proto() will populate the cache with project metadata if no objects are registered
expired = False
else:
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
_utc_now()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
)
)
)
)
if expired:
logger.info("Registry cache expired, so refreshing")
self.refresh()
Expand Down
6 changes: 4 additions & 2 deletions sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ def __init__(
registry_config, SqlRegistryConfig
), "SqlRegistry needs a valid registry_config"

self.registry_config = registry_config

self.write_engine: Engine = create_engine(
registry_config.path, **registry_config.sqlalchemy_config_kwargs
)
Expand Down Expand Up @@ -281,7 +283,7 @@ def __init__(
def _sync_feast_metadata_to_projects_table(self):
feast_metadata_projects: set = []
projects_set: set = []
with self.write_engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value
)
Expand All @@ -290,7 +292,7 @@ def _sync_feast_metadata_to_projects_table(self):
feast_metadata_projects.append(row._mapping["project_id"])

if len(feast_metadata_projects) > 0:
with self.write_engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(projects)
rows = conn.execute(stmt).all()
for row in rows:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1767,3 +1767,92 @@ def test_apply_entity_success_with_purge_feast_metadata(test_registry):
assert len(entities) == 0

test_registry.teardown()


combined_sql_fixtures = [
pytest.param(
lazy_fixture("pg_registry"), marks=pytest.mark.xdist_group(name="pg_registry")
),
pytest.param(
lazy_fixture("mysql_registry"),
marks=pytest.mark.xdist_group(name="mysql_registry"),
),
lazy_fixture("sqlite_registry"),
pytest.param(
lazy_fixture("pg_registry_async"),
marks=pytest.mark.xdist_group(name="pg_registry"),
),
pytest.param(
lazy_fixture("mysql_registry_async"),
marks=pytest.mark.xdist_group(name="mysql_registry"),
),
pytest.param(
lazy_fixture("pg_registry_purge_feast_metadata"),
marks=pytest.mark.xdist_group(name="pg_registry"),
),
pytest.param(
lazy_fixture("mysql_registry_purge_feast_metadata"),
marks=pytest.mark.xdist_group(name="mysql_registry"),
),
]


@pytest.mark.integration
@pytest.mark.parametrize(
"test_registry",
combined_sql_fixtures,
)
def test_apply_entity_to_sql_registry_and_reinitialize_sql_registry(test_registry):
entity = Entity(
name="driver_car_id",
description="Car driver id",
tags={"team": "matchmaking"},
)

project = "project"

# Register Entity
test_registry.apply_entity(entity, project)
assert_project(project, test_registry)

entities = test_registry.list_entities(project, tags=entity.tags)
assert_project(project, test_registry)

entity = entities[0]
assert (
len(entities) == 1
and entity.name == "driver_car_id"
and entity.description == "Car driver id"
and "team" in entity.tags
and entity.tags["team"] == "matchmaking"
)

entity = test_registry.get_entity("driver_car_id", project)
assert (
entity.name == "driver_car_id"
and entity.description == "Car driver id"
and "team" in entity.tags
and entity.tags["team"] == "matchmaking"
)

# After the first apply, the created_timestamp should be the same as the last_update_timestamp.
assert entity.created_timestamp == entity.last_updated_timestamp
updated_test_registry = SqlRegistry(test_registry.registry_config, "project", None)

# Update entity
updated_entity = Entity(
name="driver_car_id",
description="Car driver Id",
tags={"team": "matchmaking"},
)
updated_test_registry.apply_entity(updated_entity, project)

updated_entity = updated_test_registry.get_entity("driver_car_id", project)
updated_test_registry.delete_entity("driver_car_id", project)
assert_project(project, updated_test_registry)
entities = updated_test_registry.list_entities(project)
assert_project(project, updated_test_registry)
assert len(entities) == 0

updated_test_registry.teardown()
test_registry.teardown()

0 comments on commit cf8387d

Please sign in to comment.