Skip to content

Commit

Permalink
Merge pull request #2140 from solliancenet/cp-dalle-tool-fix
Browse files Browse the repository at this point in the history
Fix built-in DALL-E tool content artifacts
  • Loading branch information
ciprianjichici authored Jan 18, 2025
2 parents 1edb34a + ad35180 commit a85055d
Showing 1 changed file with 21 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from foundationallm.langchain.common import FoundationaLLMToolBase
from foundationallm.config import Configuration, UserIdentity
from foundationallm.models.agents import AgentTool
from foundationallm.models.orchestration import ContentArtifact
from foundationallm.models.resource_providers.ai_models import AIModelBase
from foundationallm.models.resource_providers.configuration import APIEndpointConfiguration
from foundationallm.utils import ObjectUtils
Expand Down Expand Up @@ -52,7 +53,7 @@ class DALLEImageGenerationTool(FoundationaLLMToolBase):
Supports only Azure Identity authentication.
"""
args_schema: Type[BaseModel] = DALLEImageGenerationToolInput

def __init__(self, tool_config: AgentTool, objects: dict, user_identity:UserIdentity, config: Configuration):
""" Initializes the DALLEImageGenerationTool class with the tool configuration,
exploded objects collection, user identity, and platform configuration. """
Expand All @@ -73,20 +74,20 @@ def __init__(self, tool_config: AgentTool, objects: dict, user_identity:UserIden

def _run(self,
prompt: str,
n: int,
quality: DALLEImageGenerationToolQualityEnum,
style: DALLEImageGenerationToolStyleEnum,
size: DALLEImageGenerationToolSizeEnum,
n: int = 1,
quality: DALLEImageGenerationToolQualityEnum = DALLEImageGenerationToolQualityEnum.hd,
style: DALLEImageGenerationToolStyleEnum = DALLEImageGenerationToolStyleEnum.natural,
size: DALLEImageGenerationToolSizeEnum = DALLEImageGenerationToolSizeEnum.size1024x1024,
run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
raise ToolException("This tool does not support synchronous execution. Please use the async version of the tool.")

async def _arun(self,
prompt: str,
n: int,
quality: DALLEImageGenerationToolQualityEnum,
style: DALLEImageGenerationToolStyleEnum,
size: DALLEImageGenerationToolSizeEnum,
n: int = 1,
quality: DALLEImageGenerationToolQualityEnum = DALLEImageGenerationToolQualityEnum.hd,
style: DALLEImageGenerationToolStyleEnum = DALLEImageGenerationToolStyleEnum.natural,
size: DALLEImageGenerationToolSizeEnum = DALLEImageGenerationToolSizeEnum.size1024x1024,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None
) -> str:
"""
Expand All @@ -102,7 +103,17 @@ async def _arun(self,
style = style,
size = size
)
return json.loads(result.model_dump_json())
content_artifacts = [
ContentArtifact(
id=image_data.url,
title=image_data.revised_prompt,
filepath=image_data.url,
type='image'
)
for image_data in result.data
if image_data.revised_prompt and image_data.url
]
return json.loads(result.model_dump_json()), content_artifacts
except Exception as e:
print(f'Image generation error code and message: {e.code}; {e}')
# Specifically handle content policy violation errors.
Expand Down

0 comments on commit a85055d

Please sign in to comment.