diff options
Diffstat (limited to 'cerebral/tools.py')
| -rw-r--r-- | cerebral/tools.py | 194 |
1 file changed, 194 insertions, 0 deletions
"""Tooling layer for the cerebral assistant: PKM semantic search, image
analysis, web search with provider fallback, and an append-only agenda."""

import os
import subprocess
import ollama
from typing import List, Dict, Callable
from abc import ABC, abstractmethod
from ddgs import DDGS
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter


class PKMManager:
    """Semantic search over a git-tracked org-mode PKM directory.

    A FAISS index of 500-char chunks is cached on disk and invalidated by
    comparing the stored git commit hash of `main` against the current one.
    """

    def __init__(self, pkm_dir: str, index_path: str, hash_file: str, embed_model_name: str, log: Callable[[str], None] = print):
        self.pkm_dir = pkm_dir
        self.index_path = index_path
        self.hash_file = hash_file
        self.log = log

        self.log(f"[dim italic]Waking up Ollama embeddings ({embed_model_name})...[/dim italic]")
        self.embeddings = OllamaEmbeddings(model=embed_model_name)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        # May be None when the PKM directory holds no .org files at all.
        self.vectorstore = self._load_or_build()

    def _get_main_commit_hash(self) -> str:
        """Return the commit hash of the local `main` branch, or "unknown".

        "unknown" forces a rebuild on every run, which is the safe fallback
        when git is unavailable or the directory is not a repository.
        """
        try:
            result = subprocess.run(
                ["git", "rev-parse", "main"],
                cwd=self.pkm_dir, capture_output=True, text=True, check=True
            )
            return result.stdout.strip()
        except (subprocess.CalledProcessError, FileNotFoundError):
            # CalledProcessError: not a repo / no `main` branch.
            # FileNotFoundError: git binary not installed.
            return "unknown"

    def _load_or_build(self):
        """Load the cached FAISS index if the repo is unchanged, else rebuild.

        Returns the vectorstore, or None when no .org documents were found.
        """
        self.log("[dim]Checking Git HEAD hash for PKM changes...[/dim]")
        current_hash = self._get_main_commit_hash()

        if os.path.exists(self.index_path) and os.path.exists(self.hash_file):
            with open(self.hash_file, "r") as f:
                if f.read().strip() == current_hash:
                    self.log(f"[green]Git hash unchanged ({current_hash[:7]}). Loading cached PKM index...[/green]")
                    # Index was written by this same trusted process, so
                    # pickle deserialization is acceptable here.
                    return FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)

        self.log(f"[bold yellow]New commits detected ({current_hash[:7]}). Rebuilding PKM index...[/bold yellow]")
        raw_documents = []

        self.log(f"[dim]Scanning {self.pkm_dir} for .org files...[/dim]")
        for root, dirs, files in os.walk(self.pkm_dir):
            # Prune bulky non-content directories in place so os.walk skips them.
            if '.git' in dirs: dirs.remove('.git')
            if 'nix' in dirs: dirs.remove('nix')
            for file in files:
                if file.endswith('.org'):
                    filepath = os.path.join(root, file)
                    try:
                        with open(filepath, 'r', encoding='utf-8') as f:
                            raw_documents.append(Document(page_content=f.read(), metadata={"source": filepath}))
                    except Exception as e:
                        # Best-effort indexing: skip unreadable files, but
                        # surface the skip instead of hiding it entirely.
                        self.log(f"[dim red]Skipping unreadable {filepath}: {e}[/dim red]")

        if not raw_documents:
            self.log("[red]No .org files found in PKM directory.[/red]")
            return None

        self.log(f"[dim]Chunking {len(raw_documents)} documents...[/dim]")
        chunks = self.text_splitter.split_documents(raw_documents)

        self.log(f"[bold cyan]Embedding {len(chunks)} chunks via Ollama (this might take a minute)...[/bold cyan]")
        vectorstore = FAISS.from_documents(chunks, self.embeddings)
        vectorstore.save_local(self.index_path)

        # Record the commit hash only after a successful save so a failed
        # build does not mask itself as up to date.
        with open(self.hash_file, "w") as f:
            f.write(current_hash)

        self.log("[bold green]PKM Index successfully rebuilt and saved![/bold green]")
        return vectorstore

    def search(self, query: str) -> str:
        """Return the top-10 similar chunks formatted as a context string."""
        if not self.vectorstore:
            return "PKM is empty."
        docs = self.vectorstore.similarity_search(query, k=10)
        return "PKM Search Results:\n" + "\n\n".join([f"From {d.metadata['source']}:\n{d.page_content}" for d in docs])


