diff options
| -rw-r--r-- | cerebral/agent.py | 336 | ||||
| -rw-r--r-- | cerebral/config.py | 23 | ||||
| -rw-r--r-- | cerebral/llm.py | 29 | ||||
| -rw-r--r-- | cerebral/memory.py | 176 | ||||
| -rw-r--r-- | cerebral/tools.py | 194 | ||||
| -rw-r--r-- | main.py | 675 |
6 files changed, 776 insertions, 657 deletions
# ===========================================================================
# cerebral/agent.py -- top-level agent orchestration.
# (Reconstructed from a line-mangled diff; sibling files follow below with
# explicit file markers.)
# ===========================================================================
import json
import hashlib
import re
import datetime
import time
import ollama
from typing import List, Dict, Generator, Optional, Callable

from .config import LOCAL_LLM, LOCAL_EMBED_MODEL, PKM_DIR, MEMORY_FILE, MEMORY_INDEX_PATH, FAISS_INDEX_PATH, HASH_TRACKER_FILE, AGENDA_FILE
from .llm import BaseLLMProvider
from .memory import MemoryManager
from .tools import PKMManager, VisionProcessor, WebSearcher, AgendaManager


class CerebralAgent:
    """Wires memory, PKM search, vision, web search and the agenda tool to a
    cloud LLM provider and runs a bounded tool-calling reasoning loop."""

    def __init__(self, provider: BaseLLMProvider, log: Callable[[str], None] = print):
        self.provider = provider
        self.log = log

        self.log("[bold magenta]Initializing Cerebral Agent Modules...[/bold magenta]")

        # The memory manager needs the cloud provider for the "teacher" fact
        # extraction step performed in finalize_session().
        self.memory = MemoryManager(
            MEMORY_FILE,
            MEMORY_INDEX_PATH,
            LOCAL_LLM,
            LOCAL_EMBED_MODEL,
            self.log,
            cloud_provider=self.provider
        )

        self.pkm = PKMManager(PKM_DIR, FAISS_INDEX_PATH, HASH_TRACKER_FILE, LOCAL_EMBED_MODEL, self.log)
        self.vision = VisionProcessor(LOCAL_LLM, self.log)
        self.web = WebSearcher(self.log)
        self.agenda = AgendaManager(AGENDA_FILE, self.log)

    def generate_session_filename(self, first_prompt: str, first_response: str) -> str:
        """Ask the local model for a slug-style base name and append a short
        content hash so repeated sessions never collide. Falls back to a
        generic name if the local model misbehaves."""
        self.log("[dim italic]Generating descriptive filename based on prompt and response...[/dim italic]")

        hash_input = (first_prompt + first_response).encode('utf-8')
        combined_hash = hashlib.sha256(hash_input).hexdigest()[:6]

        sys_prompt = "You are a file naming utility. Read the user's prompt and generate a short, descriptive filename base using ONLY lowercase letters and hyphens. Do NOT add an extension. ONLY output the base filename, absolutely no other text. Example: learning-python-basics"

        try:
            response = ollama.chat(model=LOCAL_LLM, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': first_prompt}
            ])

            raw_content = response['message']['content'].strip()
            # Strip hidden reasoning emitted by "thinking" local models.
            raw_content = re.sub(r'<think>.*?</think>', '', raw_content, flags=re.DOTALL).strip()

            # Keep only the last non-empty line: chatty models often prefix
            # the answer with commentary despite the system prompt.
            lines = [line.strip() for line in raw_content.split('\n') if line.strip()]
            raw_filename = lines[-1].lower().replace(' ', '-') if lines else "cerebral-session"
            clean_base = re.sub(r'[^a-z0-9\-]', '', raw_filename).strip('-')
            clean_base = clean_base[:50].strip('-')

            if not clean_base:
                clean_base = "cerebral-session"

            return f"{clean_base}-{combined_hash}.org"

        except Exception as e:
            self.log(f"[dim red]Filename generation failed: {e}. Defaulting.[/dim red]")
            return f"cerebral-session-{combined_hash}.org"

    def _get_tools(self) -> List[Dict]:
        """Return the OpenAI-style tool schemas advertised to the cloud model."""
        return [
            {
                "type": "function",
                "function": {
                    "name": "search_pkm",
                    "description": "Search the user's personal knowledge base (PKM) for notes, code, or org files.",
                    "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "search_web",
                    "description": "Search the live internet for current events, external documentation, or facts outside your PKM.",
                    "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "append_agenda",
                    "description": "Schedule a new TODO item in the user's private agenda.org file. Use this when the user asks to remember to do something, schedule a task, or add an agenda item.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "task": {
                                "type": "string",
                                "description": "The concise title of the task (e.g., 'Buy groceries' or 'Review neuralforge code')"
                            },
                            "scheduled": {
                                "type": "string",
                                "description": "Optional. The exact date and time. If a time is requested or implied, you MUST format it strictly as 'YYYY-MM-DD HH:MM' (e.g., '2026-03-27 11:00'). If ONLY a day is requested with absolutely no time, use 'YYYY-MM-DD' (e.g., '2026-03-27'). Do not include the angle brackets (< >), just the raw date/time string."
                            },
                            "description": {
                                "type": "string",
                                "description": "Optional. A brief, 1-2 sentence description or note about the task/event."
                            }
                        },
                        "required": ["task"]
                    }
                }
            }
        ]

    def chat_stream(self, prompt: str, image_path: Optional[str] = None) -> Generator[str, None, str]:
        """Run one user turn.

        Gathers memory/vision context, then loops at most MAX_ROUNDS times
        letting the model call tools, validating every tool call before
        execution. Yields response text incrementally; the generator's return
        value is the full response string.
        """
        recent_history = ""
        if self.memory.interaction_buffer:
            recent_history = "\nRECENT UNCOMPRESSED TURNS:\n" + "\n".join(
                [f"User: {i['user']}\nAgent: {i['agent']}" for i in self.memory.interaction_buffer]
            )

        vision_context = ""
        if image_path:
            self.log("[dim italic]Analyzing image context...[/dim italic]")
            vision_summary = self.vision.process(image_path, prompt)
            vision_context = f"\n[USER ATTACHED AN IMAGE. Local Vision Summary: {vision_summary}]\n"

        self.log("[dim italic]Querying long-term memory (Ollama Embeddings)...[/dim italic]")
        relevant_memories = self.memory.search(prompt)

        current_time = datetime.datetime.now().strftime("%A, %B %d, %Y at %I:%M %p")

        system_prompt = f"""You are a highly capable AI assistant.

        CRITICAL OUTPUT FORMATTING:
        You MUST output your responses EXCLUSIVELY in Emacs org-mode format. Use org-mode headings, lists, and LaTeX fragments for math.

        FORMATTING RULES:
        1. NEVER use double asterisks (`**`) for bolding. You MUST use SINGLE asterisks for bold emphasis (e.g., *this is bold*). Double asterisks will break the parser.
        2. Cite your sources inline using proper org-mode link syntax. For web searches, use [[url][Description]]. For PKM files, use [[file:/path/to/file.org][Filename]].
        3. At the very end of your response, you MUST append a Level 1 heading `* Sources` and neatly list all the search results and PKM documents you referenced using proper org-mode syntax.

        CURRENT TIME AND DATE:
        {current_time}

        RESPONSE STYLE GUIDELINES:
        - Provide EXTREMELY detailed, exhaustive, and comprehensive answers.
        - Write in long-form prose. Do not be brief; expand deeply on concepts.
        - Use multiple paragraphs, deep conceptual explanations, and thorough analysis.

        RELEVANT LONG-TERM MEMORIES:
        {relevant_memories}

        COMPRESSED SESSION CONTEXT: {self.memory.session_summary}
        {recent_history}
        """

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt + vision_context}
        ]

        # --- THE REASONING LOOP ---
        MAX_ROUNDS = 3
        current_round = 0

        while current_round < MAX_ROUNDS:
            if current_round > 0:
                # Pacing buffer to prevent 429 rate limits from back-to-back
                # tool-calling requests.
                time.sleep(1.5)

            remaining = MAX_ROUNDS - current_round
            self.log(f"[dim italic]Reasoning cycle ({current_round + 1}/{MAX_ROUNDS})...[/dim italic]")

            # BUGFIX: the original built this message with an implicit-string-
            # concatenation ternary that was hard to read and dropped the space
            # between sentences in the final-cycle branch ("now.You will...").
            if remaining > 1:
                budget_text = (
                    f"REASONING BUDGET: This is reasoning cycle {current_round + 1} of {MAX_ROUNDS}. "
                    f"You have {remaining - 1} tool-calling cycles left after this one. "
                )
            else:
                budget_text = (
                    "FINAL REASONING CYCLE: You MUST finalise now. "
                    "You will not be allowed to call more tools after this turn."
                )
            budget_reminder = {"role": "system", "content": budget_text}

            valid_tool_calls = False
            response_message = None
            allowed_tool_names = [t["function"]["name"] for t in self._get_tools()]

            for attempt in range(3):
                # BUGFIX: rebuild the request on EVERY attempt. The original
                # snapshotted `messages + [budget_reminder]` once before this
                # loop, so the corrective feedback appended below was never
                # actually sent to the model and every retry was identical to
                # the first attempt.
                current_messages = messages + [budget_reminder]

                pre_flight = self.provider.chat_completion(messages=current_messages, tools=self._get_tools(), stream=False)
                response_message = pre_flight.choices[0].message

                # Some models "narrate" a tool call in Markdown instead of
                # using the native API; detect and push back.
                if not response_message.tool_calls and response_message.content and "**name**:" in response_message.content:
                    self.log(f"[dim yellow]Model hallucinated Markdown tool call. Retrying ({attempt+1}/3)...[/dim yellow]")
                    error_msg = f"ERROR: You attempted to call a tool using Markdown text. You MUST use the native JSON tool calling API. Allowed tools: {allowed_tool_names}"
                    messages.append({"role": "assistant", "content": response_message.content})
                    messages.append({"role": "user", "content": error_msg})
                    continue

                if not response_message.tool_calls:
                    # Plain text answer: nothing to validate.
                    valid_tool_calls = True
                    break

                # Validate every requested tool call before executing any.
                has_errors = False
                error_feedbacks = []

                for tool_call in response_message.tool_calls:
                    func_name = tool_call.function.name
                    call_error = None

                    if func_name not in allowed_tool_names:
                        has_errors = True
                        call_error = f"Tool '{func_name}' does not exist. Allowed tools: {allowed_tool_names}"
                    else:
                        try:
                            json.loads(tool_call.function.arguments)
                        except json.JSONDecodeError:
                            has_errors = True
                            call_error = f"Arguments for '{func_name}' are not valid JSON: {tool_call.function.arguments}"

                    error_feedbacks.append(call_error)

                if has_errors:
                    self.log(f"[dim yellow]Malformed tool call detected. Retrying ({attempt+1}/3)...[/dim yellow]")

                    assistant_msg = {
                        "role": "assistant",
                        "content": response_message.content or "",
                        "tool_calls": [
                            {
                                "id": t.id,
                                "type": "function",
                                "function": {"name": t.function.name, "arguments": t.function.arguments}
                            } for t in response_message.tool_calls
                        ]
                    }
                    messages.append(assistant_msg)

                    # Answer every call in the batch so the transcript stays
                    # well-formed for the next attempt.
                    for i, tool_call in enumerate(response_message.tool_calls):
                        err = error_feedbacks[i]
                        msg_content = f"ERROR: {err}" if err else "Error: Another tool in this batch failed. Please fix the batch and retry."
                        messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": msg_content})
                    continue

                valid_tool_calls = True
                break

            if not valid_tool_calls:
                self.log("[bold red]Failed to generate valid tool calls. Breaking reasoning loop.[/bold red]")
                response_message.tool_calls = None
                break

            # --- Zero-waste response: the pre-flight call already produced the
            # final answer, so stream that text instead of paying for a second
            # generation. ---
            if not response_message.tool_calls:
                if current_round == 0:
                    self.log("[dim italic]No tools needed. Outputting response...[/dim italic]")
                else:
                    self.log("[dim italic]Reasoning complete. Outputting response...[/dim italic]")

                content = response_message.content or ""

                # Artificially stream the pre-generated block so the UI stays smooth.
                chunk_size = 30
                for i in range(0, len(content), chunk_size):
                    yield content[i:i+chunk_size]
                    time.sleep(0.01)

                self.memory.add_interaction(prompt, content)
                return content

            # --- Execute validated tools ---
            assistant_msg = {
                "role": "assistant",
                "content": response_message.content or "",
                "tool_calls": [
                    {
                        "id": t.id,
                        "type": "function",
                        "function": {"name": t.function.name, "arguments": t.function.arguments}
                    } for t in response_message.tool_calls
                ]
            }
            messages.append(assistant_msg)

            for tool_call in response_message.tool_calls:
                func_name = tool_call.function.name
                args = json.loads(tool_call.function.arguments)

                if func_name == "search_pkm":
                    q = args.get("query", prompt)
                    self.log(f"[cyan]🧠 Tool Call: Searching PKM for '{q}'...[/cyan]")
                    yield f"\n*(Agent Note: Searched PKM for `{q}`)*\n\n"
                    result = self.pkm.search(q)
                elif func_name == "search_web":
                    q = args.get("query", prompt)
                    self.log(f"[cyan]🌐 Tool Call: Searching Web for '{q}'...[/cyan]")
                    yield f"\n*(Agent Note: Searched Web for `{q}`)*\n\n"
                    result = self.web.search(q)
                elif func_name == "append_agenda":
                    task = args.get("task", "Untitled Task")
                    scheduled = args.get("scheduled", "")
                    description = args.get("description", "")
                    self.log(f"[cyan]📅 Tool Call: Appending to Agenda: '{task}'...[/cyan]")
                    yield f"\n*(Agent Note: Added `{task}` to agenda)*\n\n"
                    result = self.agenda.append_todo(task, scheduled, description)

                messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result})

            current_round += 1

        # --- FALLBACK FINAL STREAMING RESPONSE ---
        # Only reached if the agent maxes out all reasoning rounds (or broke
        # out after repeated invalid tool calls) without answering.
        self.log("[dim italic]Max rounds reached. Forcing final response...[/dim italic]")
        time.sleep(1.5)

        messages.append({
            "role": "system",
            "content": "You have reached the maximum number of reasoning steps. You must now provide your final, comprehensive answer based on the context gathered so far. Use strict org-mode formatting."
        })

        self.log("[dim italic]Streaming final response...[/dim italic]")
        stream = self.provider.chat_completion(messages=messages, tools=self._get_tools(), stream=True, tool_choice="none")

        full_response = ""
        for chunk in stream:
            content = chunk.choices[0].delta.content or ""
            full_response += content
            yield content

        self.memory.add_interaction(prompt, full_response)
        return full_response

    def shutdown(self):
        """Flush buffered interactions into long-term memory before exit."""
        self.memory.finalize_session()


