Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
added unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw committed Sep 8, 2023
1 parent fb4c4af commit 27cff08
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 6 deletions.
23 changes: 21 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ var (
"test-args",
}

dummyAnnotations = map[string]string{
"annotation-key": "annotation-value",
}
dummyLabels = map[string]string{
"label-key": "label-value",
}

resourceRequirements = &corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1000m"),
Expand Down Expand Up @@ -150,8 +157,8 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut
taskExecutionMetadata := &mocks.TaskExecutionMetadata{}
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetAnnotations().Return(dummyAnnotations)
taskExecutionMetadata.OnGetLabels().Return(dummyLabels)
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Name: "blah",
Expand Down Expand Up @@ -304,6 +311,18 @@ func TestBuildResourceMPI(t *testing.T) {
assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas)
assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker)

// verify TaskExecutionMetadata labels and annotations are copied to the MPIJob
for k, v := range dummyAnnotations {
for _, replicaSpec := range mpiJob.Spec.MPIReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Annotations[k])
}
}
for k, v := range dummyLabels {
for _, replicaSpec := range mpiJob.Spec.MPIReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Labels[k])
}
}

for _, replicaSpec := range mpiJob.Spec.MPIReplicaSpecs {
for _, container := range replicaSpec.Template.Spec.Containers {
assert.Equal(t, resourceRequirements.Requests, container.Resources.Requests)
Expand Down
35 changes: 33 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ var (
"test-args",
}

dummyAnnotations = map[string]string{
"annotation-key": "annotation-value",
}
dummyLabels = map[string]string{
"label-key": "label-value",
}

resourceRequirements = &corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1000m"),
Expand Down Expand Up @@ -170,8 +177,8 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskEx
taskExecutionMetadata := &mocks.TaskExecutionMetadata{}
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetAnnotations().Return(dummyAnnotations)
taskExecutionMetadata.OnGetLabels().Return(dummyLabels)
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Name: "blah",
Expand Down Expand Up @@ -339,6 +346,18 @@ func TestBuildResourcePytorchElastic(t *testing.T) {
}

assert.True(t, hasContainerWithDefaultPytorchName)

// verify TaskExecutionMetadata labels and annotations are copied to the PyTorchJob
for k, v := range dummyAnnotations {
for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Annotations[k])
}
}
for k, v := range dummyLabels {
for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Labels[k])
}
}
}

func TestBuildResourcePytorch(t *testing.T) {
Expand All @@ -356,6 +375,18 @@ func TestBuildResourcePytorch(t *testing.T) {
assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
assert.Nil(t, pytorchJob.Spec.ElasticPolicy)

// verify TaskExecutionMetadata labels and annotations are copied to the TensorFlowJob
for k, v := range dummyAnnotations {
for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Annotations[k])
}
}
for k, v := range dummyLabels {
for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Labels[k])
}
}

for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
var hasContainerWithDefaultPytorchName = false

Expand Down
23 changes: 21 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ var (
"test-args",
}

dummyAnnotations = map[string]string{
"annotation-key": "annotation-value",
}
dummyLabels = map[string]string{
"label-key": "label-value",
}

resourceRequirements = &corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1000m"),
Expand Down Expand Up @@ -152,8 +159,8 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.Tas
taskExecutionMetadata := &mocks.TaskExecutionMetadata{}
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetAnnotations().Return(dummyAnnotations)
taskExecutionMetadata.OnGetLabels().Return(dummyLabels)
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Name: "blah",
Expand Down Expand Up @@ -306,6 +313,18 @@ func TestBuildResourceTensorFlow(t *testing.T) {
assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas)
assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas)

// verify TaskExecutionMetadata labels and annotations are copied to the TensorFlowJob
for k, v := range dummyAnnotations {
for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Annotations[k])
}
}
for k, v := range dummyLabels {
for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Labels[k])
}
}

for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs {
var hasContainerWithDefaultTensorFlowName = false

Expand Down

0 comments on commit 27cff08

Please sign in to comment.