class VisionProcessor:
    """Describes images via a local Ollama vision model so a text-only
    model can answer prompts about them."""

    def __init__(self, local_model: str, log: Callable[[str], None] = print):
        self.local_model = local_model
        self.log = log
        self.log("[dim italic]Vision Processor online...[/dim italic]")

    def process(self, image_path: str, user_prompt: str) -> str:
        """Return a textual description of the image tailored to the prompt.

        Never raises: any failure (missing file, model error) is returned
        as a bracketed placeholder string for the caller to inline.
        """
        try:
            with open(image_path, 'rb') as img_file:
                img_bytes = img_file.read()
            response = ollama.chat(model=self.local_model, messages=[{
                'role': 'user',
                'content': f"Describe this image in detail to help another AI answer this prompt: {user_prompt}",
                'images': [img_bytes]
            }])
            return response['message']['content']
        except Exception as e:
            return f"[Image analysis failed: {e}]"


class BaseSearchProvider(ABC):
    """Interface for a web search backend used by WebSearcher's fallback chain."""

    @abstractmethod
    def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        """Return results as dicts with 'title', 'href' and 'body' keys;
        raise on failure or zero results so the caller can fall back."""
        pass


class GoogleSearchProvider(BaseSearchProvider):
    """Search via the `googlesearch` scraping package."""

    def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        # Imported lazily: the package is only needed when this provider runs.
        from googlesearch import search

        results = []
        for r in search(query, num_results=max_results, advanced=True):
            results.append({
                'title': getattr(r, 'title', 'No Title'),
                'href': getattr(r, 'url', 'No URL'),
                'body': getattr(r, 'description', 'No Description')
            })

        if not results:
            # Zero results is treated as a failure so WebSearcher falls back.
            raise Exception("Google returned zero results.")

        return results


class DDGSSearchProvider(BaseSearchProvider):
    """Search via DuckDuckGo (ddgs package)."""

    def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        results = DDGS().text(query, max_results=max_results)

        if not results:
            raise Exception("DuckDuckGo returned zero results.")

        formatted_results = []
        for r in results:
            formatted_results.append({
                'title': r.get('title', 'No Title'),
                'href': r.get('href', 'No URL'),
                'body': r.get('body', 'No Description')
            })
        return formatted_results


class WebSearcher:
    """Tries each search provider in order, returning the first success."""

    def __init__(self, log: Callable[[str], None] = print):
        self.log = log
        # Priority order: Google first, DuckDuckGo as fallback.
        self.providers: List[BaseSearchProvider] = [
            GoogleSearchProvider(),
            DDGSSearchProvider()
        ]

    def search(self, query: str) -> str:
        """Return formatted results from the first provider that succeeds,
        or an error string if every provider fails."""
        for provider in self.providers:
            provider_name = provider.__class__.__name__
            try:
                self.log(f"[dim italic]Trying {provider_name}...[/dim italic]")
                results = provider.search(query, max_results=10)

                context = "Web Search Results:\n"
                for r in results:
                    context += f"- Title: {r['title']}\n  URL: {r['href']}\n  Snippet: {r['body']}\n\n"
                return context

            except Exception as e:
                self.log(f"[dim yellow]{provider_name} failed ({e}). Falling back...[/dim yellow]")
                continue

        return "Web search failed: All search providers were exhausted or rate-limited."


class AgendaManager:
    """Append-only writer for a private org-mode agenda file."""

    def __init__(self, agenda_file: str, log: Callable[[str], None] = print):
        self.agenda_file = agenda_file
        self.log = log

    def append_todo(self, task: str, scheduled: str = "", description: str = "") -> str:
        """Appends a TODO to the agenda file. Write-only for privacy.

        `scheduled` is an org timestamp body (e.g. "2024-01-01 Mon");
        returns a human-readable success/failure message, never raises.
        """
        try:
            # Create the file with an org title header on first use.
            if not os.path.exists(self.agenda_file):
                with open(self.agenda_file, "w", encoding="utf-8") as f:
                    f.write("#+TITLE: Private Agenda\n\n")

            entry = f"\n* TODO {task}\n"
            if scheduled:
                entry += f"  SCHEDULED: <{scheduled}>\n"
            if description:
                # Add a clean, indented description below the scheduled tag
                entry += f"  {description}\n"

            with open(self.agenda_file, "a", encoding="utf-8") as f:
                f.write(entry)

            self.log(f"[bold green]Successfully appended to agenda: {task}[/bold green]")
            return f"Successfully added to agenda: {task}"
        except Exception as e:
            self.log(f"[bold red]Failed to append to agenda: {e}[/bold red]")
            return f"Failed to add to agenda: {e}"
