Commit
Cleanup examples (#9)
madil90 authored Nov 25, 2021
1 parent a116f04 commit 7bb5aef
Showing 5 changed files with 192 additions and 224 deletions.
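
The common thread in this cleanup is replacing hand-built error replies (constructing an empty Shareable and calling set_return_code on it) with the make_reply helper now imported from nvflare.apis.shareable, and choosing a more specific ReturnCode on each failure path. Below is a minimal before/after sketch of that pattern, based only on the code visible in this diff; reply_before and reply_after are illustrative names, not functions from the repository.

from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.shareable import Shareable, make_reply


def reply_before() -> Shareable:
    # Before: each executor built its abort/error reply by hand.
    shareable = Shareable()
    shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
    return shareable


def reply_after() -> Shareable:
    # After: make_reply builds the same kind of reply in one call, and each
    # call site can pick a precise code such as TASK_ABORTED, BAD_TASK_DATA,
    # EXECUTION_RESULT_ERROR, or TASK_UNKNOWN.
    return make_reply(ReturnCode.TASK_ABORTED)
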
196 changes: 92 additions & 104 deletions examples/hello-numpy-cross-val/custom/np_trainer.py
@@ -22,7 +22,7 @@
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.model import ModelLearnable
from nvflare.app_common.app_constant import AppConstants
@@ -72,103 +72,103 @@ def execute(
count, interval = 0, 0.5
while count < self._sleep_time:
if abort_signal.triggered:
return self._get_exception_shareable()
return make_reply(ReturnCode.TASK_ABORTED)
time.sleep(interval)
count += interval

if task_name == self._train_task_name:
# First we extract DXO from the shareable.
try:
incoming_dxo = from_shareable(shareable)
except BaseException as e:
self.system_panic(f"Unable to convert shareable to model definition. Exception {e.__str__()}", fl_ctx)
return self._get_exception_shareable()

# Information about workflow is retrieved from the shareable header.
current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None)
total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None)

# Ensure that data is of type weights. Extract model data.
if incoming_dxo.data_kind != DataKind.WEIGHTS:
self.system_panic("Model DXO should be of kind DataKind.WEIGHTS.", fl_ctx)
return self._get_exception_shareable()
np_data = incoming_dxo.data

# Display properties.
self.log_info(fl_ctx, f"Incoming data kind: {incoming_dxo.data_kind}")
self.log_info(fl_ctx, f"Model: \n{np_data}")
self.log_info(fl_ctx, f"Current Round: {current_round}")
self.log_info(fl_ctx, f"Total Rounds: {total_rounds}")
self.log_info(fl_ctx, f"Task name: {task_name}")
self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")

# Check abort signal
if abort_signal.triggered:
return self._get_exception_shareable()

# Doing some dummy training.
if np_data:
if NPConstants.NUMPY_KEY in np_data:
np_data[NPConstants.NUMPY_KEY] += self._delta
try:
if task_name == self._train_task_name:
# First we extract DXO from the shareable.
try:
incoming_dxo = from_shareable(shareable)
except BaseException as e:
self.system_panic(f"Unable to convert shareable to model definition. "
f"Exception {e.__str__()}", fl_ctx)
return make_reply(ReturnCode.BAD_TASK_DATA)

# Ensure that data is of type weights. Extract model data.
if incoming_dxo.data_kind != DataKind.WEIGHTS:
self.system_panic("Model DXO should be of kind DataKind.WEIGHTS.", fl_ctx)
return make_reply(ReturnCode.BAD_TASK_DATA)

# Check contents of data
np_data = incoming_dxo.data
if np_data:
if NPConstants.NUMPY_KEY not in np_data:
self.log_error(fl_ctx, "numpy_key not found in model.")
return make_reply(ReturnCode.BAD_TASK_DATA)
else:
self.log_error(fl_ctx, "numpy_key not found in model.")
shareable.set_return_code(ReturnCode.EXECUTION_RESULT_ERROR)
return shareable
else:
self.log_error(fl_ctx, "No model weights found in shareable.")
shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
return shareable

# We check abort_signal regularly to make sure the task stops promptly when aborted.
if abort_signal.triggered:
return self._get_exception_shareable()

# Save local numpy model
try:
self._save_local_model(fl_ctx, np_data)
except Exception as e:
self.log_error(fl_ctx, f"Exception in saving local model: {e}.")

self.log_info(
fl_ctx,
f"Model after training: {np_data}",
)

# Checking abort signal again.
if abort_signal.triggered:
return self._get_exception_shareable()

# Prepare a DXO for our updated model. Create shareable and return
outgoing_dxo = DXO(data_kind=incoming_dxo.data_kind, data=np_data, meta={})
return outgoing_dxo.to_shareable()
elif task_name == self._submit_model_task_name:
# Retrieve the local model saved during training.
np_data = None
try:
np_data = self._load_local_model(fl_ctx)
except Exception as e:
self.log_error(fl_ctx, f"Unable to load model: {e}")

# Checking abort signal
if abort_signal.triggered:
return self._get_exception_shareable()
self.log_error(fl_ctx, "No model weights found in shareable.")
return make_reply(ReturnCode.BAD_TASK_DATA)

# Information about workflow is retrieved from the shareable header.
current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None)
total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None)

# Display properties.
self.log_info(fl_ctx, f"Incoming data kind: {incoming_dxo.data_kind}")
self.log_info(fl_ctx, f"Model: \n{np_data}")
self.log_info(fl_ctx, f"Current Round: {current_round}")
self.log_info(fl_ctx, f"Total Rounds: {total_rounds}")
self.log_info(fl_ctx, f"Task name: {task_name}")
self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")

# Check abort signal
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

np_data[NPConstants.NUMPY_KEY] += self._delta

# We check abort_signal regularly to make sure the task stops promptly when aborted.
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

# Save local numpy model
try:
self._save_local_model(fl_ctx, np_data)
except Exception as e:
self.log_error(fl_ctx, f"Exception in saving local model: {e}.")

self.log_info(
fl_ctx,
f"Model after training: {np_data}",
)

# Checking abort signal again.
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

# Prepare a DXO for our updated model. Create shareable and return
outgoing_dxo = DXO(data_kind=incoming_dxo.data_kind, data=np_data, meta={})
return outgoing_dxo.to_shareable()
elif task_name == self._submit_model_task_name:
# Retrieve the local model saved during training.
try:
np_data = self._load_local_model(fl_ctx)
except Exception as e:
self.log_error(fl_ctx, f"Unable to load model: {e}")
return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)

# Checking abort signal
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

# Create DXO and shareable from model data.
if np_data:
outgoing_dxo = DXO(data_kind=DataKind.WEIGHTS, data=np_data)
model_shareable = outgoing_dxo.to_shareable()
else:
# Set return code.
self.log_error(fl_ctx, "Local model not found.")
return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)

# Create DXO and shareable from model data.
model_shareable = Shareable()
if np_data:
outgoing_dxo = DXO(data_kind=DataKind.WEIGHTS, data=np_data)
model_shareable = outgoing_dxo.to_shareable()
return model_shareable
else:
# Set return code.
self.log_error(fl_ctx, f"local model not found.")
model_shareable.set_return_code(ReturnCode.EXECUTION_RESULT_ERROR)

return model_shareable
else:
# If unknown task name, set RC accordingly.
shareable = Shareable()
shareable.set_return_code(ReturnCode.TASK_UNKNOWN)
return shareable
# If unknown task name, set RC accordingly.
return make_reply(ReturnCode.TASK_UNKNOWN)
except:
self.log_exception(fl_ctx, "Exception in NPTrainer execute.")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

def _load_local_model(self, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
@@ -177,7 +177,6 @@ def _load_local_model(self, fl_ctx: FLContext):
model_path = os.path.join(run_dir, self._model_dir)

model_load_path = os.path.join(model_path, self._model_name)
np_data = None
try:
np_data = np.load(model_load_path)
except Exception as e:
@@ -203,14 +202,3 @@ def _save_local_model(self, fl_ctx: FLContext, model: dict):
with open(model_save_path, "wb") as f:
np.save(f, model[NPConstants.NUMPY_KEY])
self.log_info(fl_ctx, f"Saved numpy model to: {model_save_path}")

def _get_exception_shareable(self) -> Shareable:
"""Abort execution. This is used if abort_signal is triggered. Users should
make sure they abort any running processes here.
Returns:
Shareable: Shareable with return_code.
"""
shareable = Shareable()
shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
return shareable
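
Seen as a whole, the rewritten trainer now follows a single error-handling shape: check the abort signal and validate the task data up front, return make_reply with a specific ReturnCode on each failure path, and wrap the whole task handling in one try/except that falls back to EXECUTION_EXCEPTION. The skeleton below is a condensed, simplified sketch of that control flow, shown out of its class for brevity; it is not the verbatim file, and the DXO imports are assumed to come from nvflare.apis.dxo as in the surrounding example code.

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.shareable import Shareable, make_reply


def execute(self, task_name, shareable, fl_ctx, abort_signal) -> Shareable:
    try:
        if abort_signal.triggered:  # cooperative abort check
            return make_reply(ReturnCode.TASK_ABORTED)

        if task_name == self._train_task_name:
            try:
                dxo = from_shareable(shareable)
            except BaseException:
                return make_reply(ReturnCode.BAD_TASK_DATA)  # payload is not a valid DXO
            if dxo.data_kind != DataKind.WEIGHTS or not dxo.data:
                return make_reply(ReturnCode.BAD_TASK_DATA)  # wrong or empty payload
            # ... dummy training and local model saving happen here ...
            return DXO(data_kind=DataKind.WEIGHTS, data=dxo.data).to_shareable()

        if task_name == self._submit_model_task_name:
            try:
                np_data = self._load_local_model(fl_ctx)
            except Exception:
                return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)  # load failed
            if not np_data:
                return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)  # no local model
            return DXO(data_kind=DataKind.WEIGHTS, data=np_data).to_shareable()

        return make_reply(ReturnCode.TASK_UNKNOWN)  # task name not recognized
    except BaseException:
        return make_reply(ReturnCode.EXECUTION_EXCEPTION)  # any other failure
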
129 changes: 58 additions & 71 deletions examples/hello-numpy-cross-val/custom/np_validator.py
@@ -22,7 +22,7 @@
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants

@@ -63,79 +63,66 @@ def execute(
count, interval = 0, 0.5
while count < self._sleep_time:
if abort_signal.triggered:
return self._abort_execution()
return make_reply(ReturnCode.TASK_ABORTED)
time.sleep(interval)
count += interval

if task_name == self._validate_task_name:
# First we extract DXO from the shareable.
try:
model_dxo = from_shareable(shareable)
except Exception as e:
self.log_error(fl_ctx, f"Unable to extract model dxo from shareable. Exception: {e.__str__()}")
shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
return shareable

# Get model from shareable. data_kind must be WEIGHTS.
if model_dxo.data and model_dxo.data_kind == DataKind.WEIGHTS:
model = model_dxo.data
else:
self.log_error(fl_ctx, f"Model DXO doesn't have data or is not of type DataKind.WEIGHTS. Unable "
"to validate.")
shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
return shareable

# The workflow provides MODEL_OWNER information in the shareable header.
model_name = shareable.get_header(AppConstants.MODEL_OWNER, "?")

# Print properties.
self.log_info(fl_ctx, f"Model: \n{model}")
self.log_info(fl_ctx, f"Task name: {task_name}")
self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")
self.log_info(fl_ctx, f"Validating model from {model_name}.")

# Check abort signal regularly.
if abort_signal.triggered:
return self._abort_execution()

# Check if key exists in model
if NPConstants.NUMPY_KEY not in model:
self.log_error(fl_ctx, "numpy_key not in model. Unable to validate.")
shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
return shareable

# Do some dummy validation.
random_epsilon = np.random.random()
self.log_info(
fl_ctx, f"Adding random epsilon {random_epsilon} in validation."
)
val_results = {}
np_data = model[NPConstants.NUMPY_KEY]
np_data = np.sum(np_data / np.max(np_data))
val_results["accuracy"] = np_data + random_epsilon

# Check abort signal regularly.
if abort_signal.triggered:
return self._abort_execution()

self.log_info(fl_ctx, f"Validation result: {val_results}")

# Create DXO for metrics and return shareable.
metric_dxo = DXO(data_kind=DataKind.METRICS, data=val_results)
return metric_dxo.to_shareable()

# First we extract DXO from the shareable.
try:
model_dxo = from_shareable(shareable)
except Exception as e:
self.log_error(fl_ctx, f"Unable to extract model dxo from shareable. Exception: {e.__str__()}")
return make_reply(ReturnCode.BAD_TASK_DATA)

# Get model from shareable. data_kind must be WEIGHTS.
if model_dxo.data and model_dxo.data_kind == DataKind.WEIGHTS:
model = model_dxo.data
else:
self.log_error(fl_ctx, f"Model DXO doesn't have data or is not of type DataKind.WEIGHTS. Unable "
"to validate.")
return make_reply(ReturnCode.BAD_TASK_DATA)

# Check if key exists in model
if NPConstants.NUMPY_KEY not in model:
self.log_error(fl_ctx, "numpy_key not in model. Unable to validate.")
return make_reply(ReturnCode.BAD_TASK_DATA)

# The workflow provides MODEL_OWNER information in the shareable header.
model_name = shareable.get_header(AppConstants.MODEL_OWNER, "?")

# Print properties.
self.log_info(fl_ctx, f"Model: \n{model}")
self.log_info(fl_ctx, f"Task name: {task_name}")
self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")
self.log_info(fl_ctx, f"Validating model from {model_name}.")

# Check abort signal regularly.
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

# Do some dummy validation.
random_epsilon = np.random.random()
self.log_info(
fl_ctx, f"Adding random epsilon {random_epsilon} in validation."
)
val_results = {}
np_data = model[NPConstants.NUMPY_KEY]
np_data = np.sum(np_data / np.max(np_data))
val_results["accuracy"] = np_data + random_epsilon

# Check abort signal regularly.
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

self.log_info(fl_ctx, f"Validation result: {val_results}")

# Create DXO for metrics and return shareable.
metric_dxo = DXO(data_kind=DataKind.METRICS, data=val_results)
return metric_dxo.to_shareable()
except:
self.log_exception(fl_ctx, "Exception in NPValidator execute.")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)
else:
shareable = Shareable()
shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
return shareable

def _abort_execution(self) -> Shareable:
"""Abort execution. This is used if abort_signal is triggered. Users should
make sure they abort any running processes here.
Returns:
Shareable: Shareable with return_code.
"""
shareable = Shareable()
shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
return shareable
return make_reply(ReturnCode.TASK_UNKNOWN)