import os
import re
import threading
from typing import Callable, Dict, List

import ollama
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_ollama import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Import BaseLLMProvider for type hinting
from .llm import BaseLLMProvider

# Matches the <think>...</think> reasoning blocks emitted by some local models
# so they can be stripped before text is stored as memory.
# NOTE(review): the original pattern here was r'.*?', which only ever matches
# the empty string and therefore stripped nothing. The angle-bracket tags were
# presumably lost from the source at some point -- confirm '<think>' is the
# tag your local model actually emits.
_THINK_TAG_RE = re.compile(r'<think>.*?</think>', re.DOTALL)


def _strip_think(text: str) -> str:
    """Remove model reasoning tags from *text* and trim whitespace."""
    return _THINK_TAG_RE.sub('', text).strip()


class MemoryManager:
    """Two-tier memory for a chat agent.

    Short-term memory is a rolling session summary maintained by a local
    Ollama model; long-term memory is a plain-text fact log mirrored into a
    FAISS vector index for similarity search. An optional cloud provider acts
    as a higher-quality "teacher" when distilling a session into permanent
    facts.
    """

    def __init__(self, memory_file: str, index_path: str, local_model: str,
                 embed_model_name: str, log: Callable[[str], None] = print,
                 cloud_provider: BaseLLMProvider = None):
        """Load the persistent fact log and FAISS index, building the index
        from scratch if none exists on disk.

        Args:
            memory_file: Path of the plain-text long-term fact log.
            index_path: Directory holding the saved FAISS index.
            local_model: Ollama model name used for summarization/formatting.
            embed_model_name: Ollama embedding model for the vector store.
            log: Callable used for user-facing status messages (rich markup).
            cloud_provider: Optional OpenAI-style provider used as the
                fact-extraction "teacher" in finalize_session.
        """
        self.memory_file = memory_file
        self.index_path = index_path
        self.local_model = local_model
        self.cloud_provider = cloud_provider
        self.log = log

        # Rolling short-term state: summary plus the interactions not yet
        # folded into it.
        self.session_summary = "Session just started. No prior context."
        self.interaction_buffer = []
        # Number of buffered interactions that triggers background compression.
        self.COMPRESSION_THRESHOLD = 4

        self.embeddings = OllamaEmbeddings(model=embed_model_name)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)

        self.log("[dim italic]Loading persistent memory...[/dim italic]")
        if os.path.exists(self.memory_file):
            with open(self.memory_file, "r") as f:
                self.persistent_memory = f.read().strip()
        else:
            self.persistent_memory = "No known user facts or long-term preferences."

        if os.path.exists(self.index_path):
            # allow_dangerous_deserialization is required by langchain for
            # pickled FAISS indexes; safe here because we wrote the index.
            self.vectorstore = FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)
        else:
            self.log("[bold yellow]No memory index found. Building initial database...[/bold yellow]")
            self.rebuild_index()

    def get_line_count(self) -> int:
        """Return the number of lines in the memory file (0 if missing)."""
        if not os.path.exists(self.memory_file):
            return 0
        with open(self.memory_file, "r") as f:
            return sum(1 for _ in f)

    def rebuild_index(self):
        """Re-chunk the persistent memory text and rebuild the FAISS index,
        saving it to ``self.index_path``."""
        self.log("[dim italic]Reserializing memory log into vector database...[/dim italic]")
        text = self.persistent_memory if self.persistent_memory else "No known user facts or long-term preferences."
        chunks = self.text_splitter.split_text(text)
        docs = [Document(page_content=c) for c in chunks]
        self.vectorstore = FAISS.from_documents(docs, self.embeddings)
        self.vectorstore.save_local(self.index_path)
        self.log("[bold green]Memory database manually rebuilt and saved![/bold green]")

    def compress_persistent_memory(self):
        """Deduplicate the long-term fact log with the local model, rewrite
        the memory file, and rebuild the vector index.

        Failures are logged and leave the existing memory untouched.
        """
        self.log("[bold yellow]Compressing persistent memory (removing duplicates and irrelevant data)...[/bold yellow]")
        if not os.path.exists(self.memory_file):
            self.log("[dim]Memory file is empty. Nothing to compress.[/dim]")
            return

        sys_prompt = """You are a strictly robotic data deduplication script. Your ONLY job is to compress the provided memory log.
RULES:
1. Remove duplicate facts.
2. Remove conversational text, essays, or philosophical analysis.
3. Output ONLY a clean, simple bulleted list of facts.
4. NEVER use headers, bold text, or introductory/closing remarks."""
        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': f"MEMORY LOG TO COMPRESS:\n{self.persistent_memory}"}
            ])
            # Strip any <think>...</think> reasoning before persisting.
            compressed_memory = _strip_think(response['message']['content'].strip())
            with open(self.memory_file, "w") as f:
                f.write(compressed_memory)
            self.persistent_memory = compressed_memory
            self.rebuild_index()
            self.log("[bold green]Persistent memory successfully compressed and re-indexed![/bold green]")
        except Exception as e:
            self.log(f"[bold red]Memory compression failed: {e}[/bold red]")

    def search(self, query: str) -> str:
        """Return the top-3 most similar memory chunks as a bulleted string."""
        if not getattr(self, 'vectorstore', None):
            return "No long-term memories available."
        docs = self.vectorstore.similarity_search(query, k=3)
        return "\n".join([f"- {d.page_content}" for d in docs])

    def add_interaction(self, user_input: str, bot_response: str):
        """Buffer one user/agent exchange; once the buffer reaches
        COMPRESSION_THRESHOLD, hand it to a background thread that folds it
        into the session summary."""
        self.interaction_buffer.append({"user": user_input, "agent": bot_response})
        if len(self.interaction_buffer) >= self.COMPRESSION_THRESHOLD:
            # Snapshot and clear the buffer before threading so new
            # interactions are not lost or double-compressed.
            buffer_to_compress = list(self.interaction_buffer)
            self.interaction_buffer = []
            threading.Thread(target=self._compress_session, args=(buffer_to_compress,), daemon=True).start()

    def _compress_session(self, buffer: List[Dict]):
        """Merge *buffer* into ``self.session_summary`` via the local model.

        Runs on a background thread; failures are logged, never raised.
        """
        buffer_text = "\n".join([f"User: {i['user']}\nAgent: {i['agent']}" for i in buffer])
        sys_prompt = """You are a strict summarization script. Merge the recent interactions into the current session summary.
RULES:
1. Keep it brief and objective.
2. DO NOT write essays or analyze the user's intent.
3. Output ONLY the raw text of the updated summary. No conversational padding."""
        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': f"CURRENT SUMMARY:\n{self.session_summary}\n\nNEW INTERACTIONS:\n{buffer_text}"}
            ])
            self.session_summary = _strip_think(response['message']['content'].strip())
        except Exception as e:
            self.log(f"[dim red]Background session compression failed: {e}[/dim red]")

    def finalize_session(self):
        """Distill the finished session into permanent facts.

        Pipeline: (1) synchronously fold any leftover interactions into the
        session summary, (2) have the cloud "teacher" extract durable facts
        (falling back to the raw summary if no provider is set or the call
        fails), (3) have the local "student" format them as a bulleted list,
        then append to the memory file. Appends are not indexed until the
        user runs a rebuild.
        """
        self.log("[bold yellow]Extracting long-term memories from session...[/bold yellow]")

        # 1. Sweep up the leftovers: compress the final un-summarized
        # interactions synchronously so session_summary is complete.
        if self.interaction_buffer:
            self.log("[dim italic]Compressing final interactions into session summary...[/dim italic]")
            self._compress_session(self.interaction_buffer)
            self.interaction_buffer = []

        # --- STEP A: Cloud "Teacher" fact extraction ---
        # Guard explicitly: cloud_provider defaults to None, and without this
        # check the AttributeError was misreported as a cloud call failure.
        if self.cloud_provider is None:
            clean_summary = self.session_summary
        else:
            self.log("[dim italic]Cloud model filtering permanent facts from session summary...[/dim italic]")
            cloud_sys_prompt = "You are a neutral fact-extractor. Read the provided session summary and extract any NEW long-term facts, topics discussed, and explicit user preferences. Discard all temporary conversational context. Be highly objective, dense, and concise."
            try:
                cloud_msgs = [
                    {"role": "system", "content": cloud_sys_prompt},
                    {"role": "user", "content": f"SESSION SUMMARY:\n{self.session_summary}"}
                ]
                cloud_response = self.cloud_provider.chat_completion(messages=cloud_msgs, stream=False)
                clean_summary = cloud_response.choices[0].message.content.strip()
            except Exception as e:
                self.log(f"[dim red]Cloud extraction failed ({e}). Falling back to raw summary.[/dim red]")
                clean_summary = self.session_summary

        # --- STEP B: Local "Student" formatting ---
        self.log("[dim italic]Local model formatting permanent facts...[/dim italic]")
        sys_prompt = """You are a strict data extraction pipeline. Your ONLY job is to take the extracted facts enclosed in the ``` block and format them perfectly for long-term storage.
RULES:
1. NEVER write conversational text, greetings, headers, or explanations.
2. ONLY output a raw, bulleted list of concise facts (e.g., "- User uses Emacs org-mode").
3. Focus ONLY on facts about the USER (their preferences, setup, identity, projects).
4. If there are NO new permanent facts to save, output EXACTLY and ONLY the word: NONE.
"""
        # Enclose the clean summary in markdown blocks to isolate it from the prompt
        user_prompt = f"Format these facts ONLY:\n\n```\n{clean_summary}\n```"
        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': user_prompt}
            ])
            new_facts = _strip_think(response['message']['content'].strip())
            if new_facts.upper() != "NONE" and new_facts:
                with open(self.memory_file, "a") as f:
                    f.write(f"\n{new_facts}")
                self.persistent_memory += f"\n{new_facts}"
                self.log("[bold green]New facts appended to long-term memory log![/bold green]")
                self.log("[dim]Note: Run /memory rebuild to index these new facts for next time.[/dim]")
            else:
                self.log("[dim]No new long-term facts detected. Skipping memory append.[/dim]")
        except Exception as e:
            self.log(f"[bold red]Failed to save long-term memory: {e}[/bold red]")