-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathcli.py
32 lines (27 loc) · 1.28 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch.cuda
import torch.backends
from typing import Any, List, Dict, Union, Mapping, Optional
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from models.custom_llm import CustomLLM
from models.custom_agent import DeepAgent
from models.util import LocalDocQA
from models.config import *
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Module-level singletons, constructed once at import time and used by
# answer() and the __main__ entry point.
deep_agent = DeepAgent()

# Sentence-embedding model (presumably Chinese-focused, judging by its name),
# placed on the device selected in EMBEDDING_DEVICE.
embeddings = HuggingFaceEmbeddings(
    model_name="GanymedeNil/text2vec-base-chinese",
    model_kwargs=dict(device=EMBEDDING_DEVICE),
)

# Local document QA helper. LOCAL_CONTENT and VS_PATH are presumably supplied
# by the `models.config` star import. NOTE(review): confirm what `init=True`
# does in LocalDocQA (it likely builds/loads the vector store on startup).
qa_doc = LocalDocQA(
    filepath=LOCAL_CONTENT,
    vs_path=VS_PATH,
    embeddings=embeddings,
    init=True,
)
def answer(query: str = ""):
    """Answer *query* using retrieved local knowledge and the agent.

    Retrieves content related to the query via ``qa_doc.query_knowledge``,
    prefixes it with a newline, and passes it together with the query to
    ``deep_agent.query``.

    Args:
        query: The user's question. Defaults to an empty string.

    Returns:
        Whatever ``deep_agent.query`` returns.
    """
    context = qa_doc.query_knowledge(query=query)
    return deep_agent.query(related_content="\n" + context, query=query)
if __name__ == "__main__":
    import sys

    # Use any command-line arguments as the question; with no arguments, fall
    # back to the built-in demo question (the previous, hard-coded behavior).
    question = " ".join(sys.argv[1:]) or "携程最近有什么大新闻?"
    # Reuse answer() instead of duplicating its retrieve-then-query pipeline.
    print(answer(question))