From 7152cb9ebc48f179834546433bce5fe8df88e07a Mon Sep 17 00:00:00 2001 From: Sangmin Yoon Date: Mon, 30 Dec 2024 17:41:37 +0900 Subject: [PATCH] Implement fit, transform Signed-off-by: Sangmin Yoon --- .../preprocessing/disparate_impact_remover.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/aif360/algorithms/preprocessing/disparate_impact_remover.py b/aif360/algorithms/preprocessing/disparate_impact_remover.py index ed88c8c9..1b701ead 100644 --- a/aif360/algorithms/preprocessing/disparate_impact_remover.py +++ b/aif360/algorithms/preprocessing/disparate_impact_remover.py @@ -35,6 +35,49 @@ def __init__(self, repair_level=1.0, sensitive_attribute=''): self.sensitive_attribute = sensitive_attribute + def fit(self, dataset): + """Fit the model to the dataset. + + Args: + dataset (BinaryLabelDataset): Dataset containing true labels and protected attributes. + Returns: + DisparateImpactRemover: Returns self after fitting the model. + + Note: + This method sets the sensitive attribute if it is not already specified. + """ + if not self.sensitive_attribute: + self.sensitive_attribute = dataset.protected_attribute_names[0] + + return self + + def transform(self, dataset): + """Transform the dataset using the fitted model. + + Args: + dataset (BinaryLabelDataset): Dataset containing labels that needs to be transformed. + Returns: + BinaryLabelDataset: Transformed Dataset with adjusted feature values. + + Note: + The transformation preserves the rank-ordering of features while modifying them + to reduce disparate impact based on the specified sensitive attribute. + """ + features = dataset.features.tolist() + index = dataset.feature_names.index(self.sensitive_attribute) + repairer = self.Repairer(features, index, self.repair_level, False) + + transformed_features = repairer.repair(features) + transformed_dataset = dataset.copy() + transformed_dataset.features = np.array(transformed_features, dtype=np.float64) + + # Ensure protected attributes remain unchanged + transformed_dataset.features[:, index] = transformed_dataset.protected_attributes[:, + transformed_dataset.protected_attribute_names.index( + self.sensitive_attribute)] + + return transformed_dataset + def fit_transform(self, dataset): """Run a repairer on the non-protected features and return the transformed dataset.