feat(batch-jobs): Add step to pass API server env vars to spark jobs (#624)

# Description
This PR adds a small feature that allows the API server to pass certain
pre-configured environment variable values from its own environment to the
Spark drivers and executors (for batch prediction) that it spins up.

# Modifications
- `api/batch/resource.go` - Added a step that injects pre-configured API server environment variables into the Spark driver/executor manifests
- `api/config/batch.go` - Added a new config field, `APIServerEnvVars`, to specify which environment variables to propagate (see the config sketch below)
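
For illustration, here is a minimal sketch of how the new field could be set in the API server's config file. The shape follows the test fixture added in this PR; the second variable name is a placeholder.

```yaml
BatchConfig:
  # Names of environment variables whose current values the API server copies
  # into every Spark driver and executor it creates for batch prediction.
  APIServerEnvVars:
    - TEST_ENV_VAR_2      # taken from the test fixture in this PR
    - SOME_OTHER_VAR      # placeholder; any variable set on the API server works
```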

# Tests
<!-- Besides the existing / updated automated tests, what specific
scenarios should be tested? Consider the backward compatibility of the
changes, whether corner cases are covered, etc. Please describe the
tests and check the ones that have been completed. Eg:
- [x] Deploying new and existing standard models
- [ ] Deploying PyFunc models
-->

# Checklist
- [x] Added PR label
- [ ] Added unit test, integration, and/or e2e tests
- [x] Tested locally
- [ ] Updated documentation
- [ ] Updated Swagger spec if the PR introduces API changes
- [ ] Regenerated Golang and Python client if the PR introduces API
changes

# Release Notes
```release-note
NONE
```
deadlycoconuts authored Jan 7, 2025
1 parent 6b36c4c commit 967788e
Showing 5 changed files with 100 additions and 3 deletions.
10 changes: 7 additions & 3 deletions api/batch/resource.go
@@ -16,6 +16,7 @@ package batch

import (
"fmt"
"os"
"strconv"
"time"

@@ -150,7 +151,7 @@ func (t *BatchJobTemplater) createDriverSpec(job *models.PredictionJob) (v1beta2
return v1beta2.DriverSpec{}, fmt.Errorf("invalid driver memory request: %s", job.Config.ResourceRequest.DriverMemoryRequest)
}

envVars, err := addEnvVars(job)
envVars, err := t.addEnvVars(job)
if err != nil {
return v1beta2.DriverSpec{}, err
}
@@ -194,7 +195,7 @@ func (t *BatchJobTemplater) createExecutorSpec(job *models.PredictionJob) (v1bet
return v1beta2.ExecutorSpec{}, fmt.Errorf("invalid executor memory request: %s", job.Config.ResourceRequest.ExecutorMemoryRequest)
}

envVars, err := addEnvVars(job)
envVars, err := t.addEnvVars(job)
if err != nil {
return v1beta2.ExecutorSpec{}, err
}
@@ -263,13 +264,16 @@ func getCoreRequest(cpuRequest resource.Quantity) *int32 {
return &core
}

func addEnvVars(job *models.PredictionJob) ([]corev1.EnvVar, error) {
func (t *BatchJobTemplater) addEnvVars(job *models.PredictionJob) ([]corev1.EnvVar, error) {
envVars := []corev1.EnvVar{
{
Name: envServiceAccountPathKey,
Value: envServiceAccountPath,
},
}
for _, ev := range t.batchConfig.APIServerEnvVars {
envVars = append(envVars, corev1.EnvVar{Name: ev, Value: os.Getenv(ev)})
}
for _, ev := range job.Config.EnvVars.ToKubernetesEnvVars() {
if ev.Name == envServiceAccountPathKey {
return []corev1.EnvVar{}, fmt.Errorf("environment variable '%s' cannot be changed", ev.Name)
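
The mechanism in the new loop above is straightforward: for each configured variable name, the templater reads the value from the API server's own process environment at template-rendering time and attaches it to the driver/executor spec as a literal Kubernetes env var. A minimal standalone sketch of that pattern (not Merlin code; the function name and example variables are illustrative):

```go
package main

import (
	"fmt"
	"os"

	corev1 "k8s.io/api/core/v1"
)

// propagateAPIServerEnvVars copies the current values of the named variables
// from this process's environment into Kubernetes EnvVar entries. A variable
// that is unset on the API server is propagated with an empty value, because
// os.Getenv returns "" for missing keys.
func propagateAPIServerEnvVars(names []string) []corev1.EnvVar {
	envVars := make([]corev1.EnvVar, 0, len(names))
	for _, name := range names {
		envVars = append(envVars, corev1.EnvVar{Name: name, Value: os.Getenv(name)})
	}
	return envVars
}

func main() {
	// Simulate a value being present in the API server's environment.
	os.Setenv("TEST_ENV_VAR_1", "TEST_VALUE_1")
	for _, ev := range propagateAPIServerEnvVars([]string{"TEST_ENV_VAR_1", "UNSET_VAR"}) {
		fmt.Printf("%s=%q\n", ev.Name, ev.Value)
	}
}
```
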
88 changes: 88 additions & 0 deletions api/batch/resource_test.go
@@ -15,6 +15,7 @@
package batch

import (
"os"
"testing"

"github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
@@ -23,6 +24,7 @@ import (
v12 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/caraml-dev/merlin/config"
"github.com/caraml-dev/merlin/mlp"
"github.com/caraml-dev/merlin/models"
)
@@ -649,3 +651,89 @@ func TestCreateSparkApplicationResource(t *testing.T) {
})
}
}

func TestAddEnvVars(t *testing.T) {
predictionJob := &models.PredictionJob{
Name: jobName,
ID: jobID,
Metadata: models.Metadata{
App: modelName,
Component: models.ComponentBatchJob,
Stream: streamName,
Team: teamName,
Labels: userLabels,
},
VersionModelID: modelID,
VersionID: versionID,
Config: &models.Config{
JobConfig: nil,
ImageRef: imageRef,
ResourceRequest: &models.PredictionJobResourceRequest{
DriverCPURequest: driverCPURequest,
DriverMemoryRequest: driverMemory,
ExecutorReplica: executorReplica,
ExecutorCPURequest: executorCPURequest,
ExecutorMemoryRequest: executorMemory,
},
MainAppPath: mainAppPathInput,
},
}

tests := []struct {
name string
apiServerEnvVars []string
wantErr bool
wantErrMessage string
want []v12.EnvVar
}{
{
name: "api server env vars specified",
apiServerEnvVars: []string{"TEST_ENV_VAR_1"},
want: []v12.EnvVar{
{
Name: envServiceAccountPathKey,
Value: envServiceAccountPath,
},
{
Name: "TEST_ENV_VAR_1",
Value: "TEST_VALUE_1",
},
},
},
{
name: "no api server env vars specified",
apiServerEnvVars: []string{},
want: []v12.EnvVar{
{
Name: envServiceAccountPathKey,
Value: envServiceAccountPath,
},
},
},
}

err := os.Setenv("TEST_ENV_VAR_1", "TEST_VALUE_1")
assert.NoError(t, err)
err = os.Setenv("TEST_ENV_VAR_2", "TEST_VALUE_2")
assert.NoError(t, err)

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
defaultBatchConfig = config.BatchConfig{
Tolerations: []v12.Toleration{defaultToleration},
NodeSelectors: defaultNodeSelector,
}
defaultBatchConfig.APIServerEnvVars = test.apiServerEnvVars

testBatchJobTemplater := NewBatchJobTemplater(defaultBatchConfig)

sp, err := testBatchJobTemplater.addEnvVars(predictionJob)
if test.wantErr {
assert.Equal(t, test.wantErrMessage, err.Error())
return
}
assert.NoError(t, err)
assert.Equal(t, test.want, sp)
})
}
}
2 changes: 2 additions & 0 deletions api/config/batch.go
@@ -23,4 +23,6 @@ type BatchConfig struct {
Tolerations []v1.Toleration
// Node Selectors for Jobs Specification
NodeSelectors map[string]string
// APIServerEnvVars specifies the environment variables that are propagated from the API server to the Spark job
APIServerEnvVars []string
}
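
As a usage sketch, the new field flows into the templater like this. The snippet is written as if it sat alongside the tests above (inside the `batch` package, reusing its `config`, `models`, `assert`, and `v12` imports); only `APIServerEnvVars`, `NewBatchJobTemplater`, and `addEnvVars` are taken from this diff, and it assumes `addEnvVars` tolerates an otherwise empty job config, as the loop over `job.Config.EnvVars` suggests.

```go
// Illustrative test-style sketch, not part of this PR.
func TestAPIServerEnvVarPropagationSketch(t *testing.T) {
	// t.Setenv restores the variable after the test; the value stands in for
	// something already set on the API server process.
	t.Setenv("TEST_ENV_VAR_2", "TEST_VALUE_2")

	cfg := config.BatchConfig{
		APIServerEnvVars: []string{"TEST_ENV_VAR_2"}, // names to copy into Spark pods
	}
	templater := NewBatchJobTemplater(cfg)

	// Assumption: a zero-value models.Config is enough for addEnvVars.
	envVars, err := templater.addEnvVars(&models.PredictionJob{Config: &models.Config{}})
	assert.NoError(t, err)
	assert.Contains(t, envVars, v12.EnvVar{Name: "TEST_ENV_VAR_2", Value: "TEST_VALUE_2"})
}
```
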
1 change: 1 addition & 0 deletions api/config/config_test.go
@@ -463,6 +463,7 @@ func TestLoad(t *testing.T) {
NodeSelectors: map[string]string{
"purpose.caraml.com/batch": "true",
},
APIServerEnvVars: []string{"TEST_ENV_VAR_2"},
},
AuthorizationConfig: AuthorizationConfig{
AuthorizationEnabled: true,
2 changes: 2 additions & 0 deletions api/config/testdata/base-configs-1.yaml
@@ -92,6 +92,8 @@ BatchConfig:
Value: "true"
NodeSelectors:
purpose.caraml.com/batch: "true"
APIServerEnvVars:
- TEST_ENV_VAR_2
AuthorizationConfig:
AuthorizationEnabled: true
KetoRemoteRead: http://mlp-keto-read:80