# ===========================================================================
# cerebral/config.py -- configuration constants and directory bootstrap.
# ===========================================================================
import os
from dotenv import load_dotenv

load_dotenv()

# Local (Ollama) model names.
LOCAL_LLM = "qwen3-vl:8b"
LOCAL_EMBED_MODEL = "nomic-embed-text-v2-moe:latest"
PKM_DIR = os.path.expanduser("~/monorepo")

XDG_CONFIG_HOME = os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
APP_CONFIG_DIR = os.path.join(XDG_CONFIG_HOME, "cerebral")
APP_CACHE_DIR = os.path.expanduser("~/.cache/cerebral")
ORG_OUTPUT_DIR = os.path.expanduser("~/org/cerebral")
AGENDA_FILE = os.path.expanduser("~/org/agenda.org")

os.makedirs(APP_CONFIG_DIR, exist_ok=True)
os.makedirs(APP_CACHE_DIR, exist_ok=True)
os.makedirs(ORG_OUTPUT_DIR, exist_ok=True)

MEMORY_FILE = os.path.join(APP_CACHE_DIR, "memory_summary.txt")
MEMORY_INDEX_PATH = os.path.join(APP_CACHE_DIR, "memory_index")
FAISS_INDEX_PATH = os.path.join(APP_CONFIG_DIR, "pkm_index")
HASH_TRACKER_FILE = os.path.join(APP_CONFIG_DIR, "latest_commit.txt")


