Skip to content

Commit

Permalink
fix:encoding (#687)
Browse files (browse the repository at this point in the history)
* fix:encoding

* fix:exclude

* fix: utf8

* fix:error-code

---------

Co-authored-by: Zirui Cai <[email protected]>
  • Loading branch information
SAKURA-CAT and Feudalman committed Aug 30, 2024
1 parent ce7e5a3 commit f02755b
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 20 deletions.
2 changes: 1 addition & 1 deletion swanlab/cli/commands/launcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def launch(file: str, dry_run: bool):
Launch a task
"""
file = os.path.abspath(file)
config = yaml.safe_load(open(file, 'r'))
config = yaml.safe_load(open(file, 'r', encoding='utf-8'))
if not isinstance(config, dict):
raise click.FileError(file, hint='Invalid configuration file')
p = parse(config, file)
Expand Down
45 changes: 33 additions & 12 deletions swanlab/cli/commands/launcher/parser/v1/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
文件夹上传模型
"""
from typing import List, Tuple
import click
from ..model import LaunchParser
from swanlab.error import ApiError
from swanlab.cli.utils import login_init_sid, UseTaskHttp, CosUploader, UploadBytesIO
import zipfile
from rich.progress import (
Expand Down Expand Up @@ -97,22 +99,31 @@ def parse_spec(self, spec: dict):
self.spec['volumes'] = volumes
self.spec['exclude'] = exclude

def walk(self) -> Tuple[List[str], List[str]]:
def walk(self, path: str = None) -> Tuple[List[str], List[str]]:
"""
遍历path,生成文件列表,注意排除exclude中的文件
此函数为递归调用函数
返回所有命中的文件列表和排除的文件列表
"""
files = glob.glob(os.path.join(self.dirpath, '**/*'), recursive=True)
path = path or self.dirpath
all_files = glob.glob(os.path.join(path, '**'))
exclude_files = []
split_len = len(self.dirpath)

def match(f, fs):
return any([f[split_len:] == fs[i][split_len:] for i in range(len(fs))])

for g in self.spec['exclude']:
efs = glob.glob(os.path.join(self.dirpath, g), recursive=True)
files = [f for f in files if not match(f, efs)]
efs = glob.glob(os.path.join(path, g))
exclude_files.extend(efs)
exclude_files = list(set(exclude_files))
files = []
for f in all_files:
if os.path.isdir(f):
if f in exclude_files:
continue
fs, efs = self.walk(f)
files.extend(fs)
exclude_files.extend(efs)
else:
if f in exclude_files:
continue
files.append(f)
return files, exclude_files

def zip(self, files: List[str]) -> io.BytesIO:
Expand Down Expand Up @@ -153,15 +164,25 @@ def upload(self, memory_file: io.BytesIO):
)

def run(self):
# 剔除、压缩、上传、发布任务
files, _ = self.walk()
if len(files) == 0:
raise click.BadParameter(self.dirpath + " is empty")
login_info = login_init_sid()
print(FONT.swanlab("Login successfully. Hi, " + FONT.bold(FONT.default(login_info.username))) + "!")
self.api_key = login_info.api_key
# 剔除、压缩、上传、发布任务
files, _ = self.walk()
memory_file = self.zip(files)
self.upload(memory_file)
with UseTaskHttp() as http:
http.post("/task", data=self.__dict__())
try:
http.post("/task", data=self.__dict__())
except ApiError as e:
if e.resp.status_code not in [404, 401]:
raise e
elif e.resp.status_code == 404:
raise click.BadParameter("The dataset does not exist")
else:
raise click.BadParameter("The combo does not exist")

def dry_run(self):
# 剔除、显示即将发布的任务的相关信息
Expand Down
10 changes: 3 additions & 7 deletions swanlab/data/formater.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def check_load_json_yaml(file_path: str, param_name):
if not file_path.endswith((".json", ".yaml", ".yml")):
raise ValueError(
"{} must be a json or yaml file ('.json', '.yaml', '.yml'), "
"but got {}, please check if the content of config_file is correct.".format(
param_name, path_suffix
)
"but got {}, please check if the content of config_file is correct.".format(param_name, path_suffix)
)
# 转换为绝对路径
file_path = os.path.abspath(file_path)
Expand All @@ -57,11 +55,9 @@ def check_load_json_yaml(file_path: str, param_name):
raise ValueError("{} is empty, please check if the content of config_file is correct.".format(param_name))
# 无权限读取
if not os.access(file_path, os.R_OK):
raise PermissionError(
"No permission to read {}, please check if you have the permission.".format(param_name)
)
raise PermissionError("No permission to read {}, please check if you have the permission.".format(param_name))
load = json.load if path_suffix == "json" else yaml.safe_load
with open(file_path, "r") as f:
with open(file_path, "r", encoding='utf-8') as f:
# 读取配置文件的内容
file_data = load(f)
# 如果读取的内容不是字典类型,则报错
Expand Down

0 comments on commit f02755b

Please sign in to comment.