feat: add image generation support and Chinese language default

This commit is contained in:
zihanjian
2025-12-01 21:12:35 +08:00
parent 69327f63c5
commit 13815cf3ac
5 changed files with 192 additions and 33 deletions

View File

@@ -18,6 +18,7 @@ dependencies = [
"langgraph-api",
"fastapi",
"google-genai",
"pillow",
]

View File

@@ -1,11 +1,62 @@
# mypy: disable-error-code="no-untyped-def,misc"
import base64
import io
import os
import pathlib
from fastapi import FastAPI, Response
from fastapi import FastAPI, Response, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from google import genai
from google.genai import types
# Define the FastAPI app
app = FastAPI()

# Allow local dev origins (Vite + LangGraph dev)
# NOTE(review): the settings below accept every origin, method and header with
# credentials enabled, which is broader than "local dev origins" — confirm this
# is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Read the API key once at import time; a missing key aborts module import
# with a ValueError instead of failing later on the first request.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY is None:
    raise ValueError("GEMINI_API_KEY is not set")

# Module-level google-genai client built from the key above.
image_client = genai.Client(api_key=GEMINI_API_KEY)

# Per request: use Gemini 3 image preview model
# NOTE(review): this identifier is passed to `models.generate_images` further
# down — confirm the model supports that call.
IMAGE_MODEL = "models/gemini-3-pro-image-preview"
class ImageRequest(BaseModel):
    """Request body accepted by the ``/generate_image`` endpoint."""

    # Text prompt forwarded to the image model.
    prompt: str
    # How many images to request; a single image when omitted.
    number_of_images: int = 1
@app.post("/generate_image")
def generate_image(req: ImageRequest):
    """Generate images for a prompt and return them as base64 data URLs.

    Args:
        req: Request body with the text prompt and the number of images wanted.

    Returns:
        A dict ``{"images": [...]}`` where every item is a
        ``data:<mime type>;base64,<payload>`` string.

    Raises:
        HTTPException: 400 when ``number_of_images`` is smaller than 1;
            500 when the upstream call fails or yields no usable image.
    """
    # Validate before the try block so this client error is not re-wrapped
    # into a 500 by the generic handler below.
    if req.number_of_images < 1:
        raise HTTPException(
            status_code=400, detail="number_of_images must be at least 1"
        )

    try:
        response = image_client.models.generate_images(
            model=IMAGE_MODEL,
            prompt=req.prompt,
            config=types.GenerateImagesConfig(number_of_images=req.number_of_images),
        )

        images = []
        # `generated_images` is optional on the response, so treat a missing
        # list the same as an empty one.
        for generated_image in response.generated_images or []:
            image = generated_image.image
            # Skip entries that carry no inline bytes instead of crashing.
            if image is None or not image.image_bytes:
                continue
            # The SDK image object already holds the encoded bytes; its
            # `save()` only accepts a file path, so encode the bytes directly
            # rather than trying to save into an in-memory buffer.
            b64 = base64.b64encode(image.image_bytes).decode("ascii")
            mime_type = image.mime_type or "image/png"
            images.append(f"data:{mime_type};base64,{b64}")

        if not images:
            raise RuntimeError("No image generated")

        return {"images": images}
    except Exception as exc:  # pragma: no cover
        # Surface any upstream/SDK failure to the caller as a 500 with the
        # original message, keeping the cause chained for server logs.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
