summaryrefslogtreecommitdiff
path: root/cerebral/tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'cerebral/tools.py')
-rw-r--r--cerebral/tools.py194
1 files changed, 194 insertions, 0 deletions
diff --git a/cerebral/tools.py b/cerebral/tools.py
new file mode 100644
index 0000000..304edaa
--- /dev/null
+++ b/cerebral/tools.py
@@ -0,0 +1,194 @@
+import os
+import subprocess
+import ollama
+from typing import List, Dict, Callable
+from abc import ABC, abstractmethod
+from ddgs import DDGS
+from langchain_ollama import OllamaEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain_core.documents import Document
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+
class PKMManager:
    """Retrieval index over a Git-versioned org-mode PKM directory.

    Builds a FAISS vector index from the ``.org`` files under ``pkm_dir`` and
    caches it at ``index_path``.  The commit hash of the repo's ``main`` branch
    is stored in ``hash_file`` so the index is only rebuilt when new commits
    land in the PKM.
    """

    def __init__(self, pkm_dir: str, index_path: str, hash_file: str, embed_model_name: str, log: Callable[[str], None] = print):
        self.pkm_dir = pkm_dir
        self.index_path = index_path
        self.hash_file = hash_file
        self.log = log

        self.log(f"[dim italic]Waking up Ollama embeddings ({embed_model_name})...[/dim italic]")
        self.embeddings = OllamaEmbeddings(model=embed_model_name)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        self.vectorstore = self._load_or_build()

    def _get_main_commit_hash(self) -> str:
        """Return the commit hash of ``main`` in the PKM repo, or "unknown".

        "unknown" is a sentinel that never matches a real stored hash, so any
        failure here simply forces a rebuild instead of crashing startup.
        """
        try:
            result = subprocess.run(
                ["git", "rev-parse", "main"],
                cwd=self.pkm_dir, capture_output=True, text=True, check=True
            )
            return result.stdout.strip()
        except (subprocess.CalledProcessError, OSError):
            # CalledProcessError: not a git repo / no `main` branch.
            # OSError (incl. FileNotFoundError/NotADirectoryError): git binary
            # missing or pkm_dir does not exist — previously this crashed init.
            return "unknown"

    def _load_or_build(self):
        """Load the cached FAISS index if the repo is unchanged, else rebuild.

        Returns the vectorstore, or None when no ``.org`` files were found.
        """
        self.log("[dim]Checking Git HEAD hash for PKM changes...[/dim]")
        current_hash = self._get_main_commit_hash()

        # Only trust the cache when we actually resolved a hash: an "unknown"
        # current hash matching a stored "unknown" must not load a stale index.
        if current_hash != "unknown" and os.path.exists(self.index_path) and os.path.exists(self.hash_file):
            with open(self.hash_file, "r") as f:
                if f.read().strip() == current_hash:
                    self.log(f"[green]Git hash unchanged ({current_hash[:7]}). Loading cached PKM index...[/green]")
                    return FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)

        self.log(f"[bold yellow]New commits detected ({current_hash[:7]}). Rebuilding PKM index...[/bold yellow]")
        raw_documents = []

        self.log(f"[dim]Scanning {self.pkm_dir} for .org files...[/dim]")
        for root, dirs, files in os.walk(self.pkm_dir):
            # Prune bookkeeping directories so os.walk never descends into them.
            if '.git' in dirs: dirs.remove('.git')
            if 'nix' in dirs: dirs.remove('nix')
            for file in files:
                if file.endswith('.org'):
                    filepath = os.path.join(root, file)
                    try:
                        with open(filepath, 'r', encoding='utf-8') as f:
                            raw_documents.append(Document(page_content=f.read(), metadata={"source": filepath}))
                    except Exception as e:
                        # Best-effort: skip unreadable files, but say so instead
                        # of silently dropping them.
                        self.log(f"[dim yellow]Skipping unreadable file {filepath}: {e}[/dim yellow]")

        if not raw_documents:
            self.log("[red]No .org files found in PKM directory.[/red]")
            return None

        self.log(f"[dim]Chunking {len(raw_documents)} documents...[/dim]")
        chunks = self.text_splitter.split_documents(raw_documents)

        self.log(f"[bold cyan]Embedding {len(chunks)} chunks via Ollama (this might take a minute)...[/bold cyan]")
        vectorstore = FAISS.from_documents(chunks, self.embeddings)
        vectorstore.save_local(self.index_path)

        with open(self.hash_file, "w") as f:
            f.write(current_hash)

        self.log("[bold green]PKM Index successfully rebuilt and saved![/bold green]")
        return vectorstore

    def search(self, query: str) -> str:
        """Return the top-10 PKM chunks matching `query`, formatted as text."""
        if not self.vectorstore:
            return "PKM is empty."
        docs = self.vectorstore.similarity_search(query, k=10)
        # metadata.get guards against chunks that somehow lack a source path.
        return "PKM Search Results:\n" + "\n\n".join(
            [f"From {d.metadata.get('source', 'unknown')}:\n{d.page_content}" for d in docs]
        )
