-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Embedding] Support Ollama embedding (#81)
* [emb] Support ollama as embedding provider * A new embedding_utils for easier filtering/sorting relevant candidates according to the metric type * Fix new ollama emb collections cleanup * Enhance logging * update README
- Loading branch information
Showing
9 changed files
with
235 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import os | ||
import json | ||
import time | ||
|
||
import numpy as np | ||
|
||
from embedding import Embedding | ||
from langchain_community.embeddings import OllamaEmbeddings | ||
import utils | ||
|
||
|
||
class EmbeddingOllama(Embedding): | ||
""" | ||
Embedding via Ollama | ||
""" | ||
def __init__(self, model_name="nomic-embed-text", base_url=""): | ||
super().__init__(model_name) | ||
|
||
self.base_url = base_url or os.getenv("OLLAMA_URL") | ||
self.dimensions = -1 | ||
|
||
self.client = OllamaEmbeddings( | ||
base_url=self.base_url, | ||
model=self.model_name, | ||
) | ||
|
||
print(f"Initialized EmbeddingOllama: model_name: {self.model_name}, base_url: {self.base_url}") | ||
|
||
def dim(self): | ||
if self.dimensions > 0: | ||
return self.dimensions | ||
|
||
text = "This is a test query" | ||
query_result = self.client.embed_query(text) | ||
self.dimensions = len(query_result) | ||
return self.dimensions | ||
|
||
def getname(self, start_date, prefix="ollama"): | ||
""" | ||
Get a embedding collection name of milvus | ||
""" | ||
return f"embedding__{prefix}__ollama_{self.model_name}__{start_date}".replace("-", "_") | ||
|
||
def create( | ||
self, | ||
text: str, | ||
num_retries=3, | ||
retry_wait_time=0.5, | ||
error_wait_time=0.5, | ||
|
||
# ollama embedding query result is not normalized, for most | ||
# of the vector database would suggest us do the normalization | ||
# first before inserting into the vector database | ||
# here, we can apply a post-step for the normalization | ||
normalize=True, | ||
): | ||
emb = None | ||
|
||
for i in range(1, num_retries + 1): | ||
try: | ||
emb = self.client.embed_query(text) | ||
|
||
if normalize: | ||
emb = (np.array(emb) / np.linalg.norm(emb)).tolist() | ||
|
||
break | ||
|
||
except Exception as e: | ||
print(f"[ERROR] APIError during embedding ({i}/{num_retries}): {e}") | ||
|
||
if i == num_retries: | ||
raise | ||
|
||
time.sleep(error_wait_time) | ||
|
||
return emb | ||
|
||
def get_or_create( | ||
self, | ||
text: str, | ||
source="", | ||
page_id="", | ||
db_client=None, | ||
key_ttl=86400 * 30 | ||
): | ||
""" | ||
Get embedding from cache (or create if not exist) | ||
""" | ||
client = db_client | ||
embedding = None | ||
|
||
if client: | ||
# Tips: the quickest way to get rid of all previous | ||
# cache, change the provider (1st arg) | ||
embedding = client.get_milvus_embedding_item_id( | ||
"ollama-norm", | ||
self.model_name, | ||
source, | ||
page_id) | ||
|
||
if embedding: | ||
print("[EmbeddingOllama] Embedding got from cache") | ||
return utils.fix_and_parse_json(embedding) | ||
|
||
# Not found in cache, generate one | ||
print("[EmbeddingOllama] Embedding not found, create a new one and cache it") | ||
|
||
# Most of the emb models have 8k tokens, exceed it will | ||
# throw exceptions. Here we simply limited it <= 5000 chars | ||
# for the input | ||
|
||
EMBEDDING_MAX_LENGTH = int(os.getenv("EMBEDDING_MAX_LENGTH", 5000)) | ||
embedding = self.create(text[:EMBEDDING_MAX_LENGTH]) | ||
|
||
# store embedding into redis (ttl = 1 month) | ||
if client: | ||
client.set_milvus_embedding_item_id( | ||
"ollama-norm", | ||
self.model_name, | ||
source, | ||
page_id, | ||
json.dumps(embedding), | ||
expired_time=key_ttl) | ||
|
||
return embedding |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
############################################################################### | ||
# Embedding Utils | ||
############################################################################### | ||
|
||
def similarity_topk(embedding_items: list, metric_type, threshold=None, k=3): | ||
""" | ||
@param embedding_items [{item_id, distance}, ...] | ||
@param metric_type L2, IP, COSINE | ||
@threshold to filter the result | ||
@k max number of returns | ||
""" | ||
if metric_type == "L2": | ||
return similarity_topk_l2(embedding_items, threshold, k) | ||
elif metric_type in ("IP", "COSINE"): | ||
# assume IP type all embeddings has been normalized | ||
return similarity_topk_cosine(embedding_items, threshold, k) | ||
else: | ||
raise Exception(f"Unknown metric_type: {metric_type}") | ||
|
||
|
||
def similarity_topk_l2(items: list, threshold, k): | ||
""" | ||
metric_type L2, the value range [0, +inf) | ||
* The smaller (Close to 0), the more similiar | ||
* The larger, the less similar | ||
so, we will filter in distance <= threshold first, then get top-k | ||
""" | ||
valid_items = items | ||
|
||
if threshold is not None: | ||
valid_items = [x for x in items if x["distance"] <= threshold] | ||
|
||
# sort in ASC | ||
sorted_items = sorted( | ||
valid_items, | ||
key=lambda item: item["distance"], | ||
) | ||
|
||
# The returned value is sorted by most similar -> least similar | ||
return sorted_items[:k] | ||
|
||
|
||
def similarity_topk_cosine(items: list, threshold, k): | ||
""" | ||
metric_type IP (normalized) or COSINE, the value range [-1, 1] | ||
* 1 indicates that the vectors are identical in direction. | ||
* 0 indicates orthogonality (no similarity in direction). | ||
* -1 indicates that the vectors are opposite in direction. | ||
so, we will filter in distance >= threshold first, then get top-k | ||
""" | ||
valid_items = items | ||
|
||
if threshold is not None: | ||
valid_items = [x for x in items if x["distance"] >= threshold] | ||
|
||
# sort in DESC | ||
sorted_items = sorted( | ||
valid_items, | ||
key=lambda item: item["distance"], | ||
reverse=True, | ||
) | ||
|
||
# The returned value is sorted by most similar -> least similar | ||
return sorted_items[:k] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters