From f1c3e9cae84c95ecf33199f98cca5d36fcef59f0 Mon Sep 17 00:00:00 2001 From: Annie Pacheco <161406069+AnniePacheco@users.noreply.github.com> Date: Tue, 5 Nov 2024 09:19:28 -0700 Subject: [PATCH] Update SharesView.py --- backend/api/SharesView.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/backend/api/SharesView.py b/backend/api/SharesView.py index fe462d3d..8d665d07 100644 --- a/backend/api/SharesView.py +++ b/backend/api/SharesView.py @@ -4,6 +4,7 @@ from backend.managers.ResourcesManager import ResourcesManager from backend.pagination import parse_pagination_params from datetime import datetime, timezone +from backend.schemas import ShareCreateSchema class SharesView: def __init__(self): @@ -13,7 +14,7 @@ def __init__(self): async def get(self, id: str): share = await self.slm.retrieve_share(id) if share is None: - return JSONResponse(headers={"error": "Share not found"}, status_code=404) + return JSONResponse({"error": "Share not found"}, status_code=404) return JSONResponse(share.model_dump(), status_code=200) async def post(self, body: dict): @@ -32,21 +33,19 @@ async def post(self, body: dict): is_revoked=False) return JSONResponse(new_share.model_dump(), status_code=201, headers={'Location': f'{api_base_url}/shares/{new_share.id}'}) - async def put(self, id: str, body: dict): - expiration_dt = None + async def put(self, id: str, body: ShareCreateSchema): if 'expiration_dt' in body and body['expiration_dt'] is not None: expiration_dt = datetime.fromisoformat(body['expiration_dt']).astimezone(tz=timezone.utc) - user_id = None + body['expiration_dt'] = expiration_dt if 'user_id' in body and body['user_id']: - user_id = body['user_id'] - updated_share = await self.slm.update_share(id, - resource_id=body['resource_id'], - user_id=user_id, - expiration_dt=expiration_dt, - is_revoked=body['is_revoked']) + valid = await self.slm.validate_assistant_user_id(body['resource_id'], body['user_id']) + if valid is not None: + return JSONResponse({"error": valid}, status_code=400) + await self.slm.update_share(id, body) + updated_share = await self.slm.retrieve_share(id) if updated_share is None: return JSONResponse({"error": "Share not found"}, status_code=404) - return JSONResponse(updated_share.model_dump(), status_code=200) + return JSONResponse(updated_share.dict(), status_code=200) async def delete(self, id: str): success = await self.slm.delete_share(id) @@ -61,9 +60,14 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): offset, limit, sort_by, sort_order, filters = result - shares, total_count = await self.slm.retrieve_shares(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters) + shares, total_count = await self.slm.retrieve_shares(limit=limit, + offset=offset, + sort_by=sort_by, + sort_order=sort_order, + filters=filters + ) headers = { 'X-Total-Count': str(total_count), 'Content-Range': f'shares {offset}-{offset + len(shares) - 1}/{total_count}' } - return JSONResponse([share.model_dump() for share in shares], status_code=200, headers=headers) \ No newline at end of file + return JSONResponse([share.model_dump() for share in shares], status_code=200, headers=headers)