Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add --exclude-artifacts option to exclude artifacts based on glob expression #12

Merged
Merged 5 commits on Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions aimlflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,19 @@ def cli_entry_point():
writable=True))
@click.option('--mlflow-tracking-uri', required=False, default=None)
@click.option('--experiment', '-e', required=False, default=None)
def sync(aim_repo, mlflow_tracking_uri, experiment):
@click.option('--excluded-artifacts', required=False, default=None)
def sync(aim_repo, mlflow_tracking_uri, experiment, excluded_artifacts):
mihran113 marked this conversation as resolved.
Show resolved Hide resolved
repo_path = clean_repo_path(aim_repo) or Repo.default_repo_path()
repo_inst = Repo.from_path(repo_path)

mlflow_tracking_uri = mlflow_tracking_uri or os.environ.get('MLFLOW_TRACKING_URI')
if not mlflow_tracking_uri:
raise ClickException('MLFlow tracking URI must be provided either through ENV or CLI.')

watcher = MLFlowWatcher(repo_inst, mlflow_tracking_uri, experiment)
watcher = MLFlowWatcher(repo_inst, mlflow_tracking_uri, experiment, excluded_artifacts)

click.echo('Converting existing MLflow logs.')
convert_existing_logs(repo_inst, mlflow_tracking_uri, experiment)
convert_existing_logs(repo_inst, mlflow_tracking_uri, experiment, excluded_artifacts)

click.echo(f'Starting watcher on {mlflow_tracking_uri}.')
watcher.start()
Expand Down
14 changes: 11 additions & 3 deletions aimlflow/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import fnmatch
import click
import mlflow
import json
Expand Down Expand Up @@ -114,7 +115,10 @@ def collect_run_params(aim_run, mlflow_run):
}


def collect_artifacts(aim_run, mlflow_run, mlflow_client):
def collect_artifacts(aim_run, mlflow_run, mlflow_client, excluded_artifacts):
if excluded_artifacts == '*':
return

run_id = mlflow_run.info.run_id

artifacts_cache_key = '_mlflow_artifacts_cache'
Expand All @@ -136,6 +140,10 @@ def collect_artifacts(aim_run, mlflow_run, mlflow_client):
continue
else:
artifacts_cache.append(file_info.path)

if fnmatch.fnmatch(file_info.path, excluded_artifacts):
continue

downloaded_path = mlflow_client.download_artifacts(run_id, file_info.path, dst_path=temp_path)
if file_info.path.endswith(HTML_EXTENSIONS):
if not __html_warning_issued:
Expand Down Expand Up @@ -196,7 +204,7 @@ def collect_metrics(aim_run, mlflow_run, mlflow_client, timestamp=None):
aim_run.track(m.value, step=m.step, name=m.key)


def convert_existing_logs(repo_inst, tracking_uri, experiment=None, no_cache=False):
def convert_existing_logs(repo_inst, tracking_uri, experiment=None, excluded_artifacts=None, no_cache=False):
client = mlflow.tracking.client.MlflowClient(tracking_uri=tracking_uri)

experiments = get_mlflow_experiments(client, experiment)
Expand All @@ -215,7 +223,7 @@ def convert_existing_logs(repo_inst, tracking_uri, experiment=None, no_cache=Fal
collect_metrics(aim_run, run, client)

# Collect artifacts
collect_artifacts(aim_run, run, client)
collect_artifacts(aim_run, run, client, excluded_artifacts)

run_cache.refresh()

Expand Down
4 changes: 3 additions & 1 deletion aimlflow/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self,
repo: 'Repo',
tracking_uri: str,
experiment: str = None,
excluded_artifacts: str = None,
interval: Union[int, float] = WATCH_INTERVAL_DEFAULT,
):

Expand All @@ -38,6 +39,7 @@ def __init__(self,

self._client = MlflowClient(tracking_uri)
self._experiments = get_mlflow_experiments(self._client, experiment)
self._excluded_artifacts = excluded_artifacts
self._repo = repo

self._th_collector = Thread(target=self._watch, daemon=True)
Expand Down Expand Up @@ -77,7 +79,7 @@ def _process_single_run(self, aim_run, mlflow_run):
collect_metrics(aim_run, mlflow_run, self._client, timestamp=self._last_watch_time)

# Collect artifacts
collect_artifacts(aim_run, mlflow_run, self._client)
collect_artifacts(aim_run, mlflow_run, self._client, self._excluded_artifacts)

def _process_runs(self):
watch_started_time = time.time()
Expand Down