diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index e8c2823f3793c7..ac875e0570cc67 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -976,12 +976,12 @@ class TrainingArguments: ) }, ) - fsdp_config: Optional[str] = field( + fsdp_config: Optional[Union[str, Dict]] = field( default=None, metadata={ "help": ( - "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a" - "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." + "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a " + "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." ) }, ) @@ -994,11 +994,11 @@ class TrainingArguments: ) }, ) - deepspeed: Optional[str] = field( + deepspeed: Optional[Union[str, Dict]] = field( default=None, metadata={ "help": ( - "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already" + "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already" " loaded json file as a dict" ) },