Skip to content

Commit

Permalink
include user_id in conversations and resources
Browse files Browse the repository at this point in the history
  • Loading branch information
AnniePacheco committed Oct 29, 2024
1 parent 875fde1 commit 6454929
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 27 deletions.
44 changes: 40 additions & 4 deletions apis/paios/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,26 @@ paths:
description: No Content
'404':
description: Not Found
'/resources/{user_id}/shared':
get:
security:
- jwt: []
tags:
- Resource Management
summary: Retrieve resource by user id
operationId: backend.api.ResourcesView.shared
description: Retrieve the information of the resource with the specified user ID.
parameters:
- $ref: '#/components/parameters/user_id'
responses:
'200':
description: OK
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/Resource'
'/config/{key}':
get:
security:
Expand Down Expand Up @@ -1081,7 +1101,7 @@ paths:
description: No Content
'404':
description: Not Found
'/conversations/{resource_id}':
'/conversations/{resource_id}/{user_id}':
post:
security:
- jwt: []
Expand All @@ -1090,7 +1110,8 @@ paths:
- Conversation Management
operationId: backend.api.ConversationsView.post
parameters:
- $ref: '#/components/parameters/resource_id'
- $ref: '#/components/parameters/resource_id'
- $ref: '#/components/parameters/user_id'
responses:
'201':
description: Created
Expand Down Expand Up @@ -1317,6 +1338,13 @@ components:
in: path
description: this refers to the assitant id for the resource
required: true
schema:
$ref: '#/components/schemas/uuid4'
user_id:
name: user_id
in: path
description: this refers to the user id
required: true
schema:
$ref: '#/components/schemas/uuid4'
schemas:
Expand Down Expand Up @@ -1571,7 +1599,10 @@ components:
$ref: '#/components/schemas/uri'
active:
nullable: true
$ref: '#/components/schemas/boolean_str'
$ref: '#/components/schemas/boolean_str'
user_id:
$ref: '#/components/schemas/uuid4'
nullable: true
ResourceCreate:
type: object
title: ResourceCreate
Expand Down Expand Up @@ -1606,7 +1637,10 @@ components:
$ref: '#/components/schemas/uri'
active:
nullable: true
$ref: '#/components/schemas/boolean_str'
$ref: '#/components/schemas/boolean_str'
user_id:
$ref: '#/components/schemas/uuid4'
nullable: true
required:
- name
- kind
Expand Down Expand Up @@ -1958,6 +1992,8 @@ components:
type: array
items:
$ref: '#/components/schemas/Message'
user_id:
$ref: '#/components/schemas/uuid4'
required:
- name
- archive
Expand Down
4 changes: 2 additions & 2 deletions backend/api/ConversationsView.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ async def get(self, id: str):
return NOT_FOUND_RESPONSE
return JSONResponse(conversation.dict(), status_code=200)

async def post(self, resource_id: str, body: ConversationCreateSchema):
conversation_id = await self.cm.create_conversation(resource_id, body)
async def post(self, resource_id: str, user_id: str, body: ConversationCreateSchema):
conversation_id = await self.cm.create_conversation(resource_id, user_id, body)
conversation = await self.cm.retrieve_conversation(conversation_id)
if conversation is None:
return JSONResponse({"error": "Assistant not found"}, status_code=404)
Expand Down
10 changes: 9 additions & 1 deletion backend/api/ResourcesView.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from backend.managers.ConversationsManager import ConversationsManager
from backend.pagination import parse_pagination_params
from backend.schemas import ResourceCreateSchema
from typing import List

class ResourcesView:
def __init__(self):
Expand Down Expand Up @@ -71,3 +70,12 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
'Content-Range': f'resources {offset}-{offset + len(resources) - 1}/{total_count}'
}
return JSONResponse([resource.dict() for resource in resources], status_code=200, headers=headers)

async def shared(self, user_id: str = None):
print("user_id: ", user_id)
resources = await self.rm.retrieve_shared_resources(user_id)
if resources is None:
return JSONResponse({"error": "Shared resources not found"}, status_code=404)
return JSONResponse([resource.dict() for resource in resources], status_code=200)


7 changes: 6 additions & 1 deletion backend/api/SharesView.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from starlette.responses import JSONResponse, Response
from common.paths import api_base_url
from backend.managers.SharesManager import SharesManager
from backend.managers.ResourcesManager import ResourcesManager
from backend.pagination import parse_pagination_params
from datetime import datetime, timezone

class SharesView:
def __init__(self):
self.slm = SharesManager()
self.rm = ResourcesManager()

