summaryrefslogtreecommitdiff
path: root/cerebral/memory.py
diff options
context:
space:
mode:
Diffstat (limited to 'cerebral/memory.py')
-rw-r--r--cerebral/memory.py176
1 file changed, 176 insertions, 0 deletions
diff --git a/cerebral/memory.py b/cerebral/memory.py
new file mode 100644
index 0000000..cbb6bb9
--- /dev/null
+++ b/cerebral/memory.py
@@ -0,0 +1,176 @@
+import os
+import re
+import threading
+import ollama
+from typing import List, Dict, Callable
+from langchain_ollama import OllamaEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain_core.documents import Document
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+# <-- NEW: Import BaseLLMProvider for type hinting
+from .llm import BaseLLMProvider
+
class MemoryManager:
    """Long- and short-term memory for the agent.

    Persists user facts in a plain-text log (``memory_file``), mirrors them
    into a FAISS vector index (``index_path``) for similarity search, and
    keeps a rolling session summary that is periodically compressed by a
    local Ollama model. A cloud provider, when supplied, acts as the
    "teacher" fact extractor in ``finalize_session``.
    """

    def __init__(self, memory_file: str, index_path: str, local_model: str, embed_model_name: str, log: Callable[[str], None] = print, cloud_provider: "BaseLLMProvider | None" = None):
        """Load (or bootstrap) persistent memory and its vector index.

        Args:
            memory_file: Path to the plain-text long-term memory log.
            index_path: Directory where the FAISS index is saved/loaded.
            local_model: Name of the local Ollama model used for compression
                and formatting.
            embed_model_name: Ollama embedding model backing the vector store.
            log: Sink for status messages (Rich markup strings); defaults to
                ``print``.
            cloud_provider: Optional cloud LLM used for fact extraction in
                ``finalize_session``; may be ``None``.
        """
        self.memory_file = memory_file
        self.index_path = index_path
        self.local_model = local_model
        self.cloud_provider = cloud_provider
        self.log = log

        # Rolling short-term state: a summary of the session so far, plus the
        # raw interactions not yet folded into it.
        self.session_summary = "Session just started. No prior context."
        self.interaction_buffer = []
        # Number of buffered interactions that triggers a background compression.
        self.COMPRESSION_THRESHOLD = 4

        self.embeddings = OllamaEmbeddings(model=embed_model_name)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)

        self.log("[dim italic]Loading persistent memory...[/dim italic]")
        if os.path.exists(self.memory_file):
            with open(self.memory_file, "r") as f:
                self.persistent_memory = f.read().strip()
        else:
            self.persistent_memory = "No known user facts or long-term preferences."

        # Load the saved FAISS index if present; otherwise build it from the
        # memory log. allow_dangerous_deserialization opts in to loading the
        # pickled index -- presumably safe here because the index is written
        # by this same application (TODO confirm no untrusted index paths).
        if os.path.exists(self.index_path):
            self.vectorstore = FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)
        else:
            self.log("[bold yellow]No memory index found. Building initial database...[/bold yellow]")
            self.rebuild_index()
