summaryrefslogtreecommitdiff
path: root/cerebral/agent.py
blob: 20cb210a48d8a295c47709008a6eb3584ee8eccc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import json
import hashlib
import re
import datetime
import time
import ollama
from typing import List, Dict, Generator, Optional, Callable

from .config import LOCAL_LLM, LOCAL_EMBED_MODEL, PKM_DIR, MEMORY_FILE, MEMORY_INDEX_PATH, FAISS_INDEX_PATH, HASH_TRACKER_FILE, AGENDA_FILE
from .llm import BaseLLMProvider
from .memory import MemoryManager
from .tools import PKMManager, VisionProcessor, WebSearcher, AgendaManager

class CerebralAgent:
    """Top-level agent that wires an LLM provider to long-term memory,
    PKM search, vision, web search, and agenda tooling, and exposes a
    streaming chat interface over them."""

    def __init__(self, provider: BaseLLMProvider, log: Callable[[str], None] = print):
        self.provider = provider
        self.log = log

        self.log("[bold magenta]Initializing Cerebral Agent Modules...[/bold magenta]")

        # Long-term memory uses local models for embeddings and the cloud
        # provider for higher-quality summarisation.
        self.memory = MemoryManager(
            MEMORY_FILE,
            MEMORY_INDEX_PATH,
            LOCAL_LLM,
            LOCAL_EMBED_MODEL,
            self.log,
            cloud_provider=self.provider,
        )

        # Tooling modules: knowledge base, image analysis, live web, agenda.
        self.pkm = PKMManager(PKM_DIR, FAISS_INDEX_PATH, HASH_TRACKER_FILE, LOCAL_EMBED_MODEL, self.log)
        self.vision = VisionProcessor(LOCAL_LLM, self.log)
        self.web = WebSearcher(self.log)
        self.agenda = AgendaManager(AGENDA_FILE, self.log)

    def generate_session_filename(self, first_prompt: str, first_response: str) -> str:
        """Derive a short, hyphenated ``.org`` filename for this session.

        A local model proposes a descriptive slug from the first prompt; a
        6-hex-digit SHA-256 digest of prompt+response is appended so equal
        slugs from different sessions never collide. Any failure falls back
        to a generic name carrying the same digest.
        """
        self.log("[dim italic]Generating descriptive filename based on prompt and response...[/dim italic]")

        # Content digest keeps filenames unique even for identical slugs.
        digest = hashlib.sha256((first_prompt + first_response).encode('utf-8')).hexdigest()[:6]

        sys_prompt = "You are a file naming utility. Read the user's prompt and generate a short, descriptive filename base using ONLY lowercase letters and hyphens. Do NOT add an extension. ONLY output the base filename, absolutely no other text. Example: learning-python-basics"

        try:
            reply = ollama.chat(model=LOCAL_LLM, messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': first_prompt}
            ])

            # Strip any chain-of-thought markup the model may have emitted.
            text = reply['message']['content'].strip()
            text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()

            # The slug is expected on the last non-empty line of the reply.
            candidates = [ln.strip() for ln in text.split('\n') if ln.strip()]
            slug = candidates[-1].lower().replace(' ', '-') if candidates else "cerebral-session"
            slug = re.sub(r'[^a-z0-9\-]', '', slug).strip('-')
            slug = slug[:50].strip('-')

            return f"{slug or 'cerebral-session'}-{digest}.org"

        except Exception as e:
            # Best-effort: filename quality is not worth failing the session over.
            self.log(f"[dim red]Filename generation failed: {e}. Defaulting.[/dim red]")
            return f"cerebral-session-{digest}.org"

    def _get_tools(self) -> List[Dict]:
        """Return the OpenAI-style function-tool schemas exposed to the model."""

        def schema(name: str, description: str, properties: Dict, required: List[str]) -> Dict:
            # Small builder so each tool declaration below stays compact.
            return {
                "type": "function",
                "function": {
                    "name": name,
                    "description": description,
                    "parameters": {"type": "object", "properties": properties, "required": required},
                },
            }

        return [
            schema(
                "search_pkm",
                "Search the user's personal knowledge base (PKM) for notes, code, or org files.",
                {"query": {"type": "string"}},
                ["query"],
            ),
            schema(
                "search_web",
                "Search the live internet for current events, external documentation, or facts outside your PKM.",
                {"query": {"type": "string"}},
                ["query"],
            ),
            schema(
                "append_agenda",
                "Schedule a new TODO item in the user's private agenda.org file. Use this when the user asks to remember to do something, schedule a task, or add an agenda item.",
                {
                    "task": {
                        "type": "string",
                        "description": "The concise title of the task (e.g., 'Buy groceries' or 'Review neuralforge code')"
                    },
                    "scheduled": {
                        "type": "string",
                        "description": "Optional. The exact date and time. If a time is requested or implied, you MUST format it strictly as 'YYYY-MM-DD HH:MM' (e.g., '2026-03-27 11:00'). If ONLY a day is requested with absolutely no time, use 'YYYY-MM-DD' (e.g., '2026-03-27'). Do not include the angle brackets (< >), just the raw date/time string."
                    },
                    "description": {
                        "type": "string",
                        "description": "Optional. A brief, 1-2 sentence description or note about the task/event."
                    },
                },
                ["task"],
            ),
        ]

    def chat_stream(self, prompt: str, image_path: Optional[str] = None) -> Generator[str, None, str]:
        """Run one chat turn and yield the response text incrementally.

        Builds a context-rich system prompt (relevant long-term memories,
        compressed session summary, recent uncompressed turns, and an
        optional local vision summary of ``image_path``), then drives a
        bounded tool-calling reasoning loop (``MAX_ROUNDS`` cycles). Each
        cycle validates the model's tool calls (known names, JSON args)
        with up to 3 retries before executing them. If the model answers
        without tools, that answer is chunk-streamed and returned; if the
        budget is exhausted, a final tool-free streaming completion is
        forced.

        Yields response chunks (plus inline "Agent Note" markers when
        tools run); the generator's return value is the full answer text.
        """
        # Verbatim tail of the conversation that has not been compressed yet.
        recent_history = ""
        if self.memory.interaction_buffer:
            recent_history = "\nRECENT UNCOMPRESSED TURNS:\n" + "\n".join(
                [f"User: {i['user']}\nAgent: {i['agent']}" for i in self.memory.interaction_buffer]
            )

        # Optional image: summarised locally and injected as plain text context.
        vision_context = ""
        if image_path:
            self.log("[dim italic]Analyzing image context...[/dim italic]")
            vision_summary = self.vision.process(image_path, prompt)
            vision_context = f"\n[USER ATTACHED AN IMAGE. Local Vision Summary: {vision_summary}]\n"

        self.log("[dim italic]Querying long-term memory (Ollama Embeddings)...[/dim italic]")
        relevant_memories = self.memory.search(prompt)

        current_time = datetime.datetime.now().strftime("%A, %B %d, %Y at %I:%M %p")

        system_prompt = f"""You are a highly capable AI assistant.
        
        CRITICAL OUTPUT FORMATTING: 
        You MUST output your responses EXCLUSIVELY in Emacs org-mode format. Use org-mode headings, lists, and LaTeX fragments for math.
        
        FORMATTING RULES:
        1. NEVER use double asterisks (`**`) for bolding. You MUST use SINGLE asterisks for bold emphasis (e.g., *this is bold*). Double asterisks will break the parser.
        2. Cite your sources inline using proper org-mode link syntax. For web searches, use [[url][Description]]. For PKM files, use [[file:/path/to/file.org][Filename]].
        3. At the very end of your response, you MUST append a Level 1 heading `* Sources` and neatly list all the search results and PKM documents you referenced using proper org-mode syntax.
        
        CURRENT TIME AND DATE: 
        {current_time}
        
        RESPONSE STYLE GUIDELINES:
        - Provide EXTREMELY detailed, exhaustive, and comprehensive answers.
        - Write in long-form prose. Do not be brief; expand deeply on concepts.
        - Use multiple paragraphs, deep conceptual explanations, and thorough analysis.
        
        RELEVANT LONG-TERM MEMORIES: 
        {relevant_memories}
        
        COMPRESSED SESSION CONTEXT: {self.memory.session_summary}
        {recent_history}
        """

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt + vision_context}
        ]

        # --- THE REASONING LOOP ---
        MAX_ROUNDS = 3
        current_round = 0

        while current_round < MAX_ROUNDS:
            if current_round > 0:
                # Pacing buffer to prevent 429 rate limits from back-to-back tool calls.
                time.sleep(1.5)

            remaining = MAX_ROUNDS - current_round
            self.log(f"[dim italic]Reasoning cycle ({current_round + 1}/{MAX_ROUNDS})...[/dim italic]")

            # BUG FIX: the original built this content with an un-parenthesised
            # conditional expression, so implicit string concatenation glued the
            # two final-cycle sentences together with no separating space
            # ("...finalise now.You will not..."). Build each branch explicitly.
            if remaining > 1:
                budget_text = (
                    f"REASONING BUDGET: This is reasoning cycle {current_round + 1} of {MAX_ROUNDS}. "
                    f"You have {remaining - 1} tool-calling cycles left after this one. "
                )
            else:
                budget_text = (
                    "FINAL REASONING CYCLE: You MUST finalise now. "
                    "You will not be allowed to call more tools after this turn."
                )
            budget_reminder = {"role": "system", "content": budget_text}

            # The budget reminder is appended per-cycle and never persisted.
            current_messages = messages + [budget_reminder]

            valid_tool_calls = False
            response_message = None
            allowed_tool_names = [t["function"]["name"] for t in self._get_tools()]

            # Up to 3 attempts to get either plain content or well-formed tool calls.
            for attempt in range(3):
                pre_flight = self.provider.chat_completion(messages=current_messages, tools=self._get_tools(), stream=False)
                response_message = pre_flight.choices[0].message

                # Heuristic: the model wrote a Markdown-style tool call instead
                # of using the native tool-calling API.
                if not response_message.tool_calls and response_message.content and "**name**:" in response_message.content:
                    self.log(f"[dim yellow]Model hallucinated Markdown tool call. Retrying ({attempt+1}/3)...[/dim yellow]")
                    error_msg = f"ERROR: You attempted to call a tool using Markdown text. You MUST use the native JSON tool calling API. Allowed tools: {allowed_tool_names}"
                    messages.append({"role": "assistant", "content": response_message.content})
                    messages.append({"role": "user", "content": error_msg})
                    continue

                if not response_message.tool_calls:
                    # Plain answer, nothing to validate.
                    valid_tool_calls = True
                    break

                # Validate the whole batch: known tool name + JSON-parseable args.
                has_errors = False
                error_feedbacks = []  # one entry per call; None means that call was fine

                for tool_call in response_message.tool_calls:
                    func_name = tool_call.function.name
                    call_error = None

                    if func_name not in allowed_tool_names:
                        has_errors = True
                        call_error = f"Tool '{func_name}' does not exist. Allowed tools: {allowed_tool_names}"
                    else:
                        try:
                            json.loads(tool_call.function.arguments)
                        except json.JSONDecodeError:
                            has_errors = True
                            call_error = f"Arguments for '{func_name}' are not valid JSON: {tool_call.function.arguments}"

                    error_feedbacks.append(call_error)

                if has_errors:
                    self.log(f"[dim yellow]Malformed tool call detected. Retrying ({attempt+1}/3)...[/dim yellow]")

                    # Echo the failed assistant turn so the model can see what it sent.
                    assistant_msg = {
                        "role": "assistant",
                        "content": response_message.content or "",
                        "tool_calls": [
                            {
                                "id": t.id,
                                "type": "function",
                                "function": {"name": t.function.name, "arguments": t.function.arguments}
                            } for t in response_message.tool_calls
                        ]
                    }
                    messages.append(assistant_msg)

                    # Every call in the batch gets a tool reply — including valid
                    # ones — so the provider's tool_call_id bookkeeping stays consistent.
                    for i, tool_call in enumerate(response_message.tool_calls):
                        err = error_feedbacks[i]
                        msg_content = f"ERROR: {err}" if err else "Error: Another tool in this batch failed. Please fix the batch and retry."
                        messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": msg_content})
                    continue

                valid_tool_calls = True
                break

            if not valid_tool_calls:
                # Three strikes: discard the broken calls and fall through to the
                # forced final response below.
                self.log("[bold red]Failed to generate valid tool calls. Breaking reasoning loop.[/bold red]")
                response_message.tool_calls = None
                break

            # --- Zero-waste response: the pre-flight content IS the answer ---
            if not response_message.tool_calls:
                if current_round == 0:
                    self.log("[dim italic]No tools needed. Outputting response...[/dim italic]")
                else:
                    self.log("[dim italic]Reasoning complete. Outputting response...[/dim italic]")

                content = response_message.content or ""

                # Artificially stream the pre-generated block so the UI stays smooth
                chunk_size = 30
                for i in range(0, len(content), chunk_size):
                    yield content[i:i+chunk_size]
                    time.sleep(0.01)

                self.memory.add_interaction(prompt, content)
                return content

            # --- Execute Validated Tools ---
            assistant_msg = {
                "role": "assistant",
                "content": response_message.content or "",
                "tool_calls": [
                    {
                        "id": t.id,
                        "type": "function",
                        "function": {"name": t.function.name, "arguments": t.function.arguments}
                    } for t in response_message.tool_calls
                ]
            }
            messages.append(assistant_msg)

            for tool_call in response_message.tool_calls:
                func_name = tool_call.function.name
                args = json.loads(tool_call.function.arguments)  # validated above

                if func_name == "search_pkm":
                    q = args.get("query", prompt)
                    self.log(f"[cyan]🧠 Tool Call: Searching PKM for '{q}'...[/cyan]")
                    yield f"\n*(Agent Note: Searched PKM for `{q}`)*\n\n"
                    result = self.pkm.search(q)
                elif func_name == "search_web":
                    q = args.get("query", prompt)
                    self.log(f"[cyan]🌐 Tool Call: Searching Web for '{q}'...[/cyan]")
                    yield f"\n*(Agent Note: Searched Web for `{q}`)*\n\n"
                    result = self.web.search(q)
                elif func_name == "append_agenda":
                    task = args.get("task", "Untitled Task")
                    scheduled = args.get("scheduled", "")
                    description = args.get("description", "")
                    self.log(f"[cyan]📅 Tool Call: Appending to Agenda: '{task}'...[/cyan]")
                    yield f"\n*(Agent Note: Added `{task}` to agenda)*\n\n"
                    result = self.agenda.append_todo(task, scheduled, description)

                messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result})

            current_round += 1

        # --- FALLBACK FINAL STREAMING RESPONSE ---
        # Only reached if the agent maxes out all reasoning rounds (or tool-call
        # validation gave up) and still hasn't answered.
        self.log("[dim italic]Max rounds reached. Forcing final response...[/dim italic]")
        time.sleep(1.5)  # pacing buffer before the final provider call

        messages.append({
            "role": "system",
            "content": "You have reached the maximum number of reasoning steps. You must now provide your final, comprehensive answer based on the context gathered so far. Use strict org-mode formatting."
        })

        self.log("[dim italic]Streaming final response...[/dim italic]")
        # tool_choice="none" guarantees the model cannot call tools at this stage.
        stream = self.provider.chat_completion(messages=messages, tools=self._get_tools(), stream=True, tool_choice="none")

        full_response = ""
        for chunk in stream:
            content = chunk.choices[0].delta.content or ""
            full_response += content
            yield content

        self.memory.add_interaction(prompt, full_response)
        return full_response

    def shutdown(self):
        """Finalize the session's memory state (delegates to MemoryManager.finalize_session)."""
        self.memory.finalize_session()