+
+
class VisionProcessor:
    """Describes images with a local Ollama vision model for downstream AIs."""

    def __init__(self, local_model: str, log: Callable[[str], None] = print):
        self.local_model = local_model
        self.log = log
        self.log("[dim italic]Vision Processor online...[/dim italic]")

    def process(self, image_path: str, user_prompt: str) -> str:
        """Return a detailed description of the image at `image_path`.

        The description is slanted toward answering `user_prompt`.  Any
        failure (unreadable file, model error) is reported as a bracketed
        string rather than raised, so callers can embed the result directly.
        """
        try:
            with open(image_path, 'rb') as handle:
                payload = handle.read()
            request = {
                'role': 'user',
                'content': f"Describe this image in detail to help another AI answer this prompt: {user_prompt}",
                'images': [payload],
            }
            reply = ollama.chat(model=self.local_model, messages=[request])
            return reply['message']['content']
        except Exception as exc:
            return f"[Image analysis failed: {exc}]"
+
+
class BaseSearchProvider(ABC):
    """Interface for pluggable web-search backends."""

    @abstractmethod
    def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        """Return result dicts with 'title', 'href' and 'body' keys, or raise."""
+
class GoogleSearchProvider(BaseSearchProvider):
    """Search provider backed by the `googlesearch` scraping library."""

    def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        # Imported lazily so the dependency is only paid when Google is tried.
        from googlesearch import search

        hits = [
            {
                'title': getattr(item, 'title', 'No Title'),
                'href': getattr(item, 'url', 'No URL'),
                'body': getattr(item, 'description', 'No Description'),
            }
            for item in search(query, num_results=max_results, advanced=True)
        ]

        # An empty page counts as failure so the caller can fall back.
        if not hits:
            raise Exception("Google returned zero results.")

        return hits
+
class DDGSSearchProvider(BaseSearchProvider):
    """Search provider backed by the DuckDuckGo `ddgs` library."""

    def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        hits = DDGS().text(query, max_results=max_results)

        # An empty response counts as failure so the caller can fall back.
        if not hits:
            raise Exception("DuckDuckGo returned zero results.")

        return [
            {
                'title': hit.get('title', 'No Title'),
                'href': hit.get('href', 'No URL'),
                'body': hit.get('body', 'No Description'),
            }
            for hit in hits
        ]
+
class WebSearcher:
    """Queries a chain of search providers, falling back on any failure."""

    def __init__(self, log: Callable[[str], None] = print):
        self.log = log
        # Ordered by preference: Google first, DuckDuckGo as the fallback.
        self.providers: List[BaseSearchProvider] = [
            GoogleSearchProvider(),
            DDGSSearchProvider(),
        ]

    def search(self, query: str) -> str:
        """Return formatted results from the first provider that succeeds.

        Every provider error is logged and swallowed; only when the whole
        chain is exhausted does a failure message come back.
        """
        for backend in self.providers:
            name = backend.__class__.__name__
            try:
                self.log(f"[dim italic]Trying {name}...[/dim italic]")
                hits = backend.search(query, max_results=10)
                body = "".join(
                    f"- Title: {hit['title']}\n  URL: {hit['href']}\n  Snippet: {hit['body']}\n\n"
                    for hit in hits
                )
                return "Web Search Results:\n" + body
            except Exception as exc:
                self.log(f"[dim yellow]{name} failed ({exc}). Falling back...[/dim yellow]")
                continue

        return "Web search failed: All search providers were exhausted or rate-limited."
+
class AgendaManager:
    """Append-only interface to a private org-mode agenda file."""

    def __init__(self, agenda_file: str, log: Callable[[str], None] = print):
        self.agenda_file = agenda_file
        self.log = log

    def append_todo(self, task: str, scheduled: str = "", description: str = "") -> str:
        """Appends a TODO to the agenda file. Write-only for privacy.

        Returns a human-readable success/failure message instead of raising,
        so the result can be handed straight back to the agent.
        """
        try:
            # Seed a brand-new agenda file with an org title header.
            if not os.path.exists(self.agenda_file):
                with open(self.agenda_file, "w") as handle:
                    handle.write("#+TITLE: Private Agenda\n\n")

            parts = [f"\n* TODO {task}\n"]
            if scheduled:
                parts.append(f"  SCHEDULED: <{scheduled}>\n")
            if description:
                # Add a clean, indented description below the scheduled tag
                parts.append(f"  {description}\n")

            with open(self.agenda_file, "a") as handle:
                handle.write("".join(parts))

            self.log(f"[bold green]Successfully appended to agenda: {task}[/bold green]")
            return f"Successfully added to agenda: {task}"
        except Exception as exc:
            self.log(f"[bold red]Failed to append to agenda: {exc}[/bold red]")
            return f"Failed to add to agenda: {exc}"