Skip to content

Commit

Permalink
Pass results directory to the jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
Wiebke committed Mar 13, 2024
1 parent 431ddd6 commit 7ab859c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
31 changes: 25 additions & 6 deletions callbacks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def run_train(
)
model_parameters["num_classes"] = num_classes
model_parameters["network"] = model_name

if mask_uri is None:
notification = generate_notification(
"Mask Export", "red", ANNOT_ICONS["export"], mask_error_message
Expand All @@ -165,6 +166,7 @@ def run_train(
data_uri = tiled_datasets.get_data_uri_by_name(project_name)
io_parameters = assemble_io_parameters_from_uris(data_uri, mask_uri)
io_parameters["uid_retrieve"] = ""
io_parameters["models_dir"] = RESULTS_DIR

TRAIN_PARAMS_EXAMPLE["params_list"][0]["params"][
"io_parameters"
Expand Down Expand Up @@ -252,10 +254,12 @@ def run_inference(
return notification, no_update
model_parameters["num_classes"] = len(all_annotations)
model_parameters["network"] = model_name

# Set io_parameters for inference, there will be no mask
data_uri = tiled_datasets.get_data_uri_by_name(project_name)
io_parameters = assemble_io_parameters_from_uris(data_uri, "")
io_parameters["uid_retrieve"] = ""
io_parameters["models_dir"] = RESULTS_DIR

INFERENCE_PARAMS_EXAMPLE["params_list"][0]["params"][
"io_parameters"
Expand Down Expand Up @@ -420,7 +424,7 @@ def populate_segmentation_results(
ANNOT_ICONS["results"],
f"Could not retrieve result from {job_type} job!",
)
return notification, None
return notification, None, job_id
result_metadata = result_container.metadata
if result_metadata["data_uri"] == data_uri:
result_store = {
Expand All @@ -434,15 +438,16 @@ def populate_segmentation_results(
ANNOT_ICONS["results"],
f"Retrieved result from {job_type} job!",
)
return notification, result_store
return notification, result_store, job_id
else:
return no_update, None
return no_update, no_update
return no_update, None, None
return no_update, no_update, None


@callback(
Output("notifications-container", "children", allow_duplicate=True),
Output("seg-results-train-store", "data"),
Output("dvc-training-stats-link", "href"),
Input("train-job-selector", "value"),
Input("project-name-src", "value"),
prevent_initial_call=True,
Expand All @@ -452,7 +457,15 @@ def populate_segmentation_results_train(train_job_id, project_name):
This callback populates the segmentation results store based on the uids
if the training job and the inference job.
"""
return populate_segmentation_results(train_job_id, project_name, "training")
notification, result_store, segment_job_id = populate_segmentation_results(
train_job_id, project_name, "training"
)
if segment_job_id is not None:
results_link = os.path.join(RESULTS_DIR, segment_job_id, "results.html")
else:
results_link = no_update

return notification, result_store, results_link


@callback(
Expand All @@ -467,4 +480,10 @@ def populate_segmentation_results_inference(inference_job_id, project_name):
This callback populates the segmentation results store based on the uids
if the training job and the inference job.
"""
return populate_segmentation_results(inference_job_id, project_name, "inference")
notification, result_store, _ = populate_segmentation_results(
inference_job_id, project_name, "inference"
)
return (
notification,
result_store,
)
4 changes: 2 additions & 2 deletions components/control_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,8 +640,8 @@ def layout():
"dvc-training-stats",
dmc.Anchor(
dmc.Text("Open in new window"),
# href=RESULTS_DIR + uid from store "report.html",
href="assets/report.html",
id="dvc-training-stats-link",
href="",
target="_blank",
size="sm",
),
Expand Down

0 comments on commit 7ab859c

Please sign in to comment.