async def get(self, id: str):
share = await self.slm.retrieve_share(id)
Expand All @@ -20,6 +22,9 @@ async def post(self, body: dict):
expiration_dt = datetime.fromisoformat(body['expiration_dt']).astimezone(tz=timezone.utc)
user_id = None
if 'user_id' in body and body['user_id']:
valid = await self.slm.validate_assistant_user_id(body['resource_id'], body['user_id'])
if valid is not None:
return JSONResponse({"error": valid}, status_code=400)
user_id = body['user_id']
new_share = await self.slm.create_share(resource_id=body['resource_id'],
user_id=user_id,
Expand Down Expand Up @@ -61,4 +66,4 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
'X-Total-Count': str(total_count),
'Content-Range': f'shares {offset}-{offset + len(shares) - 1}/{total_count}'
}
return JSONResponse([share.model_dump() for share in shares], status_code=200, headers=headers)
return JSONResponse([share.model_dump() for share in shares], status_code=200, headers=headers)
26 changes: 15 additions & 11 deletions backend/managers/ConversationsManager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from uuid import uuid4
from threading import Lock
from sqlalchemy import select, update, delete, func
from backend.models import Conversation, Resource, Message
from backend.models import Conversation, Resource, Message, User
from backend.db import db_session_context
from backend.schemas import ConversationSchema, ConversationCreateSchema, MessageSchema
from typing import List, Tuple, Optional, Dict, Any
Expand All @@ -24,18 +24,18 @@ def __init__(self):
if not hasattr(self, '_initialized'):
self._initialized = True

async def create_conversation(self, resource_id: str, conversation_data: ConversationCreateSchema) -> Optional[str]:
async def create_conversation(self, resource_id: str, user_id: str, conversation_data: ConversationCreateSchema) -> Optional[str]:
async with db_session_context() as session:

if not await self.validate_assistant_id(resource_id):
if not await self.validate_assistant_user_id(resource_id, user_id):
return None

timestamp = get_current_timestamp()
conversation_data['created_timestamp'] = timestamp
conversation_data['last_updated_timestamp'] = timestamp
conversation_data['archive'] = "False"
conversation_data['assistant_id'] = resource_id

conversation_data['assistant_id'] = resource_id
conversation_data['user_id'] = user_id
new_conversation = Conversation(id=str(uuid4()), **conversation_data)
session.add(new_conversation)
await session.commit()
Expand Down Expand Up @@ -113,7 +113,8 @@ async def retrieve_conversation(self, id: str) -> Optional[ConversationSchema]:
last_updated_timestamp=conversation.last_updated_timestamp,
archive=conversation.archive,
assistant_id=conversation.assistant_id,
messages=messages_list
messages=messages_list,
user_id=conversation.user_id
)
return None

Expand Down Expand Up @@ -155,10 +156,13 @@ async def _get_total_count(self, session, filters: Optional[Dict[str, Any]]) ->
count_query = self._apply_filters(count_query, filters)
total_count = await session.execute(count_query)
return total_count.scalar()


async def validate_assistant_id(self, assistant_id: str) -> bool:

async def validate_assistant_user_id(self, assistant_id: str, user_id: str) -> bool:
async with db_session_context() as session:
result = await session.execute(select(Resource).filter(Resource.id == assistant_id))
return result.scalar_one_or_none() is not None
assistant = await session.execute(select(Resource).filter(Resource.id == assistant_id))
user = await session.execute(select(User).filter(User.id == user_id))
if not assistant.scalar_one_or_none() or not user.scalar_one_or_none():
return False
return True

31 changes: 24 additions & 7 deletions backend/managers/ResourcesManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from threading import Lock
import httpx
from sqlalchemy import select, update, delete, func
from backend.models import Resource, File, Conversation
from backend.models import Resource, File, Conversation, Share
from backend.db import db_session_context
from backend.schemas import ResourceCreateSchema, ResourceSchema
from typing import List, Tuple, Optional, Dict, Any
from backend.managers.UsersManager import UsersManager

