diff --git a/aws-lambda/src/databricks_cdk/resources/mlflow/registered_model.py b/aws-lambda/src/databricks_cdk/resources/mlflow/registered_model.py index 70add702..49ce1537 100644 --- a/aws-lambda/src/databricks_cdk/resources/mlflow/registered_model.py +++ b/aws-lambda/src/databricks_cdk/resources/mlflow/registered_model.py @@ -1,110 +1,27 @@ -from enum import Enum from typing import List, Optional +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.ml import ModelTag from pydantic import BaseModel -from databricks_cdk.utils import CnfResponse, delete_request, get_request, patch_request, post_request - - -class RegisteredModelTag(BaseModel): - key: str - value: str - - -class ModelVersionTag(BaseModel): - key: str - value: str +from databricks_cdk.utils import CnfResponse, get_workspace_client class RegisteredModelProperties(BaseModel): name: str - tags: Optional[List[RegisteredModelTag]] = [] + tags: Optional[List[ModelTag]] = [] description: Optional[str] workspace_url: str -class ModelVersionStatus(str, Enum): - PENDING_REGISTRATION = "PENDING_REGISTRATION" - FAILED_REGISTRATION = "FAILED_REGISTRATION" - READY = "READY" - - -class ModelVersion(BaseModel): - name: str - version: str - creation_timestamp: int - last_updated_timestamp: int - user_id: str - current_stage: str - description: Optional[str] - source: str - run_id: str - status: ModelVersionStatus - status_message: Optional[str] - tags: Optional[List[ModelVersionTag]] - run_link: Optional[str] - - -class RegisteredModel(BaseModel): - name: str - creation_timestamp: int - last_updated_timestamp: int - description: Optional[str] - latest_versions: Optional[List[ModelVersion]] - tags: Optional[List[RegisteredModelTag]] - user_id: Optional[str] # Currently not returned - - class RegisteredModelCreateResponse(CnfResponse): physical_resource_id: str -def get_registered_model_url(workspace_url: str): - """Get the mlflow registered-models url""" - return f"{workspace_url}/api/2.0/mlflow/registered-models" - - -def _create_registered_model(registered_model_url: str, properties: RegisteredModelProperties) -> str: - """Creates a registered model""" - response = post_request( - f"{registered_model_url}/create", - { - "name": properties.name, - "tags": [{"key": t.key, "value": t.value} for t in properties.tags], - "description": properties.description, - }, - ) - return response["registered_model"]["name"] - - -def _get_registered_model(registered_model_url: str, name: str) -> Optional[RegisteredModel]: - """Gets the registered model""" - response = get_request(f"{registered_model_url}/get?name={name}") - if response: - return RegisteredModel.parse_obj(response["registered_model"]) - - return None - - -def _update_registered_model_description(registered_model_url: str, registered_model_name: str, description: str): - """Updates the registered model description""" - return patch_request( - f"{registered_model_url}/update", - body={"name": registered_model_name, "description": description}, - ) - - -def _update_registered_model_name(registered_model_url: str, current_name: str, new_name: str) -> str: - """Updates the registered model name""" - return post_request(f"{registered_model_url}/rename", {"name": current_name, "new_name": new_name})[ - "registered_model" - ]["name"] - - def _update_registered_model_tags( - registered_model_url: str, + workspace_client: WorkspaceClient, properties: RegisteredModelProperties, - current_tags: List[RegisteredModelTag], + current_tags: List[ModelTag], ): """Updates the registered model tags""" tags_to_delete = [] @@ -116,22 +33,15 @@ def _update_registered_model_tags( tags_to_delete = [t for t in current_tags if t.key not in new_keys] if tags_to_delete: - [ - delete_request( - f"{registered_model_url}/delete-tag", - body={"name": properties.name, "key": t.key}, - ) - for t in tags_to_delete - ] + # delete tags that are not on the cdk object anymore + [workspace_client.model_registry.delete_model_tag(properties.name, t.key) for t in tags_to_delete] if properties.tags: # Overwrites / updates existing tags [ - post_request( - f"{registered_model_url}/set-tag", - {"name": properties.name, "key": t.key, "value": t.value}, - ) + workspace_client.model_registry.set_model_tag(properties.name, t.key, t.value) for t in properties.tags + if t.key and t.value is not None ] @@ -146,27 +56,28 @@ def create_or_update_registered_model( :param physical_resource_id: CDK Physical Resource Id belonging to the Registered Model (if exists). Defaults to None :return:physical_resource_id of the Registered Model, which equals the name of the Registered Model """ - registered_model_url = get_registered_model_url(properties.workspace_url) - if not physical_resource_id: - registered_model_name = _create_registered_model(registered_model_url, properties) - return RegisteredModelCreateResponse(physical_resource_id=registered_model_name) + workspace_client = get_workspace_client(properties.workspace_url) + + if physical_resource_id is None: + response = workspace_client.model_registry.create_model( + name=properties.name, description=properties.description, tags=properties.tags + ) - registered_model_url = get_registered_model_url(properties.workspace_url) - registered_model = _get_registered_model(registered_model_url, physical_resource_id) + name = response.registered_model.name if response.registered_model else None + if name is not None: + return RegisteredModelCreateResponse(physical_resource_id=name) - if not registered_model: + registered_model = workspace_client.model_registry.get_model(name=physical_resource_id) + if registered_model is None: raise ValueError(f"Registered model cannot be found but physical_resouce_id {physical_resource_id} is provided") - if properties.name != registered_model.name: - physical_resource_id = _update_registered_model_name( - registered_model_url, - current_name=physical_resource_id, - new_name=properties.name, - ) - - if properties.description != registered_model.description: - _update_registered_model_description(registered_model_url, physical_resource_id, properties.description) + if registered_model.registered_model_databricks is not None and ( + properties.name != registered_model.registered_model_databricks.name + or properties.description != registered_model.registered_model_databricks.description + ): + workspace_client.model_registry.update_model(name=properties.name, description=properties.description) + return RegisteredModelCreateResponse(physical_resource_id=physical_resource_id) new_tags = properties.tags if properties.tags is not None: @@ -177,14 +88,13 @@ def create_or_update_registered_model( current_tags = sorted(current_tags, key=lambda t: t.key) if new_tags != current_tags: - _update_registered_model_tags(registered_model_url, properties, registered_model.tags) + _update_registered_model_tags(workspace_client, properties, registered_model.tags) + return RegisteredModelCreateResponse(physical_resource_id=physical_resource_id) def delete_registered_model(properties: RegisteredModelProperties, physical_resource_id: str): """Deletes an existing registered model""" - delete_request( - f"{get_registered_model_url(properties.workspace_url)}/delete", - body={"name": physical_resource_id}, - ) + workspace_client = get_workspace_client(properties.workspace_url) + workspace_client.model_registry.delete_model(name=physical_resource_id) return CnfResponse(physical_resource_id=physical_resource_id) diff --git a/aws-lambda/tests/conftest.py b/aws-lambda/tests/conftest.py index c552e8eb..588f7492 100644 --- a/aws-lambda/tests/conftest.py +++ b/aws-lambda/tests/conftest.py @@ -1,6 +1,8 @@ import os +from unittest.mock import MagicMock import pytest +from databricks.sdk import ModelRegistryAPI, WorkspaceClient @pytest.fixture(scope="function", autouse=True) @@ -12,3 +14,13 @@ def aws_credentials(): os.environ["AWS_SECURITY_TOKEN"] = "testing" os.environ["AWS_SESSION_TOKEN"] = "testing" os.environ["AWS_DEFAULT_REGION"] = "eu-west-1" + + +@pytest.fixture(scope="function") +def workspace_client(): + workspace_client = MagicMock(spec=WorkspaceClient) + + # mock all of the underlying service api's + workspace_client.model_registry = MagicMock(spec=ModelRegistryAPI) + + return workspace_client diff --git a/aws-lambda/tests/resources/mlflow/test_registered_model.py b/aws-lambda/tests/resources/mlflow/test_registered_model.py index 04661a1d..de9bc68e 100644 --- a/aws-lambda/tests/resources/mlflow/test_registered_model.py +++ b/aws-lambda/tests/resources/mlflow/test_registered_model.py @@ -1,226 +1,117 @@ from unittest.mock import patch import pytest +from databricks.sdk.service.ml import CreateModelResponse, GetModelResponse, Model, ModelDatabricks, ModelTag from databricks_cdk.resources.mlflow.registered_model import ( - RegisteredModel, RegisteredModelCreateResponse, RegisteredModelProperties, - RegisteredModelTag, - _create_registered_model, - _get_registered_model, - _update_registered_model_description, - _update_registered_model_name, _update_registered_model_tags, create_or_update_registered_model, delete_registered_model, - get_registered_model_url, ) -def test_get_registered_model_url(): - workspace_url = "https://test.cloud.databricks.com" - assert ( - get_registered_model_url(workspace_url) == "https://test.cloud.databricks.com/api/2.0/mlflow/registered-models" - ) - - -@patch("databricks_cdk.resources.mlflow.registered_model.post_request") -def test__create_registered_model(patched_post_request): - props = RegisteredModelProperties( - name="test", - workspace_url="https://test.cloud.databricks.com", - description="some description", - tags=[RegisteredModelTag(key="test", value="test")], - ) - patched_post_request.return_value = {"registered_model": {"name": "test"}} - name = _create_registered_model("https://test.cloud.databricks.com/api/2.0/mlflow/registered-models", props) - - assert name == "test" - assert patched_post_request.call_count == 1 - - assert patched_post_request.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/registered-models/create", - { - "name": "test", - "description": "some description", - "tags": [{"key": "test", "value": "test"}], - }, - ) - - -@patch("databricks_cdk.resources.mlflow.registered_model.get_request") -def test__get_registered_model(patched__get_request): - patched__get_request.return_value = { - "registered_model": { - "name": "same_name", - "creation_timestamp": 1, - "last_updated_timestamp": 1, - "description": "same description", - } - } - - registered_model = _get_registered_model( - registered_model_url="https://test.cloud.databricks.com/api/2.0/mlflow/registered-models", - name="same_name", - ) - - assert isinstance(registered_model, RegisteredModel) - - patched__get_request.return_value = None - assert not _get_registered_model( - registered_model_url="https://test.cloud.databricks.com/api/2.0/mlflow/registered-models", - name="same_name", - ) - - -@patch("databricks_cdk.resources.mlflow.registered_model.post_request") -def test__update_registered_model_name(patched_post_request): - _update_registered_model_name( - registered_model_url="https://test.cloud.databricks.com/api/2.0/mlflow/registered-models", - current_name="name", - new_name="new_name", - ) - - assert patched_post_request.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/registered-models/rename", - {"name": "name", "new_name": "new_name"}, - ) - - -@patch("databricks_cdk.resources.mlflow.registered_model.patch_request") -def test__update_registered_model_description(patched_patch_request): - _update_registered_model_description( - registered_model_url="https://test.cloud.databricks.com/api/2.0/mlflow/registered-models", - registered_model_name="test", - description="new_description", - ) - - assert ( - patched_patch_request.call_args.args[0] - == "https://test.cloud.databricks.com/api/2.0/mlflow/registered-models/update" - ) - - assert patched_patch_request.call_args.kwargs == {"body": {"description": "new_description", "name": "test"}} - - -@patch("databricks_cdk.resources.mlflow.registered_model.post_request") -def test__update_registered_model_tags_add(patched_post_request): +def test__update_registered_model_tags_add(workspace_client): props = RegisteredModelProperties( name="test-model", - tags=[RegisteredModelTag(key="test", value="test-value")], + tags=[ModelTag(key="test", value="test-value")], workspace_url="https://test.cloud.databricks.com", ) _update_registered_model_tags( - registered_model_url="https://test.cloud.databricks.com/api/2.0/mlflow/registered-models", + workspace_client=workspace_client, properties=props, current_tags=[], ) - assert ( - patched_post_request.call_args.args[0] - == "https://test.cloud.databricks.com/api/2.0/mlflow/registered-models/set-tag" - ) - assert patched_post_request.call_args.args[1] == { - "name": "test-model", - "key": "test", - "value": "test-value", - } + workspace_client.model_registry.set_model_tag.assert_called_once_with("test-model", "test", "test-value") -@patch("databricks_cdk.resources.mlflow.registered_model.delete_request") -@patch("databricks_cdk.resources.mlflow.registered_model.post_request") -def test__update_registered_model_tags_update(patched_post_request, patched_delete_request): +def test__update_registered_model_tags_update(workspace_client): props = RegisteredModelProperties( name="test-model", - tags=[RegisteredModelTag(key="test", value="new-test-value")], + tags=[ModelTag(key="test", value="new-test-value")], workspace_url="https://test.cloud.databricks.com", ) _update_registered_model_tags( - registered_model_url="https://test.cloud.databricks.com/api/2.0/mlflow/registered-models", + workspace_client=workspace_client, properties=props, current_tags=[ - RegisteredModelTag(key="to-delete", value="test-delete"), - RegisteredModelTag(key="test", value="test-value"), + ModelTag(key="to-delete", value="test-delete"), + ModelTag(key="test", value="test-value"), ], ) # Make sure tag test-model is updated - assert ( - patched_post_request.call_args.args[0] - == "https://test.cloud.databricks.com/api/2.0/mlflow/registered-models/set-tag" - ) - assert patched_post_request.call_args.args[1] == { - "name": "test-model", - "key": "test", - "value": "new-test-value", - } + workspace_client.model_registry.set_model_tag.assert_called_once_with("test-model", "test", "new-test-value") # Make sure removed to-delete tag is deleted - assert ( - patched_delete_request.call_args.args[0] - == "https://test.cloud.databricks.com/api/2.0/mlflow/registered-models/delete-tag" - ) - assert patched_delete_request.call_args.kwargs == {"body": {"name": "test-model", "key": "to-delete"}} + workspace_client.model_registry.delete_model_tag.assert_called_once_with("test-model", "to-delete") -@patch("databricks_cdk.resources.mlflow.registered_model._create_registered_model") -def test_create_or_update_registered_model_new(patched__create_registered_model): +@patch("databricks_cdk.resources.mlflow.registered_model.get_workspace_client") +def test_create_or_update_registered_model_new(patched_get_workspace_client, workspace_client): + patched_get_workspace_client.return_value = workspace_client + + workspace_client.model_registry.create_model.return_value = CreateModelResponse( + registered_model=Model(name="new-model-name") + ) props = RegisteredModelProperties( name="new-model-name", workspace_url="https://test.cloud.databricks.com", description="same description", ) - patched__create_registered_model.return_value = "new-model-name" - # completely new experiment response = create_or_update_registered_model(props, physical_resource_id=None) assert response == RegisteredModelCreateResponse(physical_resource_id="new-model-name") -@patch("databricks_cdk.resources.mlflow.registered_model._get_registered_model") -def test_create_or_update_registered_model_existing( - patched__get_existing_registered_model, -): +@patch("databricks_cdk.resources.mlflow.registered_model.get_workspace_client") +def test_create_or_update_registered_model_existing(patched_get_workspace_client, workspace_client): + patched_get_workspace_client.return_value = workspace_client + + workspace_client.model_registry.get_model.return_value = GetModelResponse( + registered_model_databricks=ModelDatabricks(name="model-name", description="same description") + ) + props = RegisteredModelProperties( - name="model-name", + name="new-model-name", workspace_url="https://test.cloud.databricks.com", description="same description", ) - patched__get_existing_registered_model.return_value = RegisteredModel( - name="model-name", - last_updated_timestamp=1234, - creation_timestamp=1234, - description="same description", + response = create_or_update_registered_model(props, physical_resource_id="new-model-name") + + assert response == RegisteredModelCreateResponse(physical_resource_id="new-model-name") + workspace_client.model_registry.update_model.assert_called_once_with( + name="new-model-name", description="same description" ) - response = create_or_update_registered_model(props, physical_resource_id="model-name") - assert patched__get_existing_registered_model.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/registered-models", - "model-name", +@patch("databricks_cdk.resources.mlflow.registered_model.get_workspace_client") +def test_create_or_update_registered_model_invalid(patched_get_workspace_client, workspace_client): + props = RegisteredModelProperties( + name="new-model-name", + workspace_url="https://test.cloud.databricks.com", + description="same description", ) - - assert response == RegisteredModelCreateResponse(physical_resource_id="model-name") + patched_get_workspace_client.return_value = workspace_client + workspace_client.model_registry.get_model.return_value = None # this is invalid and should raise error - patched__get_existing_registered_model.return_value = None + # patched__get_existing_registered_model.return_value = None with pytest.raises(ValueError): create_or_update_registered_model(props, physical_resource_id="model-name") -@patch("databricks_cdk.resources.mlflow.registered_model.delete_request") -def test_delete_experiment(patched_delete_request): +@patch("databricks_cdk.resources.mlflow.registered_model.get_workspace_client") +def test_delete_experiment(patched_get_workspace_client, workspace_client): + patched_get_workspace_client.return_value = workspace_client props = RegisteredModelProperties( name="name", workspace_url="https://test.cloud.databricks.com", description="same description", ) delete_registered_model(props, "name") - assert patched_delete_request.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/registered-models/delete", - ) - - assert patched_delete_request.call_args.kwargs == {"body": {"name": "name"}} + workspace_client.model_registry.delete_model.assert_called_once_with(name="name")