mirror of
https://github.com/Zippland/NanoComic.git
synced 2026-01-19 01:21:08 +08:00
feat: add image generation support and Chinese language default
This commit is contained in:
@@ -18,6 +18,7 @@ dependencies = [
|
||||
"langgraph-api",
|
||||
"fastapi",
|
||||
"google-genai",
|
||||
"pillow",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,62 @@
|
||||
# mypy: disable - error - code = "no-untyped-def,misc"
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import pathlib
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi import FastAPI, Response, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
# Define the FastAPI app
|
||||
app = FastAPI()
|
||||
|
||||
# Allow local dev origins (Vite + LangGraph dev)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
||||
if GEMINI_API_KEY is None:
|
||||
raise ValueError("GEMINI_API_KEY is not set")
|
||||
|
||||
image_client = genai.Client(api_key=GEMINI_API_KEY)
|
||||
# Per request: use Gemini 3 image preview model
|
||||
IMAGE_MODEL = "models/gemini-3-pro-image-preview"
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
prompt: str
|
||||
number_of_images: int = 1
|
||||
|
||||
|
||||
@app.post("/generate_image")
|
||||
def generate_image(req: ImageRequest):
|
||||
"""Generate an image for a given prompt and return base64 data URLs."""
|
||||
try:
|
||||
response = image_client.models.generate_images(
|
||||
model=IMAGE_MODEL,
|
||||
prompt=req.prompt,
|
||||
config=types.GenerateImagesConfig(number_of_images=req.number_of_images),
|
||||
)
|
||||
images = []
|
||||
for generated_image in response.generated_images:
|
||||
buffer = io.BytesIO()
|
||||
generated_image.image.save(buffer, format="PNG")
|
||||
b64 = base64.b64encode(buffer.getvalue()).decode("ascii")
|
||||
images.append(f"data:image/png;base64,{b64}")
|
||||
if not images:
|
||||
raise RuntimeError("No image generated")
|
||||
return {"images": images}
|
||||
except Exception as exc: # pragma: no cover
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def create_frontend_router(build_dir="../frontend/dist"):
|
||||
"""Creates a router to serve the React frontend.
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
|
||||
from agent.tools_and_schemas import SearchQueryList, Reflection
|
||||
from dotenv import load_dotenv
|
||||
@@ -267,17 +269,38 @@ def finalize_answer(state: OverallState, config: RunnableConfig):
|
||||
)
|
||||
result = llm.invoke(formatted_prompt)
|
||||
|
||||
# Clean potential markdown fences and parse JSON so we return structured content
|
||||
content = result.content
|
||||
if isinstance(content, str):
|
||||
# Strip markdown fences ```json ... ```
|
||||
cleaned = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", content.strip())
|
||||
try:
|
||||
parsed = json.loads(cleaned)
|
||||
except Exception:
|
||||
parsed = cleaned
|
||||
content_payload = parsed
|
||||
else:
|
||||
content_payload = content
|
||||
|
||||
# Replace the short urls with the original urls and add all used urls to the sources_gathered
|
||||
unique_sources = []
|
||||
for source in state["sources_gathered"]:
|
||||
if source["short_url"] in result.content:
|
||||
result.content = result.content.replace(
|
||||
source["short_url"], source["value"]
|
||||
)
|
||||
if isinstance(content_payload, str) and source["short_url"] in content_payload:
|
||||
content_payload = content_payload.replace(source["short_url"], source["value"])
|
||||
unique_sources.append(source)
|
||||
elif isinstance(content_payload, list):
|
||||
# if list of page dicts, replace inside detail strings
|
||||
updated_pages = []
|
||||
for page in content_payload:
|
||||
if isinstance(page, dict) and isinstance(page.get("detail"), str):
|
||||
page_detail = page["detail"].replace(source["short_url"], source["value"])
|
||||
page = {**page, "detail": page_detail}
|
||||
updated_pages.append(page)
|
||||
content_payload = updated_pages
|
||||
unique_sources.append(source)
|
||||
|
||||
return {
|
||||
"messages": [AIMessage(content=result.content)],
|
||||
"messages": [AIMessage(content=content_payload)],
|
||||
"sources_gathered": unique_sources,
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { Loader2, Copy, CopyCheck } from "lucide-react";
|
||||
import { InputForm } from "@/components/InputForm";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useState, ReactNode } from "react";
|
||||
import { useState, ReactNode, useEffect } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
@@ -181,26 +181,36 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
|
||||
handleCopy,
|
||||
copiedMessageId,
|
||||
}) => {
|
||||
const [pageImages, setPageImages] = useState<
|
||||
Record<
|
||||
string,
|
||||
{ status: "pending" | "done" | "error"; url?: string; error?: string }
|
||||
>
|
||||
>({});
|
||||
|
||||
const parsedPages = (() => {
|
||||
if (typeof message.content !== "string") return null;
|
||||
try {
|
||||
const data = JSON.parse(message.content);
|
||||
if (
|
||||
Array.isArray(data) &&
|
||||
data.every(
|
||||
(p) =>
|
||||
p &&
|
||||
typeof p === "object" &&
|
||||
"id" in p &&
|
||||
"detail" in p &&
|
||||
typeof p.id === "number" &&
|
||||
typeof p.detail === "string"
|
||||
)
|
||||
) {
|
||||
return data as { id: number; detail: string }[];
|
||||
const raw = message.content;
|
||||
let data: any = raw;
|
||||
if (typeof raw === "string") {
|
||||
try {
|
||||
data = JSON.parse(raw);
|
||||
} catch (_e) {
|
||||
return null;
|
||||
}
|
||||
} catch (_e) {
|
||||
return null;
|
||||
}
|
||||
if (
|
||||
Array.isArray(data) &&
|
||||
data.every(
|
||||
(p) =>
|
||||
p &&
|
||||
typeof p === "object" &&
|
||||
"id" in p &&
|
||||
"detail" in p &&
|
||||
typeof p.id === "number" &&
|
||||
typeof p.detail === "string"
|
||||
)
|
||||
) {
|
||||
return data as { id: number; detail: string }[];
|
||||
}
|
||||
return null;
|
||||
})();
|
||||
@@ -210,6 +220,48 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
|
||||
isLastMessage && isOverallLoading ? liveActivity : historicalActivity;
|
||||
const isLiveActivityForThisBubble = isLastMessage && isOverallLoading;
|
||||
|
||||
useEffect(() => {
|
||||
if (!parsedPages || !message.id) return;
|
||||
const backendBase = import.meta.env.DEV
|
||||
? "http://localhost:2024"
|
||||
: "http://localhost:8123";
|
||||
|
||||
parsedPages.forEach((page) => {
|
||||
const key = `${message.id}-${page.id}`;
|
||||
if (pageImages[key]) return; // already requested
|
||||
|
||||
setPageImages((prev) => ({
|
||||
...prev,
|
||||
[key]: { status: "pending" },
|
||||
}));
|
||||
|
||||
fetch(`${backendBase}/generate_image`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ prompt: page.detail, number_of_images: 1 }),
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw new Error(await res.text());
|
||||
return res.json();
|
||||
})
|
||||
.then((data) => {
|
||||
const url =
|
||||
data?.images && Array.isArray(data.images) ? data.images[0] : null;
|
||||
if (!url) throw new Error("No image returned");
|
||||
setPageImages((prev) => ({
|
||||
...prev,
|
||||
[key]: { status: "done", url },
|
||||
}));
|
||||
})
|
||||
.catch((err) => {
|
||||
setPageImages((prev) => ({
|
||||
...prev,
|
||||
[key]: { status: "error", error: String(err) },
|
||||
}));
|
||||
});
|
||||
});
|
||||
}, [parsedPages, message.id, pageImages]);
|
||||
|
||||
return (
|
||||
<div className={`relative break-words flex flex-col`}>
|
||||
{activityForThisBubble && activityForThisBubble.length > 0 && (
|
||||
@@ -233,6 +285,34 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
|
||||
<ReactMarkdown components={mdComponents}>
|
||||
{page.detail}
|
||||
</ReactMarkdown>
|
||||
<div className="mt-2">
|
||||
{(() => {
|
||||
const key = `${message.id}-${page.id}`;
|
||||
const img = pageImages[key];
|
||||
if (!img || img.status === "pending") {
|
||||
return (
|
||||
<div className="flex items-center text-xs text-neutral-400 gap-2">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
<span>Generating image...</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
if (img.status === "error") {
|
||||
return (
|
||||
<div className="text-xs text-red-400">
|
||||
Image generation failed: {img.error}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<img
|
||||
src={img.url}
|
||||
alt={`Page ${page.id} illustration`}
|
||||
className="mt-1 rounded-lg border border-neutral-700"
|
||||
/>
|
||||
);
|
||||
})()}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
@@ -246,7 +326,11 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
|
||||
<Button
|
||||
variant="default"
|
||||
className={`cursor-pointer bg-neutral-700 border-neutral-600 text-neutral-300 self-end ${
|
||||
message.content.length > 0 ? "visible" : "hidden"
|
||||
(typeof message.content === "string"
|
||||
? message.content.length
|
||||
: JSON.stringify(message.content).length) > 0
|
||||
? "visible"
|
||||
: "hidden"
|
||||
}`}
|
||||
onClick={() =>
|
||||
handleCopy(
|
||||
|
||||
@@ -41,7 +41,7 @@ export const InputForm: React.FC<InputFormProps> = ({
|
||||
const [effort, setEffort] = useState("medium");
|
||||
// Default to a current, broadly capable model
|
||||
const [model, setModel] = useState("gemini-2.5-flash");
|
||||
const [language, setLanguage] = useState("English");
|
||||
const [language, setLanguage] = useState("简体中文");
|
||||
|
||||
const handleInternalSubmit = (e?: React.FormEvent) => {
|
||||
if (e) e.preventDefault();
|
||||
@@ -187,18 +187,18 @@ export const InputForm: React.FC<InputFormProps> = ({
|
||||
<SelectValue placeholder="Language" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="bg-neutral-700 border-neutral-600 text-neutral-300 cursor-pointer">
|
||||
<SelectItem
|
||||
value="English"
|
||||
className="hover:bg-neutral-600 focus:bg-neutral-600 cursor-pointer"
|
||||
>
|
||||
English
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
value="简体中文"
|
||||
className="hover:bg-neutral-600 focus:bg-neutral-600 cursor-pointer"
|
||||
>
|
||||
简体中文
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
value="English"
|
||||
className="hover:bg-neutral-600 focus:bg-neutral-600 cursor-pointer"
|
||||
>
|
||||
English
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
value="日本語"
|
||||
className="hover:bg-neutral-600 focus:bg-neutral-600 cursor-pointer"
|
||||
|
||||
Reference in New Issue
Block a user