diff --git a/swanlab/cli/commands/launcher/__init__.py b/swanlab/cli/commands/launcher/__init__.py index e0972116..f648bc6a 100644 --- a/swanlab/cli/commands/launcher/__init__.py +++ b/swanlab/cli/commands/launcher/__init__.py @@ -34,7 +34,7 @@ def launch(file: str, dry_run: bool): Launch a task """ file = os.path.abspath(file) - config = yaml.safe_load(open(file, 'r')) + config = yaml.safe_load(open(file, 'r', encoding='utf-8')) if not isinstance(config, dict): raise click.FileError(file, hint='Invalid configuration file') p = parse(config, file) diff --git a/swanlab/cli/commands/launcher/parser/v1/folder.py b/swanlab/cli/commands/launcher/parser/v1/folder.py index 5b623a40..13153286 100644 --- a/swanlab/cli/commands/launcher/parser/v1/folder.py +++ b/swanlab/cli/commands/launcher/parser/v1/folder.py @@ -8,7 +8,9 @@ 文件夹上传模型 """ from typing import List, Tuple +import click from ..model import LaunchParser +from swanlab.error import ApiError from swanlab.cli.utils import login_init_sid, UseTaskHttp, CosUploader, UploadBytesIO import zipfile from rich.progress import ( @@ -97,22 +99,31 @@ def parse_spec(self, spec: dict): self.spec['volumes'] = volumes self.spec['exclude'] = exclude - def walk(self) -> Tuple[List[str], List[str]]: + def walk(self, path: str = None) -> Tuple[List[str], List[str]]: """ 遍历path,生成文件列表,注意排除exclude中的文件 + 此函数为递归调用函数 + 返回所有命中的文件列表和排除的文件列表 """ - files = glob.glob(os.path.join(self.dirpath, '**/*'), recursive=True) + path = path or self.dirpath + all_files = glob.glob(os.path.join(path, '**')) exclude_files = [] - split_len = len(self.dirpath) - - def match(f, fs): - return any([f[split_len:] == fs[i][split_len:] for i in range(len(fs))]) - for g in self.spec['exclude']: - efs = glob.glob(os.path.join(self.dirpath, g), recursive=True) - files = [f for f in files if not match(f, efs)] + efs = glob.glob(os.path.join(path, g)) exclude_files.extend(efs) exclude_files = list(set(exclude_files)) + files = [] + for f in all_files: + if os.path.isdir(f): + if f in exclude_files: + continue + fs, efs = self.walk(f) + files.extend(fs) + exclude_files.extend(efs) + else: + if f in exclude_files: + continue + files.append(f) return files, exclude_files def zip(self, files: List[str]) -> io.BytesIO: @@ -153,15 +164,25 @@ def upload(self, memory_file: io.BytesIO): ) def run(self): + # 剔除、压缩、上传、发布任务 + files, _ = self.walk() + if len(files) == 0: + raise click.BadParameter(self.dirpath + " is empty") login_info = login_init_sid() print(FONT.swanlab("Login successfully. Hi, " + FONT.bold(FONT.default(login_info.username))) + "!") self.api_key = login_info.api_key - # 剔除、压缩、上传、发布任务 - files, _ = self.walk() memory_file = self.zip(files) self.upload(memory_file) with UseTaskHttp() as http: - http.post("/task", data=self.__dict__()) + try: + http.post("/task", data=self.__dict__()) + except ApiError as e: + if e.resp.status_code not in [404, 401]: + raise e + elif e.resp.status_code == 404: + raise click.BadParameter("The dataset does not exist") + else: + raise click.BadParameter("The combo does not exist") def dry_run(self): # 剔除、显示即将发布的任务的相关信息 diff --git a/swanlab/data/formater.py b/swanlab/data/formater.py index 07e5280f..a9304a14 100644 --- a/swanlab/data/formater.py +++ b/swanlab/data/formater.py @@ -42,9 +42,7 @@ def check_load_json_yaml(file_path: str, param_name): if not file_path.endswith((".json", ".yaml", ".yml")): raise ValueError( "{} must be a json or yaml file ('.json', '.yaml', '.yml'), " - "but got {}, please check if the content of config_file is correct.".format( - param_name, path_suffix - ) + "but got {}, please check if the content of config_file is correct.".format(param_name, path_suffix) ) # 转换为绝对路径 file_path = os.path.abspath(file_path) @@ -57,11 +55,9 @@ def check_load_json_yaml(file_path: str, param_name): raise ValueError("{} is empty, please check if the content of config_file is correct.".format(param_name)) # 无权限读取 if not os.access(file_path, os.R_OK): - raise PermissionError( - "No permission to read {}, please check if you have the permission.".format(param_name) - ) + raise PermissionError("No permission to read {}, please check if you have the permission.".format(param_name)) load = json.load if path_suffix == "json" else yaml.safe_load - with open(file_path, "r") as f: + with open(file_path, "r", encoding='utf-8') as f: # 读取配置文件的内容 file_data = load(f) # 如果读取的内容不是字典类型,则报错