diff --git a/python/federatedml/util/data_transform.py b/python/federatedml/util/data_transform.py index 3ce3674d35..b0c549a322 100644 --- a/python/federatedml/util/data_transform.py +++ b/python/federatedml/util/data_transform.py @@ -80,6 +80,8 @@ def __init__(self, data_transform_param): else: self.exclusive_data_type = None + self.col_missing_fill_method = data_transform_param.col_missing_fill_method + def _update_param(self, schema): meta = schema["meta"] self.delimitor = meta.get("delimiter", ",") @@ -229,7 +231,8 @@ def fill_missing_value(self, input_data_features, mode="fit"): if mode == "fit": input_data_features, self.default_value = imputer_processor.fit(input_data_features, replace_method=self.missing_fill_method, - replace_value=self.default_value) + replace_value=self.default_value, + col_replace_method=self.col_missing_fill_method) if self.missing_impute is None: self.missing_impute = imputer_processor.get_missing_value_list() else: @@ -693,6 +696,7 @@ def __init__(self, data_transform_param): self.missing_impute = None self.anonymous_generator = None self.anonymous_header = None + self.col_missing_fill_method = data_transform_param.col_missing_fill_method def _update_param(self, schema): meta = schema["meta"] @@ -806,7 +810,8 @@ def fill_missing_value(self, input_data, tags_dict, schema, mode="fit"): if mode == "fit": data, self.default_value = imputer_processor.fit(input_data, replace_method=self.missing_fill_method, - replace_value=self.default_value) + replace_value=self.default_value, + col_replace_method=self.col_missing_fill_method) LOGGER.debug("self.default_value is {}".format(self.default_value)) else: data = imputer_processor.transform(input_data,