1
- # mypy: check_untyped_defs = False
2
1
###############################################
3
2
# Source: https://github.com/tingofurro/summac
4
3
###############################################
5
4
5
+ from typing import Dict , List
6
6
from transformers import AutoTokenizer , AutoModelForSequenceClassification
7
7
import nltk
8
8
import numpy as np
9
+ import numpy .typing as npt
9
10
import torch
10
11
import os
11
12
import json
@@ -145,6 +146,7 @@ def build_image(self, original, generated):
145
146
146
147
if self .model is None :
147
148
self .load_nli ()
149
+ assert self .model
148
150
149
151
dataset = [
150
152
{"premise" : original_chunks [i ], "hypothesis" : generated_chunks [j ], "doc_i" : i , "gen_i" : j }
@@ -303,7 +305,7 @@ def compute_histogram(self, original=None, generated=None, image=None):
303
305
304
306
full_histogram = []
305
307
for i_gen in range (N_gen ):
306
- histos = []
308
+ histos : List [ npt . NDArray ] = []
307
309
308
310
for i_depth in range (N_depth ):
309
311
if (
@@ -317,32 +319,31 @@ def compute_histogram(self, original=None, generated=None, image=None):
317
319
histos .append (histo )
318
320
319
321
if self .norm_histo :
320
- histos = [[N_ori , N_gen ]] + histos
322
+ histos = [np . array ( [N_ori , N_gen ]) ] + histos
321
323
histogram_row = np .concatenate (histos )
322
324
full_histogram .append (histogram_row )
323
325
324
326
n_rows_missing = self .n_rows - len (full_histogram )
325
327
full_histogram += [[0.0 ] * self .full_size ] * n_rows_missing
326
328
full_histogram = full_histogram [: self .n_rows ]
327
- full_histogram = np .array (full_histogram )
328
- return image , full_histogram
329
+ return image , np .array (full_histogram )
329
330
330
331
def forward (self , originals , generateds , images = None ):
331
332
if images is not None :
332
333
# In case they've been pre-computed.
333
- histograms = []
334
+ histogram_list = []
334
335
for image in images :
335
336
_ , histogram = self .compute_histogram (image = image )
336
- histograms .append (histogram )
337
+ histogram_list .append (histogram )
337
338
else :
338
- images , histograms = [], []
339
+ images , histogram_list = [], []
339
340
for original , generated in zip (originals , generateds ):
340
341
image , histogram = self .compute_histogram (original = original , generated = generated )
341
342
images .append (image )
342
- histograms .append (histogram )
343
+ histogram_list .append (histogram )
343
344
344
- N = len (histograms )
345
- histograms = torch .FloatTensor (histograms ).to (self .device )
345
+ N = len (histogram_list )
346
+ histograms = torch .FloatTensor (histogram_list ).to (self .device )
346
347
347
348
non_zeros = (torch .sum (histograms , dim = - 1 ) != 0.0 ).long ()
348
349
seq_lengths = non_zeros .sum (dim = - 1 ).tolist ()
@@ -379,8 +380,8 @@ def forward(self, originals, generateds, images=None):
379
380
)
380
381
else :
381
382
features .append (torch .FloatTensor ([0.0 , 0.0 , 0.0 ]).unsqueeze (0 )) # .cuda()
382
- features = torch .cat (features )
383
- logits = self .layer_final (features )
383
+ features_tensor = torch .cat (features )
384
+ logits = self .layer_final (features_tensor )
384
385
histograms_out = [histogram .cpu ().numpy () for histogram in histograms ]
385
386
return logits , histograms_out , images
386
387
@@ -451,7 +452,7 @@ def score_one(self, original, generated):
451
452
return {"score" : final_score , "image" : image }
452
453
453
454
def score (self , sources , generateds , ** kwargs ):
454
- output = {"scores" : [], "images" : []}
455
+ output : Dict [ str , List ] = {"scores" : [], "images" : []}
455
456
for source , gen in zip (sources , generateds ):
456
457
score = self .score_one (source , gen )
457
458
output ["scores" ].append (score ["score" ])
0 commit comments