Skip to content

Commit

Permalink
gpu training example fix
Browse files Browse the repository at this point in the history
  • Loading branch information
deepanker13 committed Jan 11, 2024
1 parent f520329 commit 95b3e2b
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions examples/sdk/train_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@
" ),\n",
" training_parameters=TrainingArguments(\n",
" num_train_epochs=1,\n",
" per_device_train_batch_size=4,\n",
" gradient_accumulation_steps=4,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=1,\n",
" gradient_checkpointing=True,\n",
" gradient_checkpointing_kwargs={\n",
" \"use_reentrant\": False\n",
" }, # this is mandatory if checkpointng is enabled\n",
" warmup_steps=0.02,\n",
" learning_rate=1,\n",
" lr_scheduler_type=\"cosine\",\n",
Expand All @@ -93,7 +96,7 @@
" resources_per_worker={\n",
" \"gpu\": 1,\n",
" \"cpu\": 8,\n",
" \"memory\": \"16Gi\",\n",
" \"memory\": \"8Gi\",\n",
" }, # remove the gpu key if you don't want to attach gpus to the pods\n",
")"
]
Expand Down

0 comments on commit 95b3e2b

Please sign in to comment.