mirror of
https://github.com/Zippland/NanoComic.git
synced 2026-03-03 08:34:55 +08:00
init
This commit is contained in:
293
backend/src/agent/graph.py
Normal file
293
backend/src/agent/graph.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import os
|
||||
|
||||
from agent.tools_and_schemas import SearchQueryList, Reflection
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.types import Send
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.graph import START, END
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from google.genai import Client
|
||||
|
||||
from agent.state import (
|
||||
OverallState,
|
||||
QueryGenerationState,
|
||||
ReflectionState,
|
||||
WebSearchState,
|
||||
)
|
||||
from agent.configuration import Configuration
|
||||
from agent.prompts import (
|
||||
get_current_date,
|
||||
query_writer_instructions,
|
||||
web_searcher_instructions,
|
||||
reflection_instructions,
|
||||
answer_instructions,
|
||||
)
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from agent.utils import (
|
||||
get_citations,
|
||||
get_research_topic,
|
||||
insert_citation_markers,
|
||||
resolve_urls,
|
||||
)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
if os.getenv("GEMINI_API_KEY") is None:
|
||||
raise ValueError("GEMINI_API_KEY is not set")
|
||||
|
||||
# Used for Google Search API
|
||||
genai_client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
||||
|
||||
|
||||
# Nodes
|
||||
def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
|
||||
"""LangGraph node that generates a search queries based on the User's question.
|
||||
|
||||
Uses Gemini 2.0 Flash to create an optimized search query for web research based on
|
||||
the User's question.
|
||||
|
||||
Args:
|
||||
state: Current graph state containing the User's question
|
||||
config: Configuration for the runnable, including LLM provider settings
|
||||
|
||||
Returns:
|
||||
Dictionary with state update, including search_query key containing the generated query
|
||||
"""
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
|
||||
# check for custom initial search query count
|
||||
if state.get("initial_search_query_count") is None:
|
||||
state["initial_search_query_count"] = configurable.number_of_initial_queries
|
||||
|
||||
# init Gemini 2.0 Flash
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model=configurable.query_generator_model,
|
||||
temperature=1.0,
|
||||
max_retries=2,
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
)
|
||||
structured_llm = llm.with_structured_output(SearchQueryList)
|
||||
|
||||
# Format the prompt
|
||||
current_date = get_current_date()
|
||||
formatted_prompt = query_writer_instructions.format(
|
||||
current_date=current_date,
|
||||
research_topic=get_research_topic(state["messages"]),
|
||||
number_queries=state["initial_search_query_count"],
|
||||
)
|
||||
# Generate the search queries
|
||||
result = structured_llm.invoke(formatted_prompt)
|
||||
return {"query_list": result.query}
|
||||
|
||||
|
||||
def continue_to_web_research(state: QueryGenerationState):
|
||||
"""LangGraph node that sends the search queries to the web research node.
|
||||
|
||||
This is used to spawn n number of web research nodes, one for each search query.
|
||||
"""
|
||||
return [
|
||||
Send("web_research", {"search_query": search_query, "id": int(idx)})
|
||||
for idx, search_query in enumerate(state["query_list"])
|
||||
]
|
||||
|
||||
|
||||
def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
|
||||
"""LangGraph node that performs web research using the native Google Search API tool.
|
||||
|
||||
Executes a web search using the native Google Search API tool in combination with Gemini 2.0 Flash.
|
||||
|
||||
Args:
|
||||
state: Current graph state containing the search query and research loop count
|
||||
config: Configuration for the runnable, including search API settings
|
||||
|
||||
Returns:
|
||||
Dictionary with state update, including sources_gathered, research_loop_count, and web_research_results
|
||||
"""
|
||||
# Configure
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
formatted_prompt = web_searcher_instructions.format(
|
||||
current_date=get_current_date(),
|
||||
research_topic=state["search_query"],
|
||||
)
|
||||
|
||||
# Uses the google genai client as the langchain client doesn't return grounding metadata
|
||||
response = genai_client.models.generate_content(
|
||||
model=configurable.query_generator_model,
|
||||
contents=formatted_prompt,
|
||||
config={
|
||||
"tools": [{"google_search": {}}],
|
||||
"temperature": 0,
|
||||
},
|
||||
)
|
||||
# resolve the urls to short urls for saving tokens and time
|
||||
resolved_urls = resolve_urls(
|
||||
response.candidates[0].grounding_metadata.grounding_chunks, state["id"]
|
||||
)
|
||||
# Gets the citations and adds them to the generated text
|
||||
citations = get_citations(response, resolved_urls)
|
||||
modified_text = insert_citation_markers(response.text, citations)
|
||||
sources_gathered = [item for citation in citations for item in citation["segments"]]
|
||||
|
||||
return {
|
||||
"sources_gathered": sources_gathered,
|
||||
"search_query": [state["search_query"]],
|
||||
"web_research_result": [modified_text],
|
||||
}
|
||||
|
||||
|
||||
def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
|
||||
"""LangGraph node that identifies knowledge gaps and generates potential follow-up queries.
|
||||
|
||||
Analyzes the current summary to identify areas for further research and generates
|
||||
potential follow-up queries. Uses structured output to extract
|
||||
the follow-up query in JSON format.
|
||||
|
||||
Args:
|
||||
state: Current graph state containing the running summary and research topic
|
||||
config: Configuration for the runnable, including LLM provider settings
|
||||
|
||||
Returns:
|
||||
Dictionary with state update, including search_query key containing the generated follow-up query
|
||||
"""
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
# Increment the research loop count and get the reasoning model
|
||||
state["research_loop_count"] = state.get("research_loop_count", 0) + 1
|
||||
reasoning_model = state.get("reasoning_model") or configurable.reasoning_model
|
||||
|
||||
# Format the prompt
|
||||
current_date = get_current_date()
|
||||
formatted_prompt = reflection_instructions.format(
|
||||
current_date=current_date,
|
||||
research_topic=get_research_topic(state["messages"]),
|
||||
summaries="\n\n---\n\n".join(state["web_research_result"]),
|
||||
)
|
||||
# init Reasoning Model
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model=reasoning_model,
|
||||
temperature=1.0,
|
||||
max_retries=2,
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
)
|
||||
result = llm.with_structured_output(Reflection).invoke(formatted_prompt)
|
||||
|
||||
return {
|
||||
"is_sufficient": result.is_sufficient,
|
||||
"knowledge_gap": result.knowledge_gap,
|
||||
"follow_up_queries": result.follow_up_queries,
|
||||
"research_loop_count": state["research_loop_count"],
|
||||
"number_of_ran_queries": len(state["search_query"]),
|
||||
}
|
||||
|
||||
|
||||
def evaluate_research(
|
||||
state: ReflectionState,
|
||||
config: RunnableConfig,
|
||||
) -> OverallState:
|
||||
"""LangGraph routing function that determines the next step in the research flow.
|
||||
|
||||
Controls the research loop by deciding whether to continue gathering information
|
||||
or to finalize the summary based on the configured maximum number of research loops.
|
||||
|
||||
Args:
|
||||
state: Current graph state containing the research loop count
|
||||
config: Configuration for the runnable, including max_research_loops setting
|
||||
|
||||
Returns:
|
||||
String literal indicating the next node to visit ("web_research" or "finalize_summary")
|
||||
"""
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
max_research_loops = (
|
||||
state.get("max_research_loops")
|
||||
if state.get("max_research_loops") is not None
|
||||
else configurable.max_research_loops
|
||||
)
|
||||
if state["is_sufficient"] or state["research_loop_count"] >= max_research_loops:
|
||||
return "finalize_answer"
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"web_research",
|
||||
{
|
||||
"search_query": follow_up_query,
|
||||
"id": state["number_of_ran_queries"] + int(idx),
|
||||
},
|
||||
)
|
||||
for idx, follow_up_query in enumerate(state["follow_up_queries"])
|
||||
]
|
||||
|
||||
|
||||
def finalize_answer(state: OverallState, config: RunnableConfig):
|
||||
"""LangGraph node that finalizes the research summary.
|
||||
|
||||
Prepares the final output by deduplicating and formatting sources, then
|
||||
combining them with the running summary to create a well-structured
|
||||
research report with proper citations.
|
||||
|
||||
Args:
|
||||
state: Current graph state containing the running summary and sources gathered
|
||||
|
||||
Returns:
|
||||
Dictionary with state update, including running_summary key containing the formatted final summary with sources
|
||||
"""
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
reasoning_model = state.get("reasoning_model") or configurable.reasoning_model
|
||||
|
||||
# Format the prompt
|
||||
current_date = get_current_date()
|
||||
formatted_prompt = answer_instructions.format(
|
||||
current_date=current_date,
|
||||
research_topic=get_research_topic(state["messages"]),
|
||||
summaries="\n---\n\n".join(state["web_research_result"]),
|
||||
)
|
||||
|
||||
# init Reasoning Model, default to Gemini 2.5 Flash
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model=reasoning_model,
|
||||
temperature=0,
|
||||
max_retries=2,
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
)
|
||||
result = llm.invoke(formatted_prompt)
|
||||
|
||||
# Replace the short urls with the original urls and add all used urls to the sources_gathered
|
||||
unique_sources = []
|
||||
for source in state["sources_gathered"]:
|
||||
if source["short_url"] in result.content:
|
||||
result.content = result.content.replace(
|
||||
source["short_url"], source["value"]
|
||||
)
|
||||
unique_sources.append(source)
|
||||
|
||||
return {
|
||||
"messages": [AIMessage(content=result.content)],
|
||||
"sources_gathered": unique_sources,
|
||||
}
|
||||
|
||||
|
||||
# Create our Agent Graph
|
||||
builder = StateGraph(OverallState, config_schema=Configuration)
|
||||
|
||||
# Define the nodes we will cycle between
|
||||
builder.add_node("generate_query", generate_query)
|
||||
builder.add_node("web_research", web_research)
|
||||
builder.add_node("reflection", reflection)
|
||||
builder.add_node("finalize_answer", finalize_answer)
|
||||
|
||||
# Set the entrypoint as `generate_query`
|
||||
# This means that this node is the first one called
|
||||
builder.add_edge(START, "generate_query")
|
||||
# Add conditional edge to continue with search queries in a parallel branch
|
||||
builder.add_conditional_edges(
|
||||
"generate_query", continue_to_web_research, ["web_research"]
|
||||
)
|
||||
# Reflect on the web research
|
||||
builder.add_edge("web_research", "reflection")
|
||||
# Evaluate the research
|
||||
builder.add_conditional_edges(
|
||||
"reflection", evaluate_research, ["web_research", "finalize_answer"]
|
||||
)
|
||||
# Finalize the answer
|
||||
builder.add_edge("finalize_answer", END)
|
||||
|
||||
graph = builder.compile(name="pro-search-agent")
|
||||
Reference in New Issue
Block a user