diff options
| author | Preston Pan <ret2pop@nullring.xyz> | 2026-03-26 19:46:02 -0700 |
|---|---|---|
| committer | Preston Pan <ret2pop@nullring.xyz> | 2026-03-26 19:46:02 -0700 |
| commit | dcbe7fed3ab74df8dfc8f1c9affc76e2506020f8 (patch) | |
| tree | 8693621be9de583e9efb702da1b0c74c9d00de36 /cerebral/llm.py | |
| parent | f49dfb56c699da817a7deac711bd16d38df80783 (diff) | |
Diffstat (limited to 'cerebral/llm.py')
| -rw-r--r-- | cerebral/llm.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/cerebral/llm.py b/cerebral/llm.py new file mode 100644 index 0000000..7eb1597 --- /dev/null +++ b/cerebral/llm.py @@ -0,0 +1,29 @@ +import os +from abc import ABC, abstractmethod +from typing import List, Dict, Any +from cerebras.cloud.sdk import Cerebras + +class BaseLLMProvider(ABC): + @abstractmethod + def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto") -> Any: + pass + +class CerebrasProvider(BaseLLMProvider): + def __init__(self, model: str = "qwen-3-235b-a22b-instruct-2507"): + api_key = os.environ.get("CEREBRAS_API_KEY") + if not api_key: + raise ValueError("CEREBRAS_API_KEY environment variable is required.") + self.client = Cerebras(api_key=api_key) + self.model = model + + def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto"): + kwargs = { + "messages": messages, + "model": self.model, + "stream": stream, + } + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = tool_choice + + return self.client.chat.completions.create(**kwargs) |
