@@ -5,8 +5,6 @@
 from transformers.generation.stopping_criteria import (
     StoppingCriteria,
     StoppingCriteriaList,
-    STOPPING_CRITERIA_INPUTS_DOCSTRING,
-    add_start_docstrings,
 )
 from typing import Any, Dict, List, Optional
 
@@ -42,18 +40,17 @@ def resolve_alias(model_name: str) -> str:
 
 
 class StopAtSpecificTokenCriteria(StoppingCriteria):
-    def __init__(self, stop_sequence: List[int] = None):
+    def __init__(self, stop_sequence: List[int]):
         super().__init__()
         self.stop_sequence = stop_sequence
 
-    # @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         # Create a tensor from the stop_sequence
         stop_sequence_tensor = torch.tensor(self.stop_sequence, device=input_ids.device, dtype=input_ids.dtype)
 
         # Check if the current sequence ends with the stop_sequence
         current_sequence = input_ids[:, -len(self.stop_sequence) :]
-        return torch.all(current_sequence == stop_sequence_tensor).item()
+        return bool(torch.all(current_sequence == stop_sequence_tensor).item())
 
 
 class HuggingFaceServer:
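A note on how the new criteria behaves: `__call__` returns True only once the generated ids end with the full stop sequence, and because `torch.all` reduces over the batch dimension as well, with batched input it fires only when every sequence in the batch ends with the stop tokens. A minimal sketch exercising the class in isolation (the token ids below are made up for illustration):

import torch

# StopAtSpecificTokenCriteria is the class added in this patch.
# Pretend some stop string tokenized to the ids [198, 198] (made-up values).
criteria = StopAtSpecificTokenCriteria(stop_sequence=[198, 198])

ended = torch.tensor([[50, 51, 198, 198]])   # ends with the stop sequence
ongoing = torch.tensor([[50, 51, 198, 52]])  # does not

print(criteria(ended, scores=None))    # True -> generation should stop
print(criteria(ongoing, scores=None))  # False -> generation continues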
@@ -88,14 +85,17 @@ def serve_request(self, raw_request: Dict[str, Any]):
         raw_request["output_scores"] = True
         top_k_per_token: int = raw_request["top_k_per_token"]
         del raw_request["top_k_per_token"]
+        stopping_criteria: Optional[StoppingCriteriaList] = None
         if len(raw_request["stop_sequences"]) > 0:
             stop_sequence_ids = self.tokenizer(
                 raw_request["stop_sequences"], return_token_type_ids=False, add_special_tokens=False
             )
-            assert len(stop_sequence_ids.input_ids) == 1, "Total number of stop words should be 1."
-            # assert len(stop_sequence_ids.input_ids[0]) == 1, "Total number of tokens in each stop word should be 1."
+            assert len(stop_sequence_ids.input_ids) == 1, "Total number of stop sequences should be 1."
             if len(stop_sequence_ids.input_ids[0]) == 1:
                 raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
+            else:
+                stopping_criteria = StoppingCriteriaList()
+                stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_ids.input_ids[0]))
             del raw_request["stop_sequences"]
 
         # Strip out irrelevant parameters
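The single-token shortcut above exists because a stop string does not always map to one token id; only multi-token stops need the custom criteria. A quick illustration (the tokenizer choice is arbitrary and the printed ids are tokenizer-dependent, so treat them as assumptions):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice

single = tokenizer(["\n"], return_token_type_ids=False, add_special_tokens=False)
print(single.input_ids)  # e.g. [[198]] -> one id, so eos_token_id is enough

multi = tokenizer(["END OF ANSWER"], return_token_type_ids=False, add_special_tokens=False)
print(multi.input_ids)   # multiple ids per stop string -> needs StopAtSpecificTokenCriteria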
@@ -105,15 +105,11 @@ def serve_request(self, raw_request: Dict[str, Any]):
             if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
         }
 
-        stopping_criteria = StoppingCriteriaList()
-        if stop_sequence_ids != None:
-            stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_ids.input_ids[0]))
-
         # Use HuggingFace's `generate` method.
         output = self.model.generate(
             **encoded_input,
             **relevant_raw_request,
-            stopping_criteria=stopping_criteria if len(stop_sequence_ids.input_ids[0]) > 1 else None,
+            stopping_criteria=stopping_criteria,
         )
         sequences = output.sequences
         scores = output.scores
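Taken together, the patch makes generation stop on a multi-token stop string roughly like this end-to-end sketch (standalone, outside HuggingFaceServer; `gpt2`, the prompt, and the stop string are placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.stopping_criteria import StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

encoded_input = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
stop_ids = tokenizer(["\nQ:"], add_special_tokens=False).input_ids[0]

# StopAtSpecificTokenCriteria is the class added in this patch.
stopping_criteria = StoppingCriteriaList([StopAtSpecificTokenCriteria(stop_sequence=stop_ids)])

# `generate` also accepts stopping_criteria=None, which is why the patched
# code can pass the variable through unconditionally.
output = model.generate(
    **encoded_input,
    max_new_tokens=30,
    stopping_criteria=stopping_criteria,
)
print(tokenizer.decode(output[0]))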