summaryrefslogtreecommitdiff
path: root/cerebral/llm.py
diff options
context:
space:
mode:
Diffstat (limited to 'cerebral/llm.py')
-rw-r--r--cerebral/llm.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/cerebral/llm.py b/cerebral/llm.py
new file mode 100644
index 0000000..7eb1597
--- /dev/null
+++ b/cerebral/llm.py
@@ -0,0 +1,29 @@
+import os
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any
+from cerebras.cloud.sdk import Cerebras
+
+class BaseLLMProvider(ABC):
+ @abstractmethod
+ def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto") -> Any:
+ pass
+
+class CerebrasProvider(BaseLLMProvider):
+ def __init__(self, model: str = "qwen-3-235b-a22b-instruct-2507"):
+ api_key = os.environ.get("CEREBRAS_API_KEY")
+ if not api_key:
+ raise ValueError("CEREBRAS_API_KEY environment variable is required.")
+ self.client = Cerebras(api_key=api_key)
+ self.model = model
+
+ def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto"):
+ kwargs = {
+ "messages": messages,
+ "model": self.model,
+ "stream": stream,
+ }
+ if tools:
+ kwargs["tools"] = tools
+ kwargs["tool_choice"] = tool_choice
+
+ return self.client.chat.completions.create(**kwargs)