Adding Training image needed for train api #1963

Merged on Jan 11, 2024 (14 commits)
14 changes: 11 additions & 3 deletions .github/workflows/publish-core-images.yaml
@@ -10,8 +10,9 @@ jobs:
uses: ./.github/workflows/build-and-publish-images.yaml
with:
component-name: ${{ matrix.component-name }}
platforms: linux/amd64,linux/arm64,linux/ppc64le
platforms: ${{ matrix.platforms }}
dockerfile: ${{ matrix.dockerfile }}
context: ${{ matrix.context }}
secrets:
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -22,8 +23,15 @@ jobs:
include:
- component-name: training-operator
dockerfile: build/images/training-operator/Dockerfile
platforms: linux/amd64,linux/arm64,linux/ppc64le
- component-name: kubectl-delivery
dockerfile: build/images/kubectl-delivery/Dockerfile
platforms: linux/amd64,linux/arm64,linux/ppc64le
- component-name: storage-initializer
dockerfile: sdk/python/kubeflow/storage_initializer/Dockerfile
context: sdk/python/kubeflow/storage_initializer
dockerfile: sdk/python/kubeflow/storage_initializer/Dockerfile
context: sdk/python/kubeflow/storage_initializer
platforms: linux/amd64,linux/arm64
- component-name: trainer-huggingface
dockerfile: sdk/python/kubeflow/trainer/hf_dockerfile
context: sdk/python/kubeflow/trainer
platforms: linux/amd64,linux/arm64
1 change: 0 additions & 1 deletion .github/workflows/publish-example-images.yaml
@@ -52,7 +52,6 @@ jobs:
- component-name: mxnet-auto-tuning
dockerfile: examples/mxnet/tune/Dockerfile
context: examples/mxnet/tune

# TODO (tenzen-y): Fix the below broken Dockerfiles
# - component-name: pytorch-dist-mnist-mpi
# dockerfile: examples/pytorch/mnist/Dockerfile-mpi
57 changes: 57 additions & 0 deletions examples/sdk/train_api.py
@@ -0,0 +1,57 @@
from kubeflow.training.api.training_client import TrainingClient
from kubeflow.storage_initializer.hugging_face import (
    HuggingFaceModelParams,
    HuggingFaceTrainParams,
    HfDatasetParams,
)
from peft import LoraConfig
import transformers
from transformers import TrainingArguments

client = TrainingClient()

client.train(
    name="hf-test",
    num_workers=2,
    num_procs_per_worker=0,
Member: Why is this value 0?

Contributor Author: For CPU-only training.

Member: Hmm, but can torchrun be used with CPUs? For example, I might want to run torchrun --nproc-per-node=2 to use 2 CPUs per node.
cc @johnugeorge

Member: Yes, it can run on CPUs.

    model_provider_parameters=HuggingFaceModelParams(
        model_uri="hf://Jedalc/codeparrot-gp2-finetune",
        transformer_type=transformers.AutoModelForCausalLM,
    ),
    dataset_provider_parameters=HfDatasetParams(repo_id="imdatta0/ultrachat_10k"),
    train_parameters=HuggingFaceTrainParams(
        lora_config=LoraConfig(
            r=8,
            lora_alpha=8,
            target_modules=["c_attn", "c_proj", "w1", "w2"],
            layers_to_transform=list(range(30, 40)),
            # layers_pattern=['lm_head'],
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        ),
        training_parameters=TrainingArguments(
            num_train_epochs=2,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=1,
            gradient_checkpointing=True,
            warmup_steps=0.01,
            # max_steps=50, #20,
            learning_rate=1,
            lr_scheduler_type="cosine",
            bf16=False,
            logging_steps=0.01,
            output_dir="",
            optim="paged_adamw_32bit",
            save_steps=0.01,
            save_total_limit=3,
            disable_tqdm=False,
            resume_from_checkpoint=True,
            remove_unused_columns=True,
            evaluation_strategy="steps",
            eval_steps=0.01,
            per_device_eval_batch_size=1,
        ),
    ),
    resources_per_worker={"gpu": 0, "cpu": 8, "memory": "8Gi"},
)
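
The example sets logging_steps, save_steps, and eval_steps to 0.01. In the transformers release pinned by this PR, these arguments accept a float in [0, 1), which is read as a fraction of the total training steps rather than as an absolute count. The sketch below mirrors that arithmetic in plain Python; resolve_step_interval is a hypothetical helper written for illustration, not a transformers API. Note that warmup_steps is documented as an integer count, with warmup_ratio as its fractional counterpart.

```python
import math


def resolve_step_interval(value: float, max_steps: int) -> int:
    """Convert a steps setting into an absolute number of update steps.

    Hypothetical helper: values below 1 are treated as a fraction of
    max_steps (rounded up); values of 1 or more are absolute step counts.
    """
    if 0 < value < 1:
        return math.ceil(max_steps * value)
    return int(value)


# With 250 total update steps, a setting of 0.01 rounds up to every 3rd step.
print(resolve_step_interval(0.01, 250))
# Integer settings pass through unchanged.
print(resolve_step_interval(50, 250))
```

For short runs the fraction rounds up to a very small interval, so logging, saving, and evaluation may happen almost every step.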
18 changes: 18 additions & 0 deletions sdk/python/kubeflow/trainer/hf_dockerfile
@@ -0,0 +1,18 @@
# Use the NVIDIA PyTorch image as the parent image
FROM nvcr.io/nvidia/pytorch:23.12-py3
Member: Do we need to use the PyTorch image from NVIDIA for this trainer? Would it be better to use the official PyTorch image, similar to the one we use in the SDK?
docker.io/pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime

Contributor Author: As suggested by @tenzen-y in #1963 (comment).

Member: I see. @tenzen-y, do you know whether PyTorch has an official image we can use that is supported on all platforms?

Member: @andreyvelich If I remember correctly, PyTorch doesn't provide multi-architecture images with GPU support, so we need to use the official NVIDIA images.


# Set the working directory in the container
WORKDIR /app

# Copy the Python package and its source code into the container
COPY . /app

# Copy the requirements.txt file into the container
COPY requirements.txt /app/requirements.txt

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Run hf_llm_training.py with torchrun when the container launches
ENTRYPOINT ["torchrun", "hf_llm_training.py"]

118 changes: 118 additions & 0 deletions sdk/python/kubeflow/trainer/hf_llm_training.py
@@ -0,0 +1,118 @@
import argparse
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    Trainer,
)
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from urllib.parse import urlparse
import os
import json


def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
    # Set up the model and tokenizer

    parsed_uri = urlparse(model_uri)
    model_name = parsed_uri.netloc + parsed_uri.path
    transformer_type_class = getattr(transformers, transformer_type)

    model = transformer_type_class.from_pretrained(
        pretrained_model_name_or_path=model_name,
        cache_dir=model_dir,
        local_files_only=True,
        device_map="auto",
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=model_name,
        cache_dir=model_dir,
        local_files_only=True,
        device_map="auto",
    )

    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_pad_token = True

    # Freeze model parameters
    for param in model.parameters():
        param.requires_grad = False

    return model, tokenizer


def load_and_preprocess_data(dataset_name, dataset_dir):
    # Load and preprocess the dataset
    print("loading dataset")
    dataset = load_dataset(dataset_name, cache_dir=dataset_dir)
    train_data = dataset["train"]

    try:
        eval_data = dataset["eval"]
    except Exception as err:
        eval_data = None

    return train_data, eval_data


def setup_peft_model(model, lora_config):
    # Set up the PEFT model
    lora_config = LoraConfig(**json.loads(lora_config))
    print(lora_config)
    model = get_peft_model(model, lora_config)
    return model
Member: Are we always going to have a PEFT config for this trainer?
@johnugeorge @deepanker13

Contributor Author: The LoRA config can be omitted by the user; this is handled by setting an empty LoraConfig as the default value in the data class.

Member: Sounds good. @deepanker13, should we verify whether lora_config is set?
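
A minimal sketch of the check raised in this thread, assuming lora_config reaches the trainer as a JSON string (as setup_peft_model already expects) and assuming that a missing value or an empty JSON object means "no LoRA". The helper name parse_lora_config is hypothetical and not part of this PR.

```python
import json
from typing import Optional


def parse_lora_config(raw: Optional[str]) -> Optional[dict]:
    """Return LoRA settings as a dict, or None when none were supplied.

    Hypothetical helper: a missing argument, an empty string, and an empty
    JSON object are all treated as "LoRA not configured".
    """
    if not raw:
        return None
    config = json.loads(raw)
    if not isinstance(config, dict):
        raise ValueError("lora_config must be a JSON object")
    return config or None


# A populated config is returned as a dict; an empty one yields None.
print(parse_lora_config('{"r": 8, "lora_alpha": 8}'))
print(parse_lora_config("{}"))
```

With a guard like this, the trainer could call get_peft_model only when the helper returns a dict and otherwise train the base model unchanged.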



def train_model(model, train_data, eval_data, tokenizer, train_params):
    # Train the model. train_params arrives from argparse as a string; it is
    # assumed here to be JSON-encoded, like --lora_config.
    trainer = Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=eval_data,
        tokenizer=tokenizer,
        args=TrainingArguments(**json.loads(train_params)),
        # data_collator is a Trainer argument, not a TrainingArguments field
        data_collator=DataCollatorForLanguageModeling(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", mlm=False
        ),
    )

    trainer.train()
    print("training done")


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Script for training a model with PEFT configuration."
    )

    parser.add_argument("--model_uri", help="model uri")
    parser.add_argument("--transformer_type", help="model transformer type")
    parser.add_argument("--model_dir", help="directory containing model")
    parser.add_argument("--dataset_dir", help="directory containing dataset")
    parser.add_argument("--dataset_name", help="dataset name")
Member: Is the dataset_name argument meant for users who want to use this Trainer without the SDK client?
I ask because the SDK client always downloads the dataset in the storage initializer and stores it in the Trainer volume, so we don't need to provide a name.

Contributor Author: The same dataset_dir can contain multiple datasets, right?

Member: But can the train API be used to download more than one dataset?
For example, your example downloads only the ultrachat_10k dataset.

Contributor Author: Yes, if I run it with a different dataset name, it works fine.
@andreyvelich

Member: Yes, but every API execution creates a new PyTorchJob and spins up a new Trainer image, so the dataset always corresponds to a single name, doesn't it?

    parser.add_argument("--lora_config", help="lora_config")
    parser.add_argument(
        "--training_parameters", help="hugging face training parameters"
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_arguments()
    model, tokenizer = setup_model_and_tokenizer(
        args.model_uri, args.transformer_type, args.model_dir
    )
    train_data, eval_data = load_and_preprocess_data(
        args.dataset_name, args.dataset_dir
    )
    model = setup_peft_model(model, args.lora_config)
    train_model(model, train_data, eval_data, tokenizer, args.training_parameters)
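
The model-name extraction in setup_model_and_tokenizer can be checked in isolation with the standard library. The sketch below uses the model URI from examples/sdk/train_api.py; the function name model_name_from_uri is hypothetical and only wraps the two lines from the trainer script.

```python
from urllib.parse import urlparse


def model_name_from_uri(model_uri: str) -> str:
    """Strip the scheme from a URI such as hf://org/model."""
    parsed = urlparse(model_uri)
    # For hf://org/model, netloc is "org" and path is "/model".
    return parsed.netloc + parsed.path


print(model_name_from_uri("hf://Jedalc/codeparrot-gp2-finetune"))
```

Concatenating netloc and path yields the org/model form that from_pretrained expects as pretrained_model_name_or_path.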
3 changes: 3 additions & 0 deletions sdk/python/kubeflow/trainer/requirements.txt
@@ -0,0 +1,3 @@
peft==0.7.0
datasets==2.15.0
transformers==4.35.2