"""
|
|
FastAPI server for Local RAG with chat GUI.
|
|
Run with: uvicorn server:app --reload
|
|
"""
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import HTMLResponse
|
|
from pydantic import BaseModel
|
|
|
|
from local_rag import LocalRAG
|
|
|
|
OLLAMA_MODEL = "gpt-oss:20b"
OPENAI_MODEL = "gpt-5-mini"
VECTORSTORE_PATH = "./vectorstore"


def _make_rag(provider: str) -> "LocalRAG":
    """Build a LocalRAG bound to the shared vector store and model names."""
    return LocalRAG(
        vectorstore_path=VECTORSTORE_PATH,
        llm_provider=provider,
        ollama_model=OLLAMA_MODEL,
        openai_model=OPENAI_MODEL,
    )


# Dual RAG instances for on-the-fly provider switching
rag_ollama = _make_rag("ollama")

# OpenAI backend is optional: only constructed when an API key is present,
# and construction failures degrade to Ollama-only operation.
rag_openai = None
if os.environ.get("OPENAI_API_KEY"):
    try:
        rag_openai = _make_rag("openai")
    except Exception as e:
        print(f"OpenAI RAG not available: {e}")

app = FastAPI(title="Local RAG Chat", version="1.0.0")
|
|
|
|
|
|
class ChatMessage(BaseModel):
    """A single prior turn of the conversation, as sent by the GUI."""

    role: str  # "user" | "assistant"
    content: str  # the message text for this turn
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for POST /api/chat."""

    message: str  # the new user message to answer
    history: list[ChatMessage] = []  # previous turns for conversation context
    # NOTE(review): relies on pydantic copying mutable defaults per instance;
    # if that assumption is wrong for this pydantic version, switch to
    # Field(default_factory=list) — confirm.
    llm_provider: str = "ollama"  # "ollama" | "openai"
|
|
|
|
|
|
class RetrievedChunk(BaseModel):
    """One retrieved chunk echoed back to the client alongside the answer."""

    content: str  # raw chunk text
    source: str  # originating document identifier
    page: int | None  # page number when the source has pages, else None
    score: float | None = None  # L2 distance from FAISS (lower = more similar)
|
|
|
|
|
|
class ChatResponse(BaseModel):
    """Response body for POST /api/chat; errors are reported in-band."""

    answer: str  # generated answer; empty string when an error occurred
    error: str | None = None  # human-readable failure reason, if any
    retrieved: list[RetrievedChunk] | None = None  # chunks supporting the answer
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
|
|
def chat_view():
|
|
"""Serve the chat GUI."""
|
|
html_path = Path(__file__).parent / "templates" / "chat.html"
|
|
if not html_path.exists():
|
|
raise HTTPException(status_code=500, detail="Chat template not found")
|
|
return HTMLResponse(content=html_path.read_text(encoding="utf-8"))
|
|
|
|
|
|
def _get_rag(provider: str):
    """Return the RAG instance for the given provider. Fall back to Ollama if OpenAI unavailable."""
    openai_ready = provider == "openai" and rag_openai is not None
    return rag_openai if openai_ready else rag_ollama
|
|
|
|
|
|
@app.post("/api/chat", response_model=ChatResponse)
|
|
def chat(request: ChatRequest):
|
|
"""Handle a chat message and return the RAG answer."""
|
|
if not request.message or not request.message.strip():
|
|
return ChatResponse(answer="", error="Message cannot be empty")
|
|
if request.llm_provider == "openai" and rag_openai is None:
|
|
return ChatResponse(answer="", error="OpenAI not configured. Set OPENAI_API_KEY.")
|
|
rag = _get_rag(request.llm_provider)
|
|
try:
|
|
chat_history = [{"role": m.role, "content": m.content} for m in request.history]
|
|
result = rag.query_with_history(
|
|
request.message.strip(),
|
|
chat_history=chat_history,
|
|
)
|
|
answer = result["answer"]
|
|
retrieved = result.get("retrieved", [])
|
|
|
|
# Server-side console trace: shorter chunk logs + raw LLM response
|
|
if retrieved:
|
|
print(f"\n[RAG] Retrieved {len(retrieved)} chunk(s)")
|
|
for i, chunk in enumerate(retrieved):
|
|
content = chunk.get("content", "")
|
|
preview = (content[:150] + "...") if len(content) > 150 else content
|
|
print(f" [{i + 1}] {chunk.get('source', '')} p.{chunk.get('page', '?')} s={chunk.get('score')} | {preview!r}")
|
|
else:
|
|
print(f"\n[RAG] Retrieved 0 chunks")
|
|
provider_label = "OpenAI" if request.llm_provider == "openai" else "Ollama"
|
|
model_name = OPENAI_MODEL if request.llm_provider == "openai" else OLLAMA_MODEL
|
|
print(f"[RAG] LLM response ({provider_label} / {model_name}):\n{answer}")
|
|
|
|
return ChatResponse(answer=answer, retrieved=retrieved)
|
|
except Exception as e:
|
|
return ChatResponse(answer="", error=str(e))
|
|
|
|
|
|
@app.get("/api/health")
|
|
def health():
|
|
"""Health check and vector store status."""
|
|
has_docs = rag_ollama.vectorstore is not None
|
|
return {"status": "ok", "vectorstore_loaded": has_docs}
|
|
|
|
|
|
@app.get("/api/providers")
|
|
def providers():
|
|
"""Return which LLM providers are available."""
|
|
return {"ollama": True, "openai": rag_openai is not None}
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|