Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change api model registry to databricks-sdk #1054

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 31 additions & 121 deletions aws-lambda/src/databricks_cdk/resources/mlflow/registered_model.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,27 @@
from enum import Enum
from typing import List, Optional

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.ml import ModelTag
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, delete_request, get_request, patch_request, post_request


class RegisteredModelTag(BaseModel):
    """Key/value tag attached to a registered model."""

    key: str
    value: str


class ModelVersionTag(BaseModel):
    """Key/value tag attached to a model version."""

    key: str
    value: str
from databricks_cdk.utils import CnfResponse, get_workspace_client


class RegisteredModelProperties(BaseModel):
name: str
tags: Optional[List[RegisteredModelTag]] = []
tags: Optional[List[ModelTag]] = []
description: Optional[str]
workspace_url: str


class ModelVersionStatus(str, Enum):
    """Status values a model version can carry (string-valued enum)."""

    PENDING_REGISTRATION = "PENDING_REGISTRATION"
    FAILED_REGISTRATION = "FAILED_REGISTRATION"
    READY = "READY"


class ModelVersion(BaseModel):
    """Pydantic model describing a single version of a registered model."""

    name: str
    version: str
    # NOTE(review): timestamp unit is not visible here — presumably epoch milliseconds; confirm against the API
    creation_timestamp: int
    last_updated_timestamp: int
    user_id: str
    current_stage: str
    description: Optional[str]
    source: str
    run_id: str
    status: ModelVersionStatus
    status_message: Optional[str]
    tags: Optional[List[ModelVersionTag]]
    run_link: Optional[str]


class RegisteredModel(BaseModel):
    """Pydantic model of a registered model, parsed from the ``registered_model`` key of a get response."""

    name: str
    # NOTE(review): timestamp unit is not visible here — presumably epoch milliseconds; confirm against the API
    creation_timestamp: int
    last_updated_timestamp: int
    description: Optional[str]
    latest_versions: Optional[List[ModelVersion]]
    tags: Optional[List[RegisteredModelTag]]
    user_id: Optional[str]  # Currently not returned


class RegisteredModelCreateResponse(CnfResponse):
    """Response returned after creating or updating a registered model."""

    # Set to the registered model's name by the create/update handler
    physical_resource_id: str


def get_registered_model_url(workspace_url: str) -> str:
    """Get the mlflow registered-models url.

    :param workspace_url: Base url of the workspace, with or without a trailing slash
    :return: The registered-models endpoint url under ``/api/2.0/mlflow``
    """
    # Strip trailing slashes so "https://host/" does not yield "https://host//api/..."
    base_url = workspace_url.rstrip("/")
    return f"{base_url}/api/2.0/mlflow/registered-models"


def _create_registered_model(registered_model_url: str, properties: RegisteredModelProperties) -> str:
    """Creates a registered model.

    Posts the name, tags and description from ``properties`` to ``{registered_model_url}/create``.

    :param registered_model_url: Registered-models endpoint url
    :param properties: Properties of the registered model to create
    :return: Name of the registered model as reported in the response
    """
    # ``tags`` is declared Optional on the properties model; treat None as "no tags"
    # instead of failing while iterating over it.
    tags = [{"key": t.key, "value": t.value} for t in (properties.tags or [])]
    response = post_request(
        f"{registered_model_url}/create",
        {
            "name": properties.name,
            "tags": tags,
            "description": properties.description,
        },
    )
    return response["registered_model"]["name"]


def _get_registered_model(registered_model_url: str, name: str) -> Optional[RegisteredModel]:
    """Fetch the registered model called ``name``.

    :return: The parsed RegisteredModel, or None when the get request yields an empty/falsy response
    """
    payload = get_request(f"{registered_model_url}/get?name={name}")
    if not payload:
        return None

    return RegisteredModel.parse_obj(payload["registered_model"])


def _update_registered_model_description(registered_model_url: str, registered_model_name: str, description: str):
    """Patch the description of the registered model called ``registered_model_name``.

    :return: Whatever the patch request to ``{registered_model_url}/update`` returns
    """
    update_url = f"{registered_model_url}/update"
    payload = {"name": registered_model_name, "description": description}
    return patch_request(update_url, body=payload)


def _update_registered_model_name(registered_model_url: str, current_name: str, new_name: str) -> str:
    """Rename a registered model from ``current_name`` to ``new_name``.

    :return: Name of the registered model as reported in the rename response
    """
    rename_url = f"{registered_model_url}/rename"
    rename_response = post_request(rename_url, {"name": current_name, "new_name": new_name})
    return rename_response["registered_model"]["name"]


