import os
import re
import threading
from typing import Callable, Dict, List, Optional

import ollama
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_ollama import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# <-- NEW: Import BaseLLMProvider for type hinting
from .llm import BaseLLMProvider
class MemoryManager:
    """Two-tier memory for a hybrid local/cloud chat agent.

    Short-term: a rolling buffer of user/agent exchanges that the local
    Ollama model periodically compresses into ``session_summary``.
    Long-term: a plain-text fact log (``memory_file``) that is chunked and
    embedded into a FAISS vector index (``index_path``) for similarity
    search.
    """

    # Strips <think>...</think> reasoning blocks that some local models
    # (e.g. DeepSeek-R1 derivatives) emit before their actual answer.
    # NOTE(review): the previous pattern r'.*?' matched only empty strings
    # and removed nothing; this restores the evident intent.
    _THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)

    def __init__(self, memory_file: str, index_path: str, local_model: str,
                 embed_model_name: str, log: Callable[[str], None] = print,
                 cloud_provider: "Optional[BaseLLMProvider]" = None):
        """
        Args:
            memory_file: Path of the plain-text long-term fact log.
            index_path: Directory holding the persisted FAISS index.
            local_model: Local Ollama chat model used for compression tasks.
            embed_model_name: Ollama embedding model for the vector store.
            log: Rich-markup-aware logging callable (defaults to ``print``).
            cloud_provider: Optional cloud LLM ("teacher") used for fact
                extraction in ``finalize_session``; the raw session summary
                is used when it is absent or fails.
        """
        self.memory_file = memory_file
        self.index_path = index_path
        self.local_model = local_model
        self.cloud_provider = cloud_provider
        self.log = log
        self.session_summary = "Session just started. No prior context."
        self.interaction_buffer = []
        # Compress the interaction buffer every N exchanges.
        self.COMPRESSION_THRESHOLD = 4
        self.embeddings = OllamaEmbeddings(model=embed_model_name)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
        self.log("[dim italic]Loading persistent memory...[/dim italic]")
        if os.path.exists(self.memory_file):
            with open(self.memory_file, "r", encoding="utf-8") as f:
                self.persistent_memory = f.read().strip()
        else:
            self.persistent_memory = "No known user facts or long-term preferences."
        if os.path.exists(self.index_path):
            # Deserialization is acceptable here: the index is written only
            # by rebuild_index() on this machine.
            self.vectorstore = FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)
        else:
            self.log("[bold yellow]No memory index found. Building initial database...[/bold yellow]")
            self.rebuild_index()

    def _strip_think(self, text: str) -> str:
        """Return *text* with any <think>...</think> blocks removed."""
        return self._THINK_BLOCK_RE.sub("", text).strip()

    def get_line_count(self) -> int:
        """Return the number of lines in the memory file (0 if missing)."""
        if not os.path.exists(self.memory_file):
            return 0
        with open(self.memory_file, "r", encoding="utf-8") as f:
            return sum(1 for _ in f)

    def rebuild_index(self):
        """Re-embed the persistent memory log into a fresh FAISS index."""
        self.log("[dim italic]Reserializing memory log into vector database...[/dim italic]")
        text = self.persistent_memory or "No known user facts or long-term preferences."
        # Guard against an empty split result: FAISS.from_documents raises
        # on an empty document list.
        chunks = self.text_splitter.split_text(text) or [text]
        docs = [Document(page_content=c) for c in chunks]
        self.vectorstore = FAISS.from_documents(docs, self.embeddings)
        self.vectorstore.save_local(self.index_path)
        self.log("[bold green]Memory database manually rebuilt and saved![/bold green]")

    def compress_persistent_memory(self):
        """Deduplicate the fact log with the local model, then rewrite the
        file and rebuild the vector index. Best-effort: failures are logged
        and the old memory is kept."""
        self.log("[bold yellow]Compressing persistent memory (removing duplicates and irrelevant data)...[/bold yellow]")
        if not os.path.exists(self.memory_file):
            self.log("[dim]Memory file is empty. Nothing to compress.[/dim]")
            return
        sys_prompt = """You are a strictly robotic data deduplication script. Your ONLY job is to compress the provided memory log.
RULES:
1. Remove duplicate facts.
2. Remove conversational text, essays, or philosophical analysis.
3. Output ONLY a clean, simple bulleted list of facts.
4. NEVER use headers, bold text, or introductory/closing remarks."""
        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': f"MEMORY LOG TO COMPRESS:\n{self.persistent_memory}"}
            ])
            compressed_memory = self._strip_think(response['message']['content'].strip())
            with open(self.memory_file, "w", encoding="utf-8") as f:
                f.write(compressed_memory)
            self.persistent_memory = compressed_memory
            self.rebuild_index()
            self.log("[bold green]Persistent memory successfully compressed and re-indexed![/bold green]")
        except Exception as e:
            self.log(f"[bold red]Memory compression failed: {e}[/bold red]")

    def search(self, query: str) -> str:
        """Return up to three memory chunks relevant to *query* as a
        bulleted string, or a placeholder when no index is loaded."""
        if not getattr(self, 'vectorstore', None):
            return "No long-term memories available."
        docs = self.vectorstore.similarity_search(query, k=3)
        return "\n".join(f"- {d.page_content}" for d in docs)

    def add_interaction(self, user_input: str, bot_response: str):
        """Buffer one exchange; once COMPRESSION_THRESHOLD entries pile up,
        summarize them in a background daemon thread."""
        self.interaction_buffer.append({"user": user_input, "agent": bot_response})
        if len(self.interaction_buffer) >= self.COMPRESSION_THRESHOLD:
            # Swap the buffer out first so new interactions are not lost
            # while the daemon thread summarizes the old batch.
            buffer_to_compress = list(self.interaction_buffer)
            self.interaction_buffer = []
            threading.Thread(target=self._compress_session, args=(buffer_to_compress,), daemon=True).start()

    def _compress_session(self, buffer: List[Dict]):
        """Merge *buffer* into ``session_summary`` via the local model.

        Runs in a background thread; exceptions are logged, never raised,
        so the main loop cannot be crashed by a failed summary.
        """
        buffer_text = "\n".join(f"User: {i['user']}\nAgent: {i['agent']}" for i in buffer)
        sys_prompt = """You are a strict summarization script. Merge the recent interactions into the current session summary.
RULES:
1. Keep it brief and objective.
2. DO NOT write essays or analyze the user's intent.
3. Output ONLY the raw text of the updated summary. No conversational padding."""
        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': f"CURRENT SUMMARY:\n{self.session_summary}\n\nNEW INTERACTIONS:\n{buffer_text}"}
            ])
            self.session_summary = self._strip_think(response['message']['content'].strip())
        except Exception as e:
            self.log(f"[dim red]Background session compression failed: {e}[/dim red]")

    def finalize_session(self):
        """Flush the session into long-term memory.

        1. Synchronously fold any leftover buffered interactions into the
           session summary.
        2. Ask the cloud "teacher" model to extract durable facts (falls
           back to the raw summary when unconfigured or failing).
        3. Have the local "student" model format those facts and append
           them to the memory file.
        """
        self.log("[bold yellow]Extracting long-term memories from session...[/bold yellow]")
        # 1. Sweep up the leftovers synchronously so the summary is complete.
        if self.interaction_buffer:
            self.log("[dim italic]Compressing final interactions into session summary...[/dim italic]")
            self._compress_session(self.interaction_buffer)
            self.interaction_buffer = []
        # --- STEP A: Cloud "teacher" fact extraction ---
        clean_summary = self.session_summary
        if self.cloud_provider is None:
            # No teacher configured (constructor default): use raw summary.
            self.log("[dim]No cloud provider configured. Using raw session summary.[/dim]")
        else:
            self.log("[dim italic]Cloud model filtering permanent facts from session summary...[/dim italic]")
            cloud_sys_prompt = "You are a neutral fact-extractor. Read the provided session summary and extract any NEW long-term facts, topics discussed, and explicit user preferences. Discard all temporary conversational context. Be highly objective, dense, and concise."
            try:
                cloud_msgs = [
                    {"role": "system", "content": cloud_sys_prompt},
                    {"role": "user", "content": f"SESSION SUMMARY:\n{self.session_summary}"}
                ]
                cloud_response = self.cloud_provider.chat_completion(messages=cloud_msgs, stream=False)
                clean_summary = cloud_response.choices[0].message.content.strip()
            except Exception as e:
                self.log(f"[dim red]Cloud extraction failed ({e}). Falling back to raw summary.[/dim red]")
        # --- STEP B: Local "student" formatting ---
        self.log("[dim italic]Local model formatting permanent facts...[/dim italic]")
        sys_prompt = """You are a strict data extraction pipeline. Your ONLY job is to take the extracted facts enclosed in the ``` block and format them perfectly for long-term storage.
RULES:
1. NEVER write conversational text, greetings, headers, or explanations.
2. ONLY output a raw, bulleted list of concise facts (e.g., "- User uses Emacs org-mode").
3. Focus ONLY on facts about the USER (their preferences, setup, identity, projects).
4. If there are NO new permanent facts to save, output EXACTLY and ONLY the word: NONE.
"""
        # Fence the summary in a code block so the model cannot confuse it
        # with its own instructions.
        user_prompt = f"Format these facts ONLY:\n\n```\n{clean_summary}\n```"
        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': user_prompt}
            ])
            new_facts = self._strip_think(response['message']['content'].strip())
            if new_facts and new_facts.upper() != "NONE":
                with open(self.memory_file, "a", encoding="utf-8") as f:
                    f.write(f"\n{new_facts}")
                self.persistent_memory += f"\n{new_facts}"
                self.log("[bold green]New facts appended to long-term memory log![/bold green]")
                self.log("[dim]Note: Run /memory rebuild to index these new facts for next time.[/dim]")
            else:
                self.log("[dim]No new long-term facts detected. Skipping memory append.[/dim]")
        except Exception as e:
            self.log(f"[bold red]Failed to save long-term memory: {e}[/bold red]")