Skip to content

Commit

Permalink
TAS: Support Kubeflow. (#3414)
Browse files Browse the repository at this point in the history
* TAS: Support Kubeflow.

* Test ValidateOnCreate and ValidateOnUpdate on BaseWebHook.
  • Loading branch information
mbobrovskyi authored Nov 5, 2024
1 parent 6935a11 commit b3c53d5
Show file tree
Hide file tree
Showing 32 changed files with 2,417 additions and 61 deletions.
14 changes: 8 additions & 6 deletions pkg/controller/jobframework/base_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"

"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/klog/v2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
Expand Down Expand Up @@ -67,11 +66,11 @@ func (w *BaseWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (a
job := w.FromObject(obj)
log := ctrl.LoggerFrom(ctx)
log.V(5).Info("Validating create", "job", klog.KObj(job.Object()))
return nil, validateCreate(job).ToAggregate()
}

func validateCreate(job GenericJob) field.ErrorList {
return ValidateJobOnCreate(job)
allErrs := ValidateJobOnCreate(job)
if jobWithValidation, ok := job.(JobWithValidation); ok {
allErrs = append(allErrs, jobWithValidation.ValidateOnCreate()...)
}
return nil, allErrs.ToAggregate()
}

// ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type
Expand All @@ -81,6 +80,9 @@ func (w *BaseWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime
log := ctrl.LoggerFrom(ctx)
log.Info("Validating update", "job", klog.KObj(newJob.Object()))
allErrs := ValidateJobOnUpdate(oldJob, newJob)
if jobWithValidation, ok := newJob.(JobWithValidation); ok {
allErrs = append(allErrs, jobWithValidation.ValidateOnUpdate(oldJob)...)
}
return nil, allErrs.ToAggregate()
}

Expand Down
189 changes: 184 additions & 5 deletions pkg/controller/jobframework/base_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,29 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/podset"
)

type testGenericJob batchv1.Job
type testGenericJob struct {
*batchv1.Job

validateOnCreate func() field.ErrorList
validateOnUpdate func(jobframework.GenericJob) field.ErrorList
}

var _ jobframework.GenericJob = (*testGenericJob)(nil)
var _ jobframework.JobWithValidation = (*testGenericJob)(nil)

func (t *testGenericJob) Object() client.Object {
return (*batchv1.Job)(t)
return t.Job
}

func (t *testGenericJob) IsSuspended() bool {
Expand Down Expand Up @@ -78,8 +86,40 @@ func (t *testGenericJob) GVK() schema.GroupVersionKind {
panic("not implemented")
}

func fromObject(o runtime.Object) jobframework.GenericJob {
return (*testGenericJob)(o.(*batchv1.Job))
func (t *testGenericJob) ValidateOnCreate() field.ErrorList {
if t.validateOnCreate != nil {
return t.validateOnCreate()
}
return nil
}

func (t *testGenericJob) ValidateOnUpdate(oldJob jobframework.GenericJob) field.ErrorList {
if t.validateOnUpdate != nil {
return t.validateOnUpdate(oldJob)
}
return nil
}

func (t *testGenericJob) withValidateOnCreate(validateOnCreate func() field.ErrorList) *testGenericJob {
t.validateOnCreate = validateOnCreate
return t
}

func (t *testGenericJob) withValidateOnUpdate(validateOnUpdate func(jobframework.GenericJob) field.ErrorList) *testGenericJob {
t.validateOnUpdate = validateOnUpdate
return t
}

func (t *testGenericJob) fromObject(o runtime.Object) jobframework.GenericJob {
if o == nil {
return nil
}
t.Job = o.(*batchv1.Job)
return t
}

func makeTestGenericJob() *testGenericJob {
return &testGenericJob{}
}

func TestBaseWebhookDefault(t *testing.T) {
Expand Down Expand Up @@ -129,7 +169,7 @@ func TestBaseWebhookDefault(t *testing.T) {
t.Run(name, func(t *testing.T) {
w := &jobframework.BaseWebhook{
ManageJobsWithoutQueueName: tc.manageJobsWithoutQueueName,
FromObject: fromObject,
FromObject: makeTestGenericJob().fromObject,
}
if err := w.Default(context.Background(), tc.job); err != nil {
t.Errorf("set defaults to a kubeflow/mpijob by a Defaulter")
Expand All @@ -140,3 +180,142 @@ func TestBaseWebhookDefault(t *testing.T) {
})
}
}

func TestValidateOnCreate(t *testing.T) {
testcases := []struct {
name string
job *batchv1.Job
validateOnCreate func() field.ErrorList
wantErr error
wantWarn admission.Warnings
}{
{
name: "valid request",
job: &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "job",
Namespace: "default",
Labels: map[string]string{constants.QueueLabel: "queue"},
},
},
},
{
name: "invalid request with validate on create",
job: &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "job",
Namespace: "default",
Labels: map[string]string{constants.QueueLabel: "queue"},
},
},
validateOnCreate: func() field.ErrorList {
return field.ErrorList{
field.Invalid(
field.NewPath("metadata.annotations"),
field.OmitValueType{},
`invalid annotation`,
),
}
},
wantErr: field.ErrorList{
field.Invalid(
field.NewPath("metadata.annotations"),
field.OmitValueType{},
`invalid annotation`,
),
}.ToAggregate(),
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
w := &jobframework.BaseWebhook{
FromObject: makeTestGenericJob().withValidateOnCreate(tc.validateOnCreate).fromObject,
}
gotWarn, gotErr := w.ValidateCreate(context.Background(), tc.job)
if diff := cmp.Diff(tc.wantErr, gotErr); diff != "" {
t.Errorf("validate create err mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantWarn, gotWarn); diff != "" {
t.Errorf("validate create warn mismatch (-want +got):\n%s", diff)
}
})
}
}

func TestValidateOnUpdate(t *testing.T) {
testcases := []struct {
name string
oldJob *batchv1.Job
job *batchv1.Job
validateOnUpdate func(jobframework.GenericJob) field.ErrorList
wantErr error
wantWarn admission.Warnings
}{
{
name: "valid request",
oldJob: &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "job",
Namespace: "default",
Labels: map[string]string{constants.QueueLabel: "queue"},
},
},
job: &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "job",
Namespace: "default",
Labels: map[string]string{constants.QueueLabel: "queue"},
},
},
},
{
name: "invalid request with validate on update",
oldJob: &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "job",
Namespace: "default",
Labels: map[string]string{constants.QueueLabel: "queue"},
},
},
job: &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "job",
Namespace: "default",
Labels: map[string]string{constants.QueueLabel: "queue"},
},
},
validateOnUpdate: func(jobframework.GenericJob) field.ErrorList {
return field.ErrorList{
field.Invalid(
field.NewPath("metadata.annotations"),
field.OmitValueType{},
`invalid annotation`,
),
}
},
wantErr: field.ErrorList{
field.Invalid(
field.NewPath("metadata.annotations"),
field.OmitValueType{},
`invalid annotation`,
),
}.ToAggregate(),
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
w := &jobframework.BaseWebhook{
FromObject: makeTestGenericJob().withValidateOnUpdate(tc.validateOnUpdate).fromObject,
}
gotWarn, gotErr := w.ValidateUpdate(context.Background(), tc.oldJob, tc.job)
if diff := cmp.Diff(tc.wantErr, gotErr); diff != "" {
t.Errorf("validate create err mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantWarn, gotWarn); diff != "" {
t.Errorf("validate create warn mismatch (-want +got):\n%s", diff)
}
})
}
}
10 changes: 10 additions & 0 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/client-go/tools/record"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -107,6 +108,15 @@ type JobWithPriorityClass interface {
PriorityClass() string
}

// JobWithValidation optional interface that allows custom webhook validation
// for Jobs that use BaseWebhook.
type JobWithValidation interface {
// ValidateOnCreate returns list of webhook create validation errors.
ValidateOnCreate() field.ErrorList
// ValidateOnUpdate returns list of webhook update validation errors.
ValidateOnUpdate(oldJob GenericJob) field.ErrorList
}

// ComposableJob interface should be implemented by generic jobs that
// are composed out of multiple API objects.
type ComposableJob interface {
Expand Down
4 changes: 4 additions & 0 deletions pkg/controller/jobs/kubeflow/jobs/mxjob/mxjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.MXReplicaSpecs
}

func (j *JobControl) ReplicaSpecsFieldName() string {
return "mxReplicaSpecs"
}

func (j *JobControl) JobStatus() *kftraining.JobStatus {
return &j.Status
}
Expand Down
Loading

0 comments on commit b3c53d5

Please sign in to comment.