diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b01a1f80..8b7482f69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - Fix `BaseReconciliator` to work on `pandas==1.1.5` ([#1229](https://github.com/tinkoff-ai/etna/pull/1229)) - -- +- Fix warning during creation of `ResampleWithDistributionTransform` ([#1230](https://github.com/tinkoff-ai/etna/pull/1230)) ## [2.0.0] - 2023-04-11 ### Added diff --git a/etna/transforms/missing_values/resample.py b/etna/transforms/missing_values/resample.py index 48b55bb5d..c81c8121c 100644 --- a/etna/transforms/missing_values/resample.py +++ b/etna/transforms/missing_values/resample.py @@ -140,26 +140,30 @@ def __init__( self.in_column = in_column self.distribution_column = distribution_column self.inplace = inplace - self.out_column = self._get_out_column(out_column) + self.out_column = out_column self.in_column_regressor: Optional[bool] = None + + if self.inplace and out_column: + warnings.warn("Transformation will be applied inplace, out_column param will be ignored") + super().__init__( transform=_OneSegmentResampleWithDistributionTransform( in_column=in_column, distribution_column=distribution_column, inplace=inplace, - out_column=self.out_column, + out_column=self._get_column_name(), ), required_features=[in_column, distribution_column], ) - def _get_out_column(self, out_column: Optional[str]) -> str: + def _get_column_name( + self, + ) -> str: """Get the `out_column` depending on the transform's parameters.""" - if self.inplace and out_column: - warnings.warn("Transformation will be applied inplace, out_column param will be ignored") if self.inplace: return self.in_column - if out_column: - return out_column + if self.out_column: + return self.out_column return self.__repr__() def get_regressors_info(self) -> List[str]: @@ -168,7 +172,7 @@ def get_regressors_info(self) -> List[str]: raise ValueError("Fit the transform to get the correct regressors info!") if self.inplace: return [] - return [self.out_column] if self.in_column_regressor else [] + return [self._get_column_name()] if self.in_column_regressor else [] def fit(self, ts: TSDataset) -> "ResampleWithDistributionTransform": """Fit the transform."""