# This is a mock of the installed models in the system
ollama_model=[{'name': 'llama3:latest', 'model': 'llama3:latest', 'modified_at': '2024-08-24T21:57:16.6075173-06:00', 'size': 2176178913, 'digest': '4f222292793889a9a40a020799cfd28d53f3e01af25d48e06c5e708610fc47e9', 'details': {'parent_model': '', 'format': 'gguf', 'family': 'phi3', 'families': ['phi3'], 'parameter_size': '3.8B', 'quantization_level': 'Q4_0'}}]
Expand Down Expand Up @@ -42,7 +43,8 @@ async def create_resource(self, resource_data: ResourceCreateSchema) -> str:
"persona_id": resource_data.get("persona_id"),
"status": resource_data.get("status"),
"allow_edit": resource_data.get("allow_edit"),
"kind": kind
"kind": kind,
"user_id": resource_data.get("user_id")
}
else:
resource_data_table={
Expand Down Expand Up @@ -107,7 +109,8 @@ async def retrieve_resource(self, id: str) -> Optional[ResourceSchema]:
persona_id=resource.persona_id,
status=resource.status,
allow_edit=resource.allow_edit,
kind=resource.kind)
kind=resource.kind,
user_id=resource.user_id)
else :
return ResourceSchema(
id=resource.id,
Expand Down Expand Up @@ -160,7 +163,7 @@ async def retrieve_resources(self, offset: int = 0, limit: int = 100, sort_by: O
else:
query = query.filter(getattr(Resource, key) == value)

if sort_by and sort_by in ['id', 'name', 'uri','status','allow_edit','kind','active']:
if sort_by and sort_by in ['id', 'name', 'uri','status','allow_edit','kind','active','user_id']:
order_column = getattr(Resource, sort_by)
query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column)

Expand Down Expand Up @@ -189,13 +192,16 @@ async def retrieve_resources(self, offset: int = 0, limit: int = 100, sort_by: O

async def validate_resource_data(self, resource_data: ResourceCreateSchema ) -> Optional[str]:
kind = resource_data["kind"]
if kind not in ["llm", "assistant"]:
if kind not in ["llm", "assistant"]:
return "Not a valid kind"
if kind == 'assistant':
um = UsersManager()
if kind == 'assistant':
if not resource_data.get("resource_llm_id"):
return "It is mandatory to provide a resource_llm_id for an assistant"
if not await self.retrieve_resource(resource_data.get("resource_llm_id")):
return "Not a valid resource_llm_id"
if not await um.retrieve_user(resource_data.get("user_id")):
return "Not a valid user_id"
return None


Expand Down Expand Up @@ -248,4 +254,15 @@ async def retrieve_resource_conversations(self, resource_id: str) -> bool:
conversations_ids = [conversation.id for conversation in conversations]
print("conversations_ids: ", conversations_ids)
return conversations_ids


async def retrieve_shared_resources(self, user_id: str) -> List[ResourceSchema]:
async with db_session_context() as session:
result = await session.execute(select(Share).filter(Share.user_id == user_id))
shares = result.scalars().all()
resource_ids = [share.resource_id for share in shares]
query = select(Resource).filter(Resource.id.in_(resource_ids))
result = await session.execute(query)
resources = result.scalars().all()

resources = [ResourceSchema.from_orm(resource) for resource in resources]
return resources
12 changes: 11 additions & 1 deletion backend/managers/SharesManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import string
from threading import Lock
from sqlalchemy import select, insert, update, delete, func
from backend.models import Share
from backend.models import Share, Resource, User
from backend.db import db_session_context
from backend.schemas import ShareCreateSchema, ShareSchema
from typing import List, Tuple, Optional, Dict, Any
Expand Down Expand Up @@ -114,3 +114,13 @@ async def retrieve_shares(self, offset: int = 0, limit: int = 100, sort_by: Opti
total_count = total_count.scalar()

return shares, total_count

async def validate_assistant_user_id(self, assistant_id: str, user_id: str) -> Optional[str]:
async with db_session_context() as session:
assistant = await session.execute(select(Resource).filter(Resource.id == assistant_id))
user = await session.execute(select(User).filter(User.id == user_id))
if not assistant.scalar_one_or_none():
return "Not a valid resource_id"
if not user.scalar_one_or_none():
return "Not a valid user_id"
return None
2 changes: 2 additions & 0 deletions backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Resource(SQLModelBase, table=True):
kind: str | None = Field(default=None)
icon: str | None = Field(default=None)
active: str | None = Field(default=None)
user_id: str | None = Field(default=None)

class User(SQLModelBase, table=True):
id: str = Field(primary_key=True, default_factory=lambda: str(uuid4()))
Expand Down Expand Up @@ -106,6 +107,7 @@ class Conversation(SQLModelBase, table=True):
last_updated_timestamp: str = Field()
archive: str = Field()
assistant_id: str | None = Field(default=None, foreign_key="resource.id")
user_id: str | None = Field(default=None, foreign_key="user.id")

class Voice(SQLModelBase, table=True):
id: str = Field(primary_key=True, default_factory=lambda: str(uuid4()))
Expand Down
2 changes: 2 additions & 0 deletions backend/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ResourceBaseSchema(BaseModel):
kind : str
icon : Optional[str] = None
active : Optional[str] = None
user_id: Optional[str] = None
class Config:
orm_mode = True
from_attributes = True
Expand Down Expand Up @@ -183,6 +184,7 @@ class ConversationBaseSchema(BaseModel):
archive: str
assistant_id: str
messages: Optional[List[MessageSchema]] = None
user_id: str

class Config:
orm_mode = True
Expand Down
1 change: 1 addition & 0 deletions migrations/versions/75aaaf2cd1a2_added_resource_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def upgrade() -> None:
sa.Column('kind', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('icon', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('active', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('user_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.PrimaryKeyConstraint('id')
)

Expand Down
2 changes: 2 additions & 0 deletions migrations/versions/d34acf83524e_added_conversation_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def upgrade() -> None:
sa.Column('last_updated_timestamp', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('archive', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('assistant_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('user_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
Expand Down

0 comments on commit 6454929

Please sign in to comment.