|
8 | 8 | from typing import Any |
9 | 9 | from dotenv import load_dotenv |
10 | 10 |
|
11 | | -from langchain_neo4j import Neo4jVector |
12 | | -from langchain_neo4j import Neo4jChatMessageHistory |
13 | | -from langchain_neo4j import GraphCypherQAChain |
| 11 | +from langchain_neo4j import Neo4jVector, Neo4jChatMessageHistory, GraphCypherQAChain |
14 | 12 | from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
15 | 13 | from langchain_core.output_parsers import StrOutputParser |
16 | 14 | from langchain_core.runnables import RunnableBranch |
|
20 | 18 | from langchain_core.messages import HumanMessage, AIMessage |
21 | 19 | from langchain_community.chat_message_histories import ChatMessageHistory |
22 | 20 | from langchain_core.callbacks import StdOutCallbackHandler, BaseCallbackHandler |
23 | | -from src.shared.llm_graph_builder_exception import LLMGraphBuilderException |
| 21 | +# from src.shared.llm_graph_builder_exception import LLMGraphBuilderException |
24 | 22 | # LangChain chat models |
25 | 23 | from langchain_openai import ChatOpenAI, AzureChatOpenAI |
26 | | -from langchain_google_vertexai import ChatVertexAI |
| 24 | +from langchain_google_genai import ChatGoogleGenerativeAI |
27 | 25 | from langchain_groq import ChatGroq |
28 | 26 | from langchain_anthropic import ChatAnthropic |
29 | 27 | from langchain_fireworks import ChatFireworks |
@@ -75,8 +73,12 @@ def get_total_tokens(ai_response, llm): |
75 | 73 | if isinstance(llm, (ChatOpenAI, AzureChatOpenAI, ChatFireworks, ChatGroq)): |
76 | 74 | total_tokens = ai_response.response_metadata.get('token_usage', {}).get('total_tokens', 0) |
77 | 75 |
|
78 | | - elif isinstance(llm, ChatVertexAI): |
79 | | - total_tokens = ai_response.response_metadata.get('usage_metadata', {}).get('prompt_token_count', 0) |
| 76 | + elif isinstance(llm, ChatGoogleGenerativeAI): |
| 77 | + if hasattr(ai_response, 'usage_metadata') and ai_response.usage_metadata: |
| 78 | + total_tokens = ai_response.usage_metadata.get('total_tokens', 0) |
| 79 | + else: |
| 80 | + usage = ai_response.response_metadata.get('token_usage', {}) or ai_response.response_metadata.get('usage_metadata', {}) |
| 81 | + total_tokens = usage.get('total_tokens', 0) |
80 | 82 |
|
81 | 83 | elif isinstance(llm, ChatBedrock): |
82 | 84 | total_tokens = ai_response.response_metadata.get('usage', {}).get('total_tokens', 0) |
@@ -224,6 +226,13 @@ def format_documents(documents, model,chat_mode_settings): |
224 | 226 |
|
225 | 227 | return "\n\n".join(formatted_docs), sources,entities,global_communities |
226 | 228 |
|
def get_clean_text(msg):
    """Extract plain text from a chat message.

    Handles the three content shapes LangChain chat models produce:
    a plain string, text stashed in ``additional_kwargs["text"]``
    (some providers), or a list of content parts.

    Args:
        msg: A chat message object exposing ``content`` and
             ``additional_kwargs`` (e.g. a LangChain ``AIMessage``).

    Returns:
        str: The concatenated textual content; "" when no text is present.
    """
    # Fast path: most providers return a plain string.
    if isinstance(msg.content, str):
        return msg.content
    # Some providers place the rendered text in additional_kwargs.
    kw_text = msg.additional_kwargs.get("text", "")
    if kw_text:
        return kw_text
    # Multi-part content: LangChain allows a list of dict blocks AND bare
    # strings; the previous version silently dropped the bare-string parts.
    parts = []
    for part in msg.content:
        if isinstance(part, dict):
            parts.append(part.get("text", ""))
        elif isinstance(part, str):
            parts.append(part)
    return "".join(parts)
227 | 236 | def process_documents(docs, question, messages, llm, model,chat_mode_settings): |
228 | 237 | start_time = time.time() |
229 | 238 |
|
@@ -256,7 +265,7 @@ def process_documents(docs, question, messages, llm, model,chat_mode_settings): |
256 | 265 | result["nodedetails"] = node_details |
257 | 266 | result["entities"] = entities |
258 | 267 |
|
259 | | - content = ai_response.content |
| 268 | + content = get_clean_text(ai_response) |
260 | 269 | total_tokens = get_total_tokens(ai_response, llm) |
261 | 270 |
|
262 | 271 | predict_time = time.time() - start_time |
# Shared lock guarding history rewrites. NOTE: the previous code did
# ``with threading.Lock():`` inside the function, which builds a *fresh* lock
# on every call and therefore provides no mutual exclusion at all.
_history_write_lock = threading.Lock()


def summarize_and_log(history, stored_messages, llm):
    """Condense *stored_messages* into one summary and replace *history* with it.

    Args:
        history: Chat-history store supporting ``clear``, ``add_user_message``
            and ``add_message``.
        stored_messages: Prior chat messages to be summarized.
        llm: LangChain chat model used to produce the summary.

    Returns:
        bool: True on success, False if summarization or persistence failed
        (the error is logged, never raised to the caller).
    """
    try:
        start_time = time.time()

        summarization_prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="chat_history"),
            # Full instruction restored: the abbreviated "...concise message..."
            # prompt dropped the guidance about key points and excluding
            # introductions, degrading summary quality.
            (
                "human",
                "Summarize the above chat messages into a concise message, focusing on key points and relevant details that could be useful for future conversations. Exclude all introductions and extraneous information."
            ),
        ])

        summarization_chain = summarization_prompt | llm
        raw_summary = summarization_chain.invoke({"chat_history": stored_messages})

        # Normalize the model output to plain text: ``content`` may be a plain
        # string or a list of content blocks (dicts with a "text" key, or
        # other parts we stringify).
        if hasattr(raw_summary, "content"):
            content = raw_summary.content
            if isinstance(content, list):
                summary_text = "".join(
                    block.get("text", "") if isinstance(block, dict) else str(block)
                    for block in content
                )
            else:
                summary_text = str(content)
        else:
            summary_text = str(raw_summary)

        summary_message_for_db = AIMessage(content=summary_text)

        # Serialize the clear-and-rewrite so concurrent summarizers cannot
        # interleave a cleared history with a partial write.
        with _history_write_lock:
            history.clear()
            history.add_user_message("Our current conversation summary till now")
            history.add_message(summary_message_for_db)

        logging.info(f"Chat History summarized in {time.time() - start_time:.2f} seconds")
        return True

    except Exception as e:
        logging.error(f"An error occurred while summarizing messages: {e}", exc_info=True)
        return False
538 | 555 | def create_graph_chain(model, graph): |
539 | 556 | try: |
540 | 557 | logging.info(f"Graph QA Chain using LLM model: {model}") |
|
0 commit comments