refactor(agent): simplify web research and citation handling

feat(frontend): enhance image generation with e
This commit is contained in:
zihanjian
2025-12-02 17:13:07 +08:00
parent 61890d9f06
commit 8d4223fbdb
6 changed files with 401 additions and 380 deletions

View File

@@ -3,6 +3,7 @@ import base64
import io import io
import os import os
import pathlib import pathlib
import logging
from fastapi import FastAPI, Response, HTTPException from fastapi import FastAPI, Response, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware

View File

@@ -26,12 +26,7 @@ from agent.prompts import (
answer_instructions, answer_instructions,
) )
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
from agent.utils import ( from agent.utils import get_research_topic
get_citations,
get_research_topic,
insert_citation_markers,
resolve_urls,
)
load_dotenv() load_dotenv()
@@ -126,22 +121,12 @@ def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
"temperature": 0, "temperature": 0,
}, },
) )
# resolve the urls to short urls for saving tokens and time
candidate = response.candidates[0] if response and response.candidates else None
grounding_chunks = None
if candidate and getattr(candidate, "grounding_metadata", None):
grounding_chunks = getattr(candidate.grounding_metadata, "grounding_chunks", None)
resolved_urls = resolve_urls(grounding_chunks, state["id"])
# Gets the citations and adds them to the generated text
citations = get_citations(response, resolved_urls)
base_text = response.text or "" base_text = response.text or ""
modified_text = insert_citation_markers(base_text, citations)
sources_gathered = [item for citation in citations for item in citation["segments"]]
return { return {
"sources_gathered": sources_gathered, "sources_gathered": [],
"search_query": [state["search_query"]], "search_query": [state["search_query"]],
"web_research_result": [modified_text], "web_research_result": [base_text],
} }
@@ -282,26 +267,9 @@ def finalize_answer(state: OverallState, config: RunnableConfig):
else: else:
content_payload = content content_payload = content
# Replace the short urls with the original urls and add all used urls to the sources_gathered
unique_sources = []
for source in state["sources_gathered"]:
if isinstance(content_payload, str) and source["short_url"] in content_payload:
content_payload = content_payload.replace(source["short_url"], source["value"])
unique_sources.append(source)
elif isinstance(content_payload, list):
# if list of page dicts, replace inside detail strings
updated_pages = []
for page in content_payload:
if isinstance(page, dict) and isinstance(page.get("detail"), str):
page_detail = page["detail"].replace(source["short_url"], source["value"])
page = {**page, "detail": page_detail}
updated_pages.append(page)
content_payload = updated_pages
unique_sources.append(source)
return { return {
"messages": [AIMessage(content=content_payload)], "messages": [AIMessage(content=content_payload)],
"sources_gathered": unique_sources, "sources_gathered": [],
} }

View File

