diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 09eb598..660d1fc 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "langgraph-api", "fastapi", "google-genai", + "pillow", ] diff --git a/backend/src/agent/app.py b/backend/src/agent/app.py index f20f6ed..96778d4 100644 --- a/backend/src/agent/app.py +++ b/backend/src/agent/app.py @@ -1,11 +1,62 @@ # mypy: disable - error - code = "no-untyped-def,misc" +import base64 +import io +import os import pathlib -from fastapi import FastAPI, Response +from fastapi import FastAPI, Response, HTTPException from fastapi.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from google import genai +from google.genai import types # Define the FastAPI app app = FastAPI() +# Allow local dev origins (Vite + LangGraph dev) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") +if GEMINI_API_KEY is None: + raise ValueError("GEMINI_API_KEY is not set") + +image_client = genai.Client(api_key=GEMINI_API_KEY) +# Per request: use Gemini 3 image preview model +IMAGE_MODEL = "models/gemini-3-pro-image-preview" + + +class ImageRequest(BaseModel): + prompt: str + number_of_images: int = 1 + + +@app.post("/generate_image") +def generate_image(req: ImageRequest): + """Generate an image for a given prompt and return base64 data URLs.""" + try: + response = image_client.models.generate_images( + model=IMAGE_MODEL, + prompt=req.prompt, + config=types.GenerateImagesConfig(number_of_images=req.number_of_images), + ) + images = [] + for generated_image in response.generated_images: + buffer = io.BytesIO() + generated_image.image.save(buffer, format="PNG") + b64 = base64.b64encode(buffer.getvalue()).decode("ascii") + images.append(f"data:image/png;base64,{b64}") + if not images: + raise RuntimeError("No image generated") + return {"images": images} + except Exception as exc: # pragma: no cover + raise HTTPException(status_code=500, detail=str(exc)) from exc + def create_frontend_router(build_dir="../frontend/dist"): """Creates a router to serve the React frontend. diff --git a/backend/src/agent/graph.py b/backend/src/agent/graph.py index 5e2e10a..5eb39d3 100644 --- a/backend/src/agent/graph.py +++ b/backend/src/agent/graph.py @@ -1,4 +1,6 @@ import os +import json +import re from agent.tools_and_schemas import SearchQueryList, Reflection from dotenv import load_dotenv @@ -267,17 +269,38 @@ def finalize_answer(state: OverallState, config: RunnableConfig): ) result = llm.invoke(formatted_prompt) + # Clean potential markdown fences and parse JSON so we return structured content + content = result.content + if isinstance(content, str): + # Strip markdown fences ```json ... ``` + cleaned = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", content.strip()) + try: + parsed = json.loads(cleaned) + except Exception: + parsed = cleaned + content_payload = parsed + else: + content_payload = content + # Replace the short urls with the original urls and add all used urls to the sources_gathered unique_sources = [] for source in state["sources_gathered"]: - if source["short_url"] in result.content: - result.content = result.content.replace( - source["short_url"], source["value"] - ) + if isinstance(content_payload, str) and source["short_url"] in content_payload: + content_payload = content_payload.replace(source["short_url"], source["value"]) + unique_sources.append(source) + elif isinstance(content_payload, list): + # if list of page dicts, replace inside detail strings + updated_pages = [] + for page in content_payload: + if isinstance(page, dict) and isinstance(page.get("detail"), str): + page_detail = page["detail"].replace(source["short_url"], source["value"]) + page = {**page, "detail": page_detail} + updated_pages.append(page) + content_payload = updated_pages unique_sources.append(source) return { - "messages": [AIMessage(content=result.content)], + "messages": [AIMessage(content=content_payload)], "sources_gathered": unique_sources, } diff --git a/frontend/src/components/ChatMessagesView.tsx b/frontend/src/components/ChatMessagesView.tsx index 77c21d5..3edce91 100644 --- a/frontend/src/components/ChatMessagesView.tsx +++ b/frontend/src/components/ChatMessagesView.tsx @@ -4,7 +4,7 @@ import { ScrollArea } from "@/components/ui/scroll-area"; import { Loader2, Copy, CopyCheck } from "lucide-react"; import { InputForm } from "@/components/InputForm"; import { Button } from "@/components/ui/button"; -import { useState, ReactNode } from "react"; +import { useState, ReactNode, useEffect } from "react"; import ReactMarkdown from "react-markdown"; import { cn } from "@/lib/utils"; import { Badge } from "@/components/ui/badge"; @@ -181,26 +181,36 @@ const AiMessageBubble: React.FC = ({ handleCopy, copiedMessageId, }) => { + const [pageImages, setPageImages] = useState< + Record< + string, + { status: "pending" | "done" | "error"; url?: string; error?: string } + > + >({}); + const parsedPages = (() => { - if (typeof message.content !== "string") return null; - try { - const data = JSON.parse(message.content); - if ( - Array.isArray(data) && - data.every( - (p) => - p && - typeof p === "object" && - "id" in p && - "detail" in p && - typeof p.id === "number" && - typeof p.detail === "string" - ) - ) { - return data as { id: number; detail: string }[]; + const raw = message.content; + let data: any = raw; + if (typeof raw === "string") { + try { + data = JSON.parse(raw); + } catch (_e) { + return null; } - } catch (_e) { - return null; + } + if ( + Array.isArray(data) && + data.every( + (p) => + p && + typeof p === "object" && + "id" in p && + "detail" in p && + typeof p.id === "number" && + typeof p.detail === "string" + ) + ) { + return data as { id: number; detail: string }[]; } return null; })(); @@ -210,6 +220,48 @@ const AiMessageBubble: React.FC = ({ isLastMessage && isOverallLoading ? liveActivity : historicalActivity; const isLiveActivityForThisBubble = isLastMessage && isOverallLoading; + useEffect(() => { + if (!parsedPages || !message.id) return; + const backendBase = import.meta.env.DEV + ? "http://localhost:2024" + : "http://localhost:8123"; + + parsedPages.forEach((page) => { + const key = `${message.id}-${page.id}`; + if (pageImages[key]) return; // already requested + + setPageImages((prev) => ({ + ...prev, + [key]: { status: "pending" }, + })); + + fetch(`${backendBase}/generate_image`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ prompt: page.detail, number_of_images: 1 }), + }) + .then(async (res) => { + if (!res.ok) throw new Error(await res.text()); + return res.json(); + }) + .then((data) => { + const url = + data?.images && Array.isArray(data.images) ? data.images[0] : null; + if (!url) throw new Error("No image returned"); + setPageImages((prev) => ({ + ...prev, + [key]: { status: "done", url }, + })); + }) + .catch((err) => { + setPageImages((prev) => ({ + ...prev, + [key]: { status: "error", error: String(err) }, + })); + }); + }); + }, [parsedPages, message.id, pageImages]); + return (
{activityForThisBubble && activityForThisBubble.length > 0 && ( @@ -233,6 +285,34 @@ const AiMessageBubble: React.FC = ({ {page.detail} +
+ {(() => { + const key = `${message.id}-${page.id}`; + const img = pageImages[key]; + if (!img || img.status === "pending") { + return ( +
+ + Generating image... +
+ ); + } + if (img.status === "error") { + return ( +
+ Image generation failed: {img.error} +
+ ); + } + return ( + {`Page + ); + })()} +
))} @@ -246,7 +326,11 @@ const AiMessageBubble: React.FC = ({