remove finetuning models limitation. (#573)
* remove finetuning models limitation.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add ut.

* update ut and add dashboard.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update ut port.

* update finetuning params for customization.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change name.

---------

Co-authored-by: root <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Sep 4, 2024
1 parent 445c9b1 commit a924579
Showing 10 changed files with 171 additions and 150 deletions.
5 changes: 3 additions & 2 deletions comps/finetuning/docker/Dockerfile_cpu

@@ -32,7 +32,8 @@ WORKDIR /home/user/comps/finetuning
 RUN echo PKGPATH=$(python3 -c "import pkg_resources; print(pkg_resources.get_distribution('oneccl-bind-pt').location)") >> run.sh && \
     echo 'export LD_LIBRARY_PATH=$PKGPATH/oneccl_bindings_for_pytorch/opt/mpi/lib/:$LD_LIBRARY_PATH' >> run.sh && \
     echo 'source $PKGPATH/oneccl_bindings_for_pytorch/env/setvars.sh' >> run.sh && \
-    echo ray start --head >> run.sh && \
+    echo ray start --head --dashboard-host=0.0.0.0 >> run.sh && \
+    echo export RAY_ADDRESS=http://localhost:8265 >> run.sh && \
     echo python finetuning_service.py >> run.sh

-CMD bash run.sh
\ No newline at end of file
+CMD bash run.sh
comps/finetuning/finetune_config.py (moved from comps/finetuning/llm_on_ray/finetune/finetune_config.py)

@@ -7,6 +7,8 @@

 from pydantic import BaseModel, validator

+from comps.cores.proto.api_protocol import FineTuningJobsRequest
+
 PRECISION_BF16 = "bf16"
 PRECISION_FP16 = "fp16"
 PRECISION_NO = "no"
@@ -20,30 +22,31 @@
 ACCELERATE_STRATEGY_DEEPSPEED = "DEEPSPEED"


-class GeneralConfig(BaseModel):
-    trust_remote_code: bool
-    use_auth_token: Optional[str]
+class LoadConfig(BaseModel):
+    trust_remote_code: bool = False
+    # set Huggingface token to access dataset/model
+    token: Optional[str] = None


 class LoraConfig(BaseModel):
-    task_type: str
-    r: int
-    lora_alpha: int
-    lora_dropout: float
+    task_type: str = "CAUSAL_LM"
+    r: int = 8
+    lora_alpha: int = 32
+    lora_dropout: float = 0.1
     target_modules: Optional[List[str]] = None


-class General(BaseModel):
-    base_model: str
+class GeneralConfig(BaseModel):
+    base_model: str = None
     tokenizer_name: Optional[str] = None
     gaudi_config_name: Optional[str] = None
-    gpt_base_model: bool
-    output_dir: str
+    gpt_base_model: bool = False
+    output_dir: str = "./tmp"
     report_to: str = "none"
     resume_from_checkpoint: Optional[str] = None
     save_strategy: str = "no"
-    config: GeneralConfig
-    lora_config: Optional[LoraConfig] = None
+    config: LoadConfig = LoadConfig()
+    lora_config: Optional[LoraConfig] = LoraConfig()
     enable_gradient_checkpointing: bool = False

     @validator("report_to")
@@ -52,10 +55,10 @@ def check_report_to(cls, v: str):
         return v


-class Dataset(BaseModel):
-    train_file: str
-    validation_file: Optional[str]
-    validation_split_percentage: int
+class DatasetConfig(BaseModel):
+    train_file: str = None
+    validation_file: Optional[str] = None
+    validation_split_percentage: int = 5
     max_length: int = 512
     group: bool = True
     block_size: int = 512
@@ -74,23 +77,23 @@ class Dataset(BaseModel):

 class RayResourceConfig(BaseModel):
-    CPU: int
+    CPU: int = 32
     GPU: int = 0
     HPU: int = 0


-class Training(BaseModel):
-    optimizer: str
-    batch_size: int
-    epochs: int
+class TrainingConfig(BaseModel):
+    optimizer: str = "adamw_torch"
+    batch_size: int = 2
+    epochs: int = 1
     max_train_steps: Optional[int] = None
-    learning_rate: float
-    lr_scheduler: str
-    weight_decay: float
+    learning_rate: float = 5.0e-5
+    lr_scheduler: str = "linear"
+    weight_decay: float = 0.0
     device: str = DEVICE_CPU
     hpu_execution_mode: str = "lazy"
-    num_training_workers: int
-    resources_per_worker: RayResourceConfig
+    num_training_workers: int = 1
+    resources_per_worker: RayResourceConfig = RayResourceConfig()
     accelerate_mode: str = ACCELERATE_STRATEGY_DDP
     mixed_precision: str = PRECISION_NO
     gradient_accumulation_steps: int = 1
@@ -151,6 +154,13 @@ def check_logging_steps(cls, v: int):


 class FinetuneConfig(BaseModel):
-    General: General
-    Dataset: Dataset
-    Training: Training
+    General: GeneralConfig = GeneralConfig()
+    Dataset: DatasetConfig = DatasetConfig()
+    Training: TrainingConfig = TrainingConfig()
+
+
+class FineTuningParams(FineTuningJobsRequest):
+    # priority use FineTuningJobsRequest params
+    General: GeneralConfig = GeneralConfig()
+    Dataset: DatasetConfig = DatasetConfig()
+    Training: TrainingConfig = TrainingConfig()
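
Note: with every field now carrying a default, a FinetuneConfig no longer needs a per-model YAML. A minimal sketch of the new construction path (it assumes the comps.finetuning.finetune_config module layout introduced here; the override values are illustrative):

from comps.finetuning.finetune_config import FinetuneConfig, TrainingConfig

# An empty config is now valid; nested section defaults are filled in.
config = FinetuneConfig()
assert config.Training.batch_size == 2
assert config.General.lora_config.r == 8

# Sections can still be overridden selectively, e.g. for a longer run.
config = FinetuneConfig(Training=TrainingConfig(epochs=3, learning_rate=1e-4))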
2 changes: 1 addition & 1 deletion comps/finetuning/finetune_runner.py

@@ -6,7 +6,7 @@
 from pydantic_yaml import parse_yaml_raw_as
 from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

-from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig
+from comps.finetuning.finetune_config import FinetuneConfig


 class FineTuneCallback(TrainerCallback):
5 changes: 3 additions & 2 deletions comps/finetuning/finetuning_service.py

@@ -8,7 +8,8 @@
 from fastapi import BackgroundTasks, File, UploadFile

 from comps import opea_microservices, register_microservice
-from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest
+from comps.cores.proto.api_protocol import FineTuningJobIDRequest
+from comps.finetuning.finetune_config import FineTuningParams
 from comps.finetuning.handlers import (
     DATASET_BASE_PATH,
     handle_cancel_finetuning_job,
@@ -21,7 +22,7 @@


 @register_microservice(name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8015)
-def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
+def create_finetuning_jobs(request: FineTuningParams, background_tasks: BackgroundTasks):
     return handle_create_finetuning_jobs(request, background_tasks)
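
With the endpoint now typed as FineTuningParams, a job request may carry optional General/Dataset/Training overrides alongside the OpenAI-style fields. A rough sketch of such a request against a locally running service (payload values are illustrative; the training file must already sit under datasets/):

import requests

# "model" and "training_file" come from FineTuningJobsRequest; the nested
# sections are optional now that every field carries a default.
payload = {
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "training_file": "alpaca_data.json",
    "Training": {"epochs": 2, "batch_size": 4},
}
resp = requests.post("http://localhost:8015/v1/fine_tuning/jobs", json=payload)
print(resp.json())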
41 changes: 16 additions & 25 deletions comps/finetuning/handlers.py

@@ -13,28 +13,21 @@
 from ray.job_submission import JobSubmissionClient

 from comps import CustomLogger
-from comps.cores.proto.api_protocol import (
-    FineTuningJob,
-    FineTuningJobIDRequest,
-    FineTuningJobList,
-    FineTuningJobsRequest,
-)
-from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig
+from comps.cores.proto.api_protocol import FineTuningJob, FineTuningJobIDRequest, FineTuningJobList
+from comps.finetuning.finetune_config import FinetuneConfig, FineTuningParams

 logger = CustomLogger("finetuning_handlers")

-MODEL_CONFIG_FILE_MAP = {
-    "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml",
-    "mistralai/Mistral-7B-v0.1": "./models/mistral-7b-v0.1.yaml",
-}
-
 DATASET_BASE_PATH = "datasets"
 JOBS_PATH = "jobs"
+OUTPUT_DIR = "output"

 if not os.path.exists(DATASET_BASE_PATH):
     os.mkdir(DATASET_BASE_PATH)

 if not os.path.exists(JOBS_PATH):
     os.mkdir(JOBS_PATH)
+if not os.path.exists(OUTPUT_DIR):
+    os.mkdir(OUTPUT_DIR)

 FineTuningJobID = str
 CHECK_JOB_STATUS_INTERVAL = 5  # Check every 5 secs
@@ -62,23 +55,17 @@ def update_job_status(job_id: FineTuningJobID):
         time.sleep(CHECK_JOB_STATUS_INTERVAL)


-def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
+def handle_create_finetuning_jobs(request: FineTuningParams, background_tasks: BackgroundTasks):
     base_model = request.model
     train_file = request.training_file
     train_file_path = os.path.join(DATASET_BASE_PATH, train_file)

-    model_config_file = MODEL_CONFIG_FILE_MAP.get(base_model)
-    if not model_config_file:
-        raise HTTPException(status_code=404, detail=f"Base model '{base_model}' not supported!")
-
     if not os.path.exists(train_file_path):
         raise HTTPException(status_code=404, detail=f"Training file '{train_file}' not found!")

-    with open(model_config_file) as f:
-        finetune_config = parse_yaml_raw_as(FinetuneConfig, f)
-
+    finetune_config = FinetuneConfig(General=request.General, Dataset=request.Dataset, Training=request.Training)
     finetune_config.General.base_model = base_model
     finetune_config.Dataset.train_file = train_file_path

     if request.hyperparameters is not None:
         if request.hyperparameters.epochs != "auto":
             finetune_config.Training.epochs = request.hyperparameters.epochs
@@ -90,7 +77,7 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas
             finetune_config.Training.learning_rate = request.hyperparameters.learning_rate_multiplier

     if os.getenv("HF_TOKEN", None):
-        finetune_config.General.config.use_auth_token = os.getenv("HF_TOKEN", None)
+        finetune_config.General.config.token = os.getenv("HF_TOKEN", None)

     job = FineTuningJob(
         id=f"ft-job-{uuid.uuid4()}",
@@ -105,12 +92,16 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas
         status="running",
         seed=random.randint(0, 1000) if request.seed is None else request.seed,
     )
-    finetune_config.General.output_dir = os.path.join(JOBS_PATH, job.id)
+    finetune_config.General.output_dir = os.path.join(OUTPUT_DIR, job.id)
     if os.getenv("DEVICE", ""):
         logger.info(f"specific device: {os.getenv('DEVICE')}")
+
         finetune_config.Training.device = os.getenv("DEVICE")
+        if finetune_config.Training.device == "hpu":
+            if finetune_config.Training.resources_per_worker.HPU == 0:
+                # set 1
+                finetune_config.Training.resources_per_worker.HPU = 1

     finetune_config_file = f"{JOBS_PATH}/{job.id}.yaml"
     to_yaml_file(finetune_config_file, finetune_config)
@@ -122,7 +113,7 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas
         # Entrypoint shell command to execute
         entrypoint=f"python finetune_runner.py --config_file {finetune_config_file}",
         # Path to the local directory that contains the script.py file
-        runtime_env={"working_dir": "./"},
+        runtime_env={"working_dir": "./", "excludes": [f"{OUTPUT_DIR}"]},
     )

     logger.info(f"Submitted Ray job: {ray_job_id} ...")
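
The new "excludes" entry keeps job output under output/ from being re-uploaded as part of the working directory on every submission. Roughly what the submission amounts to, as a standalone sketch (the config path is illustrative):

from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://localhost:8265")
ray_job_id = client.submit_job(
    entrypoint="python finetune_runner.py --config_file jobs/ft-job-example.yaml",
    # Without "excludes", checkpoints from earlier jobs under output/ would be
    # shipped to the Ray cluster with each new working-dir upload.
    runtime_env={"working_dir": "./", "excludes": ["output"]},
)
print(ray_job_id)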
6 changes: 3 additions & 3 deletions comps/finetuning/launch.sh

@@ -2,11 +2,11 @@
 # SPDX-License-Identifier: Apache-2.0

 if [[ -n "$RAY_PORT" ]];then
-    ray start --head --port $RAY_PORT
+    ray start --head --port $RAY_PORT --dashboard-host=0.0.0.0
 else
-    ray start --head
+    ray start --head --dashboard-host=0.0.0.0
     export RAY_PORT=8265
 fi

-export RAY_ADDRESS=http://127.0.0.1:$RAY_PORT
+export RAY_ADDRESS=http://localhost:$RAY_PORT
 python finetuning_service.py
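
Both launch.sh and the run.sh generated by the Dockerfile now bind the Ray dashboard to 0.0.0.0 (so it is reachable from outside the container) and export RAY_ADDRESS for the job-submission client. A small sketch of how a client would pick that address up (assumes a Ray head started by either script):

import os

from ray.job_submission import JobSubmissionClient

# RAY_ADDRESS is exported by launch.sh / run.sh; fall back to the default
# dashboard endpoint when it is unset.
client = JobSubmissionClient(os.environ.get("RAY_ADDRESS", "http://localhost:8265"))
print(client.list_jobs())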
6 changes: 3 additions & 3 deletions comps/finetuning/llm_on_ray/finetune/finetune.py

@@ -24,9 +24,9 @@
 from ray.train.torch import TorchTrainer

 from comps import CustomLogger
+from comps.finetuning.finetune_config import FinetuneConfig
 from comps.finetuning.llm_on_ray import common
 from comps.finetuning.llm_on_ray.finetune.data_process import DataProcessor
-from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig

 logger = CustomLogger("llm_on_ray/finetune")

@@ -171,8 +171,8 @@ def local_load(name, **load_config):
     else:
         # try to download and load dataset from huggingface.co
         load_config = config["General"].get("config", {})
-        use_auth_token = load_config.get("use_auth_token", None)
-        raw_dataset = datasets.load_dataset(dataset_file, use_auth_token=use_auth_token)
+        use_auth_token = load_config.get("token", None)
+        raw_dataset = datasets.load_dataset(dataset_file, token=use_auth_token)

     validation_split_percentage = config["Dataset"].get("validation_split_percentage", 0)
     if "validation" not in raw_dataset.keys() and (
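
The rename from use_auth_token to token tracks the datasets library, where the use_auth_token= argument is deprecated in favor of token=. A minimal sketch of the updated call (the dataset name is illustrative; a gated dataset would need a real HF token rather than None):

import datasets

# token=None is fine for public datasets; gated ones need an actual HF token.
raw_dataset = datasets.load_dataset("tatsu-lab/alpaca", token=None)
print(raw_dataset)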
39 changes: 0 additions & 39 deletions comps/finetuning/models/llama-2-7b-chat-hf.yaml

This file was deleted.

45 changes: 0 additions & 45 deletions comps/finetuning/models/mistral-7b-v0.1.yaml

This file was deleted.
