summaryrefslogtreecommitdiff
path: root/cerebral/llm.py
diff options
context:
space:
mode:
authorPreston Pan <ret2pop@nullring.xyz>2026-03-26 19:46:02 -0700
committerPreston Pan <ret2pop@nullring.xyz>2026-03-26 19:46:02 -0700
commitdcbe7fed3ab74df8dfc8f1c9affc76e2506020f8 (patch)
tree8693621be9de583e9efb702da1b0c74c9d00de36 /cerebral/llm.py
parentf49dfb56c699da817a7deac711bd16d38df80783 (diff)
good versionHEADmain
Diffstat (limited to 'cerebral/llm.py')
-rw-r--r--cerebral/llm.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/cerebral/llm.py b/cerebral/llm.py
new file mode 100644
index 0000000..7eb1597
--- /dev/null
+++ b/cerebral/llm.py
@@ -0,0 +1,29 @@
+import os
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any
+from cerebras.cloud.sdk import Cerebras
+
+class BaseLLMProvider(ABC):
+ @abstractmethod
+ def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto") -> Any:
+ pass
+
+class CerebrasProvider(BaseLLMProvider):
+ def __init__(self, model: str = "qwen-3-235b-a22b-instruct-2507"):
+ api_key = os.environ.get("CEREBRAS_API_KEY")
+ if not api_key:
+ raise ValueError("CEREBRAS_API_KEY environment variable is required.")
+ self.client = Cerebras(api_key=api_key)
+ self.model = model
+
+ def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto"):
+ kwargs = {
+ "messages": messages,
+ "model": self.model,
+ "stream": stream,
+ }
+ if tools:
+ kwargs["tools"] = tools
+ kwargs["tool_choice"] = tool_choice
+
+ return self.client.chat.completions.create(**kwargs)