Skip to content

Commit 0620f63

Browse files
authored
Fix pre-commit breakages (stanford-crfm#1902)
1 parent 473acb2 commit 0620f63

File tree

4 files changed

+27
-23
lines changed

4 files changed

+27
-23
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ transformers==4.28.1
170170
trio==0.22.0
171171
trio-websocket==0.9.2
172172
typer==0.4.2
173+
types-Pillow==9.3.0.4
173174
types-pytz==2022.4.0.0
174175
types-redis==4.3.21.1
175176
types-requests==2.28.11.2

src/helm/common/images_utils.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
from .general import is_url
1010

1111

12-
def open_image(image_location: str) -> Image:
12+
def open_image(image_location: str) -> Image.Image:
1313
"""
1414
Opens image with the Python Imaging Library.
1515
"""
16-
image: Image
16+
image: Image.Image
1717
if is_url(image_location):
1818
image = Image.open(requests.get(image_location, stream=True).raw)
1919
else:
@@ -24,7 +24,7 @@ def open_image(image_location: str) -> Image:
2424
def encode_base64(image_location: str, format="JPEG") -> str:
2525
"""Returns the base64 representation of an image file."""
2626
image_file = io.BytesIO()
27-
image: Image = open_image(image_location)
27+
image: Image.Image = open_image(image_location)
2828
image.save(image_file, format=format)
2929
return base64.b64encode(image_file.getvalue()).decode("ascii")
3030

@@ -36,7 +36,8 @@ def copy_image(src: str, dest: str, width: Optional[int] = None, height: Optiona
3636
"""
3737
if (width is not None and height is not None) or is_url(src):
3838
image = open_image(src)
39-
resized_image = image.resize((width, height), Image.ANTIALIAS)
40-
resized_image.save(dest)
39+
if width is not None and height is not None:
40+
image = image.resize((width, height), Image.ANTIALIAS)
41+
image.save(dest)
4142
else:
4243
shutil.copy(src, dest)

src/helm/proxy/clients/huggingface_client.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from transformers.generation.stopping_criteria import (
66
StoppingCriteria,
77
StoppingCriteriaList,
8-
STOPPING_CRITERIA_INPUTS_DOCSTRING,
9-
add_start_docstrings,
108
)
119
from typing import Any, Dict, List, Optional
1210

@@ -42,18 +40,17 @@ def resolve_alias(model_name: str) -> str:
4240

4341

4442
class StopAtSpecificTokenCriteria(StoppingCriteria):
45-
def __init__(self, stop_sequence: List[int] = None):
43+
def __init__(self, stop_sequence: List[int]):
4644
super().__init__()
4745
self.stop_sequence = stop_sequence
4846

49-
# @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
5047
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
5148
# Create a tensor from the stop_sequence
5249
stop_sequence_tensor = torch.tensor(self.stop_sequence, device=input_ids.device, dtype=input_ids.dtype)
5350

5451
# Check if the current sequence ends with the stop_sequence
5552
current_sequence = input_ids[:, -len(self.stop_sequence) :]
56-
return torch.all(current_sequence == stop_sequence_tensor).item()
53+
return bool(torch.all(current_sequence == stop_sequence_tensor).item())
5754

5855

5956
class HuggingFaceServer:
@@ -88,14 +85,17 @@ def serve_request(self, raw_request: Dict[str, Any]):
8885
raw_request["output_scores"] = True
8986
top_k_per_token: int = raw_request["top_k_per_token"]
9087
del raw_request["top_k_per_token"]
88+
stopping_criteria: Optional[StoppingCriteriaList] = None
9189
if len(raw_request["stop_sequences"]) > 0:
9290
stop_sequence_ids = self.tokenizer(
9391
raw_request["stop_sequences"], return_token_type_ids=False, add_special_tokens=False
9492
)
95-
assert len(stop_sequence_ids.input_ids) == 1, "Total number of stop words should be 1."
96-
# assert len(stop_sequence_ids.input_ids[0]) == 1, "Total number of tokens in each stop word should be 1."
93+
assert len(stop_sequence_ids.input_ids) == 1, "Total number of stop sequences should be 1."
9794
if len(stop_sequence_ids.input_ids[0]) == 1:
9895
raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
96+
else:
97+
stopping_criteria = StoppingCriteriaList()
98+
stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_ids.input_ids[0]))
9999
del raw_request["stop_sequences"]
100100

101101
# Strip out irrelevant parameters
@@ -105,15 +105,11 @@ def serve_request(self, raw_request: Dict[str, Any]):
105105
if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"]
106106
}
107107

108-
stopping_criteria = StoppingCriteriaList()
109-
if stop_sequence_ids != None:
110-
stopping_criteria.append(StopAtSpecificTokenCriteria(stop_sequence=stop_sequence_ids.input_ids[0]))
111-
112108
# Use HuggingFace's `generate` method.
113109
output = self.model.generate(
114110
**encoded_input,
115111
**relevant_raw_request,
116-
stopping_criteria=stopping_criteria if len(stop_sequence_ids.input_ids[0]) > 1 else None,
112+
stopping_criteria=stopping_criteria,
117113
)
118114
sequences = output.sequences
119115
scores = output.scores

src/helm/proxy/clients/vision_language/idefics_client.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from helm.common.images_utils import open_image
1111
from helm.common.gpu_utils import get_torch_device_name
1212
from helm.common.hierarchical_logger import hlog
13+
from helm.common.media_object import TEXT_TYPE
1314
from helm.common.request import Request, RequestResult, Sequence, Token
1415
from helm.common.tokenization_request import (
1516
TokenizationRequest,
@@ -93,12 +94,17 @@ def make_request(self, request: Request) -> RequestResult:
9394
exit_condition = processor.tokenizer(self.END_OF_UTTERANCE_TOKEN, add_special_tokens=False).input_ids
9495
generation_args["eos_token_id"] = exit_condition
9596

96-
multimodal_prompt: List[Union[str, Image]] = [
97-
open_image(media_object.location)
98-
if media_object.is_type("image") and media_object.location
99-
else media_object.text
100-
for media_object in request.multimodal_prompt.media_objects
101-
]
97+
multimodal_prompt: List[Union[str, Image.Image]] = []
98+
for media_object in request.multimodal_prompt.media_objects:
99+
100+
if media_object.is_type("image") and media_object.location:
101+
multimodal_prompt.append(open_image(media_object.location))
102+
elif media_object.is_type(TEXT_TYPE):
103+
if media_object.text is None:
104+
raise ValueError("MediaObject of text type has missing text field value")
105+
multimodal_prompt.append(media_object.text)
106+
else:
107+
raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
102108
prompt_text: str = request.multimodal_prompt.text.replace(self.END_OF_UTTERANCE_TOKEN, " ")
103109

104110
try:

0 commit comments

Comments
 (0)