refactoring: LlmsManager => ModelsManager
pjbedard committed Dec 5, 2024
1 parent 72745c5 commit 1c1b981
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions backend/api/LlmsView.py
@@ -1,15 +1,15 @@
 from starlette.responses import JSONResponse
-from backend.managers.LlmsManager import LlmsManager
+from backend.managers.ModelsManager import ModelsManager
 from backend.pagination import parse_pagination_params
 from backend.schemas import LlmSchema
 from litellm.exceptions import BadRequestError
 
 class LlmsView:
     def __init__(self):
-        self.llmm = LlmsManager()
+        self.mm = ModelsManager()
 
     async def get(self, id: str):
-        llm = await self.llmm.get_llm(id)
+        llm = await self.mm.get_llm(id)
         if llm is None:
             return JSONResponse(headers={"error": "LLM not found"}, status_code=404)
         llm_schema = LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}",
@@ -23,7 +23,7 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
 
         offset, limit, sort_by, sort_order, filters = result
 
-        llms, total_count = await self.llmm.retrieve_llms(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters)
+        llms, total_count = await self.mm.retrieve_llms(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters)
         results = [LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}",
                              provider=llm.provider, api_base=llm.api_base,
                              is_active=llm.is_active)
@@ -36,7 +36,7 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
 
     async def completion(self, id: str, body: dict):
         print("completion. body: {}".format(body))
-        llm = await self.llmm.get_llm(id)
+        llm = await self.mm.get_llm(id)
         if llm:
             messages = []
             if 'messages' in body and body['messages']:
@@ -45,7 +45,7 @@ async def completion(self, id: str, body: dict):
             if 'optional_params' in body and body['optional_params']:
                 opt_params = body['optional_params']
             try:
-                response = self.llmm.completion(llm, messages, **opt_params)
+                response = self.mm.completion(llm, messages, **opt_params)
                 return JSONResponse(response.model_dump(), status_code=200)
             except BadRequestError as e:
                 return JSONResponse(status_code=400, content={"message": e.message})
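
From the handler above, the request body for the completion endpoint carries a `messages` list (OpenAI-style chat messages) and an optional `optional_params` dict that is forwarded as keyword arguments to the manager's completion call. A sketch of such a body; the key names come from the handler, the values are made up for illustration:

    # Illustrative request body for LlmsView.completion; keys follow the handler
    # above, values are placeholders.
    body = {
        "messages": [
            {"role": "user", "content": "Summarize this repository in one sentence."}
        ],
        "optional_params": {"temperature": 0.2, "max_tokens": 256},
    }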
8 changes: 4 additions & 4 deletions backend/managers/LlmsManager.py → backend/managers/ModelsManager.py
@@ -12,15 +12,15 @@
 import logging
 logger = logging.getLogger(__name__)
 
-class LlmsManager:
+class ModelsManager:
     _instance = None
     _lock = Lock()
 
     def __new__(cls, *args, **kwargs):
         if not cls._instance:
             with cls._lock:
                 if not cls._instance:
-                    cls._instance = super(LlmsManager, cls).__new__(cls, *args, **kwargs)
+                    cls._instance = super(ModelsManager, cls).__new__(cls, *args, **kwargs)
         return cls._instance
 
     def __init__(self):
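
Aside from the rename, the hunk above keeps the class's double-checked-locking singleton: the outer `if` skips the lock once the instance exists, and the inner `if` guards against two threads passing the outer check at the same time. A minimal standalone sketch of the same pattern, illustrative only and not code from this commit:

    from threading import Lock

    class SingletonExample:
        _instance = None
        _lock = Lock()

        def __new__(cls):
            if cls._instance is None:          # fast path once the instance exists
                with cls._lock:
                    if cls._instance is None:  # re-check inside the lock to close the race
                        cls._instance = super().__new__(cls)
            return cls._instance

    a = SingletonExample()
    b = SingletonExample()
    assert a is b  # both names refer to the single shared instance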
@@ -80,7 +80,7 @@ def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, C
         response = self.router.completion(model=llm.llm_name,
                                           messages=messages,
                                           **optional_params)
-        # below is the direct way to call the LLM (i.e. not using the router):
+        # below is the direct way to call the model (i.e. not using the router):
         #response = completion(model=llm.llm_name,
         #                      messages=messages,
         #                      **optional_params)
@@ -98,7 +98,7 @@ async def _init_router(self):
         await asyncio.gather(ollama_task,
                              openai_task,
                              return_exceptions=True)
-        # collect the available LLMs
+        # collect the available models
         llms, total_llms = await self.retrieve_llms()
         # configure router
         model_list = []
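
The hunk above only shows the start of the router setup; `_init_router` goes on to fill `model_list` from the retrieved models. For reference, litellm's `Router` is configured with entries of the shape below — a sketch of litellm's documented usage with placeholder values, not the code from this repository:

    from litellm import Router

    # Placeholder values for illustration; the real entries are built from the
    # models returned by retrieve_llms().
    model_list = [
        {
            "model_name": "llama3",                   # name used when calling the router
            "litellm_params": {
                "model": "ollama/llama3",             # provider/model string passed to litellm
                "api_base": "http://localhost:11434",
            },
        },
    ]
    router = Router(model_list=model_list)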
