diff options
Diffstat (limited to 'cerebral/agent.py')
| -rw-r--r-- | cerebral/agent.py | 336 |
1 file changed, 336 insertions, 0 deletions
import json
import hashlib
import re
import datetime
import time
import ollama
from typing import List, Dict, Generator, Optional, Callable

from .config import LOCAL_LLM, LOCAL_EMBED_MODEL, PKM_DIR, MEMORY_FILE, MEMORY_INDEX_PATH, FAISS_INDEX_PATH, HASH_TRACKER_FILE, AGENDA_FILE
from .llm import BaseLLMProvider
from .memory import MemoryManager
from .tools import PKMManager, VisionProcessor, WebSearcher, AgendaManager

class CerebralAgent:
    """Tool-using agent that wraps a cloud LLM provider with local memory,
    PKM search, vision, web search and agenda management.

    The agent runs a bounded tool-calling "reasoning loop" (see
    :meth:`chat_stream`) and persists interactions through
    :class:`MemoryManager`.
    """

    def __init__(self, provider: BaseLLMProvider, log: Callable[[str], None] = print):
        """Wire up all sub-modules.

        Args:
            provider: Cloud LLM backend used for tool-calling completions.
            log: Callable used for rich-formatted status output (defaults to
                ``print`` so the agent works without a UI attached).
        """
        self.provider = provider
        self.log = log

        self.log("[bold magenta]Initializing Cerebral Agent Modules...[/bold magenta]")

        # Long-term memory uses local Ollama models for embedding/compression
        # but also needs the cloud provider (e.g. for summarisation).
        self.memory = MemoryManager(
            MEMORY_FILE,
            MEMORY_INDEX_PATH,
            LOCAL_LLM,
            LOCAL_EMBED_MODEL,
            self.log,
            cloud_provider=self.provider
        )

        self.pkm = PKMManager(PKM_DIR, FAISS_INDEX_PATH, HASH_TRACKER_FILE, LOCAL_EMBED_MODEL, self.log)
        self.vision = VisionProcessor(LOCAL_LLM, self.log)
        self.web = WebSearcher(self.log)
        self.agenda = AgendaManager(AGENDA_FILE, self.log)

    def generate_session_filename(self, first_prompt: str, first_response: str) -> str:
        """Generate a descriptive ``.org`` session filename.

        A short slug is produced by the local LLM from the first prompt; a
        6-hex-char SHA-256 of prompt+response is appended so two sessions with
        the same slug never collide. Falls back to ``cerebral-session-<hash>``
        on any failure.
        """
        self.log("[dim italic]Generating descriptive filename based on prompt and response...[/dim italic]")

        hash_input = (first_prompt + first_response).encode('utf-8')
        combined_hash = hashlib.sha256(hash_input).hexdigest()[:6]

        sys_prompt = "You are a file naming utility. Read the user's prompt and generate a short, descriptive filename base using ONLY lowercase letters and hyphens. Do NOT add an extension. ONLY output the base filename, absolutely no other text. Example: learning-python-basics"

        try:
            response = ollama.chat(model=LOCAL_LLM, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': first_prompt}
            ])

            raw_content = response['message']['content'].strip()
            # Strip any chain-of-thought blocks emitted by reasoning models.
            raw_content = re.sub(r'<think>.*?</think>', '', raw_content, flags=re.DOTALL).strip()

            # The model may still chatter; take the last non-empty line as the slug.
            lines = [line.strip() for line in raw_content.split('\n') if line.strip()]
            raw_filename = lines[-1].lower().replace(' ', '-') if lines else "cerebral-session"
            clean_base = re.sub(r'[^a-z0-9\-]', '', raw_filename).strip('-')
            clean_base = clean_base[:50].strip('-')

            if not clean_base:
                clean_base = "cerebral-session"

            final_filename = f"{clean_base}-{combined_hash}.org"
            return final_filename

        except Exception as e:
            self.log(f"[dim red]Filename generation failed: {e}. Defaulting.[/dim red]")
            return f"cerebral-session-{combined_hash}.org"

    def _get_tools(self) -> List[Dict]:
        """Return the OpenAI-style tool schemas exposed to the provider."""
        return [
            {
                "type": "function",
                "function": {
                    "name": "search_pkm",
                    "description": "Search the user's personal knowledge base (PKM) for notes, code, or org files.",
                    "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "search_web",
                    "description": "Search the live internet for current events, external documentation, or facts outside your PKM.",
                    "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "append_agenda",
                    "description": "Schedule a new TODO item in the user's private agenda.org file. Use this when the user asks to remember to do something, schedule a task, or add an agenda item.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "task": {
                                "type": "string",
                                "description": "The concise title of the task (e.g., 'Buy groceries' or 'Review neuralforge code')"
                            },
                            "scheduled": {
                                "type": "string",
                                "description": "Optional. The exact date and time. If a time is requested or implied, you MUST format it strictly as 'YYYY-MM-DD HH:MM' (e.g., '2026-03-27 11:00'). If ONLY a day is requested with absolutely no time, use 'YYYY-MM-DD' (e.g., '2026-03-27'). Do not include the angle brackets (< >), just the raw date/time string."
                            },
                            "description": {
                                "type": "string",
                                "description": "Optional. A brief, 1-2 sentence description or note about the task/event."
                            }
                        },
                        "required": ["task"]
                    }
                }
            }
        ]

    def chat_stream(self, prompt: str, image_path: Optional[str] = None) -> Generator[str, None, str]:
        """Run one user turn through the bounded reasoning loop, yielding text.

        Yields org-mode text chunks (including inline "Agent Note" markers for
        tool calls) and returns the full response string. The loop allows up
        to ``MAX_ROUNDS`` tool-calling cycles; malformed tool calls are
        retried up to 3 times per cycle with corrective feedback.
        """
        recent_history = ""
        if self.memory.interaction_buffer:
            recent_history = "\nRECENT UNCOMPRESSED TURNS:\n" + "\n".join(
                [f"User: {i['user']}\nAgent: {i['agent']}" for i in self.memory.interaction_buffer]
            )

        vision_context = ""
        if image_path:
            self.log("[dim italic]Analyzing image context...[/dim italic]")
            vision_summary = self.vision.process(image_path, prompt)
            vision_context = f"\n[USER ATTACHED AN IMAGE. Local Vision Summary: {vision_summary}]\n"

        self.log("[dim italic]Querying long-term memory (Ollama Embeddings)...[/dim italic]")
        relevant_memories = self.memory.search(prompt)

        current_time = datetime.datetime.now().strftime("%A, %B %d, %Y at %I:%M %p")

        system_prompt = f"""You are a highly capable AI assistant.

        CRITICAL OUTPUT FORMATTING:
        You MUST output your responses EXCLUSIVELY in Emacs org-mode format. Use org-mode headings, lists, and LaTeX fragments for math.

        FORMATTING RULES:
        1. NEVER use double asterisks (`**`) for bolding. You MUST use SINGLE asterisks for bold emphasis (e.g., *this is bold*). Double asterisks will break the parser.
        2. Cite your sources inline using proper org-mode link syntax. For web searches, use [[url][Description]]. For PKM files, use [[file:/path/to/file.org][Filename]].
        3. At the very end of your response, you MUST append a Level 1 heading `* Sources` and neatly list all the search results and PKM documents you referenced using proper org-mode syntax.

        CURRENT TIME AND DATE:
        {current_time}

        RESPONSE STYLE GUIDELINES:
        - Provide EXTREMELY detailed, exhaustive, and comprehensive answers.
        - Write in long-form prose. Do not be brief; expand deeply on concepts.
        - Use multiple paragraphs, deep conceptual explanations, and thorough analysis.

        RELEVANT LONG-TERM MEMORIES:
        {relevant_memories}

        COMPRESSED SESSION CONTEXT: {self.memory.session_summary}
        {recent_history}
        """

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt + vision_context}
        ]

        # --- THE REASONING LOOP ---
        MAX_ROUNDS = 3
        current_round = 0

        while current_round < MAX_ROUNDS:
            if current_round > 0:
                # Pacing buffer to prevent 429 Rate Limits from back-to-back tool calls
                time.sleep(1.5)

            remaining = MAX_ROUNDS - current_round
            self.log(f"[dim italic]Reasoning cycle ({current_round + 1}/{MAX_ROUNDS})...[/dim italic]")

            budget_reminder = {
                "role": "system",
                "content": (
                    f"REASONING BUDGET: This is reasoning cycle {current_round + 1} of {MAX_ROUNDS}. "
                    f"You have {remaining - 1} tool-calling cycles left after this one. "
                    if remaining > 1 else
                    # FIX: a separating space was missing between these two
                    # literals, producing "finalise now.You will not..."
                    "FINAL REASONING CYCLE: You MUST finalise now. "
                    "You will not be allowed to call more tools after this turn."
                )
            }

            valid_tool_calls = False
            response_message = None
            # Parsed tool arguments keyed by call id, filled during validation
            # and reused at execution time (avoids a second json.loads).
            parsed_args: Dict[str, Dict] = {}
            allowed_tool_names = [t["function"]["name"] for t in self._get_tools()]

            for attempt in range(3):
                # FIX: rebuild the request payload on every attempt. Previously
                # this list was built once before the retry loop, so the
                # corrective error messages appended to `messages` below were
                # never actually sent back to the model on retry.
                current_messages = messages + [budget_reminder]

                pre_flight = self.provider.chat_completion(messages=current_messages, tools=self._get_tools(), stream=False)
                response_message = pre_flight.choices[0].message

                # Some models "hallucinate" a Markdown rendition of a tool call
                # instead of using the native API; detect and push back.
                if not response_message.tool_calls and response_message.content and "**name**:" in response_message.content:
                    self.log(f"[dim yellow]Model hallucinated Markdown tool call. Retrying ({attempt+1}/3)...[/dim yellow]")
                    error_msg = f"ERROR: You attempted to call a tool using Markdown text. You MUST use the native JSON tool calling API. Allowed tools: {allowed_tool_names}"
                    messages.append({"role": "assistant", "content": response_message.content})
                    messages.append({"role": "user", "content": error_msg})
                    continue

                if not response_message.tool_calls:
                    # Plain textual answer — nothing to validate.
                    valid_tool_calls = True
                    break

                # Validate every tool call in the batch before executing any.
                has_errors = False
                error_feedbacks = []
                parsed_args = {}

                for tool_call in response_message.tool_calls:
                    func_name = tool_call.function.name
                    call_error = None

                    if func_name not in allowed_tool_names:
                        has_errors = True
                        call_error = f"Tool '{func_name}' does not exist. Allowed tools: {allowed_tool_names}"
                    else:
                        try:
                            parsed_args[tool_call.id] = json.loads(tool_call.function.arguments)
                        except json.JSONDecodeError:
                            has_errors = True
                            call_error = f"Arguments for '{func_name}' are not valid JSON: {tool_call.function.arguments}"

                    error_feedbacks.append(call_error)

                if has_errors:
                    self.log(f"[dim yellow]Malformed tool call detected. Retrying ({attempt+1}/3)...[/dim yellow]")

                    # Echo the failed assistant turn plus per-call tool errors
                    # so the model can repair the batch on the next attempt.
                    assistant_msg = {
                        "role": "assistant",
                        "content": response_message.content or "",
                        "tool_calls": [
                            {
                                "id": t.id,
                                "type": "function",
                                "function": {"name": t.function.name, "arguments": t.function.arguments}
                            } for t in response_message.tool_calls
                        ]
                    }
                    messages.append(assistant_msg)

                    for i, tool_call in enumerate(response_message.tool_calls):
                        err = error_feedbacks[i]
                        msg_content = f"ERROR: {err}" if err else "Error: Another tool in this batch failed. Please fix the batch and retry."
                        messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": msg_content})
                    continue

                valid_tool_calls = True
                break

            if not valid_tool_calls:
                self.log("[bold red]Failed to generate valid tool calls. Breaking reasoning loop.[/bold red]")
                response_message.tool_calls = None
                break

            # --- Zero-Waste Response Generation ---
            # The pre-flight call already produced the full answer; stream it
            # out of the buffer instead of paying for a second completion.
            if not response_message.tool_calls:
                if current_round == 0:
                    self.log("[dim italic]No tools needed. Outputting response...[/dim italic]")
                else:
                    self.log("[dim italic]Reasoning complete. Outputting response...[/dim italic]")

                content = response_message.content or ""

                # Artificially stream the pre-generated block so the UI stays smooth
                chunk_size = 30
                for i in range(0, len(content), chunk_size):
                    yield content[i:i+chunk_size]
                    time.sleep(0.01)

                self.memory.add_interaction(prompt, content)
                return content

            # --- Execute Validated Tools ---
            assistant_msg = {
                "role": "assistant",
                "content": response_message.content or "",
                "tool_calls": [
                    {
                        "id": t.id,
                        "type": "function",
                        "function": {"name": t.function.name, "arguments": t.function.arguments}
                    } for t in response_message.tool_calls
                ]
            }
            messages.append(assistant_msg)

            for tool_call in response_message.tool_calls:
                func_name = tool_call.function.name
                # Arguments were parsed (and validated) above.
                args = parsed_args.get(tool_call.id, {})

                if func_name == "search_pkm":
                    q = args.get("query", prompt)
                    self.log(f"[cyan]🧠 Tool Call: Searching PKM for '{q}'...[/cyan]")
                    yield f"\n*(Agent Note: Searched PKM for `{q}`)*\n\n"
                    result = self.pkm.search(q)
                elif func_name == "search_web":
                    q = args.get("query", prompt)
                    self.log(f"[cyan]🌐 Tool Call: Searching Web for '{q}'...[/cyan]")
                    yield f"\n*(Agent Note: Searched Web for `{q}`)*\n\n"
                    result = self.web.search(q)
                elif func_name == "append_agenda":
                    task = args.get("task", "Untitled Task")
                    scheduled = args.get("scheduled", "")
                    description = args.get("description", "")
                    self.log(f"[cyan]📅 Tool Call: Appending to Agenda: '{task}'...[/cyan]")
                    yield f"\n*(Agent Note: Added `{task}` to agenda)*\n\n"
                    result = self.agenda.append_todo(task, scheduled, description)
                else:
                    # Unreachable while validation rejects unknown tools, but
                    # guards `result` against future tool-list drift.
                    result = f"ERROR: Tool '{func_name}' has no executor."

                messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result})

            current_round += 1

        # --- FALLBACK FINAL STREAMING RESPONSE ---
        # Only reached if the agent maxes out all 3 reasoning rounds and still hasn't answered
        self.log("[dim italic]Max rounds reached. Forcing final response...[/dim italic]")
        time.sleep(1.5)

        messages.append({
            "role": "system",
            "content": "You have reached the maximum number of reasoning steps. You must now provide your final, comprehensive answer based on the context gathered so far. Use strict org-mode formatting."
        })

        self.log("[dim italic]Streaming final response...[/dim italic]")
        # tool_choice="none" forbids further tool calls on the forced answer.
        stream = self.provider.chat_completion(messages=messages, tools=self._get_tools(), stream=True, tool_choice="none")

        full_response = ""
        for chunk in stream:
            content = chunk.choices[0].delta.content or ""
            full_response += content
            yield content

        self.memory.add_interaction(prompt, full_response)
        return full_response

    def shutdown(self):
        """Flush and finalize session memory before exit."""
        self.memory.finalize_session()
