Commit 7fe63cd

Addition of Error Flags: retriable and fatal (stanford-crfm#1533)
1 parent 356a6b2 commit 7fe63cd

File tree: 5 files changed, +68 -17 lines


src/helm/benchmark/executor.py (+6 -2)
@@ -3,7 +3,7 @@
 
 from helm.common.general import parallel_map
 from helm.common.hierarchical_logger import htrack, hlog
-from helm.common.request import RequestResult
+from helm.common.request import RequestResult, Sequence
 from helm.common.authentication import Authentication
 from helm.proxy.services.remote_service import RemoteService
 from helm.proxy.services.server_service import ServerService
@@ -85,5 +85,9 @@ def execute(self, scenario_state: ScenarioState) -> ScenarioState:
     def process(self, state: RequestState) -> RequestState:
         result: RequestResult = self.service.make_request(self.execution_spec.auth, state.request)
         if not result.success:
-            raise ExecutorError(f"{str(result.error)} Request: {state.request}")
+            if result.error_flags and not result.error_flags.is_fatal:
+                hlog(f"WARNING: Non-fatal error treated as empty completion: {result.error}")
+                result.completions = [Sequence(text="", logprob=0, tokens=[])]
+            else:
+                raise ExecutorError(f"{str(result.error)} Request: {state.request}")
         return replace(state, result=result)
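With the flags in place, a failed request marked non-fatal no longer aborts the run: the executor logs a warning, substitutes an empty completion, and moves on; anything else still raises. A minimal sketch of that decision, reusing the commit's `RequestResult` and `Sequence` types (the standalone `tolerate_or_raise` helper is hypothetical, and `RuntimeError` stands in for `ExecutorError`):

from helm.common.request import RequestResult, Sequence

def tolerate_or_raise(result: RequestResult) -> RequestResult:
    # Hypothetical helper mirroring Executor.process above.
    if not result.success:
        if result.error_flags and not result.error_flags.is_fatal:
            # Non-fatal (e.g. a prompt the provider refuses): score as empty output.
            result.completions = [Sequence(text="", logprob=0, tokens=[])]
        else:
            # Unflagged or fatal errors still abort the run.
            raise RuntimeError(f"{result.error}")
    return result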

src/helm/common/request.py (+16)
@@ -129,6 +129,19 @@ def render_lines(self) -> List[str]:
 
 
 @dataclass(frozen=True)
+class ErrorFlags:
+    """Describes how to treat errors in the request."""
+
+    is_retriable: Optional[bool] = None
+    """Whether the request is retriable or whether the error is permanent.
+    If None, the error is treated as retriable."""
+
+    is_fatal: Optional[bool] = None
+    """Whether the error is fatal, i.e. the run should be discarded.
+    If None, the error is treated as fatal."""
+
+
+@dataclass(frozen=False)
 class RequestResult:
     """What comes back due to a `Request`."""
 
@@ -155,6 +168,9 @@ class RequestResult:
     error: Optional[str] = None
     """If `success` is false, what was the error?"""
 
+    error_flags: Optional[ErrorFlags] = None
+    """Describes how to treat errors in the request."""
+
     batch_size: Optional[int] = None
     """Batch size (`TogetherClient` only)"""
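
Both flags default to None so that existing clients, which never set them, keep the old semantics: an unflagged error is retried and, if it never succeeds, treated as fatal. Note that `RequestResult` also changes from `frozen=True` to `frozen=False`, which is what lets the executor above overwrite `completions` in place. A quick sketch of the two cases, with made-up error messages:

from helm.common.request import ErrorFlags, RequestResult

# Permanent but tolerable: skip this one request, keep the run, don't retry.
moderation_failure = RequestResult(
    success=False,
    cached=False,
    error="content flagged by the provider",  # illustrative message
    completions=[],
    embedding=[],
    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
)

# No flags set: the retry layer retries as before, and the executor treats
# a lasting failure as fatal, exactly as it did before this commit.
transient_failure = RequestResult(
    success=False, cached=False, error="timeout", completions=[], embedding=[]
)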

src/helm/proxy/clients/anthropic_client.py (+24 -11)
@@ -9,7 +9,14 @@
 
 from helm.common.cache import Cache, CacheConfig
 from helm.common.hierarchical_logger import htrack_block, hlog
-from helm.common.request import EMBEDDING_UNAVAILABLE_REQUEST_RESULT, Request, RequestResult, Sequence, Token
+from helm.common.request import (
+    EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
+    Request,
+    RequestResult,
+    Sequence,
+    Token,
+    ErrorFlags,
+)
 from helm.common.tokenization_request import (
     TokenizationRequest,
     TokenizationRequestResult,
@@ -22,14 +29,6 @@
 from dataclasses import asdict
 
 
-class AnthropicPromptTooLongError(Exception):
-    pass
-
-
-class AnthropicPromptPlusMaxTokensTooLongError(Exception):
-    pass
-
-
 class AnthropicClient(Client):
     """
     Client for the Anthropic models (https://arxiv.org/abs/2204.05862).
@@ -138,9 +137,23 @@ def do_it():
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
         except Exception as error:
             if "Prompt must contain anthropic.AI_PROMPT" in str(error):
-                raise AnthropicPromptTooLongError(f"Prompt too long: {request.prompt}")
+                return RequestResult(
+                    success=False,
+                    cached=False,
+                    error=response["error"],
+                    completions=[],
+                    embedding=[],
+                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                )
             if "exceeds max (" in str(error):
-                raise AnthropicPromptPlusMaxTokensTooLongError(f"Prompt + max_tokens too long: {request.prompt}")
+                return RequestResult(
+                    success=False,
+                    cached=False,
+                    error=response["error"],
+                    completions=[],
+                    embedding=[],
+                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                )
             return RequestResult(success=False, cached=False, error=str(error), completions=[], embedding=[])
 
         # Post process the completion.
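
Both prompt-length failures are deterministic: retrying the identical request can never succeed, and dropping one oversized instance should not discard the whole run, hence `ErrorFlags(is_retriable=False, is_fatal=False)` replaces the two removed exception classes. The same match-on-message pattern as a hedged standalone sketch (the helper name is hypothetical; the matched substrings are the ones used above):

from typing import Optional
from helm.common.request import ErrorFlags

# Provider messages that signal a permanent, per-request failure.
_PERMANENT_MARKERS = ("Prompt must contain anthropic.AI_PROMPT", "exceeds max (")

def classify_error(error: Exception) -> Optional[ErrorFlags]:
    # None means "no opinion": the downstream defaults (retriable, and
    # ultimately fatal) apply to any error we don't recognize.
    if any(marker in str(error) for marker in _PERMANENT_MARKERS):
        return ErrorFlags(is_retriable=False, is_fatal=False)
    return None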

src/helm/proxy/clients/palmyra_client.py (+16 -3)
@@ -4,7 +4,7 @@
 
 from helm.common.cache import Cache, CacheConfig
 from helm.common.hierarchical_logger import hlog
-from helm.common.request import Request, RequestResult, Sequence, Token
+from helm.common.request import Request, RequestResult, Sequence, Token, ErrorFlags
 from helm.common.tokenization_request import (
     DecodeRequest,
     DecodeRequestResult,
@@ -75,8 +75,6 @@ def make_request(self, request: Request) -> RequestResult:
 
         def do_it():
            result = self._send_request(model_name, raw_request)
-            if "choices" not in result:
-                raise ValueError(f"Invalid response: {result}")
            return result
 
         # We need to include the engine's name to differentiate among requests made for different model
@@ -99,6 +97,21 @@ def do_it():
             error: str = f"PalmyraClient error: {e}"
             return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
 
+        if "choices" not in response:
+            if "errors" in response and response["errors"][0]["key"] == "fail.content.moderation.failed":
+                return RequestResult(
+                    success=False,
+                    cached=False,
+                    error=response["errors"][0]["description"],
+                    completions=[],
+                    embedding=[],
+                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                    request_time=response["request_time"],
+                    request_datetime=response["request_datetime"],
+                )
+            else:
+                raise ValueError(f"Invalid response: {response}")
+
         response_text: str = response["choices"][0]["text"]
 
         # The Writer API doesn't support echo. If `echo_prompt` is true, combine the prompt and completion.
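
Moving the `"choices"` check out of `do_it` matters: `do_it` runs under the cache-and-retry machinery, where raising would just trigger more retries, while inspecting the (possibly cached) response afterwards lets a moderation refusal come back as a non-retriable, non-fatal result that keeps its timing metadata. A sketch of the response shape this branch expects, with all field values invented for illustration (in HELM, `wrap_request_time` attaches `request_time` and `request_datetime` to the response):

# Illustrative Writer/Palmyra-style moderation-error payload.
response = {
    "errors": [
        {
            "key": "fail.content.moderation.failed",
            "description": "The prompt was rejected by content moderation.",
        }
    ],
    "request_time": 0.42,            # seconds, added by wrap_request_time
    "request_datetime": 1690000000,  # unix timestamp, added by wrap_request_time
}

is_moderation_refusal = (
    "choices" not in response
    and "errors" in response
    and response["errors"][0]["key"] == "fail.content.moderation.failed"
)
assert is_moderation_refusal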

src/helm/proxy/retry.py (+6 -1)
@@ -48,7 +48,12 @@ def retry_if_request_failed(result: Union[RequestResult, TokenizationRequestResult]) -> bool:
     """Fails if `success` of `RequestResult` or `TokenizationRequestResult` is false."""
     if not result.success:
         hlog(result.error)
-    return not result.success
+    retry_if_fail: bool = True
+    if isinstance(result, RequestResult):
+        retry_if_fail = (
+            result.error_flags is None or result.error_flags.is_retriable is None or result.error_flags.is_retriable
+        )
+    return not result.success and retry_if_fail
 
 
 retry_request: Callable = get_retry_decorator(
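
The predicate keeps its name but now distinguishes flagged failures: any of `error_flags is None`, `is_retriable is None`, or `is_retriable=True` preserves the old retry-on-failure behavior, and only an explicit `is_retriable=False` suppresses the retry. Equivalent logic for the `RequestResult` case, as a standalone sketch (the `should_retry` name is hypothetical):

from typing import Optional
from helm.common.request import ErrorFlags, RequestResult

def should_retry(result: RequestResult) -> bool:
    # Successful requests are never retried.
    if result.success:
        return False
    flags: Optional[ErrorFlags] = result.error_flags
    # Absent flags (or is_retriable=None) fall back to the old always-retry
    # path; only an explicit is_retriable=False stops the retry loop.
    return flags is None or flags.is_retriable is None or bool(flags.is_retriable)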
