Integrate aisuite in ModelsManager to replace LiteLLM
pjbedard committed Jan 9, 2025
1 parent 1c1b981 commit cda38e6
Showing 7 changed files with 42 additions and 28 deletions.
apis/paios/openapi.yaml (2 additions, 2 deletions)
@@ -860,7 +860,7 @@ paths:
content:
application/json:
schema:
-                type: object
+                type: string
'400':
description: Completion failed
'404':
@@ -1173,7 +1173,7 @@ components:
pattern: ^[a-z]{4}-[a-z]{4}-[a-z]{4}$
messagesList:
type: array
-      example: [{"role": "user", "content": "What is Kwaai.ai?"}]
+      example: [{"role": "user", "content": "What is Personal AI?"}]
items:
type: object
properties:
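Note: the completion response schema changes from object to string because the endpoint now returns only the assistant's reply text (see LlmsView.completion below) rather than a serialized LiteLLM ModelResponse. A minimal client sketch; the base URL, route, port, and model id are illustrative assumptions, and the "messages" key is inferred from the messagesList schema:

    import httpx

    body = {
        "messages": [{"role": "user", "content": "What is Personal AI?"}],  # shape per messagesList
        "optional_params": {},  # forwarded to the provider by LlmsView.completion
    }
    # hypothetical route; the real path and port come from the OpenAPI spec / deployment
    resp = httpx.post("http://localhost:8443/llms/ollama-llama3.2/completion", json=body)
    print(resp.json())  # now a bare JSON string rather than an object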
backend/api/LlmsView.py (6 additions, 6 deletions)
@@ -12,8 +12,8 @@ async def get(self, id: str):
llm = await self.mm.get_llm(id)
if llm is None:
return JSONResponse(headers={"error": "LLM not found"}, status_code=404)
-        llm_schema = LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}",
-                               provider=llm.provider, api_base=llm.api_base, is_active=llm.is_active)
+        llm_schema = LlmSchema(id=llm.id, name=llm.name, provider=llm.provider, full_name=llm.aisuite_name,
+                               api_base=llm.api_base, is_active=llm.is_active)
return JSONResponse(llm_schema.model_dump(), status_code=200)

async def search(self, filter: str = None, range: str = None, sort: str = None):
@@ -24,9 +24,8 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
offset, limit, sort_by, sort_order, filters = result

llms, total_count = await self.mm.retrieve_llms(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters)
-        results = [LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}",
-                             provider=llm.provider, api_base=llm.api_base,
-                             is_active=llm.is_active)
+        results = [LlmSchema(id=llm.id, name=llm.name, provider=llm.provider, full_name=llm.aisuite_name,
+                             api_base=llm.api_base, is_active=llm.is_active)
for llm in llms]
headers = {
'X-Total-Count': str(total_count),
@@ -46,7 +45,8 @@ async def completion(self, id: str, body: dict):
opt_params = body['optional_params']
try:
response = self.mm.completion(llm, messages, **opt_params)
-            return JSONResponse(response.model_dump(), status_code=200)
+            #return JSONResponse(response.model_dump(), status_code=200) # LiteLLM response handling
+            return JSONResponse(response.choices[0].message.content, status_code=200) # aisuite response handling
except BadRequestError as e:
return JSONResponse(status_code=400, content={"message": e.message})
except Exception as e:
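Note: LiteLLM's ModelResponse is a Pydantic model, which is why the old handler could return response.model_dump(); aisuite instead returns an OpenAI-style completion object, so the handler now extracts the first choice's text. A rough sketch of that extraction, assuming at least one choice is present:

    def extract_reply(response) -> str:
        # aisuite mirrors the OpenAI response shape: the assistant's reply
        # lives at choices[0].message.content; a hardened version would
        # guard against empty choices and streaming responses
        return response.choices[0].message.content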
backend/managers/ModelsManager.py (28 additions, 17 deletions)
@@ -1,5 +1,6 @@
import asyncio
import httpx
+import aisuite as ai
from threading import Lock
from sqlalchemy import select, insert, update, delete, func
from backend.models import Llm
@@ -27,9 +28,10 @@ def __init__(self):
if not hasattr(self, '_initialized'): # Ensure initialization happens only once
with self._lock:
if not hasattr(self, '_initialized'):
+                    self.ai_client = ai.Client()
                     self.router = None
-                    router_init_task = asyncio.create_task(self._init_router())
-                    asyncio.gather(router_init_task, return_exceptions=True)
+                    model_load_task = asyncio.create_task(self._load_models())
+                    asyncio.gather(model_load_task, return_exceptions=True)
self._initialized = True

async def get_llm(self, id: str) -> Optional[Llm]:
@@ -77,20 +79,25 @@ async def retrieve_llms(self, offset: int = 0, limit: int = 100, sort_by: Option

def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, CustomStreamWrapper]:
try:
-            response = self.router.completion(model=llm.llm_name,
-                                              messages=messages,
-                                              **optional_params)
-            # below is the direct way to call the model (i.e. not using the router):
+            response = self.ai_client.chat.completions.create(model=llm.aisuite_name,
+                                                              messages=messages,
+                                                              **optional_params)
+            # below is the way to call the model using the LiteLLM router
+            #response = self.router.completion(model=llm.llm_name,
+            #                                  messages=messages,
+            #                                  **optional_params)
+            # below is the direct way to call the model using LiteLLM (i.e. not using the router):
            #response = completion(model=llm.llm_name,
            #                      messages=messages,
            #                      **optional_params)
            print("completion response: {}".format(response))
+            #print("completion response content: {}".format(response.choices[0].message.content))
return response
except Exception as e:
logger.info(f"completion failed with error: {e.message}")
raise
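Note: aisuite's client mirrors the OpenAI SDK surface and addresses models as "provider:model". A self-contained sketch of the direct call path used above (model name and prompt are illustrative):

    import aisuite as ai

    client = ai.Client()
    response = client.chat.completions.create(
        model="ollama:llama3.2",  # aisuite_name convention; e.g. also "openai:gpt-4o"
        messages=[{"role": "user", "content": "What is Personal AI?"}],
    )
    print(response.choices[0].message.content)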

-    async def _init_router(self):
+    async def _load_models(self):
try:
# load models
ollama_task = asyncio.create_task(self._load_ollama_models())
@@ -123,13 +130,13 @@ async def _init_router(self):

async def _load_ollama_models(self):
try:
-            ollama_urlroot = get_env_key("OLLAMA_URLROOT")
+            ollama_api_url = get_env_key("OLLAMA_API_URL")
except ValueError:
print("No Ollama server specified. Skipping.")
return # no Ollama server specified, skip
# retrieve list of installed models
async with httpx.AsyncClient() as client:
response = await client.get("{}/api/tags".format(ollama_urlroot))
response = await client.get("{}/api/tags".format(ollama_api_url))
if response.status_code == 200:
data = response.json()
available_models = [model_data['model'] for model_data in data.get("models", [])]
@@ -140,24 +147,25 @@ async def _load_ollama_models(self):
models = {}
for model in available_models:
name = model.removesuffix(":latest")
+                aisuite_name = "{}:{}".format(provider,name) # what aisuite expects
llm_name = "{}/{}".format(provider,name) # what LiteLLM expects
safe_name = llm_name.replace("/", "-").replace(":", "-") # URL-friendly ID
-                models[model] = {"id": safe_name, "name": name, "llm_name": llm_name, "provider": provider, "api_base": ollama_urlroot}
+                models[model] = {"id": safe_name, "name": name, "provider": provider, "aisuite_name": aisuite_name, "llm_name": llm_name, "api_base": ollama_api_url}
await self._persist_models(provider=provider, models=models)
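Note: the three derived names differ only in their separators. For one Ollama model (name illustrative) the derivation works out to:

    provider = "ollama"
    model = "llama3.2:latest"             # as reported by Ollama's /api/tags
    name = model.removesuffix(":latest")  # "llama3.2"
    aisuite_name = f"{provider}:{name}"   # "ollama:llama3.2" (aisuite convention)
    llm_name = f"{provider}/{name}"       # "ollama/llama3.2" (LiteLLM convention)
    safe_name = llm_name.replace("/", "-").replace(":", "-")  # "ollama-llama3.2", URL-friendly id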

async def _load_openai_models(self):
try:
-            openai_key = get_env_key("OPENAI_API_KEY")
+            openai_api_key = get_env_key("OPENAI_API_KEY")
except ValueError:
print("No OpenAI API key specified. Skipping.")
return # no OpenAI API key specified, skip
# retrieve list of installed models
async with httpx.AsyncClient() as client:
-            openai_urlroot = get_env_key("OPENAI_URLROOT", "https://api.openai.com")
+            openai_api_url = get_env_key("OPENAI_API_URL", "https://api.openai.com")
headers = {
"Authorization": f"Bearer {openai_key}"
"Authorization": f"Bearer {openai_api_key}"
}
response = await client.get(f"{openai_urlroot}/v1/models", headers=headers)
response = await client.get(f"{openai_api_url}/v1/models", headers=headers)
if response.status_code == 200:
data = response.json()
available_models = [model_data['id'] for model_data in data.get("data", [])]
@@ -174,9 +182,10 @@ async def _load_openai_models(self):
llm_provider = "text-completion-openai"
if llm_provider:
name = model
+                    aisuite_name = "{}:{}".format(provider,name) # what aisuite expects
llm_name = "{}/{}".format(llm_provider,name) # what LiteLLM expects
safe_name = f"{provider}/{name}".replace("/", "-").replace(":", "-") # URL-friendly ID
-                    models[model] = {"id": safe_name, "name": name, "llm_name": llm_name, "provider": provider, "api_base": openai_urlroot}
+                    models[model] = {"id": safe_name, "name": name, "provider": provider, "aisuite_name": aisuite_name, "llm_name": llm_name, "api_base": openai_api_url}
await self._persist_models(provider=provider, models=models)

async def _persist_models(self, provider, models):
@@ -193,8 +202,9 @@ async def _persist_models(self, provider, models):
llm = await self.get_llm(model_id)
if llm:
stmt = update(Llm).where(Llm.id == model_id).values(name=parameters["name"],
-                                                                         llm_name=parameters["llm_name"],
                                                                          provider=parameters["provider"],
+                                                                         aisuite_name=parameters["aisuite_name"],
+                                                                         llm_name=parameters["llm_name"],
api_base=parameters["api_base"],
is_active=True)
result = await session.execute(stmt)
@@ -203,8 +213,9 @@ async def _persist_models(self, provider, models):
else:
new_llm = Llm(id=model_id,
name=parameters["name"],
-                                  llm_name=parameters["llm_name"],
                                   provider=parameters["provider"],
+                                  aisuite_name=parameters["aisuite_name"],
+                                  llm_name=parameters["llm_name"],
api_base=parameters["api_base"],
is_active=True)
session.add(new_llm)
backend/models.py (2 additions, 1 deletion)
@@ -68,8 +68,9 @@ class Share(SQLModelBase, table=True):
class Llm(SQLModelBase, table=True):
id: str = Field(primary_key=True) # the model's unique, URL-friendly name
name: str = Field()
-    llm_name: str = Field() # the model name known to LiteLLM
     provider: str = Field() # model provider, eg "ollama"
+    aisuite_name: str = Field() # the model name known to aisuite
+    llm_name: str = Field() # the model name known to LiteLLM
api_base: str | None = Field(default=None)
is_active: bool = Field() # is the model installed / available?

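Note: with the new aisuite_name column, a persisted Ollama-backed row would look roughly like this (values illustrative; api_base assumes Ollama's default local URL):

    llm = Llm(
        id="ollama-llama3.2",               # URL-friendly primary key
        name="llama3.2",
        provider="ollama",
        aisuite_name="ollama:llama3.2",     # the name aisuite expects
        llm_name="ollama/llama3.2",         # the name LiteLLM expects
        api_base="http://localhost:11434",  # assumed Ollama default
        is_active=True,
    )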
backend/requirements.txt (1 addition)
@@ -21,3 +21,4 @@ webauthn
greenlet
pyjwt
litellm
+aisuite[ollama,openai]
backend/schemas.py (1 addition, 1 deletion)
@@ -91,8 +91,8 @@ class ShareSchema(ShareBaseSchema):
class LlmSchema(BaseModel):
id: str
name: str
-    full_name: str
     provider: str
+    full_name: str
api_base: Optional[str] = None
is_active: bool

migrations/versions/73d50424c826_added_llm_table.py (2 additions, 1 deletion)
@@ -23,8 +23,9 @@ def upgrade() -> None:
op.create_table('llm',
sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
-    sa.Column('llm_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
     sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('aisuite_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('llm_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('api_base', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('id')