def create_frontend_router(build_dir="../frontend/dist"):
"""Creates a router to serve the React frontend.

View File

@@ -1,4 +1,6 @@
import os
import json
import re
from agent.tools_and_schemas import SearchQueryList, Reflection
from dotenv import load_dotenv
@@ -267,17 +269,38 @@ def finalize_answer(state: OverallState, config: RunnableConfig):
)
result = llm.invoke(formatted_prompt)
# Clean potential markdown fences and parse JSON so we return structured content
content = result.content
if isinstance(content, str):
# Strip markdown fences ```json ... ```
cleaned = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", content.strip())
try:
parsed = json.loads(cleaned)
except Exception:
parsed = cleaned
content_payload = parsed
else:
content_payload = content
# Replace the short urls with the original urls and add all used urls to the sources_gathered
unique_sources = []
for source in state["sources_gathered"]:
if source["short_url"] in result.content:
result.content = result.content.replace(
source["short_url"], source["value"]
)
if isinstance(content_payload, str) and source["short_url"] in content_payload:
content_payload = content_payload.replace(source["short_url"], source["value"])
unique_sources.append(source)
elif isinstance(content_payload, list):
# if list of page dicts, replace inside detail strings
updated_pages = []
for page in content_payload:
if isinstance(page, dict) and isinstance(page.get("detail"), str):
page_detail = page["detail"].replace(source["short_url"], source["value"])
page = {**page, "detail": page_detail}
updated_pages.append(page)
content_payload = updated_pages
unique_sources.append(source)
return {
"messages": [AIMessage(content=result.content)],
"messages": [AIMessage(content=content_payload)],
"sources_gathered": unique_sources,
}

View File

@@ -4,7 +4,7 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { Loader2, Copy, CopyCheck } from "lucide-react";
import { InputForm } from "@/components/InputForm";
import { Button } from "@/components/ui/button";
import { useState, ReactNode } from "react";
import { useState, ReactNode, useEffect } from "react";
import ReactMarkdown from "react-markdown";
import { cn } from "@/lib/utils";
import { Badge } from "@/components/ui/badge";
@@ -181,26 +181,36 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
handleCopy,
copiedMessageId,
}) => {
const [pageImages, setPageImages] = useState<
Record<
string,
{ status: "pending" | "done" | "error"; url?: string; error?: string }
>
>({});
const parsedPages = (() => {
if (typeof message.content !== "string") return null;
try {
const data = JSON.parse(message.content);
if (
Array.isArray(data) &&
data.every(
(p) =>
p &&
typeof p === "object" &&
"id" in p &&
"detail" in p &&
typeof p.id === "number" &&
typeof p.detail === "string"
)
) {
return data as { id: number; detail: string }[];
const raw = message.content;
let data: any = raw;
if (typeof raw === "string") {
try {
data = JSON.parse(raw);
} catch (_e) {
return null;
}
} catch (_e) {
return null;
}
if (
Array.isArray(data) &&
data.every(
(p) =>
p &&
typeof p === "object" &&
"id" in p &&
"detail" in p &&
typeof p.id === "number" &&
typeof p.detail === "string"
)
) {
return data as { id: number; detail: string }[];
}
return null;
})();
@@ -210,6 +220,48 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
isLastMessage && isOverallLoading ? liveActivity : historicalActivity;
const isLiveActivityForThisBubble = isLastMessage && isOverallLoading;
useEffect(() => {
if (!parsedPages || !message.id) return;
const backendBase = import.meta.env.DEV
? "http://localhost:2024"
: "http://localhost:8123";
parsedPages.forEach((page) => {
const key = `${message.id}-${page.id}`;
if (pageImages[key]) return; // already requested
setPageImages((prev) => ({
...prev,
[key]: { status: "pending" },
}));
fetch(`${backendBase}/generate_image`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ prompt: page.detail, number_of_images: 1 }),
})
.then(async (res) => {
if (!res.ok) throw new Error(await res.text());
return res.json();
})
.then((data) => {
const url =
data?.images && Array.isArray(data.images) ? data.images[0] : null;
if (!url) throw new Error("No image returned");
setPageImages((prev) => ({
...prev,
[key]: { status: "done", url },
}));
})
.catch((err) => {
setPageImages((prev) => ({
...prev,
[key]: { status: "error", error: String(err) },
}));
});
});
}, [parsedPages, message.id, pageImages]);
return (
<div className={`relative break-words flex flex-col`}>
{activityForThisBubble && activityForThisBubble.length > 0 && (
@@ -233,6 +285,34 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
<ReactMarkdown components={mdComponents}>
{page.detail}
</ReactMarkdown>
<div className="mt-2">
{(() => {
const key = `${message.id}-${page.id}`;
const img = pageImages[key];
if (!img || img.status === "pending") {
return (
<div className="flex items-center text-xs text-neutral-400 gap-2">
<Loader2 className="h-4 w-4 animate-spin" />
<span>Generating image...</span>
</div>
);
}
if (img.status === "error") {
return (
<div className="text-xs text-red-400">
Image generation failed: {img.error}
</div>
);
}
return (
<img
src={img.url}
alt={`Page ${page.id} illustration`}
className="mt-1 rounded-lg border border-neutral-700"
/>
);
})()}
</div>
</div>
))}
</div>
@@ -246,7 +326,11 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
<Button
variant="default"
className={`cursor-pointer bg-neutral-700 border-neutral-600 text-neutral-300 self-end ${
message.content.length > 0 ? "visible" : "hidden"
(typeof message.content === "string"
? message.content.length
: JSON.stringify(message.content).length) > 0
? "visible"
: "hidden"
}`}
onClick={() =>
handleCopy(

View File

@@ -41,7 +41,7 @@ export const InputForm: React.FC<InputFormProps> = ({
const [effort, setEffort] = useState("medium");
// Default to a current, broadly capable model
const [model, setModel] = useState("gemini-2.5-flash");
const [language, setLanguage] = useState("English");
const [language, setLanguage] = useState("简体中文");
const handleInternalSubmit = (e?: React.FormEvent) => {
if (e) e.preventDefault();
@@ -187,18 +187,18 @@ export const InputForm: React.FC<InputFormProps> = ({
<SelectValue placeholder="Language" />
</SelectTrigger>
<SelectContent className="bg-neutral-700 border-neutral-600 text-neutral-300 cursor-pointer">
<SelectItem
value="English"
className="hover:bg-neutral-600 focus:bg-neutral-600 cursor-pointer"
>
English
</SelectItem>
<SelectItem
value="简体中文"
className="hover:bg-neutral-600 focus:bg-neutral-600 cursor-pointer"
>
</SelectItem>
<SelectItem
value="English"
className="hover:bg-neutral-600 focus:bg-neutral-600 cursor-pointer"
>
English
</SelectItem>
<SelectItem
value="日本語"
className="hover:bg-neutral-600 focus:bg-neutral-600 cursor-pointer"