Skip to content

Commit

Permalink
fix: 🐛 issue #90
Browse files Browse the repository at this point in the history
Thanks to @yuqinie98 for providing a clue on how to reproduce this
error.
  • Loading branch information
zezhishao committed Jan 25, 2024
1 parent 7b80981 commit 0af86ed
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
2 changes: 1 addition & 1 deletion basicts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .launcher import launch_training, launch_runner
from .runners import BaseRunner

__version__ = "0.3.10"
__version__ = "0.3.11"

__all__ = ["__version__", "launch_training", "launch_runner", "BaseRunner"]
5 changes: 2 additions & 3 deletions basicts/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

from easytorch.utils.registry import scan_modules

from ..utils.misc import scan_modules
from .registry import SCALER_REGISTRY
from .dataset_zoo.simple_tsf_dataset import TimeSeriesForecastingDataset
from .dataset_zoo.m4_dataset import M4ForecastingDataset
Expand All @@ -10,4 +9,4 @@

# fix bugs on Windows systems and on jupyter
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
scan_modules(project_dir, __file__, ["__init__.py", "registry.py"])
scan_modules(project_dir, __file__, ["__init__.py", "registry.py"], ["dataset_zoo/", ".ipynb_checkpoints/"])
31 changes: 31 additions & 0 deletions basicts/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
import os
import time
import importlib
from typing import List
from functools import partial

import torch
from easytorch.utils.misc import scan_dir


def scan_modules(work_dir: str, file_dir: str, exclude_files: List[str] = None, exclude_dirs: List[str] = None):
"""
overwrite easytorch.utils.scan_modeuls: automatically scan and import modules for registry, and exclude some files and dirs.
"""
module_dir = os.path.dirname(os.path.abspath(file_dir))
import_prefix = module_dir[module_dir.find(work_dir) + len(work_dir) + 1:].replace('/', '.').replace('\\', '.')

if exclude_files is None:
exclude_files = []
if exclude_dirs is None:
exclude_dirs = []

# get all file names, and remove the files in exclude_files
model_file_names = [
v[:v.find('.py')].replace('/', '.').replace('\\', '.') \
for v in scan_dir(module_dir, suffix='py', recursive=True) if v not in exclude_files
]

# remove the files in exclude_dirs. TODO: use os.path to check
for exclude_dir in exclude_dirs:
exclude_dir = exclude_dir.replace('/', '.').replace('\\', '.')
model_file_names = [file_name for file_name in model_file_names if exclude_dir not in file_name]

# import all modules
return [importlib.import_module(f'{import_prefix}.{file_name}') for file_name in model_file_names]


class partial_func(partial):
Expand Down

0 comments on commit 0af86ed

Please sign in to comment.