@@ -17,6 +17,7 @@ class TestICETokenizerClient:
17
17
def setup_method (self , method ):
18
18
cache_file = tempfile .NamedTemporaryFile (delete = False )
19
19
self .cache_path : str = cache_file .name
20
+ self .tokenizer_name = "TsinghuaKEG/ice"
20
21
self .client = ICETokenizerClient (SqliteCacheConfig (self .cache_path ))
21
22
22
23
# The test cases were created using the examples from https://github.com/THUDM/icetk#tokenization
@@ -27,26 +28,28 @@ def teardown_method(self, method):
27
28
os .remove (self .cache_path )
28
29
29
30
def test_tokenize(self):
    """Tokenize the test prompt twice, checking the cache flag and the raw tokens.

    The same request is sent two times: the first response must not be
    marked as cached, the second must be, and the second response's
    ``raw_tokens`` are compared against the expected token strings.
    """
    request = TokenizationRequest(text=self.test_prompt, tokenizer=self.tokenizer_name)

    # `result` is annotated once here; the repeat call below rebinds the same
    # name without re-declaring its type.
    result: TokenizationRequestResult = self.client.tokenize(request)
    assert not result.cached, "First time making the tokenize request. Result should not be cached"

    result = self.client.tokenize(request)
    assert result.cached, "Result should be cached"
    assert result.raw_tokens == [" Hello", " World", "!", " I", " am", " ice", "tk", "."]
36
37
37
38
def test_encode(self):
    """Tokenize the test prompt with ``encode=True`` and compare to the expected encoding."""
    encode_request = TokenizationRequest(
        text=self.test_prompt,
        tokenizer=self.tokenizer_name,
        encode=True,
    )
    encoded: TokenizationRequestResult = self.client.tokenize(encode_request)
    assert encoded.raw_tokens == self.encoded_test_prompt
41
42
42
43
def test_encode_with_truncation(self):
    """Encode the test prompt with truncation and expect only the first ``limit`` tokens."""
    limit: int = 3
    truncating_request = TokenizationRequest(
        text=self.test_prompt,
        tokenizer=self.tokenizer_name,
        encode=True,
        truncation=True,
        max_length=limit,
    )
    truncated: TokenizationRequestResult = self.client.tokenize(truncating_request)
    # The expected value is the leading slice of the full expected encoding.
    assert truncated.raw_tokens == self.encoded_test_prompt[:limit]
47
50
48
51
def test_decode (self ):
49
- request = DecodeRequest (tokens = self .encoded_test_prompt )
52
+ request = DecodeRequest (tokens = self .encoded_test_prompt , tokenizer = self . tokenizer_name )
50
53
result : DecodeRequestResult = self .client .decode (request )
51
54
assert not result .cached , "First time making the decode request. Result should not be cached"
52
55
result : DecodeRequestResult = self .client .decode (request )
0 commit comments