added processors, and used adapter to use them in collectionsManager #76

Open
wants to merge 1 commit into base: feat/collections
32 changes: 24 additions & 8 deletions backend/managers/CollectionsManager.py
@@ -2,14 +2,22 @@
from threading import Lock
from sqlalchemy import select, insert, update, delete, func, or_
from backend.models import Collection
from backend.processors.embedders.EmbeddersAdapter import EmbeddersAdapter
from backend.processors.text_splitters.TextSplitterAdapter import TextSplitterAdapter
from backend.processors.vector_stores.VectorStoresAdapter import VectorStoresAdapter
from backend.schemas import CollectionCreate # Import from schemas.py
from backend.db import db_session_context
from backend.processors.BaseProcessor import BaseProcessor
from backend.processors.SimpleTextSplitter import SimpleTextSplitter
from backend.processors.SimpleEmbedder import SimpleEmbedder
from backend.processors.SimpleVectorStore import SimpleVectorStore

from backend.processors.embedders import *
from backend.processors.text_splitters import *
from backend.processors.vector_stores import *
from typing import List, Tuple, Optional, Dict, Any


class CollectionsManager:
_instance = None
_lock = Lock()
@@ -71,16 +79,23 @@ async def process_collection(self, collection_id: str, text_splitter: str, embed

if not all([text_splitter, embedder, vector_store]):
raise ValueError("Invalid processor selection")

await self._process_collection(collection, text_splitter, embedder, vector_store)

async def _process_collection(self, collection: Collection, text_splitter: BaseProcessor, embedder: BaseProcessor, vector_store: BaseProcessor):
async def _process_collection(self, collection: Collection, text_splitter: BaseProcessor, embedder: BaseProcessor,
vector_store: BaseProcessor):
# Implement a method to process collections
# This could involve streaming data from an external source, like an email server
text_splitter_instance = TextSplitterAdapter(text_splitter, process=text_splitter.split_text)
embedder_instance = EmbeddersAdapter(embedder, process=embedder.embed_documents)
vector_store_instance = VectorStoresAdapter(vector_store, process=vector_store.from_texts)
for batch in self._collection_batch_generator(collection):
chunks = await text_splitter.process(batch)
embeddings = await embedder.process(chunks)
await vector_store.process(f"{collection.id}_{uuid4()}", chunks, embeddings)
chunks = await text_splitter_instance.process(batch)
embeddings = await embedder_instance.process(chunks)
# Specify parameter values as required by each vector store type (vector store methods may have different signatures).
if vector_store_instance.obj.__class__.__name__ == "Chroma":
await vector_store_instance.process(chunks, embeddings, collection_name="collection_name_to_create")
elif vector_store_instance.obj.__class__.__name__ == "AWS":
await vector_store_instance.process(chunks, embeddings)

def _collection_batch_generator(self, collection: Collection):
# This is a dummy generator. In a real scenario, this would fetch data from the actual source
@@ -95,7 +110,8 @@ async def retrieve_collection(self, id: str) -> Optional[Collection]:

async def update_collection(self, id: str, collection_data: CollectionCreate) -> Optional[Collection]:
async with db_session_context() as session:
stmt = update(Collection).where(Collection.id == id).values(**collection_data.model_dump(exclude_unset=True))
stmt = update(Collection).where(Collection.id == id).values(
**collection_data.model_dump(exclude_unset=True))
result = await session.execute(stmt)
if result.rowcount > 0:
await session.commit()
@@ -110,8 +126,8 @@ async def delete_collection(self, id: str) -> bool:
await session.commit()
return result.rowcount > 0

async def list_collections(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None,
sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None,
async def list_collections(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None,
sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None,
query: Optional[str] = None) -> Tuple[List[Collection], int]:
async with db_session_context() as session:
stmt = select(Collection)
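For review context, here is a minimal, self-contained sketch of the call pattern `_process_collection` relies on: each concrete processor is wrapped in an adapter that exposes a uniform `process` callable, the batch loop only talks to `process`, and the branch on the wrapped class name selects per-store arguments. The `FakeEmbedder` stand-in and the sample batch are hypothetical, purely so the sketch runs without an embedding service; the wrapped methods are synchronous here, so `process` is called directly.

# Hedged sketch of the adapter wiring; not the project's actual entry point.
from backend.processors.text_splitters.CharacterTextSplitter import CharacterTextSplitter
from backend.processors.text_splitters.TextSplitterAdapter import TextSplitterAdapter
from backend.processors.embedders.EmbeddersAdapter import EmbeddersAdapter


class FakeEmbedder:
    """Hypothetical stand-in embedder so the sketch needs no external service."""

    def embed_documents(self, texts):
        return [[float(len(t))] for t in texts]


raw_splitter = CharacterTextSplitter()
raw_embedder = FakeEmbedder()

splitter = TextSplitterAdapter(raw_splitter, process=raw_splitter.split_text)
embedder = EmbeddersAdapter(raw_embedder, process=raw_embedder.embed_documents)

batch = "First paragraph.\n\nSecond paragraph."   # illustrative batch content
chunks = splitter.process(batch)                   # -> List[str]
vectors = embedder.process(chunks)                 # -> List[List[float]]
print(f"{len(chunks)} chunks, {len(vectors)} vectors")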
18 changes: 18 additions & 0 deletions backend/processors/embedders/BaseProcessorEmbeddings.py
@@ -0,0 +1,18 @@
from typing import List
from abc import ABC, abstractmethod

class BaseProcessorEmbeddings(ABC):
    @abstractmethod
    def is_available(self) -> bool:
        """Check if the processor is available (dependencies installed)."""
        pass

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        pass

    @abstractmethod
    def embed_documents(
        self, texts: List[str], chunk_size: int
    ) -> List[List[float]]:
        pass
17 changes: 17 additions & 0 deletions backend/processors/embedders/EmbeddersAdapter.py
@@ -0,0 +1,17 @@
from typing import Callable, TypeVar

T = TypeVar("T")

class EmbeddersAdapter:
    """Adapter that exposes a uniform `process` callable over a wrapped embedder."""

    def __init__(self, obj: T, **adapted_methods: Callable):
        self.obj = obj
        self.__dict__.update(adapted_methods)

    def __getattr__(self, attr):
        """All non-adapted calls are passed to the wrapped object."""
        return getattr(self.obj, attr)

    def original_dict(self):
        """Return the wrapped object's __dict__."""
        return self.obj.__dict__
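A note on the adapter mechanics: the adapted callables are stored in the instance `__dict__`, so they win the attribute lookup, while `__getattr__` forwards everything else to the wrapped object. A tiny hedged example of that lookup order (the `Doubler` class is made up for illustration; the same pattern applies to `TextSplitterAdapter` below):

# Hypothetical example; Doubler is not part of the PR.
from backend.processors.embedders.EmbeddersAdapter import EmbeddersAdapter


class Doubler:
    def double(self, xs):
        return [x * 2 for x in xs]

    def name(self):
        return "doubler"


doubler = Doubler()
wrapped = EmbeddersAdapter(doubler, process=doubler.double)
print(wrapped.process([1, 2, 3]))  # adapted method found in the instance __dict__ -> [2, 4, 6]
print(wrapped.name())              # falls through __getattr__ to the wrapped object -> "doubler"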
25 changes: 25 additions & 0 deletions backend/processors/embedders/OllamaEmbeddings.py
@@ -0,0 +1,25 @@
from typing import List

from backend.processors.embedders import BaseProcessorEmbeddings
# Alias the library class so it does not clash with this wrapper's class name.
from langchain_ollama import OllamaEmbeddings as LangchainOllamaEmbeddings


class OllamaEmbeddings(BaseProcessorEmbeddings):
    def __init__(self, model_name="llama3"):
        self.embed = LangchainOllamaEmbeddings(
            model=model_name
        )

    def is_available(self) -> bool:
        return True

    def embed_query(self, text: str) -> List[float]:
        return self.embed.embed_query(text)

    def embed_documents(
        self, texts: List[str]
    ) -> List[List[float]]:
        return self.embed.embed_documents(texts)

    # TODO implement other methods
29 changes: 29 additions & 0 deletions backend/processors/embedders/OpenAIEmbeddings.py
@@ -0,0 +1,29 @@
from typing import List

from backend.processors.embedders import BaseProcessorEmbeddings
# Alias the library class so it does not clash with this wrapper's class name.
from langchain_openai import OpenAIEmbeddings as LangchainOpenAIEmbeddings


class OpenAIEmbeddings(BaseProcessorEmbeddings):
    def __init__(self, model_name="text-embedding-3-large"):
        self.embed = LangchainOpenAIEmbeddings(
            model=model_name
            # With the `text-embedding-3` class of models, you can specify
            # the size of the embeddings you want returned.
            # dimensions=1024
        )

    def is_available(self) -> bool:
        return True

    def embed_query(self, text: str) -> List[float]:
        return self.embed.embed_query(text)

    def embed_documents(
        self, texts: List[str], chunk_size: int
    ) -> List[List[float]]:
        return self.embed.embed_documents(texts, chunk_size)

    # TODO implement other methods

18 changes: 18 additions & 0 deletions backend/processors/text_splitters/BaseProcessorTextSplitter.py
@@ -0,0 +1,18 @@
from typing import List, Optional
from abc import ABC, abstractmethod
from langchain_core.documents import Document


class BaseProcessorTextSplitter(ABC):
    @abstractmethod
    def is_available(self) -> bool:
        """Check if the processor is available (dependencies installed)."""
        pass

    @abstractmethod
    def split_documents(
        self, texts: List[str], metadatas: Optional[List[dict]]
    ) -> List[Document]:
        pass

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        pass
32 changes: 32 additions & 0 deletions backend/processors/text_splitters/CharacterTextSplitter.py
@@ -0,0 +1,32 @@
from typing import List, Optional

from backend.processors.text_splitters import BaseProcessorTextSplitter
# Alias the library class so it does not clash with this wrapper's class name.
from langchain_text_splitters import CharacterTextSplitter as LangchainCharacterTextSplitter
from langchain_core.documents import Document


class CharacterTextSplitter(BaseProcessorTextSplitter):

    def __init__(self):
        self.text_splitter = LangchainCharacterTextSplitter(
            separator="\n\n",
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            is_separator_regex=False,
        )

    def is_available(self) -> bool:
        return True

    def split_documents(
        self, texts: List[str], metadatas: Optional[List[dict]]
    ) -> List[Document]:
        return self.text_splitter.create_documents(
            texts, metadatas=metadatas
        )

    def split_text(self, text: str) -> List[str]:
        return self.text_splitter.split_text(text)

    # TODO implement other methods
33 changes: 33 additions & 0 deletions backend/processors/text_splitters/RecursiveCharacterTextSplitter.py
@@ -0,0 +1,33 @@
from typing import List

from backend.processors.text_splitters import BaseProcessorTextSplitter
# Alias the library class so it does not clash with this wrapper's class name.
from langchain_text_splitters import RecursiveCharacterTextSplitter as LangchainRecursiveCharacterTextSplitter
from langchain_core.documents import Document


class RecursiveCharacterTextSplitter(BaseProcessorTextSplitter):

    def __init__(self):
        self.text_splitter = LangchainRecursiveCharacterTextSplitter(
            # Set a really small chunk size, just to show.
            chunk_size=100,
            chunk_overlap=20,
            length_function=len,
            is_separator_regex=False,
        )

    def is_available(self) -> bool:
        return True

    def split_documents(
        self, texts: List[str]
    ) -> List[Document]:
        return self.text_splitter.create_documents(
            texts
        )

    def split_text(self, text: str) -> List[str]:
        return self.text_splitter.split_text(text)

    # TODO implement other methods
17 changes: 17 additions & 0 deletions backend/processors/text_splitters/TextSplitterAdapter.py
@@ -0,0 +1,17 @@
from typing import Callable, TypeVar

T = TypeVar("T")

class TextSplitterAdapter:
    """Adapter that exposes a uniform `process` callable over a wrapped text splitter."""

    def __init__(self, obj: T, **adapted_methods: Callable):
        self.obj = obj
        self.__dict__.update(adapted_methods)

    def __getattr__(self, attr):
        """All non-adapted calls are passed to the wrapped object."""
        return getattr(self.obj, attr)

    def original_dict(self):
        """Return the wrapped object's __dict__."""
        return self.obj.__dict__
56 changes: 56 additions & 0 deletions backend/processors/vector_stores/AWS.py
@@ -0,0 +1,56 @@
from typing import List, Optional

import numpy as np

from backend.processors.vector_stores import BaseProcessorVectorStore
from langchain_core.documents import Document
from langchain_core.embeddings.embeddings import Embeddings

from langchain_aws.vectorstores import InMemoryVectorStore
from langchain_aws.vectorstores.inmemorydb.filters import InMemoryDBFilterExpression
from langchain_aws.utilities.math import cosine_similarity


class AWS(BaseProcessorVectorStore):
    def __init__(self, documents, embeddings, redis="redis://cluster_endpoint:6379"):
        self.aws = InMemoryVectorStore.from_documents(
            documents,    # a list of Document objects, from loaders or created directly
            embeddings,   # an Embeddings object
            redis_url=redis,
        )

    def is_available(self) -> bool:
        return True

    def add_document(self, documents):
        # Returns the list of ids of the added documents.
        return self.aws.add_documents(documents=documents)

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[InMemoryDBFilterExpression] = None,
        return_metadata: bool = True,
        distance_threshold: Optional[float] = None,
    ) -> List[Document]:
        """Run similarity search with AWS."""
        return self.aws.similarity_search(query, k)

    def from_texts(
        self,
        # cls: Type[InMemoryVectorStore],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        index_name: Optional[str] = None,
        # index_schema: Optional[Union[Dict[str, ListOfDict], str, os.PathLike]] = None,
        # vector_schema: Optional[Dict[str, Union[str, int]]] = None,
        # **kwargs: Any,
    ) -> InMemoryVectorStore:
        return self.aws.from_texts(texts, embedding)

    def cosine_similarity(self, b: np.ndarray, a: np.ndarray):
        return cosine_similarity(a, b)

    # TODO implement other methods
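A hedged usage sketch of the AWS store (not part of the PR): the Redis endpoint, documents, and query are placeholders, and it assumes a reachable MemoryDB/Redis instance plus OpenAI credentials. The embedder wrapper above keeps the underlying langchain `Embeddings` object in its `embed` attribute, which is what `InMemoryVectorStore` expects.

# Illustrative only; endpoint, texts, and query are placeholders.
from backend.processors.vector_stores.AWS import AWS
from backend.processors.embedders.OpenAIEmbeddings import OpenAIEmbeddings
from langchain_core.documents import Document

embedder = OpenAIEmbeddings()  # requires OPENAI_API_KEY in the environment
docs = [
    Document(page_content="hello world"),
    Document(page_content="goodbye world"),
]

# .embed is the langchain Embeddings instance held by the wrapper
store = AWS(docs, embedder.embed, redis="redis://localhost:6379")
print(store.similarity_search("hello", k=1))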
33 changes: 33 additions & 0 deletions backend/processors/vector_stores/BaseProcessorVectorStore.py
@@ -0,0 +1,33 @@
from typing import List, Optional
from abc import ABC, abstractmethod

import numpy as np

from langchain_core.documents import Document
from langchain_aws.vectorstores.inmemorydb.filters import InMemoryDBFilterExpression


class BaseProcessorVectorStore(ABC):
    @abstractmethod
    def is_available(self) -> bool:
        """Check if the processor is available (dependencies installed)."""
        pass

    @abstractmethod
    def add_document(self, documents):
        """Return the list of ids of the added documents."""
        pass

    @abstractmethod
    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[InMemoryDBFilterExpression] = None,
        return_metadata: bool = True,
        distance_threshold: Optional[float] = None,
    ) -> List[Document]:
        """Run a similarity search against the underlying vector store."""
        pass

    @abstractmethod
    def cosine_similarity(self, b: np.ndarray, a: np.ndarray):
        pass