
Commit 50e6565

Add the Tokenizer object logic (stanford-crfm#1874)
Parent: 020255d


54 files changed, +1411 -1221 lines

scripts/cache/fix_anthropic_cache.py

+4-1
@@ -8,6 +8,7 @@
 from helm.common.hierarchical_logger import hlog, htrack
 from helm.proxy.clients.anthropic_client import AnthropicLegacyClient
 from helm.proxy.retry import get_retry_decorator
+from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer


 """
@@ -47,7 +48,9 @@ def add_logprobs(mongo_uri: str, credentials_path: str, dry_run: bool):
     api_key: str = credentials["anthropicApiKey"]

     cache_config = MongoCacheConfig(mongo_uri, collection_name="anthropic")
-    client = AnthropicLegacyClient(api_key, cache_config)
+    client = AnthropicLegacyClient(
+        api_key=api_key, tokenizer=HuggingFaceTokenizer(cache_config), cache_config=cache_config
+    )

     with create_key_value_store(cache_config) as cache:
         for i, (request, response) in enumerate(cache.get_all()):
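
A minimal standalone sketch of the construction pattern in the hunk above: the client now receives a Tokenizer object instead of building one internally. The MongoCacheConfig import path, the URI, and the API key below are illustrative assumptions, not part of this commit.

# Sketch only: wiring a Tokenizer into AnthropicLegacyClient, mirroring the diff above.
from helm.common.cache import MongoCacheConfig  # assumed import path
from helm.proxy.clients.anthropic_client import AnthropicLegacyClient
from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer

cache_config = MongoCacheConfig("mongodb://localhost:27017/helm", collection_name="anthropic")  # placeholder URI
tokenizer = HuggingFaceTokenizer(cache_config)
client = AnthropicLegacyClient(api_key="YOUR_ANTHROPIC_API_KEY", tokenizer=tokenizer, cache_config=cache_config)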

scripts/compute_request_limits.py

+24-27
@@ -10,6 +10,7 @@

 # TODO #1592: reenable this once the imports are faster
 # from helm.proxy.clients.client import Client
+from helm.proxy.tokenizers.tokenizer import Tokenizer

 import os
 import math
@@ -31,17 +32,17 @@ def get_credentials(path: str) -> Dict[str, str]:
     return credentials


-def get_number_of_tokens(prompt: str, tokenizer_client: Any, tokenizer_name: str) -> int:
+def get_number_of_tokens(prompt: str, tokenizer: Tokenizer, tokenizer_name: str) -> int:
     tokenization_request = TokenizationRequest(tokenizer=tokenizer_name, text=prompt, encode=True)
-    tokenization_response = tokenizer_client.tokenize(tokenization_request)
+    tokenization_response = tokenizer.tokenize(tokenization_request)
     return len(tokenization_response.tokens)


 def try_request(
     client: Any,
     model_name: str,
     tokenizer_name: str,
-    tokenizer_client: Any,
+    tokenizer: Tokenizer,
     sequence_length: int,
     num_tokens: int,
     prefix: str = "",
@@ -51,8 +52,8 @@ def try_request(
     Try to make a request with the given sequence_length and num_tokens.
     Return True if the request was successful, False otherwise.
     """
-    num_tokens_prefix = get_number_of_tokens(prefix, tokenizer_client, tokenizer_name)
-    num_tokens_suffix = get_number_of_tokens(suffix, tokenizer_client, tokenizer_name)
+    num_tokens_prefix = get_number_of_tokens(prefix, tokenizer, tokenizer_name)
+    num_tokens_suffix = get_number_of_tokens(suffix, tokenizer, tokenizer_name)

     try:
         request = Request(
@@ -76,25 +77,25 @@ class RequestLimits:


 def figure_out_max_prompt_length(
-    client: Any,  # Client,
+    client: AutoClient,
     model_name: str,
     tokenizer_name: str,
     upper_bound: int = 9500,
     lower_bound: int = 450,
     prefix: str = "",
     suffix: str = "",
 ) -> RequestLimits:
-    tokenizer_client = client._get_tokenizer_client(tokenizer_name)
-    num_tokens_prefix = get_number_of_tokens(prefix, tokenizer_client, tokenizer_name)
-    num_tokens_suffix = get_number_of_tokens(suffix, tokenizer_client, tokenizer_name)
+    tokenizer = client._get_tokenizer(tokenizer_name)
+    num_tokens_prefix = get_number_of_tokens(prefix, tokenizer, tokenizer_name)
+    num_tokens_suffix = get_number_of_tokens(suffix, tokenizer, tokenizer_name)

     # Perform a binary search to find the max tokens between lower_bound and upper_bound
     lower_bound += num_tokens_prefix + num_tokens_suffix
     pbar: tqdm
     with tqdm(total=int(math.log2(upper_bound - lower_bound))) as pbar:
         while lower_bound < upper_bound:
             middle = math.ceil((lower_bound + upper_bound) / 2)
-            if try_request(client, model_name, tokenizer_name, tokenizer_client, middle, 0, prefix, suffix):
+            if try_request(client, model_name, tokenizer_name, tokenizer, middle, 0, prefix, suffix):
                 lower_bound = middle
             else:
                 upper_bound = middle - 1
@@ -103,7 +104,7 @@ def figure_out_max_prompt_length(
     # Just in case the number of tokens does not match the number of words, check number of tokens with tokenizer
     max_prompt_length = get_number_of_tokens(
         prefix + " ".join(["hello"] * (lower_bound - num_tokens_prefix - num_tokens_suffix)) + suffix,
-        tokenizer_client,
+        tokenizer,
         tokenizer_name,
     )
     return RequestLimits(
@@ -122,7 +123,7 @@ def figure_out_max_prompt_length_plus_tokens(
     prefix: str = "",
     suffix: str = "",
 ) -> int:
-    tokenizer_client = client._get_tokenizer_client(tokenizer_name)
+    tokenizer = client._get_tokenizer(tokenizer_name)
     lower_bound = 1
     upper_bound = 2 * max_prompt_length + 1

@@ -131,7 +132,7 @@ def figure_out_max_prompt_length_plus_tokens(
         client,
         model_name,
         tokenizer_name,
-        tokenizer_client,
+        tokenizer,
         max_prompt_length,
         2**31 - 2 - max_prompt_length,
         prefix,
@@ -147,9 +148,7 @@ def figure_out_max_prompt_length_plus_tokens(
     with tqdm(total=int(math.log2(upper_bound - lower_bound))) as pbar:
         while lower_bound < upper_bound:
             middle = math.ceil((lower_bound + upper_bound) / 2)
-            if try_request(
-                client, model_name, tokenizer_name, tokenizer_client, max_prompt_length, middle, prefix, suffix
-            ):
+            if try_request(client, model_name, tokenizer_name, tokenizer, max_prompt_length, middle, prefix, suffix):
                 lower_bound = middle
             else:
                 upper_bound = middle - 1
@@ -159,39 +158,37 @@ def figure_out_max_prompt_length_plus_tokens(


 def check_limits(
-    client: Any,  # Client,
+    client: AutoClient,
     model_name: str,
     tokenizer_name: str,
     limits: RequestLimits,
     prefix: str = "",
     suffix: str = "",
 ) -> bool:
-    tokenizer_client = client._get_tokenizer_client(tokenizer_name)
+    tokenizer = client._get_tokenizer(tokenizer_name)
     result: bool = True

     # Check the max_prompt_length
     max_prompt_length = limits.max_prompt_length
     if max_prompt_length < 0:
         print("No limit on the number of tokens")
-        if not try_request(client, model_name, tokenizer_name, tokenizer_client, 2**32 - 2, 0, prefix, suffix):
+        if not try_request(client, model_name, tokenizer_name, tokenizer, 2**32 - 2, 0, prefix, suffix):
             print(f"There is a limit on the number of tokens. Params: max_prompt_length={2**32 - 2}, max_tokens=1")
             result = False
     else:
         # There is a limit on the number of tokens
         # If there is no limit on the number of tokens, max_prompt_length should be -1
         # And we should not be here
         # Check that max_prompt_length is ok
-        if not try_request(client, model_name, tokenizer_name, tokenizer_client, max_prompt_length, 0, prefix, suffix):
+        if not try_request(client, model_name, tokenizer_name, tokenizer, max_prompt_length, 0, prefix, suffix):
            print(f"max_prompt_length is too big. Params: max_prompt_length={max_prompt_length}, max_tokens=1")
            result = False
         # Check that max_prompt_length + 1 is not ok
-        if try_request(client, model_name, tokenizer_name, tokenizer_client, max_prompt_length + 1, 0, prefix, suffix):
+        if try_request(client, model_name, tokenizer_name, tokenizer, max_prompt_length + 1, 0, prefix, suffix):
            print(f"max_prompt_length could be bigger. Params: max_prompt_length={max_prompt_length+1}, max_tokens=1")
            result = False
         # Check that max_prompt_length - 1 is ok
-        if not try_request(
-            client, model_name, tokenizer_name, tokenizer_client, max_prompt_length - 1, 0, prefix, suffix
-        ):
+        if not try_request(client, model_name, tokenizer_name, tokenizer, max_prompt_length - 1, 0, prefix, suffix):
            print(
                f"max_prompt_length ssems to be inconsistent. max_prompt_length={max_prompt_length} "
                f"is ok but max_prompt_length={max_prompt_length-1} is not, with max_tokens=0"
@@ -206,7 +203,7 @@ def check_limits(
     if max_prompt_length_plus_tokens < 0:
         print("No limit on the number of tokens")
         if not try_request(
-            client, model_name, tokenizer_name, tokenizer_client, max(1, max_prompt_length), 2**32 - 2, prefix, suffix
+            client, model_name, tokenizer_name, tokenizer, max(1, max_prompt_length), 2**32 - 2, prefix, suffix
         ):
             print(
                 f"There is a limit on the number of tokens. Params: max_prompt_length={max_prompt_length},"
@@ -221,7 +218,7 @@ def check_limits(
         client,
         model_name,
         tokenizer_name,
-        tokenizer_client,
+        tokenizer,
         max_prompt_length,
         max_prompt_length_plus_tokens - max_prompt_length,
         prefix,
@@ -236,7 +233,7 @@ def check_limits(
         client,
         model_name,
         tokenizer_name,
-        tokenizer_client,
+        tokenizer,
         max_prompt_length,
         max_prompt_length_plus_tokens - max_prompt_length + 1,
         prefix,
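
The core of this script is the binary search in figure_out_max_prompt_length: test the midpoint of a candidate range and shrink the range depending on whether the request succeeds. A self-contained sketch of that pattern, with try_fn standing in for the try_request call above:

import math

def find_max_length(try_fn, lower_bound: int, upper_bound: int) -> int:
    """Largest value in [lower_bound, upper_bound] for which try_fn(value) succeeds.

    Assumes try_fn is monotone: once a length fails, every larger length fails too.
    """
    while lower_bound < upper_bound:
        middle = math.ceil((lower_bound + upper_bound) / 2)
        if try_fn(middle):
            lower_bound = middle  # midpoint works, so the answer is at least middle
        else:
            upper_bound = middle - 1  # midpoint fails, so the answer is below middle
    return lower_bound

# Example against a fake 2048-token limit instead of a live client:
assert find_max_length(lambda n: n <= 2048, 450, 9500) == 2048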

src/helm/benchmark/window_services/cohere_window_service.py

+5-5
@@ -1,6 +1,6 @@
 from typing import List, Optional

-from helm.proxy.clients.cohere_client import CohereClient
+from helm.proxy.tokenizers.cohere_tokenizer import CohereTokenizer
 from .local_window_service import LocalWindowService
 from .tokenizer_service import TokenizerService
 from .window_service import EncodeResult
@@ -62,7 +62,7 @@ def encode(self, text: str, truncation: bool = False, max_length: Optional[int]

         response: TokenizationRequestResult
         tokens: List[TokenizationToken] = []
-        if truncation or len(text) <= CohereClient.TOKENIZE_API_MAX_TEXT_LENGTH:
+        if truncation or len(text) <= CohereTokenizer.TOKENIZE_API_MAX_TEXT_LENGTH:
             response = self.service.tokenize(
                 TokenizationRequest(
                     text,
@@ -80,7 +80,7 @@ def encode(self, text: str, truncation: bool = False, max_length: Optional[int]
             # and make a request for each chunk.
             # This can potentially break up valid tokens at the end of the chunk, but the chunk size
             # is large enough that this happens infrequently.
-            chunk_size: int = CohereClient.TOKENIZE_API_MAX_TEXT_LENGTH
+            chunk_size: int = CohereTokenizer.TOKENIZE_API_MAX_TEXT_LENGTH
             for i in range(0, len(text), chunk_size):
                 chunk: str = text[i : chunk_size + i]
                 response = self.service.tokenize(
@@ -120,7 +120,7 @@ def fits_within_context_window(self, text: str, expected_completion_token_length
         so first check if the text has fewer than 65,536 characters.
         """
         return (
-            len(text) <= CohereClient.TOKENIZE_API_MAX_TEXT_LENGTH
+            len(text) <= CohereTokenizer.TOKENIZE_API_MAX_TEXT_LENGTH
             and self.get_num_tokens(text) + expected_completion_token_length <= self.max_request_length
         )

@@ -130,7 +130,7 @@ def truncate_from_right(self, text: str, expected_completion_token_length: int =
         minus the expected completion length (defaults to 0).
         """
         # First truncate the text so it's within `CohereClient.TOKENIZE_MAX_TEXT_LENGTH` length.
-        text = text[: CohereClient.TOKENIZE_API_MAX_TEXT_LENGTH]
+        text = text[: CohereTokenizer.TOKENIZE_API_MAX_TEXT_LENGTH]

         max_length: int = self.max_request_length - expected_completion_token_length
         result: str = self.decode(self.encode(text, truncation=True, max_length=max_length).tokens)
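
The renamed constant drives the chunking logic in encode(): when the text exceeds the tokenize API's character limit, it is split into fixed-size chunks and each chunk is tokenized separately, accepting the rare token split at a chunk boundary. A rough sketch of that pattern, with tokenize_fn standing in for self.service.tokenize and the limit value taken from the 65,536-character figure quoted above:

from typing import Callable, List

TOKENIZE_API_MAX_TEXT_LENGTH = 65_536  # assumed value of CohereTokenizer.TOKENIZE_API_MAX_TEXT_LENGTH

def tokenize_long_text(text: str, tokenize_fn: Callable[[str], List[str]]) -> List[str]:
    """Tokenize text in chunks no longer than the API limit and concatenate the results."""
    tokens: List[str] = []
    for i in range(0, len(text), TOKENIZE_API_MAX_TEXT_LENGTH):
        chunk = text[i : i + TOKENIZE_API_MAX_TEXT_LENGTH]  # may split a token at the boundary
        tokens.extend(tokenize_fn(chunk))
    return tokens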

src/helm/benchmark/window_services/huggingface_window_service.py

+2-2
@@ -1,5 +1,5 @@
 from typing import Optional
-from helm.proxy.clients.huggingface_tokenizer import HuggingFaceTokenizers
+from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
 from .local_window_service import LocalWindowService
 from .tokenizer_service import TokenizerService

@@ -16,7 +16,7 @@ def __init__(
     ):
         super().__init__(service)
         self._tokenizer_name = tokenizer_name
-        tokenizer = HuggingFaceTokenizers.get_tokenizer(
+        tokenizer = HuggingFaceTokenizer.get_tokenizer(
             helm_tokenizer_name=tokenizer_name,
             pretrained_model_name_or_path=pretrained_model_name_or_path or tokenizer_name,
             revision=revision,

src/helm/benchmark/window_services/yalm_window_service.py

+1-1
@@ -1,4 +1,4 @@
-from helm.proxy.clients.yalm_tokenizer.yalm_tokenizer import YaLMTokenizer
+from helm.proxy.tokenizers.yalm_tokenizer_data.yalm_tokenizer import YaLMTokenizer
 from .local_window_service import LocalWindowService
 from .tokenizer_service import TokenizerService

src/helm/common/request.py

+16-1
@@ -1,5 +1,6 @@
+import time
 from dataclasses import dataclass, field
-from typing import List, Optional, Dict
+from typing import Any, Callable, Dict, List, Optional

 from helm.common.media_object import MultimediaObject
 from helm.proxy.models import Model, get_model
@@ -213,3 +214,17 @@ def render_lines(self) -> List[str]:
         completions=[],
         embedding=[],
     )
+
+
+def wrap_request_time(compute: Callable[[], Dict[str, Any]]) -> Callable[[], Any]:
+    """Return a version of `compute` that puts `request_time` into its output."""
+
+    def wrapped_compute():
+        start_time = time.time()
+        response = compute()
+        end_time = time.time()
+        response["request_time"] = end_time - start_time
+        response["request_datetime"] = int(start_time)
+        return response
+
+    return wrapped_compute
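
The new wrap_request_time helper decorates a zero-argument compute function so its result dict also records how long the call took and when it started. A small usage sketch; the fake compute function is illustrative only:

from helm.common.request import wrap_request_time

def compute():
    # Stand-in for a real API call that returns a raw response dict.
    return {"completions": ["hello world"]}

response = wrap_request_time(compute)()
print(response["request_time"])      # elapsed seconds as a float
print(response["request_datetime"])  # start time as an integer Unix timestamp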
