Commit 5e2711e

Add authentication and improve error handling in TogetherClient (stanford-crfm#1560)
1 parent 9652316 commit 5e2711e

1 file changed

src/helm/proxy/clients/together_client.py (+22 -16)
@@ -37,6 +37,10 @@ def fix_text(x: str, model: str) -> str:
     return x
 
 
+class TogetherClientError(Exception):
+    pass
+
+
 class TogetherClient(Client):
     """
     Client for the models where we evaluate offline. Since the queries are handled offline, the `TogetherClient` just
@@ -72,22 +76,24 @@ def make_request(self, request: Request) -> RequestResult:
         raw_request = TogetherClient.convert_to_raw_request(request)
         cache_key: Dict = Client.make_cache_key(raw_request, request)
 
-        try:
-
-            def do_it():
-                result = requests.post(TogetherClient.INFERENCE_ENDPOINT, json=raw_request).json()
-                assert "output" in result, f"Invalid response: {result}"
-                return result["output"]
-
-            def fail():
-                raise RuntimeError(
-                    f"The result has not been uploaded to the cache for the following request: {cache_key}"
-                )
-
-            response, cached = self.cache.get(cache_key, wrap_request_time(do_it if self.api_key else fail))
-        except RuntimeError as e:
-            error: str = f"TogetherClient error: {e}"
-            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+        def do_it():
+            if not self.api_key:
+                raise TogetherClientError("togetherApiKey not set in credentials.conf")
+            headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"}
+            response = requests.post(TogetherClient.INFERENCE_ENDPOINT, headers=headers, json=raw_request)
+            try:
+                response.raise_for_status()
+            except Exception as e:
+                raise TogetherClientError(
+                    f"Together request failed with {response.status_code}: {response.text}"
+                ) from e
+            result = response.json()
+            return result["output"]
+
+        def fail():
+            raise RuntimeError(f"The result has not been uploaded to the cache for the following request: {cache_key}")
+
+        response, cached = self.cache.get(cache_key, wrap_request_time(do_it if self.api_key else fail))
 
         # Expect the result to be structured the same way as a response from OpenAI API.
         completions: List[Sequence] = []
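
Read on its own, the new do_it path roughly amounts to the standalone sketch below: send the request with a Bearer token, turn HTTP errors into a TogetherClientError that carries the status code and response body, and fail fast if no key is configured. The function name, ENDPOINT value, and call signature are illustrative placeholders, not part of the commit.

# Minimal sketch of the authenticated request path added in this commit.
# ENDPOINT and query_together are placeholders, not HELM's actual names.
from typing import Any, Dict

import requests


class TogetherClientError(Exception):
    pass


ENDPOINT = "https://example.invalid/inference"  # placeholder, not the real INFERENCE_ENDPOINT


def query_together(api_key: str, raw_request: Dict[str, Any]) -> Any:
    if not api_key:
        # Mirrors the diff: refuse to send an unauthenticated request.
        raise TogetherClientError("togetherApiKey not set in credentials.conf")
    headers: Dict[str, str] = {"Authorization": f"Bearer {api_key}"}
    response = requests.post(ENDPOINT, headers=headers, json=raw_request)
    try:
        # Raise on 4xx/5xx so HTTP failures surface as TogetherClientError below.
        response.raise_for_status()
    except Exception as e:
        raise TogetherClientError(
            f"Together request failed with {response.status_code}: {response.text}"
        ) from e
    return response.json()["output"]

Note that in the actual client the cache lookup still routes to fail() when no API key is configured, so only uncached requests ever reach do_it and the Together endpoint.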
