|
9 | 9 |
|
10 | 10 | from helm.common.cache import Cache, CacheConfig
|
11 | 11 | from helm.common.hierarchical_logger import htrack_block, hlog
|
12 |
| -from helm.common.request import EMBEDDING_UNAVAILABLE_REQUEST_RESULT, Request, RequestResult, Sequence, Token |
| 12 | +from helm.common.request import ( |
| 13 | + EMBEDDING_UNAVAILABLE_REQUEST_RESULT, |
| 14 | + Request, |
| 15 | + RequestResult, |
| 16 | + Sequence, |
| 17 | + Token, |
| 18 | + ErrorFlags, |
| 19 | +) |
13 | 20 | from helm.common.tokenization_request import (
|
14 | 21 | TokenizationRequest,
|
15 | 22 | TokenizationRequestResult,
|
|
22 | 29 | from dataclasses import asdict
|
23 | 30 |
|
24 | 31 |
|
25 |
| -class AnthropicPromptTooLongError(Exception): |
26 |
| - pass |
27 |
| - |
28 |
| - |
29 |
| -class AnthropicPromptPlusMaxTokensTooLongError(Exception): |
30 |
| - pass |
31 |
| - |
32 |
| - |
33 | 32 | class AnthropicClient(Client):
|
34 | 33 | """
|
35 | 34 | Client for the Anthropic models (https://arxiv.org/abs/2204.05862).
|
@@ -138,9 +137,23 @@ def do_it():
|
138 | 137 | response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
|
139 | 138 | except Exception as error:
|
140 | 139 | if "Prompt must contain anthropic.AI_PROMPT" in str(error):
|
141 |
| - raise AnthropicPromptTooLongError(f"Prompt too long: {request.prompt}") |
| 140 | + return RequestResult( |
| 141 | + success=False, |
| 142 | + cached=False, |
| 143 | + error=response["error"], |
| 144 | + completions=[], |
| 145 | + embedding=[], |
| 146 | + error_flags=ErrorFlags(is_retriable=False, is_fatal=False), |
| 147 | + ) |
142 | 148 | if "exceeds max (" in str(error):
|
143 |
| - raise AnthropicPromptPlusMaxTokensTooLongError(f"Prompt + max_tokens too long: {request.prompt}") |
| 149 | + return RequestResult( |
| 150 | + success=False, |
| 151 | + cached=False, |
| 152 | + error=response["error"], |
| 153 | + completions=[], |
| 154 | + embedding=[], |
| 155 | + error_flags=ErrorFlags(is_retriable=False, is_fatal=False), |
| 156 | + ) |
144 | 157 | return RequestResult(success=False, cached=False, error=str(error), completions=[], embedding=[])
|
145 | 158 |
|
146 | 159 | # Post process the completion.
|
|
0 commit comments