Skip to content

Commit 9be35a3

Browse files
Fix typing issues in metrics/ and remove check_untyped_defs. (stanford-crfm#1942)
1 parent 50e6565 commit 9be35a3

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

src/helm/benchmark/metrics/summac/model_summac.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
# mypy: check_untyped_defs = False
21
###############################################
32
# Source: https://github.com/tingofurro/summac
43
###############################################
54

5+
from typing import Dict, List
66
from transformers import AutoTokenizer, AutoModelForSequenceClassification
77
import nltk
88
import numpy as np
9+
import numpy.typing as npt
910
import torch
1011
import os
1112
import json
@@ -145,6 +146,7 @@ def build_image(self, original, generated):
145146

146147
if self.model is None:
147148
self.load_nli()
149+
assert self.model
148150

149151
dataset = [
150152
{"premise": original_chunks[i], "hypothesis": generated_chunks[j], "doc_i": i, "gen_i": j}
@@ -303,7 +305,7 @@ def compute_histogram(self, original=None, generated=None, image=None):
303305

304306
full_histogram = []
305307
for i_gen in range(N_gen):
306-
histos = []
308+
histos: List[npt.NDArray] = []
307309

308310
for i_depth in range(N_depth):
309311
if (
@@ -317,32 +319,31 @@ def compute_histogram(self, original=None, generated=None, image=None):
317319
histos.append(histo)
318320

319321
if self.norm_histo:
320-
histos = [[N_ori, N_gen]] + histos
322+
histos = [np.array([N_ori, N_gen])] + histos
321323
histogram_row = np.concatenate(histos)
322324
full_histogram.append(histogram_row)
323325

324326
n_rows_missing = self.n_rows - len(full_histogram)
325327
full_histogram += [[0.0] * self.full_size] * n_rows_missing
326328
full_histogram = full_histogram[: self.n_rows]
327-
full_histogram = np.array(full_histogram)
328-
return image, full_histogram
329+
return image, np.array(full_histogram)
329330

330331
def forward(self, originals, generateds, images=None):
331332
if images is not None:
332333
# In case they've been pre-computed.
333-
histograms = []
334+
histogram_list = []
334335
for image in images:
335336
_, histogram = self.compute_histogram(image=image)
336-
histograms.append(histogram)
337+
histogram_list.append(histogram)
337338
else:
338-
images, histograms = [], []
339+
images, histogram_list = [], []
339340
for original, generated in zip(originals, generateds):
340341
image, histogram = self.compute_histogram(original=original, generated=generated)
341342
images.append(image)
342-
histograms.append(histogram)
343+
histogram_list.append(histogram)
343344

344-
N = len(histograms)
345-
histograms = torch.FloatTensor(histograms).to(self.device)
345+
N = len(histogram_list)
346+
histograms = torch.FloatTensor(histogram_list).to(self.device)
346347

347348
non_zeros = (torch.sum(histograms, dim=-1) != 0.0).long()
348349
seq_lengths = non_zeros.sum(dim=-1).tolist()
@@ -379,8 +380,8 @@ def forward(self, originals, generateds, images=None):
379380
)
380381
else:
381382
features.append(torch.FloatTensor([0.0, 0.0, 0.0]).unsqueeze(0)) # .cuda()
382-
features = torch.cat(features)
383-
logits = self.layer_final(features)
383+
features_tensor = torch.cat(features)
384+
logits = self.layer_final(features_tensor)
384385
histograms_out = [histogram.cpu().numpy() for histogram in histograms]
385386
return logits, histograms_out, images
386387

@@ -451,7 +452,7 @@ def score_one(self, original, generated):
451452
return {"score": final_score, "image": image}
452453

453454
def score(self, sources, generateds, **kwargs):
454-
output = {"scores": [], "images": []}
455+
output: Dict[str, List] = {"scores": [], "images": []}
455456
for source, gen in zip(sources, generateds):
456457
score = self.score_one(source, gen)
457458
output["scores"].append(score["score"])

src/helm/benchmark/metrics/test_bias_metrics.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: check_untyped_defs = False
21
from dataclasses import dataclass
32
from typing import Callable, List, Optional
43

@@ -12,7 +11,7 @@ class TestCase:
1211
rel_tol: float = 0.01
1312

1413

15-
def check_test_cases(test_cases: List[TestCase], bias_func: Callable[[List[str]], float]):
14+
def check_test_cases(test_cases: List[TestCase], bias_func: Callable[[List[str]], Optional[float]]):
1615
for test_case in test_cases:
1716
bias_score = bias_func(test_case.texts)
1817
error_msg = f"Expected: {test_case.bias_score}, Actual:{bias_score}"

0 commit comments

Comments
 (0)