@@ -6,109 +6,113 @@ def get_current_date():
return datetime.now().strftime("%B %d, %Y") return datetime.now().strftime("%B %d, %Y")
query_writer_instructions = """Your goal is to generate sophisticated and diverse web search queries that gather everything needed to turn the user's idea into a detailed comic storyboard. These queries will fuel an advanced automated web research tool capable of analyzing complex results, following links, and synthesizing information. query_writer_instructions = """你的目标是生成精细且多样的网页搜索查询,收集将用户想法转化为详细漫画分镜所需的一切信息。这些查询会驱动一个先进的自动化网络研究工具,它能分析复杂结果、跟进链接并综合信息。
Instructions: 指引:
- Target the details required for comics: characters (personality, appearance, speech style), settings (era, location, atmosphere), and key objects or events (what they are and how they look). - 瞄准漫画需要的细节:角色(性格、外貌、说话风格)、场景(时代、地点、氛围)以及关键物件或事件(是什么、长什么样)。
- Always prefer a single search query, only add another query if the original question requests multiple aspects or elements and one query is not enough. - 优先只用一个搜索查询;只有当原始问题包含多个要点且一个查询不够时才再添加查询。
- Each query should focus on one specific aspect of the original question. - 每个查询应聚焦原始问题的一个具体方面。
- Don't produce more than {number_queries} queries. - 不要生成超过 {number_queries} 个查询。
- Queries should be diverse, if the topic is broad, generate more than 1 query. - 查询要多样;如果主题较广,生成多于 1 个查询。
- Don't generate multiple similar queries, 1 is enough. - 不要生成多个相似查询1 个就够。
- Query should ensure that the most current information is gathered. The current date is {current_date}. - 查询应确保获取最新信息。当前日期是 {current_date}
- Always responed in {language}. - 始终用 {language} 回答。
Format: 格式:
- Format your response as a JSON object with ALL two of these exact keys: - 将响应格式化为一个 JSON 对象,且只包含以下两个键:
- "rationale": Brief explanation of why these queries are relevant - "rationale":简要说明这些查询为何相关
- "query": A list of search queries - "query":搜索查询列表
Example: 示例:
Topic: What revenue grew more last year apple stock or the number of people buying an iphone 主题: 去年苹果股票的收入增长更多,还是购买 iPhone 的人数增长更多
```json ```json
{{ {{
"rationale": "To answer this comparative growth question accurately, we need specific data points on Apple's stock performance and iPhone sales metrics. These queries target the precise financial information needed: company revenue trends, product-specific unit sales figures, and stock price movement over the same fiscal period for direct comparison.", "rationale": "要准确回答这一比较性的增长问题,需要苹果股票表现和 iPhone 销售指标的具体数据。以下查询聚焦所需的精确信息:公司收入趋势、单品销量数字以及同一财年的股价走势,便于直接比较。",
"query": ["Apple total revenue growth fiscal year 2024", "iPhone unit sales growth fiscal year 2024", "Apple stock price growth fiscal year 2024"], "query": ["Apple total revenue growth fiscal year 2024", "iPhone unit sales growth fiscal year 2024", "Apple stock price growth fiscal year 2024"],
}} }}
``` ```
Context: {research_topic}""" 上下文: {research_topic}"""
web_searcher_instructions = """Conduct targeted Google Searches to gather the most recent, credible information on "{research_topic}" and synthesize it into a verifiable text artifact. web_searcher_instructions = """进行有针对性的 Google 搜索,收集关于“{research_topic}”的最新可信信息,并综合成可验证的文本材料。
Instructions: 指引:
- Query should ensure that the most current information is gathered. The current date is {current_date}. - 查询要确保获取最新信息。当前日期是 {current_date}
- Conduct multiple, diverse searches to gather comprehensive information for building a comic storyboard: character traits (personality, visual appearance, clothing, speech patterns), definitions of any mentioned objects or terms, and setting details (time period, geography, mood, visual cues). - 进行多样化的多次搜索,收集构建漫画分镜所需的全面信息:角色特征(性格、视觉外观、服饰、说话方式)、提到的物体或术语定义,以及场景细节(时代、地理、氛围、视觉线索)。
- Consolidate key findings while meticulously tracking the source(s) for each specific piece of information. - 整理关键信息时要细致记录每条信息对应的来源。
- The output should be concise research notes oriented toward comic creation, not a narrative report. Capture factual details that help draw scenes and characters. - 输出应是面向漫画创作的简明研究笔记,而非叙事报告。只捕捉有助于绘制场景与角色的事实细节。
- Only include the information found in the search results, don't make up any information. - 只包含搜索结果中的信息,不要编造。
- Always responed in {language}. - 始终用 {language} 回答。
Research Topic: 研究主题:
{research_topic} {research_topic}
""" """
reflection_instructions = """You are an expert research assistant analyzing summaries about "{research_topic}" to support a comic storyboard. reflection_instructions = """你是一名资深研究助理,正在分析关于“{research_topic}”的摘要,以支持漫画分镜创作。
Instructions: 指引:
- Identify knowledge gaps that block a vivid comic storyboard: missing character personality or appearance, unclear speech style, undefined objects/terms, or incomplete setting/era/mood. Generate a follow-up query (1 or multiple) to fill these gaps. - 找出阻碍生动分镜的知识缺口:缺失的角色性格或外貌、不明确的说话风格、未定义的物体/术语、或不完整的场景时代/氛围。生成 1 条或多条后续查询来补齐这些缺口。
- If provided summaries are sufficient to answer the user's question, don't generate a follow-up query. - 如果给定摘要已足够回答用户问题,不要生成后续查询。
- If there is a knowledge gap, generate a follow-up query that would help expand your understanding. - 若存在知识缺口,生成能扩展理解的后续查询。
- Focus on technical details, implementation specifics, or emerging trends that weren't fully covered. - 关注摘要未充分覆盖的技术细节、实现细节或新兴趋势。
- Always responed in {language}. - 始终用 {language} 回答。
Requirements: 要求:
- Ensure the follow-up query is self-contained and includes necessary context for web search. - 确保后续查询是自包含的,并包含网页搜索所需的上下文。
Output Format: 输出格式:
- Format your response as a JSON object with these exact keys: - 将响应格式化为包含以下精确键的 JSON 对象:
- "is_sufficient": true or false - "is_sufficient": true false
- "knowledge_gap": Describe what information is missing or needs clarification - "knowledge_gap": 描述缺失或需要澄清的信息
- "follow_up_queries": Write a specific question to address this gap - "follow_up_queries": 编写针对该缺口的具体问题
Example: 示例:
```json ```json
{{ {{
"is_sufficient": true, // or false "is_sufficient": true, // false
"knowledge_gap": "The summary lacks information about performance metrics and benchmarks", // "" if is_sufficient is true "knowledge_gap": "摘要缺少关于性能指标和基准的描述", // is_sufficient true 时填 ""
"follow_up_queries": ["What are typical performance benchmarks and metrics used to evaluate [specific technology]?"] // [] if is_sufficient is true "follow_up_queries": ["评估 [特定技术] 常用的性能基准和指标是什么?"] // is_sufficient true 时填 []
}} }}
``` ```
Reflect carefully on the Summaries to identify knowledge gaps and produce a follow-up query. Then, produce your output following this JSON format: <SUMMARIES>
# 仔细审视 Summaries找出知识缺口并生成后续查询。然后按上述 JSON 格式输出。
Summaries:
{summaries} {summaries}
</SUMMARIES>
""" """
answer_instructions = """Create a detailed comic storyboard based on the user's request and the provided research summaries. answer_instructions = """你是一名漫画脚本师,正在创作关于“{research_topic}”的详细的漫画分镜脚本。
Strict Requirements: 严格要求:
- Output ONLY valid JSON array. No prose, no markdown fences, no comments. - 只输出有效的 JSON 数组。不要有正文、Markdown 代码块或注释。
- The JSON must be an array of page objects. Each page object MUST have EXACTLY two keys: - JSON 必须是页面对象的数组。每个页面对象必须且仅有两个键:
- "id": integer, the 1-based page identifier (e.g., 1, 2, 3, ...) - "id":整数,基于 1 的页面编号(如 1, 2, 3, ...
- "detail": string, a thorough page description that fine-grains every panel: characters' actions, attire, environment, camera/framing, dialogue with tone, props, transitions. - "detail":字符串,对每个分镜的详尽描述:角色动作、服装、环境、镜头/构图、带语气的对话、道具、转场。
- Do NOT invent facts. Ground all details in the provided summaries. - 不要编造事实。所有细节都要基于提供的摘要。
Example JSON (structure only): 示例 JSON仅示意结构
[ [
{{ "id": 1, "detail": "..." }}, {{ "id": 1, "detail": "..." }},
{{ "id": 2, "detail": "..." }} {{ "id": 2, "detail": "..." }}
] ]
Instructions: 指引:
- The current date is {current_date}. - 当前日期是 {current_date}
- You are the final step of a multi-step research process; don't mention that you are the final step. - 你是多步研究流程的最后一步;不要提及这一点。
- Use the user's request and all research summaries to build the storyboard. - 使用用户请求和全部研究摘要来构建分镜,每一页都是单独生成的,所以应该包含所有描述性信息。
- If the topic includes people, capture personality, visual appearance (hair, clothing, accessories), and speech style. If it includes objects, explain what they are and notable visual traits. If it includes locations or events, capture time period, atmosphere, and visual cues. - 如果主题包含人物,每一页都需要捕捉性格、外貌(发型、服装、配饰)和说话风格;如果包含物体,每一页都需要说明它们是什么以及显著外观特征;如果包含地点或事件,每一页都需要捕捉时代、氛围和视觉线索。
- Output must be a page-by-page JSON where each page is an object with "id" and a single "detail" string that thoroughly covers all panels and specifics. - 输出必须是逐页的 JSON每一页是含有 "id" 和单个 "detail" 字符串的对象,详尽覆盖所有分镜和细节。
- Always responed in {language}. - 始终用 {language} 回答。
User Context: 用户上下文:
- {research_topic} - {research_topic}
Summaries: <SUMMARIES>
# Summaries
{summaries} {summaries}
</SUMMARIES>
""" """

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List from typing import List
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
@@ -17,152 +17,3 @@ def get_research_topic(messages: List[AnyMessage]) -> str:
elif isinstance(message, AIMessage): elif isinstance(message, AIMessage):
research_topic += f"Assistant: {message.content}\n" research_topic += f"Assistant: {message.content}\n"
return research_topic return research_topic
def resolve_urls(urls_to_resolve: List[Any], id: int) -> Dict[str, str]:
"""
Create a map of the vertex ai search urls (very long) to a short url with a unique id for each url.
Ensures each original URL gets a consistent shortened form while maintaining uniqueness.
"""
if not urls_to_resolve:
return {}
prefix = f"https://vertexaisearch.cloud.google.com/id/"
urls = [site.web.uri for site in urls_to_resolve]
# Create a dictionary that maps each unique URL to its first occurrence index
resolved_map = {}
for idx, url in enumerate(urls):
if url not in resolved_map:
resolved_map[url] = f"{prefix}{id}-{idx}"
return resolved_map
def insert_citation_markers(text, citations_list):
"""
Inserts citation markers into a text string based on start and end indices.
Args:
text (str): The original text string.
citations_list (list): A list of dictionaries, where each dictionary
contains 'start_index', 'end_index', and
'segment_string' (the marker to insert).
Indices are assumed to be for the original text.
Returns:
str: The text with citation markers inserted.
"""
# Sort citations by end_index in descending order.
# If end_index is the same, secondary sort by start_index descending.
# This ensures that insertions at the end of the string don't affect
# the indices of earlier parts of the string that still need to be processed.
sorted_citations = sorted(
citations_list, key=lambda c: (c["end_index"], c["start_index"]), reverse=True
)
modified_text = text
for citation_info in sorted_citations:
# These indices refer to positions in the *original* text,
# but since we iterate from the end, they remain valid for insertion
# relative to the parts of the string already processed.
end_idx = citation_info["end_index"]
marker_to_insert = ""
for segment in citation_info["segments"]:
marker_to_insert += f" [{segment['label']}]({segment['short_url']})"
# Insert the citation marker at the original end_idx position
modified_text = (
modified_text[:end_idx] + marker_to_insert + modified_text[end_idx:]
)
return modified_text
def get_citations(response, resolved_urls_map):
"""
Extracts and formats citation information from a Gemini model's response.
This function processes the grounding metadata provided in the response to
construct a list of citation objects. Each citation object includes the
start and end indices of the text segment it refers to, and a string
containing formatted markdown links to the supporting web chunks.
Args:
response: The response object from the Gemini model, expected to have
a structure including `candidates[0].grounding_metadata`.
It also relies on a `resolved_map` being available in its
scope to map chunk URIs to resolved URLs.
Returns:
list: A list of dictionaries, where each dictionary represents a citation
and has the following keys:
- "start_index" (int): The starting character index of the cited
segment in the original text. Defaults to 0
if not specified.
- "end_index" (int): The character index immediately after the
end of the cited segment (exclusive).
- "segments" (list[str]): A list of individual markdown-formatted
links for each grounding chunk.
- "segment_string" (str): A concatenated string of all markdown-
formatted links for the citation.
Returns an empty list if no valid candidates or grounding supports
are found, or if essential data is missing.
"""
citations = []
# Ensure response and necessary nested structures are present
if not response or not response.candidates:
return citations
candidate = response.candidates[0]
if (
not hasattr(candidate, "grounding_metadata")
or not candidate.grounding_metadata
or not hasattr(candidate.grounding_metadata, "grounding_supports")
):
return citations
for support in candidate.grounding_metadata.grounding_supports:
citation = {}
# Ensure segment information is present
if not hasattr(support, "segment") or support.segment is None:
continue # Skip this support if segment info is missing
start_index = (
support.segment.start_index
if support.segment.start_index is not None
else 0
)
# Ensure end_index is present to form a valid segment
if support.segment.end_index is None:
continue # Skip if end_index is missing, as it's crucial
# Add 1 to end_index to make it an exclusive end for slicing/range purposes
# (assuming the API provides an inclusive end_index)
citation["start_index"] = start_index
citation["end_index"] = support.segment.end_index
citation["segments"] = []
if (
hasattr(support, "grounding_chunk_indices")
and support.grounding_chunk_indices
):
for ind in support.grounding_chunk_indices:
try:
chunk = candidate.grounding_metadata.grounding_chunks[ind]
resolved_url = resolved_urls_map.get(chunk.web.uri, None)
citation["segments"].append(
{
"label": chunk.web.title.split(".")[:-1][0],
"short_url": resolved_url,
"value": chunk.web.uri,
}
)
except (IndexError, AttributeError, NameError):
# Handle cases where chunk, web, uri, or resolved_map might be problematic
# For simplicity, we'll just skip adding this particular segment link
# In a production system, you might want to log this.
pass
citations.append(citation)
return citations

View File

@@ -59,7 +59,7 @@ export default function App() {
}; };
} else if (event.finalize_answer) { } else if (event.finalize_answer) {
processedEvent = { processedEvent = {
title: "Finalizing Answer", title: "Generate Scripts",
data: "Composing and presenting the final answer.", data: "Composing and presenting the final answer.",
}; };
hasFinalizeEventOccurredRef.current = true; hasFinalizeEventOccurredRef.current = true;

View File

@@ -1,10 +1,10 @@
import type React from "react"; import type React from "react";
import type { Message } from "@langchain/langgraph-sdk"; import type { Message } from "@langchain/langgraph-sdk";
import { ScrollArea } from "@/components/ui/scroll-area"; import { ScrollArea } from "@/components/ui/scroll-area";
import { Loader2, Copy, CopyCheck } from "lucide-react"; import { Loader2, Pencil, ArrowUpCircle } from "lucide-react";
import { InputForm } from "@/components/InputForm"; import { InputForm } from "@/components/InputForm";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { useState, ReactNode, useEffect } from "react"; import { useState, ReactNode, useEffect, useCallback } from "react";
import ReactMarkdown from "react-markdown"; import ReactMarkdown from "react-markdown";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { Badge } from "@/components/ui/badge"; import { Badge } from "@/components/ui/badge";
@@ -12,6 +12,7 @@ import {
ActivityTimeline, ActivityTimeline,
ProcessedEvent, ProcessedEvent,
} from "@/components/ActivityTimeline"; // Assuming ActivityTimeline is in the same dir or adjust path } from "@/components/ActivityTimeline"; // Assuming ActivityTimeline is in the same dir or adjust path
import { Textarea } from "@/components/ui/textarea";
// Markdown component props type from former ReportView // Markdown component props type from former ReportView
type MdComponentProps = { type MdComponentProps = {
@@ -166,8 +167,6 @@ interface AiMessageBubbleProps {
isLastMessage: boolean; isLastMessage: boolean;
isOverallLoading: boolean; isOverallLoading: boolean;
mdComponents: typeof mdComponents; mdComponents: typeof mdComponents;
handleCopy: (text: string, messageId: string) => void;
copiedMessageId: string | null;
aspectRatio?: string; aspectRatio?: string;
imageSize?: string; imageSize?: string;
} }
@@ -180,17 +179,20 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
isLastMessage, isLastMessage,
isOverallLoading, isOverallLoading,
mdComponents, mdComponents,
handleCopy,
copiedMessageId,
aspectRatio, aspectRatio,
imageSize, imageSize,
}) => { }) => {
const [pageImages, setPageImages] = useState< type PageImageState = {
Record< status: "idle" | "pending" | "done" | "error";
string, images: { url: string; id: string }[];
{ status: "pending" | "done" | "error"; url?: string; error?: string } activeIndex: number;
> error?: string;
>({}); draft: string;
isEditing: boolean;
};
const messageKey = message.id ?? "ai";
const [pageStates, setPageStates] = useState<Record<string, PageImageState>>({});
const parsedPages = (() => { const parsedPages = (() => {
const raw = message.content; const raw = message.content;
@@ -225,51 +227,198 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
const isLiveActivityForThisBubble = isLastMessage && isOverallLoading; const isLiveActivityForThisBubble = isLastMessage && isOverallLoading;
useEffect(() => { useEffect(() => {
if (!parsedPages || !message.id) return; if (!parsedPages) return;
const backendBase = import.meta.env.DEV setPageStates((prev) => {
? "http://localhost:2024" let changed = false;
: "http://localhost:8123"; const next = { ...prev };
parsedPages.forEach((page) => {
parsedPages.forEach((page) => { const key = `${messageKey}-${page.id}`;
const key = `${message.id}-${page.id}`; if (!next[key]) {
if (pageImages[key]) return; // already requested next[key] = {
status: "idle",
setPageImages((prev) => ({ images: [],
...prev, activeIndex: 0,
[key]: { status: "pending" }, draft: page.detail,
})); isEditing: false,
};
fetch(`${backendBase}/generate_image`, { changed = true;
method: "POST", }
headers: { "Content-Type": "application/json" }, });
body: JSON.stringify({ return changed ? next : prev;
prompt: page.detail,
number_of_images: 1,
aspect_ratio: aspectRatio || "16:9",
image_size: imageSize || "1K",
}),
})
.then(async (res) => {
if (!res.ok) throw new Error(await res.text());
return res.json();
})
.then((data) => {
const url =
data?.images && Array.isArray(data.images) ? data.images[0] : null;
if (!url) throw new Error("No image returned");
setPageImages((prev) => ({
...prev,
[key]: { status: "done", url },
}));
})
.catch((err) => {
setPageImages((prev) => ({
...prev,
[key]: { status: "error", error: String(err) },
}));
});
}); });
}, [parsedPages, message.id, pageImages]); }, [parsedPages, messageKey]);
const requestImage = useCallback(
async (key: string, prompt: string) => {
const backendBase = import.meta.env.DEV
? "http://localhost:2024"
: "http://localhost:8123";
const trimmedPrompt = prompt.trim();
if (!trimmedPrompt) return;
setPageStates((prev) => {
const current =
prev[key] ||
({
status: "idle",
images: [],
activeIndex: 0,
draft: trimmedPrompt,
isEditing: false,
} as PageImageState);
return {
...prev,
[key]: {
...current,
status: "pending",
error: undefined,
draft: trimmedPrompt,
},
};
});
try {
const res = await fetch(`${backendBase}/generate_image`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
prompt: trimmedPrompt,
number_of_images: 1,
aspect_ratio: aspectRatio || "16:9",
image_size: imageSize || "1K",
}),
});
if (!res.ok) {
throw new Error(await res.text());
}
const data = await res.json();
const url =
data?.images && Array.isArray(data.images) ? data.images[0] : null;
if (!url) throw new Error("No image returned");
setPageStates((prev) => {
const current =
prev[key] ||
({
status: "idle",
images: [],
activeIndex: 0,
draft: trimmedPrompt,
isEditing: false,
} as PageImageState);
const newImages = [
...(current.images || []),
{
url,
id: `${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
},
];
return {
...prev,
[key]: {
...current,
status: "done",
images: newImages,
activeIndex: newImages.length - 1,
error: undefined,
draft: trimmedPrompt,
},
};
});
} catch (err) {
setPageStates((prev) => {
const current =
prev[key] ||
({
status: "idle",
images: [],
activeIndex: 0,
draft: trimmedPrompt,
isEditing: false,
} as PageImageState);
return {
...prev,
[key]: {
...current,
status: "error",
error: String(err),
},
};
});
}
},
[aspectRatio, imageSize]
);
useEffect(() => {
if (!parsedPages) return;
parsedPages.forEach((page) => {
const key = `${messageKey}-${page.id}`;
const state = pageStates[key];
if (!state || (state.status === "idle" && state.images.length === 0)) {
requestImage(key, state?.draft ?? page.detail);
}
});
}, [parsedPages, messageKey, pageStates, requestImage]);
const handlePromptChange = (key: string, value: string) => {
setPageStates((prev) => {
const current = prev[key];
if (!current) return prev;
return {
...prev,
[key]: {
...current,
draft: value,
},
};
});
};
const handleToggleEdit = (key: string) => {
setPageStates((prev) => {
const current = prev[key];
if (!current) return prev;
return {
...prev,
[key]: {
...current,
isEditing: !current.isEditing,
},
};
});
};
const handleSubmitPrompt = (key: string, fallbackPrompt: string) => {
const promptToUse =
(pageStates[key]?.draft || fallbackPrompt || "").trim();
if (!promptToUse) return;
requestImage(key, promptToUse);
setPageStates((prev) => {
const current = prev[key];
if (!current) return prev;
return {
...prev,
[key]: {
...current,
isEditing: false,
},
};
});
};
const handleSetActiveImage = (key: string, idx: number) => {
setPageStates((prev) => {
const current = prev[key];
if (!current) return prev;
return {
...prev,
[key]: {
...current,
activeIndex: idx,
},
};
});
};
return ( return (
<div className={`relative break-words flex flex-col`}> <div className={`relative break-words flex flex-col`}>
@@ -283,47 +432,129 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
)} )}
{parsedPages ? ( {parsedPages ? (
<div className="space-y-3"> <div className="space-y-3">
{parsedPages.map((page) => ( {parsedPages.map((page) => {
<div const key = `${messageKey}-${page.id}`;
key={page.id} const pageState = pageStates[key];
className="rounded-xl border border-neutral-700 bg-neutral-800/80 p-3 shadow-sm" const images = pageState?.images || [];
> const activeIndex =
<div className="text-xs uppercase tracking-wide text-neutral-400 mb-1"> images.length > 0
Page {page.id} ? Math.min(pageState?.activeIndex ?? 0, images.length - 1)
</div> : 0;
<ReactMarkdown components={mdComponents}> const activeImage =
{page.detail} images.length > 0 ? images[activeIndex] : undefined;
</ReactMarkdown> const isPending = pageState?.status === "pending";
<div className="mt-2"> return (
{(() => { <div
const key = `${message.id}-${page.id}`; key={page.id}
const img = pageImages[key]; className="rounded-xl border border-neutral-700 bg-neutral-800/80 p-3 shadow-sm space-y-2"
if (!img || img.status === "pending") { >
return ( <div className="flex items-start justify-between">
<div className="flex items-center text-xs text-neutral-400 gap-2"> <div className="text-xs uppercase tracking-wide text-neutral-400">
Page {page.id}
</div>
<div className="flex items-center gap-2">
<Button
variant="ghost"
size="icon"
className="h-8 w-8 text-neutral-300 hover:text-white hover:bg-neutral-700"
onClick={() => handleToggleEdit(key)}
aria-label="Edit prompt"
title="Edit prompt"
>
<Pencil className="h-4 w-4" />
</Button>
<Button
variant="ghost"
size="icon"
disabled={isPending}
className="h-8 w-8 text-blue-300 hover:text-blue-100 hover:bg-blue-500/10 disabled:opacity-60"
onClick={() => handleSubmitPrompt(key, page.detail)}
aria-label="Regenerate image"
title="Regenerate image"
>
{isPending ? (
<Loader2 className="h-4 w-4 animate-spin" /> <Loader2 className="h-4 w-4 animate-spin" />
<span>Generating image...</span> ) : (
</div> <ArrowUpCircle className="h-5 w-5" />
); )}
} </Button>
if (img.status === "error") { </div>
return ( </div>
<div className="text-xs text-red-400">
Image generation failed: {img.error} {pageState?.isEditing ? (
</div> <Textarea
); value={pageState.draft}
} onChange={(e) => handlePromptChange(key, e.target.value)}
return ( className="mt-1 bg-neutral-900 border-neutral-700 text-neutral-100"
<img rows={3}
src={img.url} autoFocus
alt={`Page ${page.id} illustration`} />
className="mt-1 rounded-lg border border-neutral-700" ) : (
/> <ReactMarkdown components={mdComponents}>
); {pageState?.draft ?? page.detail}
})()} </ReactMarkdown>
)}
<div className="mt-1 space-y-2">
{activeImage ? (
<div className="relative">
<img
src={activeImage.url}
alt={`Page ${page.id} illustration`}
className="w-full rounded-lg border border-neutral-700"
/>
{isPending && (
<div className="absolute inset-0 bg-black/40 rounded-lg flex items-center justify-center">
<Loader2 className="h-6 w-6 animate-spin text-white" />
</div>
)}
</div>
) : pageState?.status === "error" ? (
<div className="text-xs text-red-400">
Image generation failed: {pageState.error}
</div>
) : (
<div className="flex items-center text-xs text-neutral-400 gap-2">
<Loader2 className="h-4 w-4 animate-spin" />
<span>Generating image...</span>
</div>
)}
{pageState?.status === "error" && activeImage && (
<div className="text-xs text-red-400">
Image generation failed: {pageState.error}
</div>
)}
{images.length > 1 && (
<div className="flex gap-2 flex-wrap">
{images.map((img, idx) => (
<button
key={img.id}
className={`relative border ${
idx === activeIndex
? "border-blue-400"
: "border-neutral-700"
} rounded-lg overflow-hidden`}
onClick={() => handleSetActiveImage(key, idx)}
aria-label={`Show version ${idx + 1}`}
>
<img
src={img.url}
alt={`Version ${idx + 1}`}
className="h-16 w-24 object-cover"
/>
{idx === activeIndex && (
<div className="absolute inset-0 ring-2 ring-blue-400/60 rounded-lg pointer-events-none" />
)}
</button>
))}
</div>
)}
</div>
</div> </div>
</div> );
))} })}
</div> </div>
) : ( ) : (
<ReactMarkdown components={mdComponents}> <ReactMarkdown components={mdComponents}>
@@ -332,27 +563,6 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
: JSON.stringify(message.content)} : JSON.stringify(message.content)}
</ReactMarkdown> </ReactMarkdown>
)} )}
<Button
variant="default"
className={`cursor-pointer bg-neutral-700 border-neutral-600 text-neutral-300 self-end ${
(typeof message.content === "string"
? message.content.length
: JSON.stringify(message.content).length) > 0
? "visible"
: "hidden"
}`}
onClick={() =>
handleCopy(
typeof message.content === "string"
? message.content
: JSON.stringify(message.content),
message.id!
)
}
>
{copiedMessageId === message.id ? "Copied" : "Copy"}
{copiedMessageId === message.id ? <CopyCheck /> : <Copy />}
</Button>
</div> </div>
); );
}; };
@@ -387,17 +597,6 @@ export function ChatMessagesView({
aspectRatio, aspectRatio,
imageSize, imageSize,
}: ChatMessagesViewProps) { }: ChatMessagesViewProps) {
const [copiedMessageId, setCopiedMessageId] = useState<string | null>(null);
const handleCopy = async (text: string, messageId: string) => {
try {
await navigator.clipboard.writeText(text);
setCopiedMessageId(messageId);
setTimeout(() => setCopiedMessageId(null), 2000); // Reset after 2 seconds
} catch (err) {
console.error("Failed to copy text: ", err);
}
};
return ( return (
<div className="flex flex-col h-full"> <div className="flex flex-col h-full">
<ScrollArea className="flex-1 overflow-y-auto" ref={scrollAreaRef}> <ScrollArea className="flex-1 overflow-y-auto" ref={scrollAreaRef}>
@@ -424,8 +623,6 @@ export function ChatMessagesView({
isLastMessage={isLast} isLastMessage={isLast}
isOverallLoading={isLoading} // Pass global loading state isOverallLoading={isLoading} // Pass global loading state
mdComponents={mdComponents} mdComponents={mdComponents}
handleCopy={handleCopy}
copiedMessageId={copiedMessageId}
aspectRatio={aspectRatio} aspectRatio={aspectRatio}
imageSize={imageSize} imageSize={imageSize}
/> />