feat: Implement cache for predictions #1334
Draft
DRMPN wants to merge 28 commits into master from DRMPN-better-caching
Changes shown from 6 of 28 commits.

Commits (all by DRMPN):
4abcae2 refactor: rename pipeline cache to operation cache
e2f76d8 chore: Add test-cache.py for benchmarking and debugging purposes
fcab118 chore: add TODOs to insert the data_cache functionality
5248af3 feat: Add DataCache class for storing and loading predictions
ef90a68 feat: Add DataCacheDB class for caching predicted output using a rela…
77e47ca chore: add TODO to save the predictions
cfaef4f feat: change the logic to save the entire OutputData instead
612b40e feat: get/put pickled OutputData into SQL table
8b68240 chore: modify test script to use generated dataset
e86ec6a chore: modify error message
064ab27 feat: pass data_cache parameter down to store a prediction in DB
6edef1f feat: test access to the stored data
5328fe9 chore: remove old .pyc files
a9624d1 Merge remote-tracking branch 'origin/master' into DRMPN-better-caching
20a401d chore: add comment to remove redundant param
0ed1309 fix: take blob column instead of str
64734b3 fix: generate better dataset
662c248 feat: load predicted data from cache to calculate loss function
8a5b424 chore: decrease timeout for test script
9fe700e feat: add cache for pipeline metrics
12b96ad feat: add intermediate metrics' cache
d75b8d9 feat: add fit/predict cache for a single node
1f4980b feat: add cache effectiveness metric
a4718a4 feat: save cache effectiveness to csv file
33e3234 feat: extract metrics cache to dictionary
fa08705 fix: turn on prediction cache
bab641f chore: turn off fit cache
54f2232 fix: check and grab metric's cache before fit
New file: the `DataCache` class.

@@ -0,0 +1,105 @@
import sqlite3
from typing import TYPE_CHECKING, Optional

import numpy as np

from fedot.core.caching.base_cache import BaseCache
from fedot.core.caching.data_cache_db import DataCacheDB
from fedot.core.data.data import OutputData

if TYPE_CHECKING:
    from fedot.core.pipelines.pipeline import Pipeline


class DataCache(BaseCache):
    """
    Stores/loads predictions to increase performance of calculations.

    :param cache_dir: path to the place where cache files should be stored.
    """

    def __init__(self, cache_dir: Optional[str] = None, custom_pid=None):
        super().__init__(DataCacheDB(cache_dir, custom_pid))

    def save_prediction(self, prediction: np.ndarray, uid: str):
        """
        Save the prediction for a given UID.

        :param prediction: the prediction to be saved.
        :param uid: the unique identifier for the prediction.
        """
        try:
            self._db.add_prediction([(uid, prediction)])
        except Exception as ex:
            unexpected_exc = not (
                isinstance(ex, sqlite3.DatabaseError) and "disk is full" in str(ex)
            )
            self.log.warning(
                f"Predictions can not be saved: {ex}. Continue",
                exc=ex,
                raise_if_test=unexpected_exc,
            )

    def load_prediction(self, uid: str) -> Optional[np.ndarray]:
        """
        Load the prediction for the given unique identifier.

        :param uid: the unique identifier of the prediction.
        :return: the loaded prediction, or None on a cache miss.
        """
        # DataCacheDB.get_prediction works on batches, so wrap the single uid
        predict = self._db.get_prediction([uid])[0]
        # TODO: restore OutputData from predict
        return predict

    def save_data(
        self,
        pipeline: "Pipeline",
        outputData: OutputData,
        fold_id: Optional[int] = None,
    ):
        """
        Save the pipeline's prediction to the cache.

        :param pipeline: the pipeline whose prediction is cached.
        :param outputData: the output data to be saved.
        :param fold_id: optional part of the cache item UID (can be used to specify the number of the CV fold).
        """
        uid = self._create_uid(pipeline, fold_id)
        # TODO: save OutputData as a whole to the cache
        self.save_prediction(outputData.predict, uid)

    def try_load_data(
        self, pipeline: "Pipeline", fold_id: Optional[int] = None
    ) -> OutputData:
        """
        Try to load cached data for the given pipeline and fold ID.

        :param pipeline: the pipeline for which to load the data.
        :param fold_id: the fold ID for which to load the data. Defaults to None.
        :return: the loaded data.
        """
        # TODO: implement loading of pipeline data as OutputData
        uid = self._create_uid(pipeline, fold_id)
        return self.load_prediction(uid)

    def _create_uid(
        self,
        pipeline: "Pipeline",
        fold_id: Optional[int] = None,
    ) -> str:
        """
        Generate a unique identifier for a pipeline.

        :param pipeline: the pipeline for which the unique identifier is generated.
        :param fold_id: the fold ID (default: None).
        :return: the unique identifier generated for the pipeline.
        """
        # concatenate the descriptive ids of all nodes, then append the fold id
        base_uid = ""
        for node in pipeline.nodes:
            base_uid += f"{node.descriptive_id}_"
        if fold_id is not None:
            base_uid += f"{fold_id}"
        return base_uid
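For context, a minimal round-trip sketch of the class above. The module path in the import and the uid value are assumptions for illustration, not fixed by this diff:

import numpy as np

from fedot.core.caching.data_cache import DataCache  # assumed module path for the file above

cache = DataCache(cache_dir="/tmp/fedot_cache")

# Save a raw prediction under an explicit uid and read it back.
# The uid is illustrative: _create_uid normally builds it from the
# nodes' descriptive_id values plus the optional fold_id.
prediction = np.array([0.1, 0.9, 0.4])
cache.save_prediction(prediction, uid="scaling_rf_0")
restored = cache.load_prediction("scaling_rf_0")
assert np.array_equal(restored, prediction)

Note that, because the key is a plain concatenation of node descriptive ids, two structurally identical pipelines would share cache entries.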
New file: the `DataCacheDB` class.

@@ -0,0 +1,90 @@
import pickle
import sqlite3
from contextlib import closing
from typing import List, Optional, Tuple

import numpy as np

from fedot.core.caching.base_cache_db import BaseCacheDB


class DataCacheDB(BaseCacheDB):
    """
    Database for the `DataCache` class.
    Implements the low-level caching of predicted output in a relational database.

    :param cache_dir: path to the place where cache files should be stored.
    """

    def __init__(self, cache_dir: Optional[str] = None, custom_pid=None):
        # custom_pid is accepted for interface compatibility but is not used here
        super().__init__("prediction", cache_dir)
        self._init_db()

    def add_prediction(self, uid_val_lst: List[Tuple[str, np.ndarray]]):
        """
        Adds predictions to the DB table via their uids.

        :param uid_val_lst: list of (uid, prediction) pairs to be saved.
        """
        try:
            with closing(sqlite3.connect(self.db_path)) as conn:
                with conn:
                    cur = conn.cursor()
                    # pickle each prediction and wrap it so it is stored as a BLOB
                    pickled = [
                        (
                            uid,
                            sqlite3.Binary(pickle.dumps(val, pickle.HIGHEST_PROTOCOL)),
                        )
                        for uid, val in uid_val_lst
                    ]
                    cur.executemany(
                        f"INSERT OR IGNORE INTO {self._main_table} VALUES (?, ?);",
                        pickled,
                    )
        except sqlite3.Error as e:
            print(f"SQLite error: {e}")

    def get_prediction(self, uids: List[str]) -> List[Optional[np.ndarray]]:
        """
        Maps the given uids to predictions from the DB, with None for uids
        that are not present.

        :param uids: list of prediction uids to be mapped.
        :return: list of predictions taken from the DB table, with None where a uid wasn't present.
        """
        try:
            with closing(sqlite3.connect(self.db_path)) as conn:
                with conn:
                    cur = conn.cursor()
                    placeholders = ",".join("?" for _ in uids)
                    query = (
                        f"SELECT id, prediction FROM {self._main_table} "
                        f"WHERE id IN ({placeholders})"
                    )
                    cur.execute(query, uids)
                    results = {row[0]: pickle.loads(row[1]) for row in cur.fetchall()}
                    retrieved = [results.get(uid) for uid in uids]
                    return retrieved
        except sqlite3.Error as e:
            print(f"SQLite error: {e}")
            return [None] * len(uids)

    def _init_db(self):
        """
        Initializes the DB working table.
        """
        try:
            with closing(sqlite3.connect(self.db_path)) as conn:
                with conn:
                    cur = conn.cursor()
                    cur.execute(
                        f"CREATE TABLE IF NOT EXISTS {self._main_table} ("
                        "id TEXT PRIMARY KEY,"
                        "prediction BLOB"
                        ");"
                    )
        except sqlite3.Error as e:
            print(f"SQLite error: {e}")
File renamed without changes.
Review comment:

Why not implement DataCacheDB as a singleton, connecting to the DB once at initialization (and re-initializing it in new Python instances under multiprocessing)?
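For illustration, a minimal sketch of the singleton the comment suggests (the class name, default path, and locking details are assumptions, not code from this PR):

import sqlite3
import threading


class SingletonDataCacheDB:
    """Hypothetical variant of DataCacheDB: one instance (and one SQLite
    connection) per Python process, created lazily and guarded by a lock.
    A freshly spawned worker process starts with empty class state, so it
    creates its own instance, matching the re-initialization described."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, db_path: str = "predictions.db"):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                # check_same_thread=False lets threads share the connection;
                # SQLite serializes writes, but transactions spanning several
                # statements may still need coarser application-level locking.
                cls._instance._conn = sqlite3.connect(
                    db_path, check_same_thread=False
                )
            return cls._instance

    @property
    def connection(self) -> sqlite3.Connection:
        return self._conn

One caveat of this pattern: later calls with a different db_path silently return the first instance, so the path effectively becomes process-global configuration.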