290 lines
11 KiB
Python
290 lines
11 KiB
Python
"""
|
|
Local RAG setup with LangChain, Ollama/OpenAI, and FAISS
|
|
Minimal dependencies, simple code
|
|
"""
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from langchain_community.document_loaders import PyPDFLoader, TextLoader
|
|
from langchain_community.vectorstores import FAISS
|
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
from langchain_ollama import ChatOllama
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
|
|
|
|
class LocalRAG:
    """Retrieval-augmented question answering over local files.

    Files (PDF, ``.txt``, ``.md``) are split into chunks, embedded with a
    sentence-transformers model and stored in a FAISS index persisted under
    ``vectorstore_path``. Questions are answered by an Ollama or OpenAI chat
    model, using the retrieved chunks and optional chat history as context.
    """

    # LLM back-ends accepted by ``llm_provider`` (compared case-insensitively).
    SUPPORTED_PROVIDERS = ("ollama", "openai")

    def __init__(
        self,
        vectorstore_path="./vectorstore",
        llm_provider="ollama",
        ollama_model="gpt-oss:20b",
        openai_model="gpt-5.2",
        ollama_base_url="http://localhost:11434",
    ):
        """Initialize local RAG system. llm_provider: 'ollama' or 'openai'.

        Args:
            vectorstore_path: Directory the FAISS index is loaded from and
                saved to.
            llm_provider: ``"ollama"`` or ``"openai"`` (case-insensitive;
                ``None`` falls back to ``"ollama"``).
            ollama_model: Model name passed to ``ChatOllama``.
            openai_model: Model name passed to ``ChatOpenAI``.
            ollama_base_url: Base URL of the Ollama server.

        Raises:
            ValueError: If ``llm_provider`` is not supported, or if it is
                ``"openai"`` and the ``OPENAI_API_KEY`` environment variable
                is not set.
        """
        provider = (llm_provider or "ollama").strip().lower()

        # Validate configuration before the (slow) embedding model load, and
        # reject unknown providers instead of silently falling back to Ollama.
        if provider not in self.SUPPORTED_PROVIDERS:
            raise ValueError(
                f"Unsupported llm_provider {llm_provider!r}; "
                f"expected one of: {', '.join(self.SUPPORTED_PROVIDERS)}"
            )
        api_key = None
        if provider == "openai":
            api_key = os.environ.get("OPENAI_API_KEY")
            if not api_key:
                raise ValueError(
                    "OPENAI_API_KEY environment variable is required when llm_provider='openai'"
                )

        self.vectorstore_path = vectorstore_path
        self.llm_provider = provider

        # Embeddings
        print("Loading embeddings model...")
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Text splitter.
        # NOTE(review): all-MiniLM-L6-v2 truncates input beyond 256 word pieces
        # (per its model card); 2000-character chunks likely exceed that, so
        # the tail of each chunk may not affect its embedding. Confirm / tune.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=400,
        )

        # LLM (Ollama or OpenAI)
        if provider == "openai":
            print(f"Using OpenAI (model: {openai_model})...")
            self.llm = ChatOpenAI(model=openai_model, api_key=api_key)
        else:
            print(f"Using Ollama (model: {ollama_model})...")
            self.llm = ChatOllama(
                model=ollama_model,
                base_url=ollama_base_url,
            )

        # Vector store (load if exists, otherwise None)
        self.vectorstore = None
        self._load_vectorstore()

    def _load_vectorstore(self):
        """Load the persisted FAISS index from ``vectorstore_path`` if present.

        Best effort: when ``index.faiss`` is missing nothing happens; when
        loading fails the error is printed and ``self.vectorstore`` is None.
        """
        index_file = os.path.join(self.vectorstore_path, "index.faiss")
        if not os.path.exists(index_file):
            return
        try:
            # SECURITY: allow_dangerous_deserialization lets load_local unpickle
            # the stored docstore. Only point vectorstore_path at directories
            # written by this application, never at untrusted data.
            self.vectorstore = FAISS.load_local(
                self.vectorstore_path,
                self.embeddings,
                allow_dangerous_deserialization=True,
            )
        except Exception as e:
            print(f"Could not load vector store: {e}")
            self.vectorstore = None
        else:
            print(f"Loaded existing vector store from {self.vectorstore_path}")

    @staticmethod
    def _make_loader(path):
        """Return a document loader for *path*, or None if its type is unsupported."""
        suffix = path.suffix.lower()
        if suffix == ".pdf":
            return PyPDFLoader(str(path))
        if suffix in (".txt", ".md"):
            return TextLoader(str(path))
        return None

    def add_documents(self, file_paths):
        """Chunk, embed and add files to the vector store, then save it.

        Args:
            file_paths: Iterable of file paths (str or path-like); a single
                path is also accepted. Missing files, unsupported file types
                and files whose loader raises are reported and skipped.

        Note:
            There is no de-duplication: adding a file that is already indexed
            stores its chunks a second time.
        """
        if isinstance(file_paths, (str, os.PathLike)):
            # A bare string would otherwise be iterated character by character.
            file_paths = [file_paths]
        file_paths = list(file_paths)

        print(f"\nLoading {len(file_paths)} document(s)...")
        all_docs = []

        for file_path in file_paths:
            path = Path(file_path)
            if not path.exists():
                print(f"Warning: {file_path} not found, skipping")
                continue

            loader = self._make_loader(path)
            if loader is None:
                print(f"Warning: Unsupported file type {path.suffix}, skipping")
                continue

            try:
                docs = loader.load()
            except Exception as exc:
                # Consistent with the skips above: one unreadable file must not
                # discard the chunks already collected from the other files.
                print(f"Warning: could not load {file_path} ({exc}), skipping")
                continue

            chunks = self.text_splitter.split_documents(docs)
            all_docs.extend(chunks)
            print(f" - {path.name}: {len(chunks)} chunks")

        if not all_docs:
            print("No documents loaded!")
            return

        # Create or update vector store
        print(f"\nCreating embeddings for {len(all_docs)} chunks...")
        new_store = FAISS.from_documents(all_docs, self.embeddings)
        if self.vectorstore is None:
            self.vectorstore = new_store
        else:
            self.vectorstore.merge_from(new_store)

        # Save
        os.makedirs(self.vectorstore_path, exist_ok=True)
        self.vectorstore.save_local(self.vectorstore_path)
        print(f"Vector store saved to {self.vectorstore_path}")

    def _iter_stored_documents(self):
        """Yield every chunk stored in the FAISS index, in index order.

        Reads the docstore directly instead of running a similarity search,
        so no query is embedded and no result-count cap can truncate the list.
        """
        docstore = self.vectorstore.docstore
        for _, doc_id in sorted(self.vectorstore.index_to_docstore_id.items()):
            doc = docstore.search(doc_id)
            # The in-memory docstore returns an error string for unknown ids.
            if isinstance(doc, str):
                continue
            yield doc

    def list_documents(self):
        """Print and return a per-source summary of the indexed chunks.

        Returns:
            A list of ``{'source', 'chunks', 'page'}`` dicts sorted by source.
            ``chunks`` is the number of stored chunks for that source and
            ``page`` is the ``page`` metadata of its first chunk in index
            order (None when absent). ``[]`` when no vector store is loaded.
        """
        if self.vectorstore is None:
            print("No documents in vector store.")
            return []

        # Extract unique document sources from metadata
        documents = {}
        for doc in self._iter_stored_documents():
            source = doc.metadata.get('source', 'Unknown')
            if source not in documents:
                documents[source] = {
                    'source': source,
                    'chunks': 0,
                    'page': doc.metadata.get('page', None),
                }
            documents[source]['chunks'] += 1

        doc_list = sorted(documents.values(), key=lambda info: info['source'])

        print(f"\nDocuments in vector store ({len(doc_list)} unique documents):")
        print("-" * 60)
        for doc_info in doc_list:
            print(f" - {doc_info['source']}")
            print(f" Chunks: {doc_info['chunks']}")
            if doc_info['page'] is not None:
                print(f" Page: {doc_info['page']}")

        return doc_list

    def _format_history(self, chat_history):
        """Render chat turns as ``User: ...`` / ``Assistant: ...`` lines.

        Args:
            chat_history: Iterable of dicts with ``role`` and ``content`` keys,
                or None. Roles are matched case-insensitively; turns with any
                other role (e.g. ``system``) are dropped.

        Returns:
            The newline-joined transcript, or ``""`` when no turn is kept.
        """
        lines = []
        for turn in chat_history or []:
            role = (turn.get("role") or "").lower()
            content = (turn.get("content") or "").strip()
            if role == "user":
                lines.append(f"User: {content}")
            elif role == "assistant":
                lines.append(f"Assistant: {content}")
        return "\n".join(lines)

    @staticmethod
    def _chunk_to_dict(doc):
        """Convert one document chunk to the API's ``content/source/page`` dict."""
        return {
            "content": doc.page_content,
            "source": doc.metadata.get("source", ""),
            "page": doc.metadata.get("page"),
        }

    def _docs_to_retrieved(self, docs):
        """Convert document list to retrieved chunks format for API."""
        return [self._chunk_to_dict(doc) for doc in docs]

    def _docs_scores_to_retrieved(self, docs_with_scores):
        """Convert (Document, score) list to retrieved chunks format with score. FAISS returns L2 distance (lower = more similar)."""
        return [
            {**self._chunk_to_dict(doc), "score": float(score)}
            for doc, score in docs_with_scores
        ]

    @staticmethod
    def _response_text(response):
        """Return the text of an LLM response (``.content`` when present, else ``str()``)."""
        return response.content if hasattr(response, "content") else str(response)

    def query(self, question, k=8):
        """Query the RAG system (no conversation history). Returns dict with 'answer' and 'retrieved'."""
        return self.query_with_history(question, chat_history=[], k=k)

    def query_with_history(self, question, chat_history=None, k=8):
        """Answer *question* using retrieved chunks and optional chat history.

        Steps:
            1. If there is history, ask the LLM to rewrite the question as a
               standalone search query.
            2. Retrieve the ``k`` nearest chunks from the vector store.
            3. Ask the LLM to answer from history + retrieved context.

        Args:
            question: The latest user question.
            chat_history: Optional list of ``{"role", "content"}`` dicts.
            k: Number of chunks to retrieve.

        Returns:
            ``{"answer": str, "retrieved": [...]}`` where each retrieved item
            has ``content``, ``source``, ``page`` and ``score`` (FAISS L2
            distance). If no vector store is loaded, ``answer`` carries an
            error message and ``retrieved`` is empty.
        """
        if self.vectorstore is None:
            return {
                "answer": "Error: No documents loaded. Please add documents first.",
                "retrieved": [],
            }

        history_str = self._format_history(chat_history)
        search_query = question
        # NOTE(review): this text is appended to the query that gets *embedded*
        # for similarity search, not sent to an LLM. An embedding model cannot
        # follow instructions, so it likely just biases every query vector
        # toward this wording. Kept for compatibility; consider evaluating its
        # removal.
        rag_query_instruction = (
            "Do not return a list of references but prioritize meaningful text from abstracts, results and discussion sections."
        )

        print(f"[RAG] User question: {question!r}")

        # 1) If we have history, rephrase the question into a standalone query for better retrieval
        if history_str.strip():
            rephrase_prompt = f"""Given this chat history and the latest user question, write a single standalone question that captures what the user is asking.
Do not answer it; only output the standalone question. If the latest question is already clear on its own, output it unchanged.

Chat history:
{history_str}

Latest user question: {question}

Standalone question:"""
            rephrase_response = self.llm.invoke(rephrase_prompt)
            # Fall back to the raw question if the model returns nothing.
            search_query = self._response_text(rephrase_response).strip() or question
            print(f"[RAG] Standalone search query (rephrased): {search_query!r}")

        retrieval_query = f"{search_query}\n\n{rag_query_instruction}"
        print(f"[RAG] Search query: {search_query!r}")
        print(f"[RAG] Retrieval query sent to vector store: {retrieval_query!r}")

        # 2) Retrieve documents with scores (FAISS: L2 distance, lower = more similar)
        docs_with_scores = self.vectorstore.similarity_search_with_score(retrieval_query, k=k)
        docs = [doc for doc, _ in docs_with_scores]
        retrieved = self._docs_scores_to_retrieved(docs_with_scores)
        if docs_with_scores:
            scores = [f"{s:.3f}" for _, s in docs_with_scores]
            print(f"[RAG] Retrieved {len(docs)} chunk(s), scores (L2 dist): [{', '.join(scores)}]")
        context = "\n\n".join(doc.page_content for doc in docs)

        # 3) Answer using conversation history + retrieved context
        history_block = f"Chat history:\n{history_str}\n\n" if history_str else ""
        answer_prompt = f"""You are an assistant for question-answering. Use the chat history (if any) and the retrieved context below to answer the current question.
If you don't know the answer, say so. Keep the conversation coherent.

{history_block}Relevant context from documents:

{context}

Current question: {question}

Answer:"""
        response = self.llm.invoke(answer_prompt)
        answer = self._response_text(response)

        return {"answer": answer, "retrieved": retrieved}
|
|
|
|
|
|
def main():
    """Example usage: build a LocalRAG, list indexed documents, ask one question.

    Uses the default Ollama provider. To index files first, uncomment the
    ``add_documents`` call below and point it at your own paths.
    """
    print("=" * 60)
    print("Local RAG with LangChain, Ollama/OpenAI, and FAISS")
    print("=" * 60)

    # Initialize
    rag = LocalRAG(ollama_model="gpt-oss:20b")

    # Add documents (uncomment and add your file paths)
    # rag.add_documents([
    #     "data/dok1.pdf",
    #     "data/dok2.pdf",
    #     "data/dok3.pdf",
    # ])

    # List documents
    rag.list_documents()

    # Query. The question text is both embedded for retrieval and sent to the
    # LLM, so typos in it degrade results.
    question = "What do you know about modality for perceived message perception?"
    result = rag.query(question)
    print(f"\nQuestion: {question}")
    print(f"Answer: {result['answer']}")
    if result.get("retrieved"):
        print(f"Retrieved {len(result['retrieved'])} chunks")


if __name__ == "__main__":
    main()
|