@@ -10,6 +10,7 @@
 
 # TODO #1592: reenable this once the imports are faster
 # from helm.proxy.clients.client import Client
+from helm.proxy.tokenizers.tokenizer import Tokenizer
 
 import os
 import math
@@ -31,17 +32,17 @@ def get_credentials(path: str) -> Dict[str, str]:
     return credentials
 
 
-def get_number_of_tokens(prompt: str, tokenizer_client: Any, tokenizer_name: str) -> int:
+def get_number_of_tokens(prompt: str, tokenizer: Tokenizer, tokenizer_name: str) -> int:
     tokenization_request = TokenizationRequest(tokenizer=tokenizer_name, text=prompt, encode=True)
-    tokenization_response = tokenizer_client.tokenize(tokenization_request)
+    tokenization_response = tokenizer.tokenize(tokenization_request)
     return len(tokenization_response.tokens)
 
 
 def try_request(
     client: Any,
     model_name: str,
     tokenizer_name: str,
-    tokenizer_client: Any,
+    tokenizer: Tokenizer,
     sequence_length: int,
     num_tokens: int,
     prefix: str = "",
@@ -51,8 +52,8 @@ def try_request(
     Try to make a request with the given sequence_length and num_tokens.
     Return True if the request was successful, False otherwise.
     """
-    num_tokens_prefix = get_number_of_tokens(prefix, tokenizer_client, tokenizer_name)
-    num_tokens_suffix = get_number_of_tokens(suffix, tokenizer_client, tokenizer_name)
+    num_tokens_prefix = get_number_of_tokens(prefix, tokenizer, tokenizer_name)
+    num_tokens_suffix = get_number_of_tokens(suffix, tokenizer, tokenizer_name)
 
     try:
         request = Request(
@@ -76,25 +77,25 @@ class RequestLimits:
 
 
 def figure_out_max_prompt_length(
-    client: Any,  # Client,
+    client: AutoClient,
     model_name: str,
     tokenizer_name: str,
     upper_bound: int = 9500,
     lower_bound: int = 450,
     prefix: str = "",
     suffix: str = "",
 ) -> RequestLimits:
-    tokenizer_client = client._get_tokenizer_client(tokenizer_name)
-    num_tokens_prefix = get_number_of_tokens(prefix, tokenizer_client, tokenizer_name)
-    num_tokens_suffix = get_number_of_tokens(suffix, tokenizer_client, tokenizer_name)
+    tokenizer = client._get_tokenizer(tokenizer_name)
+    num_tokens_prefix = get_number_of_tokens(prefix, tokenizer, tokenizer_name)
+    num_tokens_suffix = get_number_of_tokens(suffix, tokenizer, tokenizer_name)
 
     # Perform a binary search to find the max tokens between lower_bound and upper_bound
     lower_bound += num_tokens_prefix + num_tokens_suffix
     pbar: tqdm
     with tqdm(total=int(math.log2(upper_bound - lower_bound))) as pbar:
         while lower_bound < upper_bound:
             middle = math.ceil((lower_bound + upper_bound) / 2)
-            if try_request(client, model_name, tokenizer_name, tokenizer_client, middle, 0, prefix, suffix):
+            if try_request(client, model_name, tokenizer_name, tokenizer, middle, 0, prefix, suffix):
                 lower_bound = middle
             else:
                 upper_bound = middle - 1
@@ -103,7 +104,7 @@ def figure_out_max_prompt_length(
     # Just in case the number of tokens does not match the number of words, check number of tokens with tokenizer
     max_prompt_length = get_number_of_tokens(
         prefix + " ".join(["hello"] * (lower_bound - num_tokens_prefix - num_tokens_suffix)) + suffix,
-        tokenizer_client,
+        tokenizer,
         tokenizer_name,
     )
     return RequestLimits(
@@ -122,7 +123,7 @@ def figure_out_max_prompt_length_plus_tokens(
     prefix: str = "",
     suffix: str = "",
 ) -> int:
-    tokenizer_client = client._get_tokenizer_client(tokenizer_name)
+    tokenizer = client._get_tokenizer(tokenizer_name)
     lower_bound = 1
     upper_bound = 2 * max_prompt_length + 1
 
@@ -131,7 +132,7 @@ def figure_out_max_prompt_length_plus_tokens(
         client,
         model_name,
         tokenizer_name,
-        tokenizer_client,
+        tokenizer,
         max_prompt_length,
         2**31 - 2 - max_prompt_length,
         prefix,
@@ -147,9 +148,7 @@ def figure_out_max_prompt_length_plus_tokens(
     with tqdm(total=int(math.log2(upper_bound - lower_bound))) as pbar:
         while lower_bound < upper_bound:
             middle = math.ceil((lower_bound + upper_bound) / 2)
-            if try_request(
-                client, model_name, tokenizer_name, tokenizer_client, max_prompt_length, middle, prefix, suffix
-            ):
+            if try_request(client, model_name, tokenizer_name, tokenizer, max_prompt_length, middle, prefix, suffix):
                 lower_bound = middle
             else:
                 upper_bound = middle - 1
@@ -159,39 +158,37 @@ def figure_out_max_prompt_length_plus_tokens(
 
 
 def check_limits(
-    client: Any,  # Client,
+    client: AutoClient,
     model_name: str,
     tokenizer_name: str,
     limits: RequestLimits,
     prefix: str = "",
     suffix: str = "",
 ) -> bool:
-    tokenizer_client = client._get_tokenizer_client(tokenizer_name)
+    tokenizer = client._get_tokenizer(tokenizer_name)
     result: bool = True
 
     # Check the max_prompt_length
     max_prompt_length = limits.max_prompt_length
     if max_prompt_length < 0:
         print("No limit on the number of tokens")
-        if not try_request(client, model_name, tokenizer_name, tokenizer_client, 2**32 - 2, 0, prefix, suffix):
+        if not try_request(client, model_name, tokenizer_name, tokenizer, 2**32 - 2, 0, prefix, suffix):
             print(f"There is a limit on the number of tokens. Params: max_prompt_length={2**32 - 2}, max_tokens=1")
             result = False
     else:
         # There is a limit on the number of tokens
         # If there is no limit on the number of tokens, max_prompt_length should be -1
         # And we should not be here
         # Check that max_prompt_length is ok
-        if not try_request(client, model_name, tokenizer_name, tokenizer_client, max_prompt_length, 0, prefix, suffix):
+        if not try_request(client, model_name, tokenizer_name, tokenizer, max_prompt_length, 0, prefix, suffix):
             print(f"max_prompt_length is too big. Params: max_prompt_length={max_prompt_length}, max_tokens=1")
             result = False
         # Check that max_prompt_length + 1 is not ok
-        if try_request(client, model_name, tokenizer_name, tokenizer_client, max_prompt_length + 1, 0, prefix, suffix):
+        if try_request(client, model_name, tokenizer_name, tokenizer, max_prompt_length + 1, 0, prefix, suffix):
             print(f"max_prompt_length could be bigger. Params: max_prompt_length={max_prompt_length + 1}, max_tokens=1")
             result = False
         # Check that max_prompt_length - 1 is ok
-        if not try_request(
-            client, model_name, tokenizer_name, tokenizer_client, max_prompt_length - 1, 0, prefix, suffix
-        ):
+        if not try_request(client, model_name, tokenizer_name, tokenizer, max_prompt_length - 1, 0, prefix, suffix):
             print(
                 f"max_prompt_length seems to be inconsistent. max_prompt_length={max_prompt_length} "
                 f"is ok but max_prompt_length={max_prompt_length - 1} is not, with max_tokens=0"
@@ -206,7 +203,7 @@ def check_limits(
     if max_prompt_length_plus_tokens < 0:
         print("No limit on the number of tokens")
         if not try_request(
-            client, model_name, tokenizer_name, tokenizer_client, max(1, max_prompt_length), 2**32 - 2, prefix, suffix
+            client, model_name, tokenizer_name, tokenizer, max(1, max_prompt_length), 2**32 - 2, prefix, suffix
         ):
             print(
                 f"There is a limit on the number of tokens. Params: max_prompt_length={max_prompt_length},"
@@ -221,7 +218,7 @@ def check_limits(
             client,
             model_name,
             tokenizer_name,
-            tokenizer_client,
+            tokenizer,
             max_prompt_length,
             max_prompt_length_plus_tokens - max_prompt_length,
             prefix,
@@ -236,7 +233,7 @@ def check_limits(
             client,
             model_name,
             tokenizer_name,
-            tokenizer_client,
+            tokenizer,
             max_prompt_length,
             max_prompt_length_plus_tokens - max_prompt_length + 1,
             prefix,
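
For context, here is a minimal sketch of how the renamed pieces fit together after this change. It is illustrative only, not part of the diff: it assumes `AutoClient` is importable from `helm.proxy.clients.auto_client` and takes credentials plus a cache path, and the credentials path, model name, and tokenizer name below are placeholders. Note the invariant behind both `tqdm`-wrapped loops: `lower_bound` is always the largest size known to succeed and everything above `upper_bound` has failed, so each probe halves the interval, which is why the progress bars use a `log2` total.

```python
# Illustrative driver for the updated helpers (assumptions: AutoClient's
# import path and constructor, plus all names/paths, are examples only).
from helm.proxy.clients.auto_client import AutoClient

credentials = get_credentials("prod_env/credentials.conf")
client = AutoClient(credentials=credentials, cache_path="prod_env/cache")

# The Tokenizer now comes from the client's tokenizer registry instead of
# a tokenizer *client*, which is the point of this rename.
tokenizer = client._get_tokenizer("huggingface/gpt2")
print(get_number_of_tokens("Hello world", tokenizer, "huggingface/gpt2"))

# Binary-search the largest accepted prompt, then verify the boundary:
# limits.max_prompt_length should succeed while max_prompt_length + 1 fails.
limits = figure_out_max_prompt_length(client, "openai/davinci", "huggingface/gpt2")
print(check_limits(client, "openai/davinci", "huggingface/gpt2", limits))
```

Beyond the rename, typing the parameter as `Tokenizer` rather than `Any` lets mypy check the call sites this diff touches instead of deferring errors to runtime.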