1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from cerebras.cloud.sdk import Cerebras
class BaseLLMProvider(ABC):
    """Abstract interface for chat-completion backends.

    Concrete providers must implement :meth:`chat_completion`; the class
    cannot be instantiated until they do.
    """

    @abstractmethod
    def chat_completion(
        self,
        messages: List[Dict],
        tools: Optional[List[Dict]] = None,
        stream: bool = False,
        tool_choice: str = "auto",
    ) -> Any:
        """Send a chat-completion request and return the backend's response.

        Args:
            messages: Conversation history as a list of dicts (presumably
                OpenAI-style ``{"role": ..., "content": ...}`` entries —
                confirm against the concrete provider).
            tools: Optional list of tool/function definitions; ``None`` means
                no tools are offered.
            stream: Whether the caller wants a streamed response.
            tool_choice: Tool-selection strategy passed to the backend.

        Returns:
            Provider-specific response object.
        """
class CerebrasProvider(BaseLLMProvider):
    """Chat-completion provider backed by the Cerebras Cloud SDK client."""

    def __init__(
        self,
        model: str = "qwen-3-235b-a22b-instruct-2507",
        api_key: Optional[str] = None,
    ):
        """Create a Cerebras SDK client bound to ``model``.

        Args:
            model: Model identifier sent with every request.
            api_key: Cerebras API key. When omitted or empty, the value of
                the ``CEREBRAS_API_KEY`` environment variable is used.

        Raises:
            ValueError: If no key is passed and ``CEREBRAS_API_KEY`` is
                unset or empty.
        """
        # An explicit key wins; otherwise fall back to the environment so
        # existing callers that rely on CEREBRAS_API_KEY keep working.
        api_key = api_key or os.environ.get("CEREBRAS_API_KEY")
        if not api_key:
            raise ValueError("CEREBRAS_API_KEY environment variable is required.")
        self.client = Cerebras(api_key=api_key)
        self.model = model

    def chat_completion(
        self,
        messages: List[Dict],
        tools: Optional[List[Dict]] = None,
        stream: bool = False,
        tool_choice: str = "auto",
    ) -> Any:
        """Send a chat-completion request through the Cerebras client.

        Args:
            messages: Conversation history forwarded unchanged as the
                ``messages`` field.
            tools: Optional tool definitions. ``None`` or an empty list means
                neither ``tools`` nor ``tool_choice`` is sent.
            stream: Forwarded as the ``stream`` field.
            tool_choice: Forwarded only when ``tools`` is non-empty.

        Returns:
            Whatever ``client.chat.completions.create`` returns (presumably
            a streaming iterator when ``stream`` is True — see the SDK docs).
        """
        request: Dict[str, Any] = {
            "messages": messages,
            "model": self.model,
            "stream": stream,
        }
        # Only attach tool_choice together with a non-empty tool list; with
        # no tools, neither field is included in the request.
        if tools:
            request["tools"] = tools
            request["tool_choice"] = tool_choice
        return self.client.chat.completions.create(**request)
|