# ===========================================================================
# cerebral/llm.py -- abstract LLM provider plus the Cerebras implementation.
# ===========================================================================
import os
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from cerebras.cloud.sdk import Cerebras


class BaseLLMProvider(ABC):
    """Abstract interface for LLM providers so backends can be swapped."""

    @abstractmethod
    def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto") -> Any:
        pass


class CerebrasProvider(BaseLLMProvider):
    """Cloud provider backed by the Cerebras chat-completions API."""

    def __init__(self, model: str = "qwen-3-235b-a22b-instruct-2507"):
        api_key = os.environ.get("CEREBRAS_API_KEY")
        if not api_key:
            raise ValueError("CEREBRAS_API_KEY environment variable is required.")
        self.client = Cerebras(api_key=api_key)
        self.model = model

    def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto"):
        kwargs = {
            "messages": messages,
            "model": self.model,
            "stream": stream,
        }
        # tool_choice is only meaningful when tools are supplied.
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = tool_choice

        return self.client.chat.completions.create(**kwargs)
# ===========================================================================
# cerebral/memory.py -- two-tier memory: rolling session summary plus a
# FAISS-indexed long-term fact log.
# ===========================================================================
import os
import re
import threading
import ollama
from typing import List, Dict, Callable
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from .llm import BaseLLMProvider


class MemoryManager:
    """Maintains a compressed session summary (local model) and a persistent
    fact log indexed with FAISS for semantic recall."""

    def __init__(self, memory_file: str, index_path: str, local_model: str, embed_model_name: str, log: Callable[[str], None] = print, cloud_provider: BaseLLMProvider = None):
        self.memory_file = memory_file
        self.index_path = index_path
        self.local_model = local_model
        # Optional "teacher" model used in finalize_session(); may be None.
        self.cloud_provider = cloud_provider
        self.log = log

        self.session_summary = "Session just started. No prior context."
        self.interaction_buffer = []
        # Number of buffered turns before a background compression kicks in.
        self.COMPRESSION_THRESHOLD = 4

        self.embeddings = OllamaEmbeddings(model=embed_model_name)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)

        self.log("[dim italic]Loading persistent memory...[/dim italic]")
        if os.path.exists(self.memory_file):
            with open(self.memory_file, "r") as f:
                self.persistent_memory = f.read().strip()
        else:
            self.persistent_memory = "No known user facts or long-term preferences."

        if os.path.exists(self.index_path):
            self.vectorstore = FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)
        else:
            self.log("[bold yellow]No memory index found. Building initial database...[/bold yellow]")
            self.rebuild_index()

    def get_line_count(self) -> int:
        """Number of lines in the on-disk memory log (0 if absent)."""
        if not os.path.exists(self.memory_file):
            return 0
        with open(self.memory_file, "r") as f:
            return sum(1 for _ in f)

    def rebuild_index(self):
        """Re-embed the whole persistent log into a fresh FAISS index."""
        self.log("[dim italic]Reserializing memory log into vector database...[/dim italic]")
        text = self.persistent_memory if self.persistent_memory else "No known user facts or long-term preferences."

        chunks = self.text_splitter.split_text(text)
        docs = [Document(page_content=c) for c in chunks]
        self.vectorstore = FAISS.from_documents(docs, self.embeddings)
        self.vectorstore.save_local(self.index_path)
        self.log("[bold green]Memory database manually rebuilt and saved![/bold green]")

    def compress_persistent_memory(self):
        """Deduplicate the on-disk fact log via the local model, then re-index."""
        self.log("[bold yellow]Compressing persistent memory (removing duplicates and irrelevant data)...[/bold yellow]")
        if not os.path.exists(self.memory_file):
            self.log("[dim]Memory file is empty. Nothing to compress.[/dim]")
            return

        sys_prompt = """You are a strictly robotic data deduplication script. Your ONLY job is to compress the provided memory log.
        RULES:
        1. Remove duplicate facts.
        2. Remove conversational text, essays, or philosophical analysis.
        3. Output ONLY a clean, simple bulleted list of facts.
        4. NEVER use headers, bold text, or introductory/closing remarks."""

        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': f"MEMORY LOG TO COMPRESS:\n{self.persistent_memory}"}
            ])
            compressed_memory = response['message']['content'].strip()
            # Strip hidden reasoning from "thinking" local models.
            compressed_memory = re.sub(r'<think>.*?</think>', '', compressed_memory, flags=re.DOTALL).strip()

            with open(self.memory_file, "w") as f:
                f.write(compressed_memory)
            self.persistent_memory = compressed_memory
            self.rebuild_index()
            self.log("[bold green]Persistent memory successfully compressed and re-indexed![/bold green]")
        except Exception as e:
            self.log(f"[bold red]Memory compression failed: {e}[/bold red]")

    def search(self, query: str) -> str:
        """Return the top-3 semantically similar memory chunks as a bullet list."""
        if not getattr(self, 'vectorstore', None):
            return "No long-term memories available."
        docs = self.vectorstore.similarity_search(query, k=3)
        return "\n".join([f"- {d.page_content}" for d in docs])

    def add_interaction(self, user_input: str, bot_response: str):
        """Buffer a turn; compress the buffer in a daemon thread at threshold."""
        self.interaction_buffer.append({"user": user_input, "agent": bot_response})
        if len(self.interaction_buffer) >= self.COMPRESSION_THRESHOLD:
            buffer_to_compress = list(self.interaction_buffer)
            self.interaction_buffer = []
            threading.Thread(target=self._compress_session, args=(buffer_to_compress,), daemon=True).start()

    def _compress_session(self, buffer: List[Dict]):
        """Merge buffered turns into self.session_summary via the local model."""
        buffer_text = "\n".join([f"User: {i['user']}\nAgent: {i['agent']}" for i in buffer])
        sys_prompt = """You are a strict summarization script. Merge the recent interactions into the current session summary.
        RULES:
        1. Keep it brief and objective.
        2. DO NOT write essays or analyze the user's intent.
        3. Output ONLY the raw text of the updated summary. No conversational padding."""

        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': f"CURRENT SUMMARY:\n{self.session_summary}\n\nNEW INTERACTIONS:\n{buffer_text}"}
            ])
            self.session_summary = response['message']['content'].strip()
            self.session_summary = re.sub(r'<think>.*?</think>', '', self.session_summary, flags=re.DOTALL).strip()
        except Exception as e:
            self.log(f"[dim red]Background session compression failed: {e}[/dim red]")

    def finalize_session(self):
        """Distill the finished session into permanent facts and append them
        to the long-term log (cloud "teacher" extracts, local "student" formats)."""
        self.log("[bold yellow]Extracting long-term memories from session...[/bold yellow]")

        # 1. Sweep up the leftovers: compress the final un-summarized
        # interactions synchronously so session_summary is complete.
        if self.interaction_buffer:
            self.log("[dim italic]Compressing final interactions into session summary...[/dim italic]")
            self._compress_session(self.interaction_buffer)
            self.interaction_buffer = []

        # --- STEP A: Cerebras "Teacher" fact extraction ---
        # BUGFIX: cloud_provider defaults to None; the original dereferenced it
        # unconditionally and the resulting AttributeError was swallowed by the
        # broad except and logged as "Cloud extraction failed", disguising a
        # configuration gap as a network failure. Guard explicitly instead.
        clean_summary = self.session_summary
        if self.cloud_provider is not None:
            self.log("[dim italic]Cloud model filtering permanent facts from session summary...[/dim italic]")
            cloud_sys_prompt = "You are a neutral fact-extractor. Read the provided session summary and extract any NEW long-term facts, topics discussed, and explicit user preferences. Discard all temporary conversational context. Be highly objective, dense, and concise."
            try:
                cloud_msgs = [
                    {"role": "system", "content": cloud_sys_prompt},
                    {"role": "user", "content": f"SESSION SUMMARY:\n{self.session_summary}"}
                ]
                cloud_response = self.cloud_provider.chat_completion(messages=cloud_msgs, stream=False)
                clean_summary = cloud_response.choices[0].message.content.strip()
            except Exception as e:
                self.log(f"[dim red]Cloud extraction failed ({e}). Falling back to raw summary.[/dim red]")
        else:
            self.log("[dim yellow]No cloud provider configured. Using raw session summary.[/dim yellow]")

        # --- STEP B: Local "Student" formatting ---
        self.log("[dim italic]Local model formatting permanent facts...[/dim italic]")
        sys_prompt = """You are a strict data extraction pipeline. Your ONLY job is to take the extracted facts enclosed in the ``` block and format them perfectly for long-term storage.

        RULES:
        1. NEVER write conversational text, greetings, headers, or explanations.
        2. ONLY output a raw, bulleted list of concise facts (e.g., "- User uses Emacs org-mode").
        3. Focus ONLY on facts about the USER (their preferences, setup, identity, projects).
        4. If there are NO new permanent facts to save, output EXACTLY and ONLY the word: NONE.
        """

        # Enclose the clean summary in markdown blocks to isolate it from the prompt.
        user_prompt = f"Format these facts ONLY:\n\n```\n{clean_summary}\n```"

        try:
            response = ollama.chat(model=self.local_model, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': user_prompt}
            ])

            new_facts = response['message']['content'].strip()
            new_facts = re.sub(r'<think>.*?</think>', '', new_facts, flags=re.DOTALL).strip()

            if new_facts.upper() != "NONE" and new_facts:
                with open(self.memory_file, "a") as f:
                    f.write(f"\n{new_facts}")
                self.persistent_memory += f"\n{new_facts}"
                self.log("[bold green]New facts appended to long-term memory log![/bold green]")
                self.log("[dim]Note: Run /memory rebuild to index these new facts for next time.[/dim]")
            else:
                self.log("[dim]No new long-term facts detected. Skipping memory append.[/dim]")
        except Exception as e:
            self.log(f"[bold red]Failed to save long-term memory: {e}[/bold red]")
