summaryrefslogtreecommitdiff
path: root/cerebral/llm.py
blob: 7eb159752194ea5ec181df6b3d0274948e9b6c44 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from cerebras.cloud.sdk import Cerebras

class BaseLLMProvider(ABC):
    """Abstract interface that every LLM backend in this module implements.

    Subclasses must override :meth:`chat_completion`; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def chat_completion(
        self,
        messages: List[Dict],
        tools: List[Dict] = None,
        stream: bool = False,
        tool_choice: str = "auto",
    ) -> Any:
        """Run a chat completion against the backend.

        Args:
            messages: Conversation messages, one dict per message.
            tools: Optional list of tool definition dicts; defaults to
                ``None`` (no tools).
            stream: Streaming flag handed to the backend; defaults to
                ``False``.
            tool_choice: Tool selection setting; defaults to ``"auto"``.

        Returns:
            Backend-specific result; the concrete type is defined by the
            implementing subclass.
        """

class CerebrasProvider(BaseLLMProvider):
    """LLM provider that sends chat completions through a ``Cerebras`` client."""

    def __init__(self, model: str = "qwen-3-235b-a22b-instruct-2507"):
        """Create the client using the key from the environment.

        Args:
            model: Model identifier sent with every completion request.

        Raises:
            ValueError: If ``CEREBRAS_API_KEY`` is unset or empty.
        """
        key = os.getenv("CEREBRAS_API_KEY")
        if not key:
            raise ValueError("CEREBRAS_API_KEY environment variable is required.")
        self.client = Cerebras(api_key=key)
        self.model = model

    def chat_completion(
        self,
        messages: List[Dict],
        tools: List[Dict] = None,
        stream: bool = False,
        tool_choice: str = "auto",
    ):
        """Forward a chat completion request to the Cerebras client.

        Args:
            messages: Conversation messages, one dict per message.
            tools: Optional tool definition dicts. When ``None`` or empty,
                neither ``tools`` nor ``tool_choice`` is sent.
            stream: Passed through unchanged as the ``stream`` argument.
            tool_choice: Sent only together with a non-empty ``tools``.

        Returns:
            Whatever ``self.client.chat.completions.create`` returns.
        """
        request = dict(messages=messages, model=self.model, stream=stream)
        if tools:
            # tool_choice is meaningless without tools, so both travel together.
            request.update(tools=tools, tool_choice=tool_choice)
        return self.client.chat.completions.create(**request)