diff options
Diffstat (limited to 'cerebral/memory.py')
| -rw-r--r-- | cerebral/memory.py | 176 |
1 file changed, 176 insertions, 0 deletions
import os
import re
import threading
import ollama
from typing import List, Dict, Callable, Optional
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Imported purely for the optional cloud-provider type hint.
from .llm import BaseLLMProvider


class MemoryManager:
    """Two-tier conversational memory for the agent.

    Short-term: user/agent exchanges accumulate in ``interaction_buffer`` and,
    every ``COMPRESSION_THRESHOLD`` interactions, are merged into the rolling
    ``session_summary`` by the local Ollama model on a background thread.

    Long-term: ``memory_file`` holds a plain-text fact log that is chunked and
    embedded into a FAISS index at ``index_path`` for similarity search. At
    session end (``finalize_session``) new facts are distilled from the session
    summary -- via ``cloud_provider`` when one is configured -- and appended to
    the log (the index is NOT rebuilt automatically; see ``rebuild_index``).
    """

    def __init__(self, memory_file: str, index_path: str, local_model: str,
                 embed_model_name: str, log: Callable[[str], None] = print,
                 cloud_provider: Optional[BaseLLMProvider] = None):
        """
        Args:
            memory_file: Path of the plain-text long-term fact log.
            index_path: Directory of the persisted FAISS index.
            local_model: Ollama model name used for summarizing/formatting.
            embed_model_name: Ollama model name used for embeddings.
            log: Sink for rich-markup status messages (defaults to ``print``).
            cloud_provider: Optional stronger LLM used as the fact-extraction
                "teacher" at session end; when ``None`` the raw session
                summary is used instead.
        """
        self.memory_file = memory_file
        self.index_path = index_path
        self.local_model = local_model
        self.cloud_provider = cloud_provider
        self.log = log

        self.session_summary = "Session just started. No prior context."
        self.interaction_buffer: List[Dict[str, str]] = []
        self.COMPRESSION_THRESHOLD = 4
        # interaction_buffer and session_summary are mutated both by the
        # caller's thread and by the daemon thread spawned in add_interaction,
        # so all reads/writes of them go through this lock.
        self._state_lock = threading.Lock()

        self.embeddings = OllamaEmbeddings(model=embed_model_name)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)

        self.log("[dim italic]Loading persistent memory...[/dim italic]")
        if os.path.exists(self.memory_file):
            with open(self.memory_file, "r", encoding="utf-8") as f:
                self.persistent_memory = f.read().strip()
        else:
            self.persistent_memory = "No known user facts or long-term preferences."

        if os.path.exists(self.index_path):
            self.vectorstore = FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)
        else:
            self.log("[bold yellow]No memory index found. Building initial database...[/bold yellow]")
            self.rebuild_index()

    @staticmethod
    def _strip_think(text: str) -> str:
        """Drop ``<think>...</think>`` scratchpad blocks emitted by reasoning models."""
        return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()

    def get_line_count(self) -> int:
        """Return the number of lines in the long-term memory log (0 if absent)."""
        if not os.path.exists(self.memory_file):
            return 0
        with open(self.memory_file, "r", encoding="utf-8") as f:
            return sum(1 for _ in f)

    def rebuild_index(self):
        """Re-embed the entire persistent memory log into a fresh FAISS index."""
        self.log("[dim italic]Reserializing memory log into vector database...[/dim italic]")
        text = self.persistent_memory if self.persistent_memory else "No known user facts or long-term preferences."

        chunks = self.text_splitter.split_text(text)
        docs = [Document(page_content=c) for c in chunks]
        self.vectorstore = FAISS.from_documents(docs, self.embeddings)
        self.vectorstore.save_local(self.index_path)
        self.log("[bold green]Memory database manually rebuilt and saved![/bold green]")

    def compress_persistent_memory(self):
        """Deduplicate the long-term log with the local model, then re-index.

        Best-effort: any failure is logged and the existing log is left intact.
        """
        self.log("[bold yellow]Compressing persistent memory (removing duplicates and irrelevant data)...[/bold yellow]")
        if not os.path.exists(self.memory_file):
            self.log("[dim]Memory file is empty. Nothing to compress.[/dim]")
            return

        sys_prompt = """You are a strictly robotic data deduplication script. Your ONLY job is to compress the provided memory log.
        RULES:
        1. Remove duplicate facts.
        2. Remove conversational text, essays, or philosophical analysis.
        3. Output ONLY a clean, simple bulleted list of facts.
        4. NEVER use headers, bold text, or introductory/closing remarks."""

        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': f"MEMORY LOG TO COMPRESS:\n{self.persistent_memory}"}
            ])
            compressed_memory = self._strip_think(response['message']['content'].strip())

            # Only overwrite the file once the model call has succeeded.
            with open(self.memory_file, "w", encoding="utf-8") as f:
                f.write(compressed_memory)
            self.persistent_memory = compressed_memory
            self.rebuild_index()
            self.log("[bold green]Persistent memory successfully compressed and re-indexed![/bold green]")
        except Exception as e:
            self.log(f"[bold red]Memory compression failed: {e}[/bold red]")

    def search(self, query: str) -> str:
        """Return up to 3 long-term memory chunks relevant to ``query``, as a bulleted string."""
        if not getattr(self, 'vectorstore', None):
            return "No long-term memories available."
        docs = self.vectorstore.similarity_search(query, k=3)
        return "\n".join([f"- {d.page_content}" for d in docs])

    def add_interaction(self, user_input: str, bot_response: str):
        """Buffer one exchange; kick off background compression when the buffer is full."""
        buffer_to_compress = None
        with self._state_lock:
            self.interaction_buffer.append({"user": user_input, "agent": bot_response})
            if len(self.interaction_buffer) >= self.COMPRESSION_THRESHOLD:
                buffer_to_compress = list(self.interaction_buffer)
                self.interaction_buffer = []
        if buffer_to_compress:
            threading.Thread(target=self._compress_session, args=(buffer_to_compress,), daemon=True).start()

    def _compress_session(self, buffer: List[Dict]):
        """Merge ``buffer`` into ``session_summary`` via the local model.

        May run on a background thread; the slow model call happens outside
        the lock, only the summary read/write is serialized. Failures are
        logged and leave the previous summary untouched.
        """
        buffer_text = "\n".join([f"User: {i['user']}\nAgent: {i['agent']}" for i in buffer])
        sys_prompt = """You are a strict summarization script. Merge the recent interactions into the current session summary.
        RULES:
        1. Keep it brief and objective.
        2. DO NOT write essays or analyze the user's intent.
        3. Output ONLY the raw text of the updated summary. No conversational padding."""

        with self._state_lock:
            current_summary = self.session_summary
        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': f"CURRENT SUMMARY:\n{current_summary}\n\nNEW INTERACTIONS:\n{buffer_text}"}
            ])
            new_summary = self._strip_think(response['message']['content'].strip())
            with self._state_lock:
                self.session_summary = new_summary
        except Exception as e:
            self.log(f"[dim red]Background session compression failed: {e}[/dim red]")

    def finalize_session(self):
        """Distill the session into long-term facts and append them to the log.

        Pipeline: (1) synchronously fold any leftover buffered interactions
        into the session summary; (2) have the cloud "teacher" extract
        permanent facts (skipped when no provider is configured); (3) have the
        local "student" model format them as a bulleted list; (4) append to
        ``memory_file``. The FAISS index is NOT rebuilt here -- the user is
        told to run /memory rebuild.
        """
        self.log("[bold yellow]Extracting long-term memories from session...[/bold yellow]")

        # 1. Sweep up the leftovers: compress the final un-summarized
        # interactions synchronously so the summary is fully up to date.
        with self._state_lock:
            leftovers = list(self.interaction_buffer)
            self.interaction_buffer = []
        if leftovers:
            self.log("[dim italic]Compressing final interactions into session summary...[/dim italic]")
            self._compress_session(leftovers)

        # --- STEP A: Cloud "Teacher" fact extraction ---
        clean_summary = self.session_summary
        if self.cloud_provider is None:
            # Explicit guard: without it the attribute access below raised
            # AttributeError on every default-configured run and was masked
            # by the broad except with a misleading message.
            self.log("[dim]No cloud provider configured. Using raw session summary.[/dim]")
        else:
            self.log("[dim italic]Cloud model filtering permanent facts from session summary...[/dim italic]")
            cloud_sys_prompt = "You are a neutral fact-extractor. Read the provided session summary and extract any NEW long-term facts, topics discussed, and explicit user preferences. Discard all temporary conversational context. Be highly objective, dense, and concise."
            try:
                cloud_msgs = [
                    {"role": "system", "content": cloud_sys_prompt},
                    {"role": "user", "content": f"SESSION SUMMARY:\n{self.session_summary}"}
                ]
                cloud_response = self.cloud_provider.chat_completion(messages=cloud_msgs, stream=False)
                clean_summary = cloud_response.choices[0].message.content.strip()
            except Exception as e:
                self.log(f"[dim red]Cloud extraction failed ({e}). Falling back to raw summary.[/dim red]")
                clean_summary = self.session_summary

        # --- STEP B: Local "Student" formatting ---
        self.log("[dim italic]Local model formatting permanent facts...[/dim italic]")
        sys_prompt = """You are a strict data extraction pipeline. Your ONLY job is to take the extracted facts enclosed in the ``` block and format them perfectly for long-term storage.

        RULES:
        1. NEVER write conversational text, greetings, headers, or explanations.
        2. ONLY output a raw, bulleted list of concise facts (e.g., "- User uses Emacs org-mode").
        3. Focus ONLY on facts about the USER (their preferences, setup, identity, projects).
        4. If there are NO new permanent facts to save, output EXACTLY and ONLY the word: NONE.
        """

        # Enclose the clean summary in a fenced block to isolate it from the prompt.
        user_prompt = f"Format these facts ONLY:\n\n```\n{clean_summary}\n```"

        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': user_prompt}
            ])

            new_facts = self._strip_think(response['message']['content'].strip())

            if new_facts.upper() != "NONE" and new_facts:
                with open(self.memory_file, "a", encoding="utf-8") as f:
                    f.write(f"\n{new_facts}")
                self.persistent_memory += f"\n{new_facts}"
                self.log("[bold green]New facts appended to long-term memory log![/bold green]")
                self.log("[dim]Note: Run /memory rebuild to index these new facts for next time.[/dim]")
            else:
                self.log("[dim]No new long-term facts detected. Skipping memory append.[/dim]")
        except Exception as e:
            self.log(f"[bold red]Failed to save long-term memory: {e}[/bold red]")
