From 6454929525db1d8bf530883ae6b15a09c220e992 Mon Sep 17 00:00:00 2001 From: Ana Pacheco Date: Tue, 29 Oct 2024 12:29:46 -0600 Subject: [PATCH] include user_id in conversations and resources --- apis/paios/openapi.yaml | 44 +++++++++++++++++-- backend/api/ConversationsView.py | 4 +- backend/api/ResourcesView.py | 10 ++++- backend/api/SharesView.py | 7 ++- backend/managers/ConversationsManager.py | 26 ++++++----- backend/managers/ResourcesManager.py | 31 ++++++++++--- backend/managers/SharesManager.py | 12 ++++- backend/models.py | 2 + backend/schemas.py | 2 + .../75aaaf2cd1a2_added_resource_table.py | 1 + .../d34acf83524e_added_conversation_table.py | 2 + 11 files changed, 114 insertions(+), 27 deletions(-) diff --git a/apis/paios/openapi.yaml b/apis/paios/openapi.yaml index c2268bd1..3e149144 100644 --- a/apis/paios/openapi.yaml +++ b/apis/paios/openapi.yaml @@ -363,6 +363,26 @@ paths: description: No Content '404': description: Not Found + '/resources/{user_id}/shared': + get: + security: + - jwt: [] + tags: + - Resource Management + summary: Retrieve resource by user id + operationId: backend.api.ResourcesView.shared + description: Retrieve the information of the resource with the specified user ID. + parameters: + - $ref: '#/components/parameters/user_id' + responses: + '200': + description: OK + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Resource' '/config/{key}': get: security: @@ -1081,7 +1101,7 @@ paths: description: No Content '404': description: Not Found - '/conversations/{resource_id}': + '/conversations/{resource_id}/{user_id}': post: security: - jwt: [] @@ -1090,7 +1110,8 @@ paths: - Conversation Management operationId: backend.api.ConversationsView.post parameters: - - $ref: '#/components/parameters/resource_id' + - $ref: '#/components/parameters/resource_id' + - $ref: '#/components/parameters/user_id' responses: '201': description: Created @@ -1317,6 +1338,13 @@ components: in: path description: this refers to the assitant id for the resource required: true + schema: + $ref: '#/components/schemas/uuid4' + user_id: + name: user_id + in: path + description: this refers to the user id + required: true schema: $ref: '#/components/schemas/uuid4' schemas: @@ -1571,7 +1599,10 @@ components: $ref: '#/components/schemas/uri' active: nullable: true - $ref: '#/components/schemas/boolean_str' + $ref: '#/components/schemas/boolean_str' + user_id: + $ref: '#/components/schemas/uuid4' + nullable: true ResourceCreate: type: object title: ResourceCreate @@ -1606,7 +1637,10 @@ components: $ref: '#/components/schemas/uri' active: nullable: true - $ref: '#/components/schemas/boolean_str' + $ref: '#/components/schemas/boolean_str' + user_id: + $ref: '#/components/schemas/uuid4' + nullable: true required: - name - kind @@ -1958,6 +1992,8 @@ components: type: array items: $ref: '#/components/schemas/Message' + user_id: + $ref: '#/components/schemas/uuid4' required: - name - archive diff --git a/backend/api/ConversationsView.py b/backend/api/ConversationsView.py index 2dcc35c3..407cb3f1 100644 --- a/backend/api/ConversationsView.py +++ b/backend/api/ConversationsView.py @@ -17,8 +17,8 @@ async def get(self, id: str): return NOT_FOUND_RESPONSE return JSONResponse(conversation.dict(), status_code=200) - async def post(self, resource_id: str, body: ConversationCreateSchema): - conversation_id = await self.cm.create_conversation(resource_id, body) + async def post(self, resource_id: str, user_id: str, body: ConversationCreateSchema): + conversation_id = await self.cm.create_conversation(resource_id, user_id, body) conversation = await self.cm.retrieve_conversation(conversation_id) if conversation is None: return JSONResponse({"error": "Assistant not found"}, status_code=404) diff --git a/backend/api/ResourcesView.py b/backend/api/ResourcesView.py index 92a2c0f7..d36d949c 100644 --- a/backend/api/ResourcesView.py +++ b/backend/api/ResourcesView.py @@ -6,7 +6,6 @@ from backend.managers.ConversationsManager import ConversationsManager from backend.pagination import parse_pagination_params from backend.schemas import ResourceCreateSchema -from typing import List class ResourcesView: def __init__(self): @@ -71,3 +70,12 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): 'Content-Range': f'resources {offset}-{offset + len(resources) - 1}/{total_count}' } return JSONResponse([resource.dict() for resource in resources], status_code=200, headers=headers) + + async def shared(self, user_id: str = None): + print("user_id: ", user_id) + resources = await self.rm.retrieve_shared_resources(user_id) + if resources is None: + return JSONResponse({"error": "Shared resources not found"}, status_code=404) + return JSONResponse([resource.dict() for resource in resources], status_code=200) + + \ No newline at end of file diff --git a/backend/api/SharesView.py b/backend/api/SharesView.py index b04f5231..fe462d3d 100644 --- a/backend/api/SharesView.py +++ b/backend/api/SharesView.py @@ -1,12 +1,14 @@ from starlette.responses import JSONResponse, Response from common.paths import api_base_url from backend.managers.SharesManager import SharesManager +from backend.managers.ResourcesManager import ResourcesManager from backend.pagination import parse_pagination_params from datetime import datetime, timezone class SharesView: def __init__(self): self.slm = SharesManager() + self.rm = ResourcesManager() async def get(self, id: str): share = await self.slm.retrieve_share(id) @@ -20,6 +22,9 @@ async def post(self, body: dict): expiration_dt = datetime.fromisoformat(body['expiration_dt']).astimezone(tz=timezone.utc) user_id = None if 'user_id' in body and body['user_id']: + valid = await self.slm.validate_assistant_user_id(body['resource_id'], body['user_id']) + if valid is not None: + return JSONResponse({"error": valid}, status_code=400) user_id = body['user_id'] new_share = await self.slm.create_share(resource_id=body['resource_id'], user_id=user_id, @@ -61,4 +66,4 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): 'X-Total-Count': str(total_count), 'Content-Range': f'shares {offset}-{offset + len(shares) - 1}/{total_count}' } - return JSONResponse([share.model_dump() for share in shares], status_code=200, headers=headers) + return JSONResponse([share.model_dump() for share in shares], status_code=200, headers=headers) \ No newline at end of file diff --git a/backend/managers/ConversationsManager.py b/backend/managers/ConversationsManager.py index 963e1b7c..049495b4 100644 --- a/backend/managers/ConversationsManager.py +++ b/backend/managers/ConversationsManager.py @@ -1,7 +1,7 @@ from uuid import uuid4 from threading import Lock from sqlalchemy import select, update, delete, func -from backend.models import Conversation, Resource, Message +from backend.models import Conversation, Resource, Message, User from backend.db import db_session_context from backend.schemas import ConversationSchema, ConversationCreateSchema, MessageSchema from typing import List, Tuple, Optional, Dict, Any @@ -24,18 +24,18 @@ def __init__(self): if not hasattr(self, '_initialized'): self._initialized = True - async def create_conversation(self, resource_id: str, conversation_data: ConversationCreateSchema) -> Optional[str]: + async def create_conversation(self, resource_id: str, user_id: str, conversation_data: ConversationCreateSchema) -> Optional[str]: async with db_session_context() as session: - if not await self.validate_assistant_id(resource_id): + if not await self.validate_assistant_user_id(resource_id, user_id): return None timestamp = get_current_timestamp() conversation_data['created_timestamp'] = timestamp conversation_data['last_updated_timestamp'] = timestamp conversation_data['archive'] = "False" - conversation_data['assistant_id'] = resource_id - + conversation_data['assistant_id'] = resource_id + conversation_data['user_id'] = user_id new_conversation = Conversation(id=str(uuid4()), **conversation_data) session.add(new_conversation) await session.commit() @@ -113,7 +113,8 @@ async def retrieve_conversation(self, id: str) -> Optional[ConversationSchema]: last_updated_timestamp=conversation.last_updated_timestamp, archive=conversation.archive, assistant_id=conversation.assistant_id, - messages=messages_list + messages=messages_list, + user_id=conversation.user_id ) return None @@ -155,10 +156,13 @@ async def _get_total_count(self, session, filters: Optional[Dict[str, Any]]) -> count_query = self._apply_filters(count_query, filters) total_count = await session.execute(count_query) return total_count.scalar() - - - async def validate_assistant_id(self, assistant_id: str) -> bool: + + async def validate_assistant_user_id(self, assistant_id: str, user_id: str) -> bool: async with db_session_context() as session: - result = await session.execute(select(Resource).filter(Resource.id == assistant_id)) - return result.scalar_one_or_none() is not None \ No newline at end of file + assistant = await session.execute(select(Resource).filter(Resource.id == assistant_id)) + user = await session.execute(select(User).filter(User.id == user_id)) + if not assistant.scalar_one_or_none() or not user.scalar_one_or_none(): + return False + return True + \ No newline at end of file diff --git a/backend/managers/ResourcesManager.py b/backend/managers/ResourcesManager.py index 52b9f2cf..dabc0fc6 100644 --- a/backend/managers/ResourcesManager.py +++ b/backend/managers/ResourcesManager.py @@ -2,10 +2,11 @@ from threading import Lock import httpx from sqlalchemy import select, update, delete, func -from backend.models import Resource, File, Conversation +from backend.models import Resource, File, Conversation, Share from backend.db import db_session_context from backend.schemas import ResourceCreateSchema, ResourceSchema from typing import List, Tuple, Optional, Dict, Any +from backend.managers.UsersManager import UsersManager # This is a mock of the installed models in the system ollama_model=[{'name': 'llama3:latest', 'model': 'llama3:latest', 'modified_at': '2024-08-24T21:57:16.6075173-06:00', 'size': 2176178913, 'digest': '4f222292793889a9a40a020799cfd28d53f3e01af25d48e06c5e708610fc47e9', 'details': {'parent_model': '', 'format': 'gguf', 'family': 'phi3', 'families': ['phi3'], 'parameter_size': '3.8B', 'quantization_level': 'Q4_0'}}] @@ -42,7 +43,8 @@ async def create_resource(self, resource_data: ResourceCreateSchema) -> str: "persona_id": resource_data.get("persona_id"), "status": resource_data.get("status"), "allow_edit": resource_data.get("allow_edit"), - "kind": kind + "kind": kind, + "user_id": resource_data.get("user_id") } else: resource_data_table={ @@ -107,7 +109,8 @@ async def retrieve_resource(self, id: str) -> Optional[ResourceSchema]: persona_id=resource.persona_id, status=resource.status, allow_edit=resource.allow_edit, - kind=resource.kind) + kind=resource.kind, + user_id=resource.user_id) else : return ResourceSchema( id=resource.id, @@ -160,7 +163,7 @@ async def retrieve_resources(self, offset: int = 0, limit: int = 100, sort_by: O else: query = query.filter(getattr(Resource, key) == value) - if sort_by and sort_by in ['id', 'name', 'uri','status','allow_edit','kind','active']: + if sort_by and sort_by in ['id', 'name', 'uri','status','allow_edit','kind','active','user_id']: order_column = getattr(Resource, sort_by) query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) @@ -189,13 +192,16 @@ async def retrieve_resources(self, offset: int = 0, limit: int = 100, sort_by: O async def validate_resource_data(self, resource_data: ResourceCreateSchema ) -> Optional[str]: kind = resource_data["kind"] - if kind not in ["llm", "assistant"]: + if kind not in ["llm", "assistant"]: return "Not a valid kind" - if kind == 'assistant': + um = UsersManager() + if kind == 'assistant': if not resource_data.get("resource_llm_id"): return "It is mandatory to provide a resource_llm_id for an assistant" if not await self.retrieve_resource(resource_data.get("resource_llm_id")): return "Not a valid resource_llm_id" + if not await um.retrieve_user(resource_data.get("user_id")): + return "Not a valid user_id" return None @@ -248,4 +254,15 @@ async def retrieve_resource_conversations(self, resource_id: str) -> bool: conversations_ids = [conversation.id for conversation in conversations] print("conversations_ids: ", conversations_ids) return conversations_ids - \ No newline at end of file + + async def retrieve_shared_resources(self, user_id: str) -> List[ResourceSchema]: + async with db_session_context() as session: + result = await session.execute(select(Share).filter(Share.user_id == user_id)) + shares = result.scalars().all() + resource_ids = [share.resource_id for share in shares] + query = select(Resource).filter(Resource.id.in_(resource_ids)) + result = await session.execute(query) + resources = result.scalars().all() + + resources = [ResourceSchema.from_orm(resource) for resource in resources] + return resources \ No newline at end of file diff --git a/backend/managers/SharesManager.py b/backend/managers/SharesManager.py index a729678a..d47d738c 100644 --- a/backend/managers/SharesManager.py +++ b/backend/managers/SharesManager.py @@ -2,7 +2,7 @@ import string from threading import Lock from sqlalchemy import select, insert, update, delete, func -from backend.models import Share +from backend.models import Share, Resource, User from backend.db import db_session_context from backend.schemas import ShareCreateSchema, ShareSchema from typing import List, Tuple, Optional, Dict, Any @@ -114,3 +114,13 @@ async def retrieve_shares(self, offset: int = 0, limit: int = 100, sort_by: Opti total_count = total_count.scalar() return shares, total_count + + async def validate_assistant_user_id(self, assistant_id: str, user_id: str) -> Optional[str]: + async with db_session_context() as session: + assistant = await session.execute(select(Resource).filter(Resource.id == assistant_id)) + user = await session.execute(select(User).filter(User.id == user_id)) + if not assistant.scalar_one_or_none(): + return "Not a valid resource_id" + if not user.scalar_one_or_none(): + return "Not a valid user_id" + return None \ No newline at end of file diff --git a/backend/models.py b/backend/models.py index 03c74fe2..385453b1 100644 --- a/backend/models.py +++ b/backend/models.py @@ -25,6 +25,7 @@ class Resource(SQLModelBase, table=True): kind: str | None = Field(default=None) icon: str | None = Field(default=None) active: str | None = Field(default=None) + user_id: str | None = Field(default=None) class User(SQLModelBase, table=True): id: str = Field(primary_key=True, default_factory=lambda: str(uuid4())) @@ -106,6 +107,7 @@ class Conversation(SQLModelBase, table=True): last_updated_timestamp: str = Field() archive: str = Field() assistant_id: str | None = Field(default=None, foreign_key="resource.id") + user_id: str | None = Field(default=None, foreign_key="user.id") class Voice(SQLModelBase, table=True): id: str = Field(primary_key=True, default_factory=lambda: str(uuid4())) diff --git a/backend/schemas.py b/backend/schemas.py index e5b1b4c4..fb395ca0 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -25,6 +25,7 @@ class ResourceBaseSchema(BaseModel): kind : str icon : Optional[str] = None active : Optional[str] = None + user_id: Optional[str] = None class Config: orm_mode = True from_attributes = True @@ -183,6 +184,7 @@ class ConversationBaseSchema(BaseModel): archive: str assistant_id: str messages: Optional[List[MessageSchema]] = None + user_id: str class Config: orm_mode = True diff --git a/migrations/versions/75aaaf2cd1a2_added_resource_table.py b/migrations/versions/75aaaf2cd1a2_added_resource_table.py index f6c86d92..bc062509 100644 --- a/migrations/versions/75aaaf2cd1a2_added_resource_table.py +++ b/migrations/versions/75aaaf2cd1a2_added_resource_table.py @@ -32,6 +32,7 @@ def upgrade() -> None: sa.Column('kind', sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column('icon', sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column('active', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('user_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.PrimaryKeyConstraint('id') ) diff --git a/migrations/versions/d34acf83524e_added_conversation_table.py b/migrations/versions/d34acf83524e_added_conversation_table.py index 0c4e5f5e..92cfd320 100644 --- a/migrations/versions/d34acf83524e_added_conversation_table.py +++ b/migrations/versions/d34acf83524e_added_conversation_table.py @@ -28,6 +28,8 @@ def upgrade() -> None: sa.Column('last_updated_timestamp', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('archive', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('assistant_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('user_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), sa.PrimaryKeyConstraint('id') ) # ### end Alembic commands ###