diff --git a/api/batch/resource.go b/api/batch/resource.go index b8b24dc89..2361ccbd6 100644 --- a/api/batch/resource.go +++ b/api/batch/resource.go @@ -16,6 +16,7 @@ package batch import ( "fmt" + "os" "strconv" "time" @@ -150,7 +151,7 @@ func (t *BatchJobTemplater) createDriverSpec(job *models.PredictionJob) (v1beta2 return v1beta2.DriverSpec{}, fmt.Errorf("invalid driver memory request: %s", job.Config.ResourceRequest.DriverMemoryRequest) } - envVars, err := addEnvVars(job) + envVars, err := t.addEnvVars(job) if err != nil { return v1beta2.DriverSpec{}, err } @@ -194,7 +195,7 @@ func (t *BatchJobTemplater) createExecutorSpec(job *models.PredictionJob) (v1bet return v1beta2.ExecutorSpec{}, fmt.Errorf("invalid executor memory request: %s", job.Config.ResourceRequest.ExecutorMemoryRequest) } - envVars, err := addEnvVars(job) + envVars, err := t.addEnvVars(job) if err != nil { return v1beta2.ExecutorSpec{}, err } @@ -263,13 +264,16 @@ func getCoreRequest(cpuRequest resource.Quantity) *int32 { return &core } -func addEnvVars(job *models.PredictionJob) ([]corev1.EnvVar, error) { +func (t *BatchJobTemplater) addEnvVars(job *models.PredictionJob) ([]corev1.EnvVar, error) { envVars := []corev1.EnvVar{ { Name: envServiceAccountPathKey, Value: envServiceAccountPath, }, } + for _, ev := range t.batchConfig.APIServerEnvVars { + envVars = append(envVars, corev1.EnvVar{Name: ev, Value: os.Getenv(ev)}) + } for _, ev := range job.Config.EnvVars.ToKubernetesEnvVars() { if ev.Name == envServiceAccountPathKey { return []corev1.EnvVar{}, fmt.Errorf("environment variable '%s' cannot be changed", ev.Name) diff --git a/api/batch/resource_test.go b/api/batch/resource_test.go index 2fe04c3fb..1adb15e84 100644 --- a/api/batch/resource_test.go +++ b/api/batch/resource_test.go @@ -15,6 +15,7 @@ package batch import ( + "os" "testing" "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" @@ -23,6 +24,7 @@ import ( v12 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/caraml-dev/merlin/config" "github.com/caraml-dev/merlin/mlp" "github.com/caraml-dev/merlin/models" ) @@ -649,3 +651,89 @@ func TestCreateSparkApplicationResource(t *testing.T) { }) } } + +func TestAddEnvVars(t *testing.T) { + predictionJob := &models.PredictionJob{ + Name: jobName, + ID: jobID, + Metadata: models.Metadata{ + App: modelName, + Component: models.ComponentBatchJob, + Stream: streamName, + Team: teamName, + Labels: userLabels, + }, + VersionModelID: modelID, + VersionID: versionID, + Config: &models.Config{ + JobConfig: nil, + ImageRef: imageRef, + ResourceRequest: &models.PredictionJobResourceRequest{ + DriverCPURequest: driverCPURequest, + DriverMemoryRequest: driverMemory, + ExecutorReplica: executorReplica, + ExecutorCPURequest: executorCPURequest, + ExecutorMemoryRequest: executorMemory, + }, + MainAppPath: mainAppPathInput, + }, + } + + tests := []struct { + name string + apiServerEnvVars []string + wantErr bool + wantErrMessage string + want []v12.EnvVar + }{ + { + name: "api server env vars specified", + apiServerEnvVars: []string{"TEST_ENV_VAR_1"}, + want: []v12.EnvVar{ + { + Name: envServiceAccountPathKey, + Value: envServiceAccountPath, + }, + { + Name: "TEST_ENV_VAR_1", + Value: "TEST_VALUE_1", + }, + }, + }, + { + name: "no api server env vars specified", + apiServerEnvVars: []string{}, + want: []v12.EnvVar{ + { + Name: envServiceAccountPathKey, + Value: envServiceAccountPath, + }, + }, + }, + } + + err := os.Setenv("TEST_ENV_VAR_1", "TEST_VALUE_1") + assert.NoError(t, err) + err = os.Setenv("TEST_ENV_VAR_2", "TEST_VALUE_2") + assert.NoError(t, err) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defaultBatchConfig = config.BatchConfig{ + Tolerations: []v12.Toleration{defaultToleration}, + NodeSelectors: defaultNodeSelector, + } + defaultBatchConfig.APIServerEnvVars = test.apiServerEnvVars + + testBatchJobTemplater := NewBatchJobTemplater(defaultBatchConfig) + + sp, err := testBatchJobTemplater.addEnvVars(predictionJob) + if test.wantErr { + assert.Equal(t, test.wantErrMessage, err.Error()) + return + } + assert.NoError(t, err) + assert.Equal(t, test.want, sp) + }) + } +} diff --git a/api/config/batch.go b/api/config/batch.go index 6200eb09d..5a2de551d 100644 --- a/api/config/batch.go +++ b/api/config/batch.go @@ -23,4 +23,6 @@ type BatchConfig struct { Tolerations []v1.Toleration // Node Selectors for Jobs Specification NodeSelectors map[string]string + // APIServerEnvVars specifies the environment variables that are propagated from the API server to the Spark job + APIServerEnvVars []string } diff --git a/api/config/config_test.go b/api/config/config_test.go index c618e8340..43505c8dc 100644 --- a/api/config/config_test.go +++ b/api/config/config_test.go @@ -463,6 +463,7 @@ func TestLoad(t *testing.T) { NodeSelectors: map[string]string{ "purpose.caraml.com/batch": "true", }, + APIServerEnvVars: []string{"TEST_ENV_VAR_2"}, }, AuthorizationConfig: AuthorizationConfig{ AuthorizationEnabled: true, diff --git a/api/config/testdata/base-configs-1.yaml b/api/config/testdata/base-configs-1.yaml index 3855f823c..5bb8dfa31 100644 --- a/api/config/testdata/base-configs-1.yaml +++ b/api/config/testdata/base-configs-1.yaml @@ -92,6 +92,8 @@ BatchConfig: Value: "true" NodeSelectors: purpose.caraml.com/batch: "true" + APIServerEnvVars: + - TEST_ENV_VAR_2 AuthorizationConfig: AuthorizationEnabled: true KetoRemoteRead: http://mlp-keto-read:80