Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --tpu-topology flag for specifying custom topology types #57

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,10 @@ workload.

# More advanced facts:

* Workload create accepts a --env-file flag to allow specifying the container's
* `xpk cluster create` accepts a `--tpu-topology` flag to allow for custom tpu topologies.
See https://cloud.google.com/kubernetes-engine/docs/concepts/tpus#topology for more details.

* `xpk workload create` accepts a `--env-file`` flag to allow specifying the container's
environment from a file. Usage is the same as Docker's
[--env-file flag](https://docs.docker.com/engine/reference/commandline/run/#env)

Expand All @@ -383,7 +386,7 @@ environment from a file. Usage is the same as Docker's
MY_ENV_VAR=hello
```

* Workload create accepts a --debug-dump-gcs flag which is a path to GCS bucket.
* `xpk workload create` accepts a --debug-dump-gcs flag which is a path to GCS bucket.
Passing this flag sets the XLA_FLAGS='--xla_dump_to=/tmp/xla_dump/' and uploads
hlo dumps to the specified GCS bucket for each worker.

Expand Down
57 changes: 56 additions & 1 deletion xpk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,7 +1319,11 @@ def run_gke_node_pool_create_command(args, system) -> int:
f' {args.custom_tpu_nodepool_arguments}'
)
if system.accelerator_type == AcceleratorType['TPU']:
command += (f' --tpu-topology={system.topology}')
tpu_topology, return_code = get_tpu_topology(system, args)
if return_code > 0:
xpk_print('Parsing tpu_topology failed!')
return return_code
command += (f' --tpu-topology={tpu_topology}')
elif system.accelerator_type == AcceleratorType['GPU']:
command += f' --accelerator type={system.gke_accelerator},count={str(system.chips_per_vm)}'
task = f'NodepoolCreate-{node_pool_name}'
Expand Down Expand Up @@ -1557,6 +1561,35 @@ def default_subcommand_function(_args) -> int: # args is unused, so pylint: dis
return 0


def get_tpu_topology(system: SystemCharacteristics, args) -> tuple[str, int]:
"""Function around parsing tpu-topology argument and obtaining the tpu-topology.

Args:
system: SystemCharacteristics for the device type.
args: User provided arguments for running commands.

Returns:
A tuple of:
str: topology type to use
int: 0 if successful and 1 otherwise.
"""
if system.accelerator_type != AcceleratorType['TPU']:
xpk_print(
'tpu_topology argument is only supported when the AcceleratorType is TPU.'
f' The AcceleratorType you are using is: {system.accelerator_type}'
)
return None, 1
tpu_topology = system.topology
if args.tpu_topology is not None:
tpu_topology = args.tpu_topology
xpk_print(
f'Using custom tpu topology of {args.tpu_topology} for {system.device_type}'
' in node pool creation.'
)

return tpu_topology, 0


def cluster_create(args) -> int:
"""Function around cluster creation.

Expand Down Expand Up @@ -2519,6 +2552,18 @@ def directory_path_type(value):
return value


def tpu_topology_type(value, pat=re.compile(r'^[\d]+(x[\d.*]+){1,2}$')):
match = pat.fullmatch(value)
if not match:
raise argparse.ArgumentTypeError(
f'Custom TPU Topology must match the pattern `{pat.pattern}` such as 1x2x3'
f' or 10x10. TPU Topology set through `--tpu-topology` is currently {value}.'
' See https://cloud.google.com/kubernetes-engine/docs/concepts/tpus#topology'
' for more details.'
)
return value


#### "cluster" command parser. ####
cluster_parser = xpk_subcommands.add_parser(
'cluster',
Expand Down Expand Up @@ -2605,6 +2650,16 @@ def directory_path_type(value):
)

### Optional Arguments
cluster_create_optional_arguments.add_argument(
'--tpu-topology',
type=tpu_topology_type,
default=None,
help=(
'The slice topology to create the TPU slice with. This only supports TPUs.'
'By default, tpu node pool creation will use the tpu-topology defined in'
' the SystemCharacteristics within xpk code.'
)
)
cluster_create_optional_arguments.add_argument(
'--host-maintenance-interval',
type=str,
Expand Down