# ===========================================================================
# cerebral/tools.py (part 1) -- PKM vector search, local vision, and the
# pluggable web-search providers.
# ===========================================================================
import os
import subprocess
import ollama
from typing import List, Dict, Callable
from abc import ABC, abstractmethod
from ddgs import DDGS
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter


class PKMManager:
    """FAISS index over the user's .org files, cache-invalidated by the
    repository's `main` commit hash."""

    def __init__(self, pkm_dir: str, index_path: str, hash_file: str, embed_model_name: str, log: Callable[[str], None] = print):
        self.pkm_dir = pkm_dir
        self.index_path = index_path
        self.hash_file = hash_file
        self.log = log

        self.log(f"[dim italic]Waking up Ollama embeddings ({embed_model_name})...[/dim italic]")
        self.embeddings = OllamaEmbeddings(model=embed_model_name)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        self.vectorstore = self._load_or_build()

    def _get_main_commit_hash(self) -> str:
        """Current commit hash of `main`, or "unknown" on any git failure."""
        try:
            result = subprocess.run(
                ["git", "rev-parse", "main"],
                cwd=self.pkm_dir, capture_output=True, text=True, check=True
            )
            return result.stdout.strip()
        # BUGFIX: the original caught only CalledProcessError; a missing git
        # binary or a nonexistent pkm_dir raises FileNotFoundError /
        # NotADirectoryError and crashed startup instead of degrading to the
        # "unknown" sentinel.
        except (subprocess.CalledProcessError, FileNotFoundError, NotADirectoryError):
            return "unknown"

    def _load_or_build(self):
        """Load the cached index when the git hash is unchanged, else rebuild."""
        self.log("[dim]Checking Git HEAD hash for PKM changes...[/dim]")
        current_hash = self._get_main_commit_hash()

        if os.path.exists(self.index_path) and os.path.exists(self.hash_file):
            with open(self.hash_file, "r") as f:
                if f.read().strip() == current_hash:
                    self.log(f"[green]Git hash unchanged ({current_hash[:7]}). Loading cached PKM index...[/green]")
                    return FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)

        self.log(f"[bold yellow]New commits detected ({current_hash[:7]}). Rebuilding PKM index...[/bold yellow]")
        raw_documents = []

        self.log(f"[dim]Scanning {self.pkm_dir} for .org files...[/dim]")
        for root, dirs, files in os.walk(self.pkm_dir):
            # Prune VCS internals and the 'nix' tree from the walk in place.
            if '.git' in dirs: dirs.remove('.git')
            if 'nix' in dirs: dirs.remove('nix')
            for file in files:
                if file.endswith('.org'):
                    filepath = os.path.join(root, file)
                    try:
                        with open(filepath, 'r', encoding='utf-8') as f:
                            raw_documents.append(Document(page_content=f.read(), metadata={"source": filepath}))
                    except Exception:
                        # Best-effort scan: skip unreadable files silently.
                        pass

        if not raw_documents:
            self.log("[red]No .org files found in PKM directory.[/red]")
            return None

        self.log(f"[dim]Chunking {len(raw_documents)} documents...[/dim]")
        chunks = self.text_splitter.split_documents(raw_documents)

        self.log(f"[bold cyan]Embedding {len(chunks)} chunks via Ollama (this might take a minute)...[/bold cyan]")
        vectorstore = FAISS.from_documents(chunks, self.embeddings)
        vectorstore.save_local(self.index_path)

        with open(self.hash_file, "w") as f:
            f.write(current_hash)

        self.log("[bold green]PKM Index successfully rebuilt and saved![/bold green]")
        return vectorstore

    def search(self, query: str) -> str:
        """Top-10 similar chunks formatted for injection into the LLM context."""
        if not self.vectorstore:
            return "PKM is empty."
        docs = self.vectorstore.similarity_search(query, k=10)
        return "PKM Search Results:\n" + "\n\n".join([f"From {d.metadata['source']}:\n{d.page_content}" for d in docs])


class VisionProcessor:
    """Summarizes an attached image with the local multimodal model."""

    def __init__(self, local_model: str, log: Callable[[str], None] = print):
        self.local_model = local_model
        self.log = log
        self.log("[dim italic]Vision Processor online...[/dim italic]")

    def process(self, image_path: str, user_prompt: str) -> str:
        """Return a text description of the image, or an error marker string."""
        try:
            with open(image_path, 'rb') as img_file:
                img_bytes = img_file.read()
            response = ollama.chat(model=self.local_model, messages=[{
                'role': 'user',
                'content': f"Describe this image in detail to help another AI answer this prompt: {user_prompt}",
                'images': [img_bytes]
            }])
            return response['message']['content']
        except Exception as e:
            return f"[Image analysis failed: {e}]"