+
+ def get_line_count(self) -> int:
+ if not os.path.exists(self.memory_file):
+ return 0
+ with open(self.memory_file, "r") as f:
+ return sum(1 for _ in f)
+
+ def rebuild_index(self):
+ self.log("[dim italic]Reserializing memory log into vector database...[/dim italic]")
+ text = self.persistent_memory if self.persistent_memory else "No known user facts or long-term preferences."
+
+ chunks = self.text_splitter.split_text(text)
+ docs = [Document(page_content=c) for c in chunks]
+ self.vectorstore = FAISS.from_documents(docs, self.embeddings)
+ self.vectorstore.save_local(self.index_path)
+ self.log("[bold green]Memory database manually rebuilt and saved![/bold green]")
+
+ def compress_persistent_memory(self):
+ self.log("[bold yellow]Compressing persistent memory (removing duplicates and irrelevant data)...[/bold yellow]")
+ if not os.path.exists(self.memory_file):
+ self.log("[dim]Memory file is empty. Nothing to compress.[/dim]")
+ return
+
+ sys_prompt = """You are a strictly robotic data deduplication script. Your ONLY job is to compress the provided memory log.
+ RULES:
+ 1. Remove duplicate facts.
+ 2. Remove conversational text, essays, or philosophical analysis.
+ 3. Output ONLY a clean, simple bulleted list of facts.
+ 4. NEVER use headers, bold text, or introductory/closing remarks."""
+
+ try:
+ response = ollama.chat(model=self.local_model, messages=[
+ {'role': 'system', 'content': sys_prompt},
+ {'role': 'user', 'content': f"MEMORY LOG TO COMPRESS:\n{self.persistent_memory}"}
+ ])
+ compressed_memory = response['message']['content'].strip()
+ compressed_memory = re.sub(r'<think>.*?</think>', '', compressed_memory, flags=re.DOTALL).strip()
+
+ with open(self.memory_file, "w") as f:
+ f.write(compressed_memory)
+ self.persistent_memory = compressed_memory
+ self.rebuild_index()
+ self.log("[bold green]Persistent memory successfully compressed and re-indexed![/bold green]")
+ except Exception as e:
+ self.log(f"[bold red]Memory compression failed: {e}[/bold red]")
+
+ def search(self, query: str) -> str:
+ if not getattr(self, 'vectorstore', None):
+ return "No long-term memories available."
+ docs = self.vectorstore.similarity_search(query, k=3)
+ return "\n".join([f"- {d.page_content}" for d in docs])
+
+ def add_interaction(self, user_input: str, bot_response: str):
+ self.interaction_buffer.append({"user": user_input, "agent": bot_response})
+ if len(self.interaction_buffer) >= self.COMPRESSION_THRESHOLD:
+ buffer_to_compress = list(self.interaction_buffer)
+ self.interaction_buffer = []
+ threading.Thread(target=self._compress_session, args=(buffer_to_compress,), daemon=True).start()
+
+ def _compress_session(self, buffer: List[Dict]):
+ buffer_text = "\n".join([f"User: {i['user']}\nAgent: {i['agent']}" for i in buffer])
+ sys_prompt = """You are a strict summarization script. Merge the recent interactions into the current session summary.
+ RULES:
+ 1. Keep it brief and objective.
+ 2. DO NOT write essays or analyze the user's intent.
+ 3. Output ONLY the raw text of the updated summary. No conversational padding."""
+
+ try:
+ response = ollama.chat(model=self.local_model, messages=[
+ {'role': 'system', 'content': sys_prompt},
+ {'role': 'user', 'content': f"CURRENT SUMMARY:\n{self.session_summary}\n\nNEW INTERACTIONS:\n{buffer_text}"}
+ ])
+ self.session_summary = response['message']['content'].strip()
+ self.session_summary = re.sub(r'<think>.*?</think>', '', self.session_summary, flags=re.DOTALL).strip()
+ except Exception as e:
+ self.log(f"[dim red]Background session compression failed: {e}[/dim red]")
+
+ def finalize_session(self):
+ self.log("[bold yellow]Extracting long-term memories from session...[/bold yellow]")
+
+ # 1. Sweep up the leftovers: Compress the final un-summarized interactions synchronously
+ if self.interaction_buffer:
+ self.log("[dim italic]Compressing final interactions into session summary...[/dim italic]")
+ self._compress_session(self.interaction_buffer)
+ self.interaction_buffer = [] # Clear it out
+
+ # Now self.session_summary contains the complete, 100% up-to-date story of the current session.
+
+ # --- STEP A: Cerebras "Teacher" Fact Extraction ---
+ self.log("[dim italic]Cloud model filtering permanent facts from session summary...[/dim italic]")
+ cloud_sys_prompt = "You are a neutral fact-extractor. Read the provided session summary and extract any NEW long-term facts, topics discussed, and explicit user preferences. Discard all temporary conversational context. Be highly objective, dense, and concise."
+ try:
+ cloud_msgs = [
+ {"role": "system", "content": cloud_sys_prompt},
+ {"role": "user", "content": f"SESSION SUMMARY:\n{self.session_summary}"}
+ ]
+ cloud_response = self.cloud_provider.chat_completion(messages=cloud_msgs, stream=False)
+ clean_summary = cloud_response.choices[0].message.content.strip()
+ except Exception as e:
+ self.log(f"[dim red]Cloud extraction failed ({e}). Falling back to raw summary.[/dim red]")
+ clean_summary = self.session_summary
+
+ # --- STEP B: Local "Student" Formatting ---
+ self.log("[dim italic]Local model formatting permanent facts...[/dim italic]")
+ sys_prompt = """You are a strict data extraction pipeline. Your ONLY job is to take the extracted facts enclosed in the ``` block and format them perfectly for long-term storage.
+
+ RULES:
+ 1. NEVER write conversational text, greetings, headers, or explanations.
+ 2. ONLY output a raw, bulleted list of concise facts (e.g., "- User uses Emacs org-mode").
+ 3. Focus ONLY on facts about the USER (their preferences, setup, identity, projects).
+ 4. If there are NO new permanent facts to save, output EXACTLY and ONLY the word: NONE.
+ """
+
+ # Enclose the clean summary in markdown blocks to isolate it from the prompt
+ user_prompt = f"Format these facts ONLY:\n\n```\n{clean_summary}\n```"
+
+ try:
+ response = ollama.chat(model=self.local_model, messages=[
+ {'role': 'system', 'content': sys_prompt},
+ {'role': 'user', 'content': user_prompt}
+ ])
+
+ new_facts = response['message']['content'].strip()
+ new_facts = re.sub(r'<think>.*?</think>', '', new_facts, flags=re.DOTALL).strip()
+
+ if new_facts.upper() != "NONE" and new_facts:
+ with open(self.memory_file, "a") as f:
+ f.write(f"\n{new_facts}")
+ self.persistent_memory += f"\n{new_facts}"
+ self.log("[bold green]New facts appended to long-term memory log![/bold green]")
+ self.log("[dim]Note: Run /memory rebuild to index these new facts for next time.[/dim]")
+ else:
+ self.log("[dim]No new long-term facts detected. Skipping memory append.[/dim]")
+ except Exception as e:
+ self.log(f"[bold red]Failed to save long-term memory: {e}[/bold red]")