add flake8-comprehensions (#689)
iKintosh authored May 26, 2022
1 parent eebf273 commit dd3175c
Showing 19 changed files with 148 additions and 131 deletions.
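This commit wires the flake8-comprehensions plugin into the project's style tooling and applies the rewrites it suggests across the code base: redundant list() and set() wrapper calls are dropped in favour of plain iterables, literals, and comprehensions. A minimal sketch of the two most common patterns in the hunks below, assuming the usual rule codes C414 and C405 (illustrative, not copied from the diff):

quantiles = {0.1, 0.5, 0.9}

# C414: an inner list() call inside sorted() is redundant, since sorted()
# accepts any iterable and already returns a new list.
selected = sorted(list(quantiles))  # before
selected = sorted(quantiles)        # after

# C405: a list literal passed to set() can be written as a set literal.
segments = set(["segment_0", "segment_1"])  # before
segments = {"segment_0", "segment_1"}       # after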
2 changes: 1 addition & 1 deletion .flake8
@@ -1,5 +1,5 @@
[flake8]
-ignore = F, E203, W605, E501, W503, D100, D104
+ignore = F, E203, W605, E501, W503, D100, D104, C408
max-line-length = 121
max-complexity = 18
docstring-convention=numpy
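Note that C408 ("unnecessary dict/list/tuple call, rewrite as a literal") is added to the ignore list, so keyword-style dict() construction stays allowed in the code base. A small sketch of what that rule would otherwise flag, with hypothetical values:

# With C408 enabled, flake8-comprehensions would flag the dict() call below
params = dict(window_size=10, alpha=2)
# and suggest the literal form instead.
params = {"window_size": 10, "alpha": 2}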
4 changes: 2 additions & 2 deletions etna/analysis/plotters.py
@@ -56,12 +56,12 @@ def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optio
intersection_quantiles_set = set.intersection(
*[_get_existing_quantiles(forecast) for forecast in forecast_results.values()]
)
-intersection_quantiles = sorted(list(intersection_quantiles_set))
+intersection_quantiles = sorted(intersection_quantiles_set)

if quantiles is None:
selected_quantiles = intersection_quantiles
else:
-selected_quantiles = sorted(list(set(quantiles) & intersection_quantiles_set))
+selected_quantiles = sorted(set(quantiles) & intersection_quantiles_set)
non_existent = set(quantiles) - intersection_quantiles_set
if non_existent:
warnings.warn(f"Quantiles {non_existent} do not exist in each forecast dataset. They will be dropped.")
4 changes: 2 additions & 2 deletions etna/datasets/tsdataset.py
@@ -324,7 +324,7 @@ def _check_known_future(

if isinstance(known_future, str):
if known_future == "all":
-return sorted(list(exog_columns))
+return sorted(exog_columns)
else:
raise ValueError("The only possible literal is 'all'")
else:
@@ -335,7 +335,7 @@ def _check_known_future(
f"{known_future_unique.difference(exog_columns)}"
)
else:
-return sorted(list(known_future_unique))
+return sorted(known_future_unique)

@staticmethod
def _check_regressors(df: pd.DataFrame, df_regressors: pd.DataFrame):
2 changes: 1 addition & 1 deletion etna/ensembles/base.py
@@ -17,7 +17,7 @@ def _validate_pipeline_number(pipelines: List[BasePipeline]):
@staticmethod
def _get_horizon(pipelines: List[BasePipeline]) -> int:
"""Get ensemble's horizon."""
-horizons = set([pipeline.horizon for pipeline in pipelines])
+horizons = {pipeline.horizon for pipeline in pipelines}
if len(horizons) > 1:
raise ValueError("All the pipelines should have the same horizon.")
return horizons.pop()
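The ensemble helper now builds the set of horizons with a set comprehension instead of wrapping a list comprehension in set(), which flake8-comprehensions reports as C403. A self-contained sketch of the same pattern, using a hypothetical stand-in for BasePipeline:

from dataclasses import dataclass
from typing import List

@dataclass
class DummyPipeline:  # hypothetical stand-in for BasePipeline
    horizon: int

def get_horizon(pipelines: List[DummyPipeline]) -> int:
    horizons = {p.horizon for p in pipelines}  # set comprehension, not set([...]) (C403)
    if len(horizons) > 1:
        raise ValueError("All the pipelines should have the same horizon.")
    return horizons.pop()

assert get_horizon([DummyPipeline(7), DummyPipeline(7)]) == 7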
2 changes: 1 addition & 1 deletion etna/transforms/feature_selection/base.py
@@ -18,7 +18,7 @@ def __init__(self, features_to_use: Union[List[str], Literal["all"]] = "all"):

def _get_features_to_use(self, df: pd.DataFrame) -> List[str]:
"""Get list of features from the dataframe to perform the selection on."""
-features = set(df.columns.get_level_values("feature")) - set(["target"])
+features = set(df.columns.get_level_values("feature")) - {"target"}
if self.features_to_use != "all":
features = features.intersection(self.features_to_use)
if sorted(features) != sorted(self.features_to_use):
2 changes: 1 addition & 1 deletion etna/transforms/missing_values/resample.py
@@ -56,7 +56,7 @@ def _get_folds(self, df: pd.DataFrame) -> List[int]:
in_column_start_index = in_column_index[0]
left_tie_len = len(df[:in_column_start_index]) - 1
right_tie_len = len(df[in_column_start_index:])
-folds_for_left_tie = [fold for fold in range(n_folds_per_gap - left_tie_len, n_folds_per_gap)]
+folds_for_left_tie = list(range(n_folds_per_gap - left_tie_len, n_folds_per_gap))
folds_for_right_tie = [fold for _ in range(n_periods) for fold in range(n_folds_per_gap)][:right_tie_len]
return folds_for_left_tie + folds_for_right_tie

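In this hunk a comprehension that merely copies each element of a range() becomes a direct list() call, the rewrite flake8-comprehensions suggests as C416. A tiny illustration with made-up fold counts:

n_folds_per_gap = 4
left_tie_len = 2

# Before (C416): the comprehension only re-emits each element of the range.
folds_for_left_tie = [fold for fold in range(n_folds_per_gap - left_tie_len, n_folds_per_gap)]
# After: a plain list() call over the same range.
folds_for_left_tie = list(range(n_folds_per_gap - left_tie_len, n_folds_per_gap))

assert folds_for_left_tie == [2, 3]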
2 changes: 1 addition & 1 deletion etna/transforms/utils.py
@@ -5,4 +5,4 @@
def match_target_quantiles(features: Set[str]) -> Set[str]:
"""Find quantiles in dataframe columns."""
pattern = re.compile("target_\d+\.\d+$")
-return set(i for i in list(features) if pattern.match(i) is not None)
+return {i for i in list(features) if pattern.match(i) is not None}
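match_target_quantiles now returns a set comprehension instead of passing a generator to set(), the C401 rewrite. A hedged usage sketch of the function as it reads after this hunk; for brevity the regex is written as a raw string and the inner list(features) call is dropped, neither of which is part of the diff:

import re
from typing import Set

def match_target_quantiles(features: Set[str]) -> Set[str]:
    """Find quantile columns among feature names."""
    pattern = re.compile(r"target_\d+\.\d+$")
    # C401: set comprehension instead of set(generator)
    return {i for i in features if pattern.match(i) is not None}

assert match_target_quantiles({"target", "target_0.25", "target_0.975"}) == {"target_0.25", "target_0.975"}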
230 changes: 123 additions & 107 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions pyproject.toml
@@ -46,6 +46,7 @@ ruptures = "1.1.5"
numba = ">=0.53.1,<0.56.0"
seaborn = "^0.11.1"
statsmodels = ">=0.12,<0.14"
+pmdarima = ">=1.8.0"
dill = "^0.3.4"
toml = "^0.10.2"
loguru = "^0.5.3"
@@ -89,6 +90,7 @@ isort = {version = "^5.8.0", optional = true}
flake8 = {version = "^3.9.2", optional = true}
pep8-naming = {version = "^0.12.1", optional = true}
flake8-bugbear = {version = "^22.4.25", optional = true}
+flake8-comprehensions = {version = "^3.9.0", optional = true}
flake8-docstrings = {version = "^1.6.0", optional = true}
mypy = {version = "^0.910", optional = true}
types-PyYAML = {version = "^6.0.0", optional = true}
@@ -100,7 +102,6 @@ ipywidgets = {version = "^7.6.5", optional = true}

jupyter = {version = "*", optional = true}
nbconvert = {version = "*", optional = true}
-pmdarima = ">=1.8.0"


[tool.poetry.extras]
@@ -113,7 +114,7 @@ release = ["click", "semver"]
docs = ["Sphinx", "numpydoc", "sphinx-rtd-theme", "nbsphinx", "sphinx-mathjax-offline", "myst-parser", "GitPython"]
tests = ["pytest-cov", "coverage", "pytest"]
jupyter = ["jupyter", "nbconvert"]
-style = ["black", "isort", "flake8", "pep8-naming", "flake8-docstrings", "mypy", "types-PyYAML", "codespell", "flake8-bugbear"]
+style = ["black", "isort", "flake8", "pep8-naming", "flake8-docstrings", "mypy", "types-PyYAML", "codespell", "flake8-bugbear", "flake8-comprehensions"]

all = [
"prophet",
@@ -128,7 +129,7 @@ all-dev = [
"click", "semver",
"Sphinx", "numpydoc", "sphinx-rtd-theme", "nbsphinx", "sphinx-mathjax-offline", "myst-parser", "GitPython",
"pytest-cov", "coverage", "pytest",
"black", "isort", "flake8", "pep8-naming", "flake8-docstrings", "mypy", "types-PyYAML", "codespell", "flake8-bugbear",
"black", "isort", "flake8", "pep8-naming", "flake8-docstrings", "mypy", "types-PyYAML", "codespell", "flake8-bugbear", "flake8-comprehensions",
"click", "semver",
"jupyter", "nbconvert"
]
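On the packaging side, flake8-comprehensions becomes an optional dev dependency wired into the style and all-dev extras, and pmdarima moves from the dev block into the main dependency section. Once the style tooling is installed, one quick sanity check is flake8's version string, which lists the plugins it has registered (a sketch, assuming flake8 is available on PATH):

import subprocess

# "flake8 --version" prints the installed plugins alongside the core version,
# e.g. "3.9.2 (flake8-bugbear: ..., flake8-comprehensions: 3.9.0, ...)".
out = subprocess.run(["flake8", "--version"], capture_output=True, text=True, check=True)
assert "flake8-comprehensions" in out.stdout
print(out.stdout.strip())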
2 changes: 1 addition & 1 deletion tests/test_analysis/test_eda_utils.py
@@ -142,7 +142,7 @@ def test_cross_corr_with_full_nans(a, b, normed, expected_result):
pd.date_range(start="2020-01-03", periods=40, freq="D"),
"month",
["2020-Jan"] * 29 + ["2020-Feb"] * 11,
-[i for i in range(3, 32)] + [i for i in range(1, 12)],
+list(range(3, 32)) + list(range(1, 12)),
[str(i) for i in range(3, 32)] + [str(i) for i in range(1, 12)],
),
(
@@ -35,7 +35,7 @@ def test_get_anomalies_prediction_interval_interface(outliers_tsds, model, in_co
"""Test that `get_anomalies_prediction_interval` produces correct columns."""
anomalies = get_anomalies_prediction_interval(outliers_tsds, model=model, interval_width=0.95, in_column=in_column)
assert isinstance(anomalies, dict)
-assert sorted(list(anomalies.keys())) == sorted(outliers_tsds.segments)
+assert sorted(anomalies.keys()) == sorted(outliers_tsds.segments)
for segment in anomalies.keys():
assert isinstance(anomalies[segment], list)
for date in anomalies[segment]:
2 changes: 1 addition & 1 deletion tests/test_analysis/test_outliers/test_density_outliers.py
@@ -16,7 +16,7 @@ def simple_window() -> np.array:

def test_const_ts(const_ts_anomal):
anomal = get_anomalies_density(const_ts_anomal)
assert set(["segment_0", "segment_1"]) == set(anomal.keys())
assert {"segment_0", "segment_1"} == set(anomal.keys())
for seg in anomal.keys():
assert len(anomal[seg]) == 0

4 changes: 2 additions & 2 deletions tests/test_analysis/test_outliers/test_median_outliers.py
@@ -6,7 +6,7 @@

def test_const_ts(const_ts_anomal):
anomal = get_anomalies_median(const_ts_anomal)
assert set(["segment_0", "segment_1"]) == set(anomal.keys())
assert {"segment_0", "segment_1"} == set(anomal.keys())
for seg in anomal.keys():
assert len(anomal[seg]) == 0

@@ -34,7 +34,7 @@ def test_median_outliers(window_size, alpha, right_anomal, outliers_tsds):
def test_interface_correct_args(true_params, outliers_tsds):
d = get_anomalies_median(ts=outliers_tsds, window_size=10, alpha=2)
assert isinstance(d, dict)
-assert sorted(list(d.keys())) == sorted(true_params)
+assert sorted(d.keys()) == sorted(true_params)
for i in d.keys():
for j in d[i]:
assert isinstance(j, np.datetime64)
4 changes: 2 additions & 2 deletions tests/test_pipeline/test_pipeline.py
@@ -317,14 +317,14 @@ def test_get_fold_info_interface_daily(catboost_pipeline: Pipeline, big_daily_ex
"""Check that Pipeline.backtest returns info dataframe in correct format."""
_, _, info_df = catboost_pipeline.backtest(ts=big_daily_example_tsdf, metrics=DEFAULT_METRICS)
expected_columns = ["fold_number", "test_end_time", "test_start_time", "train_end_time", "train_start_time"]
-assert expected_columns == list(sorted(info_df.columns))
+assert expected_columns == sorted(info_df.columns)


def test_get_fold_info_interface_hours(catboost_pipeline: Pipeline, example_tsdf: TSDataset):
"""Check that Pipeline.backtest returns info dataframe in correct format with non-daily seasonality."""
_, _, info_df = catboost_pipeline.backtest(ts=example_tsdf, metrics=DEFAULT_METRICS)
expected_columns = ["fold_number", "test_end_time", "test_start_time", "train_end_time", "train_start_time"]
-assert expected_columns == list(sorted(info_df.columns))
+assert expected_columns == sorted(info_df.columns)


@pytest.mark.long
@@ -184,7 +184,7 @@ def test_naming_ohe_encoder(two_df_with_new_values):
ohe.fit(df1)
segments = ["segment_0", "segment_1"]
target = ["target", "targets_0", "targets_1", "targets_2", "regressor_0"]
-assert set([(i, j) for i in segments for j in target]) == set(ohe.transform(df2).columns.values)
+assert {(i, j) for i in segments for j in target} == set(ohe.transform(df2).columns.values)


@pytest.mark.parametrize(
@@ -268,4 +268,4 @@ def test_mrmr_right_regressors(relevance_table, ts_with_regressors):
for column in df_selected.columns.get_level_values("feature"):
if column.startswith("regressor"):
selected_regressors.add(column)
-assert set(selected_regressors) == set(["regressor_useful_0", "regressor_useful_1", "regressor_useful_2"])
+assert set(selected_regressors) == {"regressor_useful_0", "regressor_useful_1", "regressor_useful_2"}
2 changes: 1 addition & 1 deletion tests/test_transforms/test_missing_values/conftest.py
@@ -24,7 +24,7 @@ def date_range(request) -> pd.DatetimeIndex:
def all_date_present_df(date_range: pd.Series) -> pd.DataFrame:
"""Create pd.DataFrame that contains some target on given range of dates without gaps."""
df = pd.DataFrame({"timestamp": date_range})
df["target"] = [i for i in range(len(df))]
df["target"] = list(range(len(df)))
df.set_index("timestamp", inplace=True)
return df

@@ -177,7 +177,7 @@ def test_interface_correct_args_out_column(true_params: List[str], train_df: pd.
true_params = [f"{out_column}_{param}" for param in true_params]
for seg in result.columns.get_level_values(0).unique():
tmp_df = result[seg]
-assert sorted(list(tmp_df.columns)) == sorted(true_params + ["target"])
+assert sorted(tmp_df.columns) == sorted(true_params + ["target"])
for param in true_params:
assert tmp_df[param].dtype == "category"

@@ -128,7 +128,7 @@ def test_interface_out_column(true_params: List[str], train_df: pd.DataFrame):
true_params = [f"{out_column}_{param}" for param in true_params]
for seg in result.columns.get_level_values(0).unique():
tmp_df = result[seg]
-assert sorted(list(tmp_df.columns)) == sorted(true_params + ["target"])
+assert sorted(tmp_df.columns) == sorted(true_params + ["target"])
for param in true_params:
assert tmp_df[param].dtype == "category"