class BaseSearchProvider(ABC):
    """Interface for web-search backends used by WebSearcher's fallback chain."""

    @abstractmethod
    def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        pass


class GoogleSearchProvider(BaseSearchProvider):
    def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        # Imported lazily so the dependency is only needed when this provider runs.
        from googlesearch import search

        results = []
        for r in search(query, num_results=max_results, advanced=True):
            results.append({
                'title': getattr(r, 'title', 'No Title'),
                'href': getattr(r, 'url', 'No URL'),
                'body': getattr(r, 'description', 'No Description')
            })

        if not results:
            # Raising triggers WebSearcher's fallback to the next provider.
            raise Exception("Google returned zero results.")

        return results


class DDGSSearchProvider(BaseSearchProvider):
    def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        results = DDGS().text(query, max_results=max_results)

        if not results:
            raise Exception("DuckDuckGo returned zero results.")

        formatted_results = []
        for r in results:
            formatted_results.append({
                'title': r.get('title', 'No Title'),
                'href': r.get('href', 'No URL'),
                'body': r.get('body', 'No Description')
            })
        return formatted_results
class WebSearcher:
    """Runs a query through a prioritized chain of search providers, falling
    back to the next provider on any failure."""

    def __init__(self, log: Callable[[str], None] = print):
        self.log = log
        # Ordered by preference; first provider to succeed wins.
        self.providers: List[BaseSearchProvider] = [
            GoogleSearchProvider(),
            DDGSSearchProvider()
        ]

    def search(self, query: str) -> str:
        """Return formatted results from the first working provider, or an
        error string if every provider fails."""
        for provider in self.providers:
            provider_name = provider.__class__.__name__
            try:
                self.log(f"[dim italic]Trying {provider_name}...[/dim italic]")
                results = provider.search(query, max_results=10)

                context = "Web Search Results:\n"
                for r in results:
                    context += f"- Title: {r['title']}\n  URL: {r['href']}\n  Snippet: {r['body']}\n\n"
                return context

            except Exception as e:
                self.log(f"[dim yellow]{provider_name} failed ({e}). Falling back...[/dim yellow]")
                continue

        return "Web search failed: All search providers were exhausted or rate-limited."


class AgendaManager:
    """Appends TODO entries to the user's org-mode agenda file."""

    def __init__(self, agenda_file: str, log: Callable[[str], None] = print):
        self.agenda_file = agenda_file
        self.log = log

    def append_todo(self, task: str, scheduled: str = "", description: str = "") -> str:
        """Appends a TODO to the agenda file. Write-only for privacy.

        `scheduled` is inserted verbatim inside org angle brackets when given;
        `description` becomes an indented note line. Returns a human-readable
        success/failure message that is fed back to the LLM as the tool result.
        """
        try:
            # BUGFIX: the original assumed the agenda file's parent directory
            # existed; if ~/org was absent, open() raised and the tool failed.
            parent_dir = os.path.dirname(self.agenda_file)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            if not os.path.exists(self.agenda_file):
                with open(self.agenda_file, "w") as f:
                    f.write("#+TITLE: Private Agenda\n\n")

            entry = f"\n* TODO {task}\n"
            if scheduled:
                entry += f"  SCHEDULED: <{scheduled}>\n"
            if description:
                # Clean, indented description below the scheduled tag.
                entry += f"  {description}\n"

            with open(self.agenda_file, "a") as f:
                f.write(entry)

            self.log(f"[bold green]Successfully appended to agenda: {task}[/bold green]")
            return f"Successfully added to agenda: {task}"
        except Exception as e:
            self.log(f"[bold red]Failed to append to agenda: {e}[/bold red]")
            return f"Failed to add to agenda: {e}"