def _update_registered_model_tags(
registered_model_url: str,
workspace_client: WorkspaceClient,
properties: RegisteredModelProperties,
current_tags: List[RegisteredModelTag],
current_tags: List[ModelTag],
):
"""Updates the registered model tags"""
tags_to_delete = []
Expand All @@ -116,22 +33,15 @@ def _update_registered_model_tags(
tags_to_delete = [t for t in current_tags if t.key not in new_keys]

if tags_to_delete:
[
delete_request(
f"{registered_model_url}/delete-tag",
body={"name": properties.name, "key": t.key},
)
for t in tags_to_delete
]
# delete tags that are not on the cdk object anymore
[workspace_client.model_registry.delete_model_tag(properties.name, t.key) for t in tags_to_delete]

if properties.tags:
# Overwrites / updates existing tags
[
post_request(
f"{registered_model_url}/set-tag",
{"name": properties.name, "key": t.key, "value": t.value},
)
workspace_client.model_registry.set_model_tag(properties.name, t.key, t.value)
for t in properties.tags
if t.key and t.value is not None
]


Expand All @@ -146,27 +56,28 @@ def create_or_update_registered_model(
:param physical_resource_id: CDK Physical Resource Id belonging to the Registered Model (if exists). Defaults to None
:return:physical_resource_id of the Registered Model, which equals the name of the Registered Model
"""
registered_model_url = get_registered_model_url(properties.workspace_url)

if not physical_resource_id:
registered_model_name = _create_registered_model(registered_model_url, properties)
return RegisteredModelCreateResponse(physical_resource_id=registered_model_name)
workspace_client = get_workspace_client(properties.workspace_url)

if physical_resource_id is None:
response = workspace_client.model_registry.create_model(
name=properties.name, description=properties.description, tags=properties.tags
)

registered_model_url = get_registered_model_url(properties.workspace_url)
registered_model = _get_registered_model(registered_model_url, physical_resource_id)
name = response.registered_model.name if response.registered_model else None
if name is not None:
return RegisteredModelCreateResponse(physical_resource_id=name)

if not registered_model:
registered_model = workspace_client.model_registry.get_model(name=physical_resource_id)
if registered_model is None:
raise ValueError(f"Registered model cannot be found but physical_resouce_id {physical_resource_id} is provided")

if properties.name != registered_model.name:
physical_resource_id = _update_registered_model_name(
registered_model_url,
current_name=physical_resource_id,
new_name=properties.name,
)

if properties.description != registered_model.description:
_update_registered_model_description(registered_model_url, physical_resource_id, properties.description)
if registered_model.registered_model_databricks is not None and (
properties.name != registered_model.registered_model_databricks.name
or properties.description != registered_model.registered_model_databricks.description
):
workspace_client.model_registry.update_model(name=properties.name, description=properties.description)
return RegisteredModelCreateResponse(physical_resource_id=physical_resource_id)

new_tags = properties.tags
if properties.tags is not None:
Expand All @@ -177,14 +88,13 @@ def create_or_update_registered_model(
current_tags = sorted(current_tags, key=lambda t: t.key)

if new_tags != current_tags:
_update_registered_model_tags(registered_model_url, properties, registered_model.tags)
_update_registered_model_tags(workspace_client, properties, registered_model.tags)

return RegisteredModelCreateResponse(physical_resource_id=physical_resource_id)


def delete_registered_model(properties: RegisteredModelProperties, physical_resource_id: str):
"""Deletes an existing registered model"""
delete_request(
f"{get_registered_model_url(properties.workspace_url)}/delete",
body={"name": physical_resource_id},
)
workspace_client = get_workspace_client(properties.workspace_url)
workspace_client.model_registry.delete_model(name=physical_resource_id)
return CnfResponse(physical_resource_id=physical_resource_id)
12 changes: 12 additions & 0 deletions aws-lambda/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from unittest.mock import MagicMock

import pytest
from databricks.sdk import ModelRegistryAPI, WorkspaceClient


@pytest.fixture(scope="function", autouse=True)
Expand All @@ -12,3 +14,13 @@ def aws_credentials():
os.environ["AWS_SECURITY_TOKEN"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"


@pytest.fixture(scope="function")
def workspace_client():
    """Provide a WorkspaceClient mock whose model_registry attribute is a ModelRegistryAPI mock."""
    client = MagicMock(spec=WorkspaceClient)

    # each underlying service api gets its own spec'd mock
    client.model_registry = MagicMock(spec=ModelRegistryAPI)

    return client
Loading
Loading