import os
import json
import re

from agent.tools_and_schemas import SearchQueryList, Reflection
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.types import Send
from langgraph.graph import StateGraph
from langgraph.graph import START, END
from langchain_core.runnables import RunnableConfig
from google.genai import Client

from agent.state import (
    OverallState,
    QueryGenerationState,
    ReflectionState,
    WebSearchState,
)
from agent.configuration import Configuration
from agent.prompts import (
    get_current_date,
    query_writer_instructions,
    web_searcher_instructions,
    reflection_instructions,
    answer_instructions,
)
from langchain_google_genai import ChatGoogleGenerativeAI
from agent.utils import get_research_topic

load_dotenv()

if os.getenv("GEMINI_API_KEY") is None:
    raise ValueError("GEMINI_API_KEY is not set")

# Used for Google Search API
genai_client = Client(api_key=os.getenv("GEMINI_API_KEY"))


# Nodes
def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
    """LangGraph node that generates search queries based on the User's question.

    Uses the configured query generator model to create optimized search queries
    for web research based on the User's question.

    Args:
        state: Current graph state containing the User's question
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including search_query key containing the generated queries
    """
    configurable = Configuration.from_runnable_config(config)
    language = state.get("language") or "English"

    # Check for a custom initial search query count
    if state.get("initial_search_query_count") is None:
        state["initial_search_query_count"] = configurable.number_of_initial_queries

    # Init the query generator model
    llm = ChatGoogleGenerativeAI(
        model=configurable.query_generator_model,
        temperature=1.0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    structured_llm = llm.with_structured_output(SearchQueryList)

    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = query_writer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        number_queries=state["initial_search_query_count"],
        language=language,
    )
    # Generate the search queries
    result = structured_llm.invoke(formatted_prompt)
    return {"search_query": result.query}


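# Illustrative structured output (assuming SearchQueryList exposes a
# `query: list[str]` field, as the return above implies; values are made up):
#   SearchQueryList(query=["<topic> overview", "<topic> latest news"])

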
def continue_to_web_research(state: QueryGenerationState):
    """LangGraph routing function that sends the search queries to the web research node.

    This spawns one web_research branch per generated search query, run in parallel.
    """
    return [
        Send("web_research", {"search_query": search_query, "id": int(idx)})
        for idx, search_query in enumerate(state["search_query"])
    ]


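# Illustrative fan-out (not part of the graph): with
# state["search_query"] == ["q1", "q2"], continue_to_web_research returns
#   [Send("web_research", {"search_query": "q1", "id": 0}),
#    Send("web_research", {"search_query": "q2", "id": 1})]
# and LangGraph dispatches one web_research invocation per Send.

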
def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
    """LangGraph node that performs web research using the native Google Search API tool.

    Executes a web search using the native Google Search API tool in combination with
    the configured Gemini model.

    Args:
        state: Current graph state containing the search query and research loop count
        config: Configuration for the runnable, including search API settings

    Returns:
        Dictionary with state update, including sources_gathered, search_query, and web_research_result
    """
    # Configure
    configurable = Configuration.from_runnable_config(config)
    language = state.get("language") or "English"
    formatted_prompt = web_searcher_instructions.format(
        current_date=get_current_date(),
        research_topic=state["search_query"],
        language=language,
    )

    # Uses the google genai client, as the langchain client doesn't return grounding metadata
    response = genai_client.models.generate_content(
        model=configurable.query_generator_model,
        contents=formatted_prompt,
        config={
            "tools": [{"google_search": {}}],
            "temperature": 0,
        },
    )
    base_text = response.text or ""

    return {
        "sources_gathered": [],
        "search_query": [state["search_query"]],
        "web_research_result": [base_text],
    }


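# Note on the single-item lists above: each parallel web_research branch
# returns one-element lists, and the reducers on OverallState (assumed to be
# additive, e.g. operator.add) concatenate them across branches into the
# combined research state.

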
def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
    """LangGraph node that identifies knowledge gaps and generates potential follow-up queries.

    Analyzes the current summaries to identify areas for further research and generates
    potential follow-up queries. Uses structured output to extract the follow-up
    queries in JSON format.

    Args:
        state: Current graph state containing the running summaries and research topic
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including is_sufficient, knowledge_gap,
        follow_up_queries, research_loop_count, and number_of_ran_queries
    """
    configurable = Configuration.from_runnable_config(config)
    language = state.get("language") or "English"
    # Increment the research loop count and get the reasoning model
    state["research_loop_count"] = state.get("research_loop_count", 0) + 1
    reasoning_model = state.get("reasoning_model", configurable.reflection_model)

    # Format the prompt
    current_date = get_current_date()
    formatted_prompt = reflection_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries="\n\n---\n\n".join(state["web_research_result"]),
        language=language,
    )
    # Init the reasoning model
    llm = ChatGoogleGenerativeAI(
        model=reasoning_model,
        temperature=1.0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    result = llm.with_structured_output(Reflection).invoke(formatted_prompt)

    return {
        "is_sufficient": result.is_sufficient,
        "knowledge_gap": result.knowledge_gap,
        "follow_up_queries": result.follow_up_queries,
        "research_loop_count": state["research_loop_count"],
        "number_of_ran_queries": len(state["search_query"]),
    }


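# Illustrative Reflection payload (field names taken from the return above;
# the values are made up):
#   Reflection(is_sufficient=False,
#              knowledge_gap="No sources cover developments after mid-2024",
#              follow_up_queries=["<topic> updates 2025"])

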
def evaluate_research(
    state: ReflectionState,
    config: RunnableConfig,
):
    """LangGraph routing function that determines the next step in the research flow.

    Controls the research loop by deciding whether to continue gathering information
    or to finalize the answer based on the configured maximum number of research loops.

    Args:
        state: Current graph state containing the research loop count
        config: Configuration for the runnable, including max_research_loops setting

    Returns:
        The string "finalize_answer" when research is sufficient or the loop limit
        is reached; otherwise a list of Send objects fanning out to "web_research"
    """
    configurable = Configuration.from_runnable_config(config)
    max_research_loops = (
        state.get("max_research_loops")
        if state.get("max_research_loops") is not None
        else configurable.max_research_loops
    )
    if state["is_sufficient"] or state["research_loop_count"] >= max_research_loops:
        return "finalize_answer"
    else:
        return [
            Send(
                "web_research",
                {
                    "search_query": follow_up_query,
                    # Offset ids past the queries already run in earlier loops
                    "id": state["number_of_ran_queries"] + int(idx),
                },
            )
            for idx, follow_up_query in enumerate(state["follow_up_queries"])
        ]


def finalize_answer(state: OverallState, config: RunnableConfig):
    """LangGraph node that finalizes the research answer.

    Combines the gathered web research summaries into a well-structured final
    answer via the configured answer model.

    Args:
        state: Current graph state containing the research summaries
        config: Configuration for the runnable, including the answer model setting

    Returns:
        Dictionary with state update: messages gains an AIMessage carrying the
        final answer, and sources_gathered is reset to an empty list
    """
    configurable = Configuration.from_runnable_config(config)
    reasoning_model = state.get("reasoning_model") or configurable.answer_model
    language = state.get("language") or "English"

    # Format the prompt. Note: str.format only parses placeholders in the template
    # string itself; substituted values are inserted verbatim, so braces in the
    # research topic or summaries need no escaping.
    current_date = get_current_date()
    formatted_prompt = answer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries="\n\n---\n\n".join(state["web_research_result"]),
        language=language,
    )

    # Init the reasoning model, defaulting to the configured answer model
    llm = ChatGoogleGenerativeAI(
        model=reasoning_model,
        temperature=0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    result = llm.invoke(formatted_prompt)

    # Clean potential markdown fences and parse JSON so we return structured content
    content = result.content
    if isinstance(content, str):
        # Strip markdown fences ```json ... ```
        cleaned = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", content.strip())
        try:
            parsed = json.loads(cleaned)
        except Exception:
            # Not valid JSON; fall back to the raw text
            parsed = cleaned
        content_payload = parsed
    else:
        content_payload = content

    return {
        "messages": [AIMessage(content=content_payload)],
        "sources_gathered": [],
    }


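# Illustrative fence handling (made-up value): the regex above turns
#   '```json\n{"panels": [...]}\n```'  into  '{"panels": [...]}'
# which json.loads then parses; non-JSON answers fall through as plain text.

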
# Create our Agent Graph
builder = StateGraph(OverallState, config_schema=Configuration)

# Define the nodes we will cycle between
builder.add_node("generate_query", generate_query)
builder.add_node("web_research", web_research)
builder.add_node("reflection", reflection)
builder.add_node("finalize_answer", finalize_answer)

# Set the entrypoint as `generate_query`
# This means that this node is the first one called
builder.add_edge(START, "generate_query")
# Add conditional edge to continue with search queries in a parallel branch
builder.add_conditional_edges(
    "generate_query", continue_to_web_research, ["web_research"]
)
# Reflect on the web research
builder.add_edge("web_research", "reflection")
# Evaluate the research
builder.add_conditional_edges(
    "reflection", evaluate_research, ["web_research", "finalize_answer"]
)
# Finalize the answer
builder.add_edge("finalize_answer", END)

graph = builder.compile(name="pro-search-agent")
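

# Usage sketch (not part of the original module): assumes OverallState accepts
# a `messages` list plus the optional knobs read by the nodes above, and that
# GEMINI_API_KEY and network access are available.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    final_state = graph.invoke(
        {
            "messages": [HumanMessage(content="What is retrieval-augmented generation?")],
            "initial_search_query_count": 2,
            "max_research_loops": 1,
        }
    )
    # finalize_answer appends the final answer as an AIMessage
    print(final_state["messages"][-1].content)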