Configuration & Constants -# ========================================== -LOCAL_LLM = "qwen3-vl:8b" -LOCAL_EMBED_MODEL = "nomic-embed-text-v2-moe:latest" -PKM_DIR = os.path.expanduser("~/monorepo") - -XDG_CONFIG_HOME = os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) -APP_CONFIG_DIR = os.path.join(XDG_CONFIG_HOME, "cerebral") -APP_CACHE_DIR = os.path.expanduser("~/.cache/cerebral") -ORG_OUTPUT_DIR = os.path.expanduser("~/org/cerebral") - -os.makedirs(APP_CONFIG_DIR, exist_ok=True) -os.makedirs(APP_CACHE_DIR, exist_ok=True) -os.makedirs(ORG_OUTPUT_DIR, exist_ok=True) - -MEMORY_FILE = os.path.join(APP_CACHE_DIR, "memory_summary.txt") -MEMORY_INDEX_PATH = os.path.join(APP_CACHE_DIR, "memory_index") -FAISS_INDEX_PATH = os.path.join(APP_CONFIG_DIR, "pkm_index") -HASH_TRACKER_FILE = os.path.join(APP_CONFIG_DIR, "latest_commit.txt") - -# ========================================== -# 2. Abstract LLM Provider -# ========================================== -class BaseLLMProvider(ABC): - """Abstract interface for LLM providers to ensure easy swapping.""" - @abstractmethod - # <-- UPDATED: Added tool_choice parameter - def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto") -> Any: - pass - -class CerebrasProvider(BaseLLMProvider): - def __init__(self, model: str = "qwen-3-235b-a22b-instruct-2507"): - api_key = os.environ.get("CEREBRAS_API_KEY") - if not api_key: - raise ValueError("CEREBRAS_API_KEY environment variable is required.") - self.client = Cerebras(api_key=api_key) - self.model = model - - def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto"): - kwargs = { - "messages": messages, - "model": self.model, - "stream": stream, - } - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = tool_choice # <-- UPDATED - - return self.client.chat.completions.create(**kwargs) - -# ========================================== 
-# 3. Core Modules -# ========================================== -class MemoryManager: - def __init__(self, memory_file: str, index_path: str, local_model: str, embed_model_name: str, log: Callable[[str], None] = print): - self.memory_file = memory_file - self.index_path = index_path - self.local_model = local_model - self.log = log - - self.session_summary = "Session just started. No prior context." - self.interaction_buffer = [] - self.COMPRESSION_THRESHOLD = 4 - - self.embeddings = OllamaEmbeddings(model=embed_model_name) - self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50) - - self.log("[dim italic]Loading persistent memory...[/dim italic]") - if os.path.exists(self.memory_file): - with open(self.memory_file, "r") as f: - self.persistent_memory = f.read().strip() - else: - self.persistent_memory = "No known user facts or long-term preferences." - - if os.path.exists(self.index_path): - self.vectorstore = FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True) - else: - self.log("[bold yellow]No memory index found. Building initial database...[/bold yellow]") - self.rebuild_index() - - def get_line_count(self) -> int: - if not os.path.exists(self.memory_file): - return 0 - with open(self.memory_file, "r") as f: - return sum(1 for _ in f) - - def rebuild_index(self): - self.log("[dim italic]Reserializing memory log into vector database...[/dim italic]") - text = self.persistent_memory if self.persistent_memory else "No known user facts or long-term preferences." 
- - chunks = self.text_splitter.split_text(text) - docs = [Document(page_content=c) for c in chunks] - self.vectorstore = FAISS.from_documents(docs, self.embeddings) - self.vectorstore.save_local(self.index_path) - self.log("[bold green]Memory database manually rebuilt and saved![/bold green]") - - def compress_persistent_memory(self): - self.log("[bold yellow]Compressing persistent memory (removing duplicates and irrelevant data)...[/bold yellow]") - if not os.path.exists(self.memory_file): - self.log("[dim]Memory file is empty. Nothing to compress.[/dim]") - return - - # STRICT PROMPT FOR COMPRESSION - sys_prompt = """You are a strictly robotic data deduplication script. Your ONLY job is to compress the provided memory log. - RULES: - 1. Remove duplicate facts. - 2. Remove conversational text, essays, or philosophical analysis. - 3. Output ONLY a clean, simple bulleted list of facts. - 4. NEVER use headers, bold text, or introductory/closing remarks.""" - - try: - response = ollama.chat(model=self.local_model, messages=[ - {'role': 'system', 'content': sys_prompt}, - {'role': 'user', 'content': f"MEMORY LOG TO COMPRESS:\n{self.persistent_memory}"} - ]) - compressed_memory = response['message']['content'].strip() - compressed_memory = re.sub(r'<think>.*?</think>', '', compressed_memory, flags=re.DOTALL).strip() - - with open(self.memory_file, "w") as f: - f.write(compressed_memory) - self.persistent_memory = compressed_memory - self.rebuild_index() - self.log("[bold green]Persistent memory successfully compressed and re-indexed![/bold green]") - except Exception as e: - self.log(f"[bold red]Memory compression failed: {e}[/bold red]") - - def search(self, query: str) -> str: - if not getattr(self, 'vectorstore', None): - return "No long-term memories available." 
- docs = self.vectorstore.similarity_search(query, k=3) - return "\n".join([f"- {d.page_content}" for d in docs]) - - def add_interaction(self, user_input: str, bot_response: str): - self.interaction_buffer.append({"user": user_input, "agent": bot_response}) - if len(self.interaction_buffer) >= self.COMPRESSION_THRESHOLD: - buffer_to_compress = list(self.interaction_buffer) - self.interaction_buffer = [] - threading.Thread(target=self._compress_session, args=(buffer_to_compress,), daemon=True).start() - - def _compress_session(self, buffer: List[Dict]): - buffer_text = "\n".join([f"User: {i['user']}\nAgent: {i['agent']}" for i in buffer]) - - # STRICT PROMPT FOR SESSION COMPRESSION - sys_prompt = """You are a strict summarization script. Merge the recent interactions into the current session summary. - RULES: - 1. Keep it brief and objective. - 2. DO NOT write essays or analyze the user's intent. - 3. Output ONLY the raw text of the updated summary. No conversational padding.""" - - try: - response = ollama.chat(model=self.local_model, messages=[ - {'role': 'system', 'content': sys_prompt}, - {'role': 'user', 'content': f"CURRENT SUMMARY:\n{self.session_summary}\n\nNEW INTERACTIONS:\n{buffer_text}"} - ]) - self.session_summary = response['message']['content'].strip() - self.session_summary = re.sub(r'<think>.*?</think>', '', self.session_summary, flags=re.DOTALL).strip() - except Exception as e: - self.log(f"[dim red]Background session compression failed: {e}[/dim red]") - - def finalize_session(self): - self.log("[bold yellow]Extracting long-term memories from session...[/bold yellow]") - final_context = self.session_summary - if self.interaction_buffer: - final_context += "\n" + "\n".join([f"User: {i['user']}\nAgent: {i['agent']}" for i in self.interaction_buffer]) - - # STRICT PROMPT FOR EXTRACTION - sys_prompt = """You are a strict data extraction pipeline. Your ONLY job is to extract permanent, long-term facts about the user from the provided session text. 
- - RULES: - 1. NEVER write conversational text, greetings, headers, or explanations. - 2. NEVER write essays, evaluate, or analyze the meaning of the facts. - 3. ONLY output a raw, bulleted list of concise facts (e.g., "- User uses Emacs org-mode"). - 4. If there are NO new permanent facts to save, output EXACTLY and ONLY the word: NONE. - """ - - try: - response = ollama.chat(model=self.local_model, messages=[ - {'role': 'system', 'content': sys_prompt}, - {'role': 'user', 'content': f"SESSION TEXT TO EXTRACT FROM:\n{final_context}"} - ]) - - new_facts = response['message']['content'].strip() - new_facts = re.sub(r'<think>.*?</think>', '', new_facts, flags=re.DOTALL).strip() - - if new_facts.upper() != "NONE" and new_facts: - # Failsafe: If the model hallucinates an essay anyway, block it from saving. - if len(new_facts.split('\n')) > 15 or "###" in new_facts: - self.log("[dim red]Model hallucinated an essay instead of facts. Discarding to protect memory database.[/dim red]") - return - - with open(self.memory_file, "a") as f: - f.write(f"\n{new_facts}") - self.persistent_memory += f"\n{new_facts}" - self.log("[bold green]New facts appended to long-term memory log![/bold green]") - self.log("[dim]Note: Run /memory rebuild to index these new facts for next time.[/dim]") - else: - self.log("[dim]No new long-term facts detected. 
Skipping memory append.[/dim]") - except Exception as e: - self.log(f"[bold red]Failed to save long-term memory: {e}[/bold red]") - -class PKMManager: - def __init__(self, pkm_dir: str, index_path: str, hash_file: str, embed_model_name: str, log: Callable[[str], None] = print): - self.pkm_dir = pkm_dir - self.index_path = index_path - self.hash_file = hash_file - self.log = log - - self.log(f"[dim italic]Waking up Ollama embeddings ({embed_model_name})...[/dim italic]") - self.embeddings = OllamaEmbeddings(model=embed_model_name) - self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) - self.vectorstore = self._load_or_build() - - def _get_main_commit_hash(self) -> str: - try: - result = subprocess.run( - ["git", "rev-parse", "main"], - cwd=self.pkm_dir, capture_output=True, text=True, check=True - ) - return result.stdout.strip() - except subprocess.CalledProcessError: - return "unknown" - - def _load_or_build(self): - self.log("[dim]Checking Git HEAD hash for PKM changes...[/dim]") - current_hash = self._get_main_commit_hash() - - if os.path.exists(self.index_path) and os.path.exists(self.hash_file): - with open(self.hash_file, "r") as f: - if f.read().strip() == current_hash: - self.log(f"[green]Git hash unchanged ({current_hash[:7]}). Loading cached PKM index...[/green]") - return FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True) - - self.log(f"[bold yellow]New commits detected ({current_hash[:7]}). 
Rebuilding PKM index...[/bold yellow]") - raw_documents = [] - - self.log(f"[dim]Scanning {self.pkm_dir} for .org files...[/dim]") - for root, dirs, files in os.walk(self.pkm_dir): - if '.git' in dirs: dirs.remove('.git') - if 'nix' in dirs: dirs.remove('nix') - for file in files: - if file.endswith('.org'): - filepath = os.path.join(root, file) - try: - with open(filepath, 'r', encoding='utf-8') as f: - raw_documents.append(Document(page_content=f.read(), metadata={"source": filepath})) - except Exception: - pass - - if not raw_documents: - self.log("[red]No .org files found in PKM directory.[/red]") - return None - - self.log(f"[dim]Chunking {len(raw_documents)} documents...[/dim]") - chunks = self.text_splitter.split_documents(raw_documents) - - self.log(f"[bold cyan]Embedding {len(chunks)} chunks via Ollama (this might take a minute)...[/bold cyan]") - vectorstore = FAISS.from_documents(chunks, self.embeddings) - vectorstore.save_local(self.index_path) - - with open(self.hash_file, "w") as f: - f.write(current_hash) - - self.log("[bold green]PKM Index successfully rebuilt and saved![/bold green]") - return vectorstore - - def search(self, query: str) -> str: - if not self.vectorstore: - return "PKM is empty." 
- docs = self.vectorstore.similarity_search(query, k=10) - return "PKM Search Results:\n" + "\n\n".join([f"From {d.metadata['source']}:\n{d.page_content}" for d in docs]) - - -class VisionProcessor: - def __init__(self, local_model: str, log: Callable[[str], None] = print): - self.local_model = local_model - self.log = log - self.log("[dim italic]Vision Processor online...[/dim italic]") - - def process(self, image_path: str, user_prompt: str) -> str: - try: - with open(image_path, 'rb') as img_file: - img_bytes = img_file.read() - response = ollama.chat(model=self.local_model, messages=[{ - 'role': 'user', - 'content': f"Describe this image in detail to help another AI answer this prompt: {user_prompt}", - 'images': [img_bytes] - }]) - return response['message']['content'] - except Exception as e: - return f"[Image analysis failed: {e}]" - - -# ========================================== -# Web Search Providers -# ========================================== -class BaseSearchProvider(ABC): - """Abstract interface for web search engines to ensure easy swapping and fallbacks.""" - @abstractmethod - def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]: - pass - -class GoogleSearchProvider(BaseSearchProvider): - def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]: - # Imported locally so it doesn't crash the app if the package is missing - from googlesearch import search - - results = [] - # advanced=True forces it to return objects with title, url, and description - for r in search(query, num_results=max_results, advanced=True): - results.append({ - 'title': getattr(r, 'title', 'No Title'), - 'href': getattr(r, 'url', 'No URL'), - 'body': getattr(r, 'description', 'No Description') - }) - - if not results: - raise Exception("Google returned zero results.") - - return results - -class DDGSSearchProvider(BaseSearchProvider): - def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]: - results = 
DDGS().text(query, max_results=max_results) - - if not results: - raise Exception("DuckDuckGo returned zero results.") - - formatted_results = [] - for r in results: - formatted_results.append({ - 'title': r.get('title', 'No Title'), - 'href': r.get('href', 'No URL'), - 'body': r.get('body', 'No Description') - }) - return formatted_results - -class WebSearcher: - def __init__(self, log: Callable[[str], None] = print): - self.log = log - # The order of this list dictates the fallback priority - self.providers: List[BaseSearchProvider] = [ - GoogleSearchProvider(), - DDGSSearchProvider() - ] - - def search(self, query: str) -> str: - for provider in self.providers: - provider_name = provider.__class__.__name__ - try: - self.log(f"[dim italic]Trying {provider_name}...[/dim italic]") - results = provider.search(query, max_results=10) - - context = "Web Search Results:\n" - for r in results: - context += f"- Title: {r['title']}\n URL: {r['href']}\n Snippet: {r['body']}\n\n" - return context - - except Exception as e: - # Catch 429 Rate Limits, connection errors, or empty results and seamlessly fall back - self.log(f"[dim yellow]{provider_name} failed ({e}). Falling back...[/dim yellow]") - continue - - return "Web search failed: All search providers were exhausted or rate-limited." - -# ========================================== -# 4. 
The Orchestrator (Agnostic Agent) -# ========================================== -class CerebralAgent: - def __init__(self, provider: BaseLLMProvider, log: Callable[[str], None] = print): - self.provider = provider - self.log = log - - self.log("[bold magenta]Initializing Cerebral Agent Modules...[/bold magenta]") - self.memory = MemoryManager(MEMORY_FILE, MEMORY_INDEX_PATH, LOCAL_LLM, LOCAL_EMBED_MODEL, self.log) - self.pkm = PKMManager(PKM_DIR, FAISS_INDEX_PATH, HASH_TRACKER_FILE, LOCAL_EMBED_MODEL, self.log) - self.vision = VisionProcessor(LOCAL_LLM, self.log) - self.web = WebSearcher(self.log) - - def generate_session_filename(self, first_prompt: str, first_response: str) -> str: - self.log("[dim italic]Generating descriptive filename based on prompt and response...[/dim italic]") - - hash_input = (first_prompt + first_response).encode('utf-8') - combined_hash = hashlib.sha256(hash_input).hexdigest()[:6] - - sys_prompt = "You are a file naming utility. Read the user's prompt and generate a short, descriptive filename base using ONLY lowercase letters and hyphens. Do NOT add an extension. ONLY output the base filename, absolutely no other text. Example: learning-python-basics" - - try: - response = ollama.chat(model=LOCAL_LLM, messages=[ - {'role': 'system', 'content': sys_prompt}, - {'role': 'user', 'content': first_prompt} - ]) - - raw_content = response['message']['content'].strip() - raw_content = re.sub(r'<think>.*?</think>', '', raw_content, flags=re.DOTALL).strip() - - lines = [line.strip() for line in raw_content.split('\n') if line.strip()] - raw_filename = lines[-1].lower().replace(' ', '-') if lines else "cerebral-session" - clean_base = re.sub(r'[^a-z0-9\-]', '', raw_filename).strip('-') - clean_base = clean_base[:50].strip('-') - - if not clean_base: - clean_base = "cerebral-session" - - final_filename = f"{clean_base}-{combined_hash}.org" - return final_filename - - except Exception as e: - self.log(f"[dim red]Filename generation failed: {e}. 
Defaulting.[/dim red]") - return f"cerebral-session-{combined_hash}.org" - - def _get_tools(self) -> List[Dict]: - return [ - { - "type": "function", - "function": { - "name": "search_pkm", - "description": "Search the user's personal knowledge base (PKM) for notes, code, or org files.", - "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]} - } - }, - { - "type": "function", - "function": { - "name": "search_web", - "description": "Search the live internet for current events, external documentation, or facts outside your PKM.", - "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]} - } - } - ] - - def chat_stream(self, prompt: str, image_path: Optional[str] = None) -> Generator[str, None, str]: - """Core interaction loop. Yields text chunks. Returns full text when done.""" - - recent_history = "" - if self.memory.interaction_buffer: - recent_history = "\nRECENT UNCOMPRESSED TURNS:\n" + "\n".join( - [f"User: {i['user']}\nAgent: {i['agent']}" for i in self.memory.interaction_buffer] - ) - - vision_context = "" - if image_path: - self.log("[dim italic]Analyzing image context...[/dim italic]") - vision_summary = self.vision.process(image_path, prompt) - vision_context = f"\n[USER ATTACHED AN IMAGE. Local Vision Summary: {vision_summary}]\n" - - self.log("[dim italic]Querying long-term memory (Ollama Embeddings)...[/dim italic]") - relevant_memories = self.memory.search(prompt) - - current_time = datetime.datetime.now().strftime("%A, %B %d, %Y at %I:%M %p") - - system_prompt = f"""You are a highly capable AI assistant. - - CRITICAL OUTPUT FORMATTING: - You MUST output your responses EXCLUSIVELY in Emacs org-mode format. Use org-mode headings, lists, and LaTeX fragments for math. - - FORMATTING RULES: - 1. NEVER use double asterisks (`**`) for bolding. You MUST use SINGLE asterisks for bold emphasis (e.g., *this is bold*). Double asterisks will break the parser. - 2. 
Cite your sources inline using proper org-mode link syntax. For web searches, use [[url][Description]]. For PKM files, use [[file:/path/to/file.org][Filename]]. - 3. At the very end of your response, you MUST append a Level 1 heading `* Sources` and neatly list all the search results and PKM documents you referenced using proper org-mode syntax. - - CURRENT TIME AND DATE: - {current_time} - - RESPONSE STYLE GUIDELINES: - - Provide EXTREMELY detailed, exhaustive, and comprehensive answers. - - Write in long-form prose. Do not be brief; expand deeply on concepts. - - Use multiple paragraphs, deep conceptual explanations, and thorough analysis. - - RELEVANT LONG-TERM MEMORIES: - {relevant_memories} - - COMPRESSED SESSION CONTEXT: {self.memory.session_summary} - {recent_history} - """ - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt + vision_context} - ] - - self.log("[dim italic]Analyzing intent & tool requirements (Cerebras)...[/dim italic]") - - # --- NEW: Self-Healing Tool Call Loop --- - MAX_RETRIES = 3 - valid_tool_calls = False - response_message = None - allowed_tool_names = [t["function"]["name"] for t in self._get_tools()] - - for attempt in range(MAX_RETRIES): - pre_flight = self.provider.chat_completion(messages=messages, tools=self._get_tools(), stream=False) - response_message = pre_flight.choices[0].message - - # Scenario A: Hallucinated Markdown Tool Call - if not response_message.tool_calls and response_message.content and "**name**:" in response_message.content: - self.log(f"[dim yellow]Model hallucinated Markdown tool call. Retrying ({attempt+1}/{MAX_RETRIES})...[/dim yellow]") - error_msg = f"ERROR: You attempted to call a tool using Markdown text. You MUST use the native JSON tool calling API. 
Allowed tools: {allowed_tool_names}" - messages.append({"role": "assistant", "content": response_message.content}) - messages.append({"role": "user", "content": error_msg}) - continue - - # Scenario B: Legitimate text response (No tools needed) - if not response_message.tool_calls: - valid_tool_calls = True - break - - # Scenario C: Native API Tool Calls (Needs Validation) - has_errors = False - error_feedbacks = [] - - for tool_call in response_message.tool_calls: - func_name = tool_call.function.name - call_error = None - - if func_name not in allowed_tool_names: - has_errors = True - call_error = f"Tool '{func_name}' does not exist. Allowed tools: {allowed_tool_names}" - else: - try: - json.loads(tool_call.function.arguments) - except json.JSONDecodeError: - has_errors = True - call_error = f"Arguments for '{func_name}' are not valid JSON: {tool_call.function.arguments}" - - error_feedbacks.append(call_error) - - if has_errors: - self.log(f"[dim yellow]Malformed tool call detected. Retrying ({attempt+1}/{MAX_RETRIES})...[/dim yellow]") - - # Append the bad tool call to history so it learns what it did wrong - assistant_msg = { - "role": "assistant", - "content": response_message.content or "", - "tool_calls": [ - { - "id": t.id, - "type": "function", - "function": {"name": t.function.name, "arguments": t.function.arguments} - } for t in response_message.tool_calls - ] - } - messages.append(assistant_msg) - - # Append the specific errors as API tool responses - for i, tool_call in enumerate(response_message.tool_calls): - err = error_feedbacks[i] - msg_content = f"ERROR: {err}" if err else "Error: Another tool in this batch failed. Please fix the batch and retry." 
- messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": msg_content}) - continue - - # Scenario D: Valid Tool Calls - valid_tool_calls = True - break - - # Failsafe: If it fails 3 times, wipe the tool calls to force a graceful text degradation - if not valid_tool_calls: - self.log("[bold red]Failed to generate valid tool calls. Proceeding without tools.[/bold red]") - response_message.tool_calls = None - # ---------------------------------------- - - if not response_message.tool_calls: - self.log("[dim italic]No tools needed. Outputting response...[/dim italic]") - content = response_message.content or "" - yield content - self.memory.add_interaction(prompt, content) - return content - - # --- Execute Validated Tools --- - assistant_msg = { - "role": "assistant", - "content": response_message.content or "", - "tool_calls": [ - { - "id": t.id, - "type": "function", - "function": {"name": t.function.name, "arguments": t.function.arguments} - } for t in response_message.tool_calls - ] - } - messages.append(assistant_msg) - - for tool_call in response_message.tool_calls: - func_name = tool_call.function.name - args = json.loads(tool_call.function.arguments) # Guaranteed to be safe now - - if func_name == "search_pkm": - q = args.get("query", prompt) - self.log(f"[cyan]🧠 Tool Call: Searching PKM for '{q}'...[/cyan]") - yield f"\n*(Agent Note: Searched PKM for `{q}`)*\n\n" - result = self.pkm.search(q) - elif func_name == "search_web": - q = args.get("query", prompt) - self.log(f"[cyan]🌐 Tool Call: Searching Web for '{q}'...[/cyan]") - yield f"\n*(Agent Note: Searched Web for `{q}`)*\n\n" - result = self.web.search(q) - - messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result}) - - messages.append({ - "role": "system", - "content": "Tool results received. Now provide your final, comprehensive answer in strict org-mode. REMEMBER: Use *single asterisks* for bold, NEVER double asterisks." 
- }) +from cerebral.config import ORG_OUTPUT_DIR +from cerebral.llm import CerebrasProvider +from cerebral.agent import CerebralAgent - self.log("[dim italic]Streaming final response...[/dim italic]") - stream = self.provider.chat_completion(messages=messages, tools=self._get_tools(), stream=True, tool_choice="none") - - full_response = "" - for chunk in stream: - content = chunk.choices[0].delta.content or "" - full_response += content - yield content - - self.memory.add_interaction(prompt, full_response) - return full_response - - def shutdown(self): - self.memory.finalize_session() - - -# ========================================== -# 5. The CLI Presentation Layer -# ========================================== class CLIApp: def __init__(self, agent: CerebralAgent, console: Console): self.agent = agent @@ -653,35 +16,36 @@ class CLIApp: self.current_session_file = None def run(self): - self.console.print(Panel.fit("🤖 [bold blue]Modular Cerebral Agent[/bold blue] initialized.\n- Type [bold]/image /path/to/img.png <prompt>[/bold] to attach images.\n- Type [bold]/exit[/bold] to quit.", border_style="blue")) + self.console.print(Panel.fit("🤖 [bold blue]Modular Cerebral Agent[/bold blue] initialized.\n- Type [bold]/image /path/to/img.png <prompt>[/bold] to attach images.\n- Type [bold]/exit[/bold] to quit.\n- Type [bold]/memory[/bold] for DB tools.", border_style="blue")) while True: try: user_input = Prompt.ask("\n[bold magenta]You[/bold magenta]") + clean_input = user_input.strip().lower() - if user_input.lower() == '/memory count': + if clean_input == '/memory': + help_text = ( + "[bold cyan]/memory count[/bold cyan] : Print the number of lines in persistent memory.\n" + "[bold cyan]/memory rebuild[/bold cyan] : Manually reserialize the FAISS database from the log.\n" + "[bold cyan]/memory compress[/bold cyan] : Use the local LLM to scrub duplicates and compress the log." 
+ ) + self.console.print(Panel.fit(help_text, title="🧠 Memory Commands", border_style="cyan")) + continue + + if clean_input == '/memory count': count = self.agent.memory.get_line_count() self.console.print(f"[bold cyan]Persistent Memory Lines:[/bold cyan] {count}") continue - if user_input.lower() == '/memory rebuild': + if clean_input == '/memory rebuild': self.agent.memory.rebuild_index() continue - if user_input.lower() == '/memory compress': + if clean_input == '/memory compress': self.agent.memory.compress_persistent_memory() continue - if clean_input == '/memory': - help_text = ( - "[bold cyan]/memory count[/bold cyan] : Print the number of lines in persistent memory.\n" - "[bold cyan]/memory rebuild[/bold cyan] : Manually reserialize the FAISS database from the log.\n" - "[bold cyan]/memory compress[/bold cyan] : Use the local LLM to scrub duplicates and compress the log." - ) - self.console.print(Panel.fit(help_text, title="🧠 Memory Commands", border_style="cyan")) - continue - - if user_input.lower() in ['/exit', '/quit']: + if clean_input in ['/exit', '/quit']: self.console.print("\n[dim italic]Initiating shutdown sequence...[/dim italic]") self.agent.shutdown() self.console.print("[bold red]Exiting...[/bold red]") @@ -745,9 +109,6 @@ class CLIApp: except Exception as e: self.console.print(f"[bold red]An error occurred: {e}[/bold red]") -# ========================================== -# 6. Entry Point -# ========================================== if __name__ == "__main__": console = Console() try: |
