Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix:encoding #687

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion swanlab/cli/commands/launcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def launch(file: str, dry_run: bool):
Launch a task
"""
file = os.path.abspath(file)
config = yaml.safe_load(open(file, 'r'))
config = yaml.safe_load(open(file, 'r', encoding='utf-8'))
if not isinstance(config, dict):
raise click.FileError(file, hint='Invalid configuration file')
p = parse(config, file)
Expand Down
45 changes: 33 additions & 12 deletions swanlab/cli/commands/launcher/parser/v1/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
文件夹上传模型
"""
from typing import List, Tuple
import click
from ..model import LaunchParser
from swanlab.error import ApiError
from swanlab.cli.utils import login_init_sid, UseTaskHttp, CosUploader, UploadBytesIO
import zipfile
from rich.progress import (
Expand Down Expand Up @@ -97,22 +99,31 @@ def parse_spec(self, spec: dict):
self.spec['volumes'] = volumes
self.spec['exclude'] = exclude

def walk(self) -> Tuple[List[str], List[str]]:
def walk(self, path: str = None) -> Tuple[List[str], List[str]]:
"""
遍历path,生成文件列表,注意排除exclude中的文件
此函数为递归调用函数
返回所有命中的文件列表和排除的文件列表
"""
files = glob.glob(os.path.join(self.dirpath, '**/*'), recursive=True)
path = path or self.dirpath
all_files = glob.glob(os.path.join(path, '**'))
exclude_files = []
split_len = len(self.dirpath)

def match(f, fs):
return any([f[split_len:] == fs[i][split_len:] for i in range(len(fs))])

for g in self.spec['exclude']:
efs = glob.glob(os.path.join(self.dirpath, g), recursive=True)
files = [f for f in files if not match(f, efs)]
efs = glob.glob(os.path.join(path, g))
exclude_files.extend(efs)
exclude_files = list(set(exclude_files))
files = []
for f in all_files:
if os.path.isdir(f):
if f in exclude_files:
continue
fs, efs = self.walk(f)
files.extend(fs)
exclude_files.extend(efs)
else:
if f in exclude_files:
continue
files.append(f)
return files, exclude_files

def zip(self, files: List[str]) -> io.BytesIO:
Expand Down Expand Up @@ -153,15 +164,25 @@ def upload(self, memory_file: io.BytesIO):
)

def run(self):
# 剔除、压缩、上传、发布任务
files, _ = self.walk()
if len(files) == 0:
raise click.BadParameter(self.dirpath + " is empty")
login_info = login_init_sid()
print(FONT.swanlab("Login successfully. Hi, " + FONT.bold(FONT.default(login_info.username))) + "!")
self.api_key = login_info.api_key
# 剔除、压缩、上传、发布任务
files, _ = self.walk()
memory_file = self.zip(files)
self.upload(memory_file)
with UseTaskHttp() as http:
http.post("/task", data=self.__dict__())
try:
http.post("/task", data=self.__dict__())
except ApiError as e:
if e.resp.status_code not in [404, 401]:
raise e
elif e.resp.status_code == 404:
raise click.BadParameter("The dataset does not exist")
else:
raise click.BadParameter("The combo does not exist")

def dry_run(self):
# 剔除、显示即将发布的任务的相关信息
Expand Down
10 changes: 3 additions & 7 deletions swanlab/data/formater.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def check_load_json_yaml(file_path: str, param_name):
if not file_path.endswith((".json", ".yaml", ".yml")):
raise ValueError(
"{} must be a json or yaml file ('.json', '.yaml', '.yml'), "
"but got {}, please check if the content of config_file is correct.".format(
param_name, path_suffix
)
"but got {}, please check if the content of config_file is correct.".format(param_name, path_suffix)
)
# 转换为绝对路径
file_path = os.path.abspath(file_path)
Expand All @@ -57,11 +55,9 @@ def check_load_json_yaml(file_path: str, param_name):
raise ValueError("{} is empty, please check if the content of config_file is correct.".format(param_name))
# 无权限读取
if not os.access(file_path, os.R_OK):
raise PermissionError(
"No permission to read {}, please check if you have the permission.".format(param_name)
)
raise PermissionError("No permission to read {}, please check if you have the permission.".format(param_name))
load = json.load if path_suffix == "json" else yaml.safe_load
with open(file_path, "r") as f:
with open(file_path, "r", encoding='utf-8') as f:
# 读取配置文件的内容
file_data = load(f)
# 如果读取的内容不是字典类型,则报错
Expand Down