diff --git a/bin/chat-chainlit.py b/bin/chat-chainlit.py index 183292a..9d96a04 100644 --- a/bin/chat-chainlit.py +++ b/bin/chat-chainlit.py @@ -7,6 +7,7 @@ import chainlit as cl from dotenv import load_dotenv +from conversational_chain.graph import RAGGraphWithMemory from retreival_chain import initialize_retrieval_chain from util.embedding_environment import EmbeddingEnvironment @@ -57,14 +58,14 @@ async def start() -> None: chat_profile = cl.user_session.get("chat_profile") embeddings_directory = EmbeddingEnvironment.get_dir(env) - llm_chain = initialize_retrieval_chain( + llm_graph = initialize_retrieval_chain( env, embeddings_directory, False, False, hf_model=EmbeddingEnvironment.get_model(env), ) - cl.user_session.set("llm_chain", llm_chain) + cl.user_session.set("llm_graph", llm_graph) initial_message: str = f"""Welcome to {chat_profile} your interactive chatbot for exploring Reactome! Ask me about biological pathways and processes""" @@ -73,13 +74,15 @@ async def start() -> None: @cl.on_message async def main(message: cl.Message) -> None: - llm_chain = cl.user_session.get("llm_chain") + llm_graph: RAGGraphWithMemory = cl.user_session.get("llm_graph") cb = cl.AsyncLangchainCallbackHandler( stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"] ) - cb.answer_reached = True - - res = await llm_chain.ainvoke(message.content, callbacks=[cb]) + res = await llm_graph.ainvoke( + message.content, + callbacks = [cb], + configurable = {"thread_id": "0"} # single thread + ) if cb.has_streamed_final_answer and cb.final_stream is not None: await cb.final_stream.update() else: diff --git a/poetry.lock b/poetry.lock index 7ab245a..f6de25e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -418,17 +418,17 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "boto3" -version = "1.35.50" +version = "1.35.51" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.35.50-py3-none-any.whl", hash = "sha256:14724b905fd13f26d9d8f7cdcea0fa65a9acad79f60f41f7662667f4e233d97c"}, - {file = "boto3-1.35.50.tar.gz", hash = "sha256:4f15d1ccb481d66f6925b8c91c970ce41b956b6ecf7c479f23e2159531b37eec"}, + {file = "boto3-1.35.51-py3-none-any.whl", hash = "sha256:c922f6a18958af9d8af0489d6d8503b517029d8159b26aa4859a8294561c72e9"}, + {file = "boto3-1.35.51.tar.gz", hash = "sha256:a57c6c7012ecb40c43e565a6f7a891f39efa990ff933eab63cd456f7501c2731"}, ] [package.dependencies] -botocore = ">=1.35.50,<1.36.0" +botocore = ">=1.35.51,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -437,13 +437,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.50" +version = "1.35.51" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.50-py3-none-any.whl", hash = "sha256:965d3b99179ac04aa98e4c4baf4a970ebce77a5e02bb2a0a21cb6304e2bc0955"}, - {file = "botocore-1.35.50.tar.gz", hash = "sha256:136ecef8d5a1088f1ba485c0bbfca40abd42b9f9fe9e11d8cde4e53b4c05b188"}, + {file = "botocore-1.35.51-py3-none-any.whl", hash = "sha256:4d65b00111bd12b98e9f920ecab602cf619cc6a6d0be6e5dd53f517e4b92901c"}, + {file = "botocore-1.35.51.tar.gz", hash = "sha256:a9b3d1da76b3e896ad74605c01d88f596324a3337393d4bfbfa0d6c35822ca9c"}, ] [package.dependencies] @@ -1368,6 +1368,17 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx-sse" +version = "0.4.0" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, + {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"}, +] + [[package]] name = "huggingface-hub" version = "0.26.2" @@ -1668,19 +1679,19 @@ adal = ["adal (>=1.0.2)"] [[package]] name = "langchain" -version = "0.3.4" +version = "0.3.5" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "langchain-0.3.4-py3-none-any.whl", hash = "sha256:7a1241d9429510d2083c62df0da998a7b2b05c730cd4255b89da9d47c57f48fd"}, - {file = "langchain-0.3.4.tar.gz", hash = "sha256:3596515fcd0157dece6ec96e0240d29f4cf542d91ecffc815d32e35198dfff37"}, + {file = "langchain-0.3.5-py3-none-any.whl", hash = "sha256:f62ea88bb12e40b1b210cf059a5e564bfb17ee3194651944d5a013abb88d30ee"}, + {file = "langchain-0.3.5.tar.gz", hash = "sha256:ff8f138e7e75c34ab155397853964494b212abda856b9c983dedc535249f29e8"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" async-timeout = {version = ">=4.0.0,<5.0.0", markers = "python_version < \"3.11\""} -langchain-core = ">=0.3.12,<0.4.0" +langchain-core = ">=0.3.13,<0.4.0" langchain-text-splitters = ">=0.3.0,<0.4.0" langsmith = ">=0.1.17,<0.2.0" numpy = [ @@ -1811,17 +1822,64 @@ tiktoken = ">=0.7,<1" [[package]] name = "langchain-text-splitters" -version = "0.3.0" +version = "0.3.1" description = "LangChain text splitting utilities" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_text_splitters-0.3.0-py3-none-any.whl", hash = "sha256:e84243e45eaff16e5b776cd9c81b6d07c55c010ebcb1965deb3d1792b7358e83"}, - {file = "langchain_text_splitters-0.3.0.tar.gz", hash = "sha256:f9fe0b4d244db1d6de211e7343d4abc4aa90295aa22e1f0c89e51f33c55cd7ce"}, + {file = "langchain_text_splitters-0.3.1-py3-none-any.whl", hash = "sha256:842c9b342dc4b1a810f911a262f0daa772beb83d6e4f13a87ae7a3de015cb2a9"}, + {file = "langchain_text_splitters-0.3.1.tar.gz", hash = "sha256:5d2a749d0e70c7c6a3e0a0a3996648fd06cb9fe8e316c3040ca6541811d8ecaf"}, ] [package.dependencies] -langchain-core = ">=0.3.0,<0.4.0" +langchain-core = ">=0.3.13,<0.4.0" + +[[package]] +name = "langgraph" +version = "0.2.39" +description = "Building stateful, multi-actor applications with LLMs" +optional = false +python-versions = "<4.0,>=3.9.0" +files = [ + {file = "langgraph-0.2.39-py3-none-any.whl", hash = "sha256:5dfbdeefbf599f16d245799609f2b43c1ec7a7e8ed6e1d7981b1a7979a4ad7fe"}, + {file = "langgraph-0.2.39.tar.gz", hash = "sha256:32af60291f9260c3acb8a3d4bec99e32abd89ddb6b4a10a79aa3dbc90fa920ac"}, +] + +[package.dependencies] +langchain-core = ">=0.2.39,<0.4" +langgraph-checkpoint = ">=2.0.0,<3.0.0" +langgraph-sdk = ">=0.1.32,<0.2.0" + +[[package]] +name = "langgraph-checkpoint" +version = "2.0.2" +description = "Library with base interfaces for LangGraph checkpoint savers." +optional = false +python-versions = "<4.0.0,>=3.9.0" +files = [ + {file = "langgraph_checkpoint-2.0.2-py3-none-any.whl", hash = "sha256:6e5dfd90e1fc71b91ccff75939ada1114e5d7f824df5f24c62d39bed69039ee2"}, + {file = "langgraph_checkpoint-2.0.2.tar.gz", hash = "sha256:c1d033e4e4855f580fa56830327eb86513b64ab5be527245363498e76b19a0b9"}, +] + +[package.dependencies] +langchain-core = ">=0.2.38,<0.4" +msgpack = ">=1.1.0,<2.0.0" + +[[package]] +name = "langgraph-sdk" +version = "0.1.35" +description = "SDK for interacting with LangGraph API" +optional = false +python-versions = "<4.0.0,>=3.9.0" +files = [ + {file = "langgraph_sdk-0.1.35-py3-none-any.whl", hash = "sha256:b137c324fbce96afe39cc6a189c61fc042164068f0f6f02ac8de864d8ece6e05"}, + {file = "langgraph_sdk-0.1.35.tar.gz", hash = "sha256:414cfbc172b883446197763f3645d86bbc6a5b8ce9693c1df3fb6ce7d854a994"}, +] + +[package.dependencies] +httpx = ">=0.25.2" +httpx-sse = ">=0.4.0" +orjson = ">=3.10.1" [[package]] name = "langsmith" @@ -2150,6 +2208,79 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4)"] tests = ["pytest (>=4.6)"] +[[package]] +name = "msgpack" +version = "1.1.0" +description = "MessagePack serializer" +optional = false +python-versions = ">=3.8" +files = [ + {file = "msgpack-1.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7ad442d527a7e358a469faf43fda45aaf4ac3249c8310a82f0ccff9164e5dccd"}, + {file = "msgpack-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74bed8f63f8f14d75eec75cf3d04ad581da6b914001b474a5d3cd3372c8cc27d"}, + {file = "msgpack-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:914571a2a5b4e7606997e169f64ce53a8b1e06f2cf2c3a7273aa106236d43dd5"}, + {file = "msgpack-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c921af52214dcbb75e6bdf6a661b23c3e6417f00c603dd2070bccb5c3ef499f5"}, + {file = "msgpack-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8ce0b22b890be5d252de90d0e0d119f363012027cf256185fc3d474c44b1b9e"}, + {file = "msgpack-1.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:73322a6cc57fcee3c0c57c4463d828e9428275fb85a27aa2aa1a92fdc42afd7b"}, + {file = "msgpack-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e1f3c3d21f7cf67bcf2da8e494d30a75e4cf60041d98b3f79875afb5b96f3a3f"}, + {file = "msgpack-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:64fc9068d701233effd61b19efb1485587560b66fe57b3e50d29c5d78e7fef68"}, + {file = "msgpack-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:42f754515e0f683f9c79210a5d1cad631ec3d06cea5172214d2176a42e67e19b"}, + {file = "msgpack-1.1.0-cp310-cp310-win32.whl", hash = "sha256:3df7e6b05571b3814361e8464f9304c42d2196808e0119f55d0d3e62cd5ea044"}, + {file = "msgpack-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:685ec345eefc757a7c8af44a3032734a739f8c45d1b0ac45efc5d8977aa4720f"}, + {file = "msgpack-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3d364a55082fb2a7416f6c63ae383fbd903adb5a6cf78c5b96cc6316dc1cedc7"}, + {file = "msgpack-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:79ec007767b9b56860e0372085f8504db5d06bd6a327a335449508bbee9648fa"}, + {file = "msgpack-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6ad622bf7756d5a497d5b6836e7fc3752e2dd6f4c648e24b1803f6048596f701"}, + {file = "msgpack-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e59bca908d9ca0de3dc8684f21ebf9a690fe47b6be93236eb40b99af28b6ea6"}, + {file = "msgpack-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e1da8f11a3dd397f0a32c76165cf0c4eb95b31013a94f6ecc0b280c05c91b59"}, + {file = "msgpack-1.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:452aff037287acb1d70a804ffd022b21fa2bb7c46bee884dbc864cc9024128a0"}, + {file = "msgpack-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8da4bf6d54ceed70e8861f833f83ce0814a2b72102e890cbdfe4b34764cdd66e"}, + {file = "msgpack-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:41c991beebf175faf352fb940bf2af9ad1fb77fd25f38d9142053914947cdbf6"}, + {file = "msgpack-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a52a1f3a5af7ba1c9ace055b659189f6c669cf3657095b50f9602af3a3ba0fe5"}, + {file = "msgpack-1.1.0-cp311-cp311-win32.whl", hash = "sha256:58638690ebd0a06427c5fe1a227bb6b8b9fdc2bd07701bec13c2335c82131a88"}, + {file = "msgpack-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd2906780f25c8ed5d7b323379f6138524ba793428db5d0e9d226d3fa6aa1788"}, + {file = "msgpack-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d46cf9e3705ea9485687aa4001a76e44748b609d260af21c4ceea7f2212a501d"}, + {file = "msgpack-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5dbad74103df937e1325cc4bfeaf57713be0b4f15e1c2da43ccdd836393e2ea2"}, + {file = "msgpack-1.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:58dfc47f8b102da61e8949708b3eafc3504509a5728f8b4ddef84bd9e16ad420"}, + {file = "msgpack-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676e5be1b472909b2ee6356ff425ebedf5142427842aa06b4dfd5117d1ca8a2"}, + {file = "msgpack-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17fb65dd0bec285907f68b15734a993ad3fc94332b5bb21b0435846228de1f39"}, + {file = "msgpack-1.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a51abd48c6d8ac89e0cfd4fe177c61481aca2d5e7ba42044fd218cfd8ea9899f"}, + {file = "msgpack-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2137773500afa5494a61b1208619e3871f75f27b03bcfca7b3a7023284140247"}, + {file = "msgpack-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:398b713459fea610861c8a7b62a6fec1882759f308ae0795b5413ff6a160cf3c"}, + {file = "msgpack-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:06f5fd2f6bb2a7914922d935d3b8bb4a7fff3a9a91cfce6d06c13bc42bec975b"}, + {file = "msgpack-1.1.0-cp312-cp312-win32.whl", hash = "sha256:ad33e8400e4ec17ba782f7b9cf868977d867ed784a1f5f2ab46e7ba53b6e1e1b"}, + {file = "msgpack-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:115a7af8ee9e8cddc10f87636767857e7e3717b7a2e97379dc2054712693e90f"}, + {file = "msgpack-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:071603e2f0771c45ad9bc65719291c568d4edf120b44eb36324dcb02a13bfddf"}, + {file = "msgpack-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0f92a83b84e7c0749e3f12821949d79485971f087604178026085f60ce109330"}, + {file = "msgpack-1.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1964df7b81285d00a84da4e70cb1383f2e665e0f1f2a7027e683956d04b734"}, + {file = "msgpack-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59caf6a4ed0d164055ccff8fe31eddc0ebc07cf7326a2aaa0dbf7a4001cd823e"}, + {file = "msgpack-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0907e1a7119b337971a689153665764adc34e89175f9a34793307d9def08e6ca"}, + {file = "msgpack-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:65553c9b6da8166e819a6aa90ad15288599b340f91d18f60b2061f402b9a4915"}, + {file = "msgpack-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7a946a8992941fea80ed4beae6bff74ffd7ee129a90b4dd5cf9c476a30e9708d"}, + {file = "msgpack-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4b51405e36e075193bc051315dbf29168d6141ae2500ba8cd80a522964e31434"}, + {file = "msgpack-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b4c01941fd2ff87c2a934ee6055bda4ed353a7846b8d4f341c428109e9fcde8c"}, + {file = "msgpack-1.1.0-cp313-cp313-win32.whl", hash = "sha256:7c9a35ce2c2573bada929e0b7b3576de647b0defbd25f5139dcdaba0ae35a4cc"}, + {file = "msgpack-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:bce7d9e614a04d0883af0b3d4d501171fbfca038f12c77fa838d9f198147a23f"}, + {file = "msgpack-1.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c40ffa9a15d74e05ba1fe2681ea33b9caffd886675412612d93ab17b58ea2fec"}, + {file = "msgpack-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1ba6136e650898082d9d5a5217d5906d1e138024f836ff48691784bbe1adf96"}, + {file = "msgpack-1.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e0856a2b7e8dcb874be44fea031d22e5b3a19121be92a1e098f46068a11b0870"}, + {file = "msgpack-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:471e27a5787a2e3f974ba023f9e265a8c7cfd373632247deb225617e3100a3c7"}, + {file = "msgpack-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:646afc8102935a388ffc3914b336d22d1c2d6209c773f3eb5dd4d6d3b6f8c1cb"}, + {file = "msgpack-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:13599f8829cfbe0158f6456374e9eea9f44eee08076291771d8ae93eda56607f"}, + {file = "msgpack-1.1.0-cp38-cp38-win32.whl", hash = "sha256:8a84efb768fb968381e525eeeb3d92857e4985aacc39f3c47ffd00eb4509315b"}, + {file = "msgpack-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:879a7b7b0ad82481c52d3c7eb99bf6f0645dbdec5134a4bddbd16f3506947feb"}, + {file = "msgpack-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:53258eeb7a80fc46f62fd59c876957a2d0e15e6449a9e71842b6d24419d88ca1"}, + {file = "msgpack-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7e7b853bbc44fb03fbdba34feb4bd414322180135e2cb5164f20ce1c9795ee48"}, + {file = "msgpack-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3e9b4936df53b970513eac1758f3882c88658a220b58dcc1e39606dccaaf01c"}, + {file = "msgpack-1.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46c34e99110762a76e3911fc923222472c9d681f1094096ac4102c18319e6468"}, + {file = "msgpack-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a706d1e74dd3dea05cb54580d9bd8b2880e9264856ce5068027eed09680aa74"}, + {file = "msgpack-1.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:534480ee5690ab3cbed89d4c8971a5c631b69a8c0883ecfea96c19118510c846"}, + {file = "msgpack-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8cf9e8c3a2153934a23ac160cc4cba0ec035f6867c8013cc6077a79823370346"}, + {file = "msgpack-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3180065ec2abbe13a4ad37688b61b99d7f9e012a535b930e0e683ad6bc30155b"}, + {file = "msgpack-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c5a91481a3cc573ac8c0d9aace09345d989dc4a0202b7fcb312c88c26d4e71a8"}, + {file = "msgpack-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f80bc7d47f76089633763f952e67f8214cb7b3ee6bfa489b3cb6a84cfac114cd"}, + {file = "msgpack-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:4d1b7ff2d6146e16e8bd665ac726a89c74163ef8cd39fa8c1087d4e52d3a2325"}, + {file = "msgpack-1.1.0.tar.gz", hash = "sha256:dd432ccc2c72b914e4cb77afce64aab761c1137cc698be3984eee260bcb2896e"}, +] + [[package]] name = "multidict" version = "6.1.0" @@ -4554,13 +4685,13 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.46.0" +version = "4.46.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.46.0-py3-none-any.whl", hash = "sha256:e161268ae8bee315eb9e9b4c0b27f1bd6980f91e0fc292d75249193d339704c0"}, - {file = "transformers-4.46.0.tar.gz", hash = "sha256:3a9e2eb537094db11c3652334d281afa4766c0e5091c4dcdb454e9921bb0d2b7"}, + {file = "transformers-4.46.1-py3-none-any.whl", hash = "sha256:f77b251a648fd32e3d14b5e7e27c913b7c29154940f519e4c8c3aa6061df0f05"}, + {file = "transformers-4.46.1.tar.gz", hash = "sha256:16d79927d772edaf218820a96f9254e2211f9cd7fb3c308562d2d636c964a68c"}, ] [package.dependencies] @@ -5167,4 +5298,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.10, <=3.12.3" -content-hash = "2e6037c49af340757ba136053d907feb6d4a65a2c9176019990d72f2667a87d8" +content-hash = "72b2db3ea198663339ad58aaad9b9722c1360f2510d6b745b8c5e4f537ad0a99" diff --git a/pyproject.toml b/pyproject.toml index bc9fe19..e0eddee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ torch = {version = "^2.4.0+cpu", source = "pytorch_cpu"} langchain-chroma = "^0.1.4" langchain-ollama = "^0.2.0" lark = "^1.2.2" +langgraph = "^0.2.39" [tool.poetry.group.dev.dependencies] ruff = "^0.7.1" diff --git a/src/conversational_chain/chain.py b/src/conversational_chain/chain.py index 7fff9a2..c966353 100644 --- a/src/conversational_chain/chain.py +++ b/src/conversational_chain/chain.py @@ -1,12 +1,17 @@ -from langchain.chains import (create_history_aware_retriever, - create_retrieval_chain) from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain.chains.history_aware_retriever import \ + create_history_aware_retriever +from langchain.chains.retrieval import create_retrieval_chain +from langchain_core.language_models import LanguageModelLike +from langchain_core.retrievers import RetrieverLike from system_prompt.reactome_prompt import contextualize_q_prompt, qa_prompt class RAGChainWithMemory: - def __init__(self, memory, retriever, llm): + def __init__( + self, memory, retriever: RetrieverLike, llm: LanguageModelLike + ): """ Initializes the Retrieval-Augmented Generation (RAG) chain with memory. """ diff --git a/src/conversational_chain/graph.py b/src/conversational_chain/graph.py new file mode 100644 index 0000000..98b7db3 --- /dev/null +++ b/src/conversational_chain/graph.py @@ -0,0 +1,60 @@ +from typing import Annotated, Any, Sequence, TypedDict + +from langchain_core.callbacks.base import Callbacks +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph.message import add_messages +from langgraph.graph.state import CompiledStateGraph, StateGraph + +from conversational_chain.chain import RAGChainWithMemory + + +class ChatResponse(TypedDict): + chat_history: Annotated[Sequence[BaseMessage], add_messages] + context: str + answer: str + + +class ChatState(ChatResponse): + input: str + + +class RAGGraphWithMemory(RAGChainWithMemory): + def __init__(self, **chain_kwargs): + super().__init__(**chain_kwargs) + + # Single-node graph (for now) + graph: StateGraph = StateGraph(ChatState) + graph.add_node("model", self.call_model) + graph.set_entry_point("model") + graph.set_finish_point("model") + + memory = MemorySaver() + self.graph: CompiledStateGraph = graph.compile(checkpointer=memory) + + async def call_model( + self, state: ChatState, config: RunnableConfig + ) -> ChatResponse: + response = await self.rag_chain.ainvoke(state, config) + return { + "chat_history": [ + HumanMessage(state["input"]), + AIMessage(response["answer"]), + ], + "context": response["context"], + "answer": response["answer"], + } + + async def ainvoke( + self, user_input: str, callbacks: Callbacks, + configurable: dict[str, Any] + ) -> str: + response: dict[str, Any] = await self.graph.ainvoke( + {"input": user_input}, + config = RunnableConfig( + callbacks = callbacks, + configurable = configurable, + ) + ) + return response["answer"] diff --git a/src/retreival_chain.py b/src/retreival_chain.py index f4bfe43..371ff7b 100644 --- a/src/retreival_chain.py +++ b/src/retreival_chain.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import AsyncGenerator, Callable, Optional +from typing import Callable, Optional from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler @@ -14,7 +14,7 @@ from langchain_ollama.chat_models import ChatOllama from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from conversational_chain.chain import RAGChainWithMemory +from conversational_chain.graph import RAGGraphWithMemory from conversational_chain.memory import ChatHistoryMemory from reactome.metadata_info import descriptions_info, field_info @@ -26,11 +26,6 @@ def list_chroma_subdirectories(directory: Path) -> list[str]: return subdirectories -async def invoke(self, query: str) -> AsyncGenerator[str, None]: - async for message in self.astream(query): - yield message - - def get_embedding( hf_model: Optional[str] = None, device: str = "cpu" ) -> Callable[[], Embeddings]: @@ -60,7 +55,7 @@ def initialize_retrieval_chain( ollama_url: str = "http://localhost:11434", hf_model: Optional[str] = None, device: str = "cpu", -) -> RAGChainWithMemory: +) -> RAGGraphWithMemory: memory = ChatHistoryMemory() callbacks: list[BaseCallbackHandler] = [] @@ -73,7 +68,7 @@ def initialize_retrieval_chain( if ollama_model is None: # Use OpenAI when Ollama not specified llm = ChatOpenAI( temperature=0.0, - streaming=commandline, + streaming=True, callbacks=callbacks, verbose=verbose, model="gpt-4o-mini", @@ -112,8 +107,7 @@ def initialize_retrieval_chain( retrievers=retriever_list, weights=[0.25] * len(retriever_list) ) - RAGChainWithMemory.invoke = invoke - qa = RAGChainWithMemory( + qa = RAGGraphWithMemory( memory=memory, retriever=reactome_retriever, llm=llm,