Skip to content

Commit 2b06784

Browse files
authored
Fix: Total number of tokens in each stop word should be 1 (stanford-crfm#1892)
1 parent 9133b2d commit 2b06784

File tree

1 file changed

+33
-3
lines changed

1 file changed

+33
-3
lines changed

src/helm/proxy/clients/huggingface_client.py

+33-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
import torch
33
from dataclasses import asdict
44
from transformers import AutoModelForCausalLM, AutoTokenizer
5+
from transformers.generation.stopping_criteria import (
6+
StoppingCriteria,
7+
StoppingCriteriaList,
8+
STOPPING_CRITERIA_INPUTS_DOCSTRING,
9+
add_start_docstrings,
10+
)
511
from typing import Any, Dict, List, Optional
612

713
from helm.common.cache import Cache, CacheConfig
@@ -35,6 +41,21 @@ def resolve_alias(model_name: str) -> str:
3541
return _MODEL_NAME_ALIASES.get(model_name, model_name)
3642

3743

44+
class StopAtSpecificTokenCriteria(StoppingCriteria):
45+
def __init__(self, stop_sequence: List[int] = None):
46+
super().__init__()
47+
self.stop_sequence = stop_sequence
48+
49+
# @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
50+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
51+
# Create a tensor from the stop_sequence
52+
stop_sequence_tensor = torch.tensor(self.stop_sequence, device=input_ids.device, dtype=input_ids.dtype)
53+
54+
# Check if the current sequence ends with the stop_sequence
55+
current_sequence = input_ids[:, -len(self.stop_sequence) :]
56+
return torch.all(current_sequence == stop_sequence_tensor).item()
57+
58+
3859
class HuggingFaceServer:
3960
"""A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""
4061

@@ -72,9 +93,10 @@ def serve_request(self, raw_request: Dict[str, Any]):
7293
raw_request["stop_sequences"], return_token_type_ids=False, add_special_tokens=False
7394
)
7495
assert len(stop_sequence_ids.input_ids) == 1, "Total number of stop words should be 1."
75-
assert len(stop_sequence_ids.input_ids[0]) == 1, "Total number of tokens in each stop word should be 1."
96+
# assert len(stop_sequence_ids.input_ids[0]) == 1, "Total number of tokens in each stop word should be 1."
97+
if len(stop_sequence_ids.input_ids[0]) == 1:
98+
raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
7699
del raw_request["stop_sequences"]
77-
raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
78100

79101
# Strip out irrelevant parameters
80102
relevant_raw_request = {
@@ -83,8 +105,16 @@ def serve_request(self, raw_request: Dict[str, Any]):
83105
if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
84106
}
85107

108+
stopping_criteria = StoppingCriteriaList()
109+
if stop_sequence_ids != None:
110+
stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_ids.input_ids[0]))
111+
86112
# Use HuggingFace's `generate` method.
87-
output = self.model.generate(**encoded_input, **relevant_raw_request)
113+
output = self.model.generate(
114+
**encoded_input,
115+
**relevant_raw_request,
116+
stopping_criteria=stopping_criteria if len(stop_sequence_ids.input_ids[0]) > 1 else None,
117+
)
88118
sequences = output.sequences
89119
scores = output.scores
90120

0 commit comments

Comments
 (0)