import torch
from dataclasses import asdict
from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.stopping_criteria import (
+    StoppingCriteria,
+    StoppingCriteriaList,
+    STOPPING_CRITERIA_INPUTS_DOCSTRING,
+    add_start_docstrings,
+)
from typing import Any, Dict, List, Optional

from helm.common.cache import Cache, CacheConfig
@@ -35,6 +41,21 @@ def resolve_alias(model_name: str) -> str:
    return _MODEL_NAME_ALIASES.get(model_name, model_name)


+class StopAtSpecificTokenCriteria(StoppingCriteria):
+    def __init__(self, stop_sequence: Optional[List[int]] = None):
+        super().__init__()
+        self.stop_sequence = stop_sequence
+
+    # @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        # Build a tensor from the stop sequence on the same device/dtype as the inputs
+        stop_sequence_tensor = torch.tensor(self.stop_sequence, device=input_ids.device, dtype=input_ids.dtype)
+
+        # Stop only when the generated ids end with the stop sequence (reduced over the whole batch)
+        current_sequence = input_ids[:, -len(self.stop_sequence):]
+        return torch.all(current_sequence == stop_sequence_tensor).item()
+
+
class HuggingFaceServer:
    """A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""
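As a quick sanity check, here is a minimal sketch of calling the new criteria directly; the token ids and batch contents are made up for illustration. One subtlety worth noting: torch.all reduces over the entire batch, so with batched generation the criteria only fires once every sequence ends with the stop ids.

import torch

criteria = StopAtSpecificTokenCriteria(stop_sequence=[13, 198])  # hypothetical token ids

# A batch of two partial generations; only the first ends with [13, 198].
input_ids = torch.tensor([[50, 13, 198], [50, 13, 199]])
scores = torch.zeros(2, 1)  # unused by this criteria, but part of the __call__ signature
print(criteria(input_ids, scores))          # False: the reduction spans the whole batch
print(criteria(input_ids[:1], scores[:1]))  # True for the first sequence alone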
@@ -72,9 +93,10 @@ def serve_request(self, raw_request: Dict[str, Any]):
                raw_request["stop_sequences"], return_token_type_ids=False, add_special_tokens=False
            )
            assert len(stop_sequence_ids.input_ids) == 1, "Total number of stop words should be 1."
-            assert len(stop_sequence_ids.input_ids[0]) == 1, "Total number of tokens in each stop word should be 1."
+            # assert len(stop_sequence_ids.input_ids[0]) == 1, "Total number of tokens in each stop word should be 1."
+            if len(stop_sequence_ids.input_ids[0]) == 1:
+                raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
            del raw_request["stop_sequences"]
-            raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]

        # Strip out irrelevant parameters
        relevant_raw_request = {
@@ -83,8 +105,16 @@ def serve_request(self, raw_request: Dict[str, Any]):
            if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
        }

+        stopping_criteria = StoppingCriteriaList()
+        if stop_sequence_ids is not None:
+            stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_ids.input_ids[0]))
+
        # Use HuggingFace's `generate` method.
-        output = self.model.generate(**encoded_input, **relevant_raw_request)
+        output = self.model.generate(
+            **encoded_input,
+            **relevant_raw_request,
+            stopping_criteria=stopping_criteria if len(stopping_criteria) > 0 else None,
+        )
        sequences = output.sequences
        scores = output.scores
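Taken together, the change lets serve_request honor multi-token stop sequences: a single-token stop still goes through eos_token_id, while longer stops are caught by the custom criteria. A minimal end-to-end sketch, assuming gpt2 and the stop string "<END>" purely for illustration:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.stopping_criteria import StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# "<END>" is illustrative; it spans several ids under the GPT-2 BPE.
stop_ids = tokenizer("<END>", add_special_tokens=False).input_ids
stopping_criteria = StoppingCriteriaList([StopAtSpecificTokenCriteria(stop_sequence=stop_ids)])

encoded = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
output = model.generate(**encoded, max_new_tokens=50, stopping_criteria=stopping_criteria)
print(tokenizer.decode(output[0]))

Generation halts right after the stop ids are produced, so they remain in the decoded text; trimming them is left to the caller. The match is on token ids rather than surface text, so a stop string only triggers if the model emits exactly those BPE pieces.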