diff options
Diffstat (limited to 'cerebral/llm.py')
| -rw-r--r-- | cerebral/llm.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/cerebral/llm.py b/cerebral/llm.py new file mode 100644 index 0000000..7eb1597 --- /dev/null +++ b/cerebral/llm.py @@ -0,0 +1,29 @@ +import os +from abc import ABC, abstractmethod +from typing import List, Dict, Any +from cerebras.cloud.sdk import Cerebras + +class BaseLLMProvider(ABC): + @abstractmethod + def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto") -> Any: + pass + +class CerebrasProvider(BaseLLMProvider): + def __init__(self, model: str = "qwen-3-235b-a22b-instruct-2507"): + api_key = os.environ.get("CEREBRAS_API_KEY") + if not api_key: + raise ValueError("CEREBRAS_API_KEY environment variable is required.") + self.client = Cerebras(api_key=api_key) + self.model = model + + def chat_completion(self, messages: List[Dict], tools: List[Dict] = None, stream: bool = False, tool_choice: str = "auto"): + kwargs = { + "messages": messages, + "model": self.model, + "stream": stream, + } + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = tool_choice + + return self.client.chat.completions.create(**kwargs) |
