Skip to content

Commit

Permalink
Implement fit, transform
Browse files Browse the repository at this point in the history
Signed-off-by: Sangmin Yoon <[email protected]>
  • Loading branch information
sanspareilsmyn committed Dec 30, 2024
1 parent cd7e213 commit 7152cb9
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions aif360/algorithms/preprocessing/disparate_impact_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,49 @@ def __init__(self, repair_level=1.0, sensitive_attribute=''):

self.sensitive_attribute = sensitive_attribute

def fit(self, dataset):
    """Resolve which protected attribute the remover operates on.

    Nothing is learned from the feature values here: the only effect is
    that, when no sensitive attribute was configured, the first entry of
    ``dataset.protected_attribute_names`` is stored on the instance. An
    attribute that is already set is left untouched.

    Args:
        dataset (BinaryLabelDataset): Dataset containing true labels and
            protected attributes.

    Returns:
        DisparateImpactRemover: Returns self.

    Raises:
        IndexError: If no sensitive attribute is configured and ``dataset``
            declares no protected attributes to fall back on.
    """
    if not self.sensitive_attribute:
        protected_names = dataset.protected_attribute_names
        # Guard explicitly so the failure names the real cause instead of
        # surfacing as a bare "list index out of range".
        if len(protected_names) == 0:
            raise IndexError(
                "Cannot infer a sensitive attribute: the dataset has no "
                "protected attribute names.")
        self.sensitive_attribute = protected_names[0]

    return self

def transform(self, dataset):
    """Return a copy of ``dataset`` whose features have been repaired.

    The repairer is built from the dataset passed in here (not from
    anything stored by ``fit``), using the column of
    ``self.sensitive_attribute`` and ``self.repair_level``. The input
    dataset is not modified: the repaired features are assigned to a copy.

    Args:
        dataset (BinaryLabelDataset): Dataset containing labels that needs
            to be transformed.

    Returns:
        BinaryLabelDataset: Transformed dataset with adjusted feature
        values (stored as ``float64``) and the sensitive attribute column
        restored from ``protected_attributes``.

    Raises:
        ValueError: If no sensitive attribute is set, or if it is not found
            in ``dataset.feature_names`` or
            ``dataset.protected_attribute_names``.

    Note:
        Per the original documentation, the repair preserves the
        rank-ordering of features while reducing disparate impact; that
        logic lives in ``self.Repairer`` and is not visible here.
    """
    attribute = self.sensitive_attribute
    if not attribute:
        raise ValueError(
            "sensitive_attribute is not set; call fit() first or pass it "
            "to the constructor.")

    # Resolve both column positions before the repair runs so that a
    # misconfigured attribute fails fast with a message that names it.
    try:
        feature_idx = dataset.feature_names.index(attribute)
    except ValueError:
        raise ValueError(
            "Sensitive attribute {!r} is not in the dataset's feature "
            "names.".format(attribute)) from None
    try:
        protected_idx = dataset.protected_attribute_names.index(attribute)
    except ValueError:
        raise ValueError(
            "Sensitive attribute {!r} is not in the dataset's protected "
            "attribute names.".format(attribute)) from None

    features = dataset.features.tolist()
    # NOTE(review): the meaning of the trailing ``False`` flag is defined
    # by the Repairer implementation; it is passed through unchanged.
    repairer = self.Repairer(features, feature_idx, self.repair_level, False)

    transformed_dataset = dataset.copy()
    transformed_dataset.features = np.array(repairer.repair(features),
                                            dtype=np.float64)

    # Ensure the protected attribute column remains unchanged by
    # overwriting it with the values kept in ``protected_attributes``.
    transformed_dataset.features[:, feature_idx] = (
        transformed_dataset.protected_attributes[:, protected_idx])

    return transformed_dataset

def fit_transform(self, dataset):
"""Run a repairer on the non-protected features and return the
transformed dataset.
Expand Down

0 comments on commit 7152cb9

Please sign in to comment.