Skip to content

Commit

Permalink
Merge pull request #5271 from FederatedAI/feature-1.11.5-stats
Browse files: browse the repository at this point in the history
edit feature imputation
  • Loading branch information
mgqa34 authored Nov 21, 2023
2 parents 7b4374e + ec38c52 commit 921d32d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
9 changes: 5 additions & 4 deletions python/federatedml/feature/feature_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ def get_summary(self):
def load_model(self, model_dict):
param_obj = list(model_dict.get('model').values())[0].get(self.model_param_name)
meta_obj = list(model_dict.get('model').values())[0].get(self.model_meta_name)
self.header = param_obj.header
self.header = list(param_obj.header)
self.missing_fill, self.missing_fill_method, \
self.missing_impute, self.fill_value, self.skip_cols = load_feature_imputer_model(self.header,
"Imputer",
meta_obj.imputer_meta,
param_obj.imputer_param)

def save_model(self):
meta_obj, param_obj = save_feature_imputer_model(missing_fill=True,
meta_obj, param_obj = save_feature_imputer_model(missing_fill=self.model_param.need_run,
missing_replace_method=self.missing_fill_method,
cols_replace_method=self.cols_replace_method,
missing_impute=self.missing_impute,
Expand Down Expand Up @@ -159,10 +159,11 @@ def save_feature_imputer_model(missing_fill=False,

if missing_fill_value is not None and header is not None:
fill_header = [col for col in header if col not in skip_cols]
feature_value_dict = dict(zip(fill_header, map(str, missing_fill_value)))
fill_value = [missing_fill_value[header.index(col)] for col in fill_header]
feature_value_dict = dict(zip(fill_header, map(str, fill_value)))

model_param.missing_replace_value.update(feature_value_dict)
missing_fill_value_type = [type(v).__name__ for v in missing_fill_value]
missing_fill_value_type = [type(v).__name__ for v in fill_value]
feature_value_type_dict = dict(zip(fill_header, missing_fill_value_type))
model_param.missing_replace_value_type.update(feature_value_type_dict)

Expand Down
22 changes: 11 additions & 11 deletions python/federatedml/feature/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,19 +234,19 @@ def __get_cols_transform_value(self, data, replace_method, replace_value=None, c
raise ValueError("Unknown replace method:{}".format(replace_method))
cols_transform_value[feature] = transform_value

# LOGGER.debug(f"cols_transform value is: {cols_transform_value}")
cols_transform_value = [cols_transform_value[key] for key in header]
# cols_transform_value = {i: round(cols_transform_value[key], 6) for i, key in enumerate(header)}
# LOGGER.debug(f"cols_transform value is: {cols_transform_value}")
return cols_transform_value

@staticmethod
def _transform_nan(instance):
def _transform_nan(instance, skip_idx):
feature_shape = instance.features.shape[0]
new_features = []
contains_na = False
skip_idx = set(skip_idx)
for i in range(feature_shape):
if instance.features[i] != instance.features[i]:
if instance.features[i] != instance.features[i] and i not in skip_idx:
new_features.append(NoneType())
contains_na = True
else:
Expand All @@ -264,23 +264,23 @@ def __fit_replace(self, data, replace_method, replace_value=None, output_format=
replace_method_per_col, skip_cols = self.__get_cols_transform_method(data, replace_method, col_replace_method)

schema = data.schema
skip_idx = [get_header(data).index(v) for v in skip_cols]
if isinstance(data.first()[1], Instance):
data = data.mapValues(lambda v: Imputer._transform_nan(v))
data = data.mapValues(lambda v: Imputer._transform_nan(v, skip_idx))
data.schema = schema
cols_transform_value = self.__get_cols_transform_value(data, replace_method_per_col,
replace_value=replace_value,
col_replace_value=col_replace_value,
error=error, multi_mode=multi_mode)
self.skip_cols = skip_cols
skip_cols = [get_header(data).index(v) for v in skip_cols]
if output_format is not None:
f = functools.partial(Imputer.replace_missing_value_with_cols_transform_value_format,
transform_list=cols_transform_value, missing_value_list=self.abnormal_value_set,
output_format=output_format, skip_cols=set(skip_cols))
output_format=output_format, skip_cols=set(skip_idx))
else:
f = functools.partial(Imputer.replace_missing_value_with_cols_transform_value,
transform_list=cols_transform_value, missing_value_list=self.abnormal_value_set,
skip_cols=set(skip_cols))
skip_cols=set(skip_idx))

transform_data = data.mapValues(f)
self.cols_replace_method = replace_method_per_col
Expand All @@ -289,11 +289,11 @@ def __fit_replace(self, data, replace_method, replace_value=None, output_format=
return transform_data, cols_transform_value

def __transform_replace(self, data, transform_value, replace_area, output_format, skip_cols):
skip_cols = [get_header(data).index(v) for v in skip_cols]
skip_idx = [get_header(data).index(v) for v in skip_cols]

schema = data.schema
if isinstance(data.first()[1], Instance):
data = data.mapValues(lambda v: Imputer._transform_nan(v))
data = data.mapValues(lambda v: Imputer._transform_nan(v, skip_idx))
data.schema = schema

if replace_area == 'all':
Expand All @@ -309,11 +309,11 @@ def __transform_replace(self, data, transform_value, replace_area, output_format
f = functools.partial(Imputer.replace_missing_value_with_cols_transform_value_format,
transform_list=transform_value, missing_value_list=self.abnormal_value_set,
output_format=output_format,
skip_cols=set(skip_cols))
skip_cols=set(skip_idx))
else:
f = functools.partial(Imputer.replace_missing_value_with_cols_transform_value,
transform_list=transform_value, missing_value_list=self.abnormal_value_set,
skip_cols=set(skip_cols))
skip_cols=set(skip_idx))
else:
raise ValueError("Unknown replace area {} in Imputer".format(replace_area))

Expand Down

0 comments on commit 921d32d

Please sign in to comment.