diff --git a/examples/hello-numpy-cross-val/custom/np_trainer.py b/examples/hello-numpy-cross-val/custom/np_trainer.py index 2ca54b9d0e..1156d3a326 100755 --- a/examples/hello-numpy-cross-val/custom/np_trainer.py +++ b/examples/hello-numpy-cross-val/custom/np_trainer.py @@ -22,7 +22,7 @@ from nvflare.apis.executor import Executor from nvflare.apis.fl_constant import FLContextKey, ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable +from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.app_common.abstract.model import ModelLearnable from nvflare.app_common.app_constant import AppConstants @@ -72,103 +72,103 @@ def execute( count, interval = 0, 0.5 while count < self._sleep_time: if abort_signal.triggered: - return self._get_exception_shareable() + return make_reply(ReturnCode.TASK_ABORTED) time.sleep(interval) count += interval - if task_name == self._train_task_name: - # First we extract DXO from the shareable. - try: - incoming_dxo = from_shareable(shareable) - except BaseException as e: - self.system_panic(f"Unable to convert shareable to model definition. Exception {e.__str__()}", fl_ctx) - return self._get_exception_shareable() - - # Information about workflow is retrieved from the shareable header. - current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None) - total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None) - - # Ensure that data is of type weights. Extract model data. - if incoming_dxo.data_kind != DataKind.WEIGHTS: - self.system_panic("Model dex should be of kind DataKind.WEIGHTS.", fl_ctx) - return self._get_exception_shareable() - np_data = incoming_dxo.data - - # Display properties. - self.log_info(fl_ctx, f"Incoming data kind: {incoming_dxo.data_kind}") - self.log_info(fl_ctx, f"Model: \n{np_data}") - self.log_info(fl_ctx, f"Current Round: {current_round}") - self.log_info(fl_ctx, f"Total Rounds: {total_rounds}") - self.log_info(fl_ctx, f"Task name: {task_name}") - self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}") - - # Check abort signal - if abort_signal.triggered: - return self._get_exception_shareable() - - # Doing some dummy training. - if np_data: - if NPConstants.NUMPY_KEY in np_data: - np_data[NPConstants.NUMPY_KEY] += self._delta + try: + if task_name == self._train_task_name: + # First we extract DXO from the shareable. + try: + incoming_dxo = from_shareable(shareable) + except BaseException as e: + self.system_panic(f"Unable to convert shareable to model definition. " + f"Exception {e.__str__()}", fl_ctx) + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Ensure that data is of type weights. Extract model data. + if incoming_dxo.data_kind != DataKind.WEIGHTS: + self.system_panic("Model dex should be of kind DataKind.WEIGHTS.", fl_ctx) + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Check contents of data + np_data = incoming_dxo.data + if np_data: + if NPConstants.NUMPY_KEY not in np_data: + self.log_error(fl_ctx, "numpy_key not found in model.") + return make_reply(ReturnCode.BAD_TASK_DATA) else: - self.log_error(fl_ctx, "numpy_key not found in model.") - shareable.set_return_code(ReturnCode.EXECUTION_RESULT_ERROR) - return shareable - else: - self.log_error(fl_ctx, "No model weights found in shareable.") - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable - - # We check abort_signal regularly to make sure - if abort_signal.triggered: - return self._get_exception_shareable() - - # Save local numpy model - try: - self._save_local_model(fl_ctx, np_data) - except Exception as e: - self.log_error(fl_ctx, f"Exception in saving local model: {e}.") - - self.log_info( - fl_ctx, - f"Model after training: {np_data}", - ) - - # Checking abort signal again. - if abort_signal.triggered: - return self._get_exception_shareable() - - # Prepare a DXO for our updated model. Create shareable and return - outgoing_dxo = DXO(data_kind=incoming_dxo.data_kind, data=np_data, meta={}) - return outgoing_dxo.to_shareable() - elif task_name == self._submit_model_task_name: - # Retrieve the local model saved during training. - np_data = None - try: - np_data = self._load_local_model(fl_ctx) - except Exception as e: - self.log_error(fl_ctx, f"Unable to load model: {e}") - - # Checking abort signal - if abort_signal.triggered: - return self._get_exception_shareable() + self.log_error(fl_ctx, "No model weights found in shareable.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Information about workflow is retrieved from the shareable header. + current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None) + total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None) + + # Display properties. + self.log_info(fl_ctx, f"Incoming data kind: {incoming_dxo.data_kind}") + self.log_info(fl_ctx, f"Model: \n{np_data}") + self.log_info(fl_ctx, f"Current Round: {current_round}") + self.log_info(fl_ctx, f"Total Rounds: {total_rounds}") + self.log_info(fl_ctx, f"Task name: {task_name}") + self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}") + + # Check abort signal + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + np_data[NPConstants.NUMPY_KEY] += self._delta + + # We check abort_signal regularly to make sure + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + # Save local numpy model + try: + self._save_local_model(fl_ctx, np_data) + except Exception as e: + self.log_error(fl_ctx, f"Exception in saving local model: {e}.") + + self.log_info( + fl_ctx, + f"Model after training: {np_data}", + ) + + # Checking abort signal again. + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + # Prepare a DXO for our updated model. Create shareable and return + outgoing_dxo = DXO(data_kind=incoming_dxo.data_kind, data=np_data, meta={}) + return outgoing_dxo.to_shareable() + elif task_name == self._submit_model_task_name: + # Retrieve the local model saved during training. + try: + np_data = self._load_local_model(fl_ctx) + except Exception as e: + self.log_error(fl_ctx, f"Unable to load model: {e}") + return make_reply(ReturnCode.EXECUTION_RESULT_ERROR) + + # Checking abort signal + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + # Create DXO and shareable from model data. + if np_data: + outgoing_dxo = DXO(data_kind=DataKind.WEIGHTS, data=np_data) + model_shareable = outgoing_dxo.to_shareable() + else: + # Set return code. + self.log_error(fl_ctx, f"local model not found.") + return make_reply(ReturnCode.EXECUTION_RESULT_ERROR) - # Create DXO and shareable from model data. - model_shareable = Shareable() - if np_data: - outgoing_dxo = DXO(data_kind=DataKind.WEIGHTS, data=np_data) - model_shareable = outgoing_dxo.to_shareable() + return model_shareable else: - # Set return code. - self.log_error(fl_ctx, f"local model not found.") - model_shareable.set_return_code(ReturnCode.EXECUTION_RESULT_ERROR) - - return model_shareable - else: - # If unknown task name, set RC accordingly. - shareable = Shareable() - shareable.set_return_code(ReturnCode.TASK_UNKNOWN) - return shareable + # If unknown task name, set RC accordingly. + return make_reply(ReturnCode.TASK_UNKNOWN) + except: + self.log_exception(fl_ctx, "Exception in NPTrainer execute.") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) def _load_local_model(self, fl_ctx: FLContext): engine = fl_ctx.get_engine() @@ -177,7 +177,6 @@ def _load_local_model(self, fl_ctx: FLContext): model_path = os.path.join(run_dir, self._model_dir) model_load_path = os.path.join(model_path, self._model_name) - np_data = None try: np_data = np.load(model_load_path) except Exception as e: @@ -203,14 +202,3 @@ def _save_local_model(self, fl_ctx: FLContext, model: dict): with open(model_save_path, "wb") as f: np.save(f, model[NPConstants.NUMPY_KEY]) self.log_info(fl_ctx, f"Saved numpy model to: {model_save_path}") - - def _get_exception_shareable(self) -> Shareable: - """Abort execution. This is used if abort_signal is triggered. Users should - make sure they abort any running processes here. - - Returns: - Shareable: Shareable with return_code. - """ - shareable = Shareable() - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable diff --git a/examples/hello-numpy-cross-val/custom/np_validator.py b/examples/hello-numpy-cross-val/custom/np_validator.py index 4698353a85..eec66dfa3c 100755 --- a/examples/hello-numpy-cross-val/custom/np_validator.py +++ b/examples/hello-numpy-cross-val/custom/np_validator.py @@ -22,7 +22,7 @@ from nvflare.apis.executor import Executor from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable +from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.app_common.app_constant import AppConstants @@ -63,79 +63,66 @@ def execute( count, interval = 0, 0.5 while count < self._sleep_time: if abort_signal.triggered: - return self._abort_execution() + return make_reply(ReturnCode.TASK_ABORTED) time.sleep(interval) count += interval if task_name == self._validate_task_name: - # First we extract DXO from the shareable. try: - model_dxo = from_shareable(shareable) - except Exception as e: - self.log_error(fl_ctx, f"Unable to extract model dxo from shareable. Exception: {e.__str__()}") - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable - - # Get model from shareable. data_kind must be WEIGHTS. - if model_dxo.data and model_dxo.data_kind == DataKind.WEIGHTS: - model = model_dxo.data - else: - self.log_error(fl_ctx, f"Model DXO doesn't have data or is not of type DataKind.WEIGHTS. Unable " - "to validate.") - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable - - # The workflow provides MODEL_OWNER information in the shareable header. - model_name = shareable.get_header(AppConstants.MODEL_OWNER, "?") - - # Print properties. - self.log_info(fl_ctx, f"Model: \n{model}") - self.log_info(fl_ctx, f"Task name: {task_name}") - self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}") - self.log_info(fl_ctx, f"Validating model from {model_name}.") - - # Check abort signal regularly. - if abort_signal.triggered: - return self._abort_execution() - - # Check if key exists in model - if NPConstants.NUMPY_KEY not in model: - self.log_error(fl_ctx, "numpy_key not in model. Unable to validate.") - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable - - # Do some dummy validation. - random_epsilon = np.random.random() - self.log_info( - fl_ctx, f"Adding random epsilon {random_epsilon} in validation." - ) - val_results = {} - np_data = model[NPConstants.NUMPY_KEY] - np_data = np.sum(np_data / np.max(np_data)) - val_results["accuracy"] = np_data + random_epsilon - - # Check abort signal regularly. - if abort_signal.triggered: - return self._abort_execution() - - self.log_info(fl_ctx, f"Validation result: {val_results}") - - # Create DXO for metrics and return shareable. - metric_dxo = DXO(data_kind=DataKind.METRICS, data=val_results) - return metric_dxo.to_shareable() - + # First we extract DXO from the shareable. + try: + model_dxo = from_shareable(shareable) + except Exception as e: + self.log_error(fl_ctx, f"Unable to extract model dxo from shareable. Exception: {e.__str__()}") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Get model from shareable. data_kind must be WEIGHTS. + if model_dxo.data and model_dxo.data_kind == DataKind.WEIGHTS: + model = model_dxo.data + else: + self.log_error(fl_ctx, f"Model DXO doesn't have data or is not of type DataKind.WEIGHTS. Unable " + "to validate.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Check if key exists in model + if NPConstants.NUMPY_KEY not in model: + self.log_error(fl_ctx, "numpy_key not in model. Unable to validate.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # The workflow provides MODEL_OWNER information in the shareable header. + model_name = shareable.get_header(AppConstants.MODEL_OWNER, "?") + + # Print properties. + self.log_info(fl_ctx, f"Model: \n{model}") + self.log_info(fl_ctx, f"Task name: {task_name}") + self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}") + self.log_info(fl_ctx, f"Validating model from {model_name}.") + + # Check abort signal regularly. + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + # Do some dummy validation. + random_epsilon = np.random.random() + self.log_info( + fl_ctx, f"Adding random epsilon {random_epsilon} in validation." + ) + val_results = {} + np_data = model[NPConstants.NUMPY_KEY] + np_data = np.sum(np_data / np.max(np_data)) + val_results["accuracy"] = np_data + random_epsilon + + # Check abort signal regularly. + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + self.log_info(fl_ctx, f"Validation result: {val_results}") + + # Create DXO for metrics and return shareable. + metric_dxo = DXO(data_kind=DataKind.METRICS, data=val_results) + return metric_dxo.to_shareable() + except: + self.log_exception(fl_ctx, "Exception in NPValidator execute.") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) else: - shareable = Shareable() - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable - - def _abort_execution(self) -> Shareable: - """Abort execution. This is used if abort_signal is triggered. Users should - make sure they abort any running processes here. - - Returns: - Shareable: Shareable with return_code. - """ - shareable = Shareable() - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable + return make_reply(ReturnCode.TASK_UNKNOWN) diff --git a/examples/hello-numpy-sag/custom/np_trainer.py b/examples/hello-numpy-sag/custom/np_trainer.py index d65b609acc..aee07fc83e 100755 --- a/examples/hello-numpy-sag/custom/np_trainer.py +++ b/examples/hello-numpy-sag/custom/np_trainer.py @@ -22,7 +22,7 @@ from nvflare.apis.executor import Executor from nvflare.apis.fl_constant import FLContextKey, ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable +from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.app_common.abstract.model import ModelLearnable from nvflare.app_common.app_constant import AppConstants @@ -72,7 +72,7 @@ def execute( count, interval = 0, 0.5 while count < self._sleep_time: if abort_signal.triggered: - return self._get_exception_shareable() + return make_reply(ReturnCode.TASK_ABORTED) time.sleep(interval) count += interval @@ -82,7 +82,7 @@ def execute( incoming_dxo = from_shareable(shareable) except BaseException as e: self.system_panic(f"Unable to convert shareable to model definition. Exception {e.__str__()}", fl_ctx) - return self._get_exception_shareable() + return make_reply(ReturnCode.BAD_TASK_DATA) # Information about workflow is retrieved from the shareable header. current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None) @@ -91,7 +91,7 @@ def execute( # Ensure that data is of type weights. Extract model data. if incoming_dxo.data_kind != DataKind.WEIGHTS: self.system_panic("Model dex should be of kind DataKind.WEIGHTS.", fl_ctx) - return self._get_exception_shareable() + return make_reply(ReturnCode.BAD_TASK_DATA) np_data = incoming_dxo.data # Display properties. @@ -104,7 +104,7 @@ def execute( # Check abort signal if abort_signal.triggered: - return self._get_exception_shareable() + return make_reply(ReturnCode.TASK_ABORTED) # Doing some dummy training. if np_data: @@ -112,16 +112,14 @@ def execute( np_data[NPConstants.NUMPY_KEY] += self._delta else: self.log_error(fl_ctx, "numpy_key not found in model.") - shareable.set_return_code(ReturnCode.EXECUTION_RESULT_ERROR) - return shareable + return make_reply(ReturnCode.BAD_TASK_DATA) else: self.log_error(fl_ctx, "No model weights found in shareable.") - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable + return make_reply(ReturnCode.BAD_TASK_DATA) # We check abort_signal regularly to make sure if abort_signal.triggered: - return self._get_exception_shareable() + return make_reply(ReturnCode.TASK_ABORTED) # Save local numpy model try: @@ -136,16 +134,14 @@ def execute( # Checking abort signal again. if abort_signal.triggered: - return self._get_exception_shareable() + return make_reply(ReturnCode.TASK_ABORTED) # Prepare a DXO for our updated model. Create shareable and return outgoing_dxo = DXO(data_kind=incoming_dxo.data_kind, data=np_data, meta={}) return outgoing_dxo.to_shareable() else: # If unknown task name, set RC accordingly. - shareable = Shareable() - shareable.set_return_code(ReturnCode.TASK_UNKNOWN) - return shareable + return make_reply(ReturnCode.TASK_UNKNOWN) def _load_local_model(self, fl_ctx: FLContext): engine = fl_ctx.get_engine() @@ -179,14 +175,3 @@ def _save_local_model(self, fl_ctx: FLContext, model: dict): with open(model_save_path, "wb") as f: np.save(f, model[NPConstants.NUMPY_KEY]) self.log_info(fl_ctx, f"Saved numpy model to: {model_save_path}") - - def _get_exception_shareable(self) -> Shareable: - """Abort execution. This is used if abort_signal is triggered. Users should - make sure they abort any running processes here. - - Returns: - Shareable: Shareable with return_code. - """ - shareable = Shareable() - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable diff --git a/examples/hello-pt/custom/cifar10trainer.py b/examples/hello-pt/custom/cifar10trainer.py index b25d94a0ee..0df51ceb02 100644 --- a/examples/hello-pt/custom/cifar10trainer.py +++ b/examples/hello-pt/custom/cifar10trainer.py @@ -24,7 +24,7 @@ from nvflare.apis.executor import Executor from nvflare.apis.fl_constant import ReturnCode, ReservedKey from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable +from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.app_common.abstract.model import make_model_learnable, model_learnable_to_dxo from nvflare.app_common.app_constant import AppConstants @@ -68,7 +68,7 @@ def __init__(self, lr=0.01, epochs=5, train_task_name=AppConstants.TASK_TRAIN, self.persistence_manager = PTModelPersistenceFormatManager(data=self.model.state_dict(), default_train_conf=self.default_train_conf) - def local_train(self, fl_ctx, weights): + def local_train(self, fl_ctx, weights, abort_signal): # Set the model weights self.model.load_state_dict(state_dict=weights) @@ -77,6 +77,11 @@ def local_train(self, fl_ctx, weights): for epoch in range(self.epochs): running_loss = 0 for i, batch in enumerate(self.train_loader): + if abort_signal.triggered: + # If abort_signal is triggered, we simply return. + # The outside function will check it again and decide steps to take. + return + images, labels = batch[0].to(self.device), batch[1].to(self.device) self.optimizer.zero_grad() @@ -95,23 +100,27 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort try: if task_name == self._train_task_name: # Get model weights - dxo = from_shareable(shareable) + try: + dxo = from_shareable(shareable) + except: + self.log_error(fl_ctx, "Unable to extract dxo from shareable.") + return make_reply(ReturnCode.BAD_TASK_DATA) # Check if dxo is valid. if not isinstance(dxo, DXO): self.log_exception(fl_ctx, f"dxo excepted type DXO. Got {type(dxo)} instead.") - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable + return make_reply(ReturnCode.BAD_TASK_DATA) # Ensure data kind is weights. if not dxo.data_kind == DataKind.WEIGHTS: self.log_exception(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.") - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable + return make_reply(ReturnCode.BAD_TASK_DATA) # Convert weights to tensor. Run training torch_weights = {k: torch.as_tensor(v) for k, v in dxo.data.items()} - self.local_train(fl_ctx, torch_weights) + self.local_train(fl_ctx, torch_weights, abort_signal) + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) self.save_local_model(fl_ctx) @@ -130,14 +139,10 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort dxo = model_learnable_to_dxo(ml) return dxo.to_shareable() else: - shareable = Shareable() - shareable.set_return_code(ReturnCode.TASK_UNKNOWN) - return shareable + return make_reply(ReturnCode.TASK_UNKNOWN) except: self.log_exception(fl_ctx, f"Exception in simple trainer.") - shareable = Shareable() - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable + return make_reply(ReturnCode.EXECUTION_EXCEPTION) def save_local_model(self, fl_ctx: FLContext): run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM)) diff --git a/examples/hello-pt/custom/cifar10validator.py b/examples/hello-pt/custom/cifar10validator.py index 195633184c..b6ab1853c9 100644 --- a/examples/hello-pt/custom/cifar10validator.py +++ b/examples/hello-pt/custom/cifar10validator.py @@ -21,7 +21,7 @@ from nvflare.apis.executor import Executor from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable +from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.app_common.app_constant import AppConstants @@ -49,20 +49,19 @@ def __init__(self, validate_task_name=AppConstants.TASK_VALIDATION): def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: out_shareable = Shareable() if task_name == self._validate_task_name: + model_owner = "?" try: dxo = from_shareable(shareable) # Check if dxo is valid. if not dxo: self.log_exception(fl_ctx, "DXO invalid") - out_shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return out_shareable + return make_reply(ReturnCode.BAD_TASK_DATA) # Ensure data_kind is weights. if not dxo.data_kind == DataKind.WEIGHTS: self.log_exception(fl_ctx, f"DXO is of type {dxo.data_kind} but expected type WEIGHTS.") - out_shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return out_shareable + return make_reply(ReturnCode.BAD_TASK_DATA) # Extract weights and ensure they are tensor. model_owner = shareable.get_header(AppConstants.MODEL_OWNER, "?") @@ -71,6 +70,9 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort # Get validation accuracy val_accuracy = self.do_validation(weights) + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + self.log_info(fl_ctx, f"Accuracy when validating {model_owner}'s model on" f" {fl_ctx.get_identity_name()}"f's data: {val_accuracy}') @@ -78,13 +80,11 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort return dxo.to_shareable() except: self.log_exception(fl_ctx, f"Exception in validating model from {model_owner}") - out_shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return out_shareable + return make_reply(ReturnCode.EXECUTION_EXCEPTION) else: - out_shareable.set_return_code(ReturnCode.TASK_UNKNOWN) - return out_shareable + return make_reply(ReturnCode.TASK_UNKNOWN) - def do_validation(self, weights): + def do_validation(self, weights, abort_signal): self.model.load_state_dict(weights) self.model.eval() @@ -93,6 +93,9 @@ def do_validation(self, weights): total = 0 with torch.no_grad(): for i, (images, labels) in enumerate(self.test_loader): + if abort_signal.triggered: + return 0 + images, labels = images.to(self.device), labels.to(self.device) output = self.model(images)