Skip to content

Commit ae0b179

Browse files
authored
Fix typing errors detected by a newer version of mypy (stanford-crfm#1841)
1 parent 7ad2997 commit ae0b179

File tree

9 files changed

+20
-50
lines changed

9 files changed

+20
-50
lines changed

scripts/data_overlap/load_documents.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,4 @@ def get_raw_document_iterator(file_path: str) -> Iterator[str]:
136136

137137
def get_custom_document_iterator(file_path: str) -> Iterator[str]:
138138
"""Define your own document reading method"""
139-
pass
139+
raise NotImplementedError()

setup.cfg

+4
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ ignore = E203,E231,E731,W503,W605
174174
# Settings for Mypy: static type checker for Python 3
175175
[mypy]
176176
ignore_missing_imports = True
177+
# TODO(#1831): Change this to True
178+
check_untyped_defs = False
179+
# TODO(#1831): Change this to True
180+
disallow_untyped_defs = False
177181

178182
[tool:pytest]
179183
addopts =

src/helm/benchmark/augmentations/contrast_sets_perturbation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,4 @@ def apply(self, instance: Instance, seed: Optional[int] = None) -> Instance:
8383
)
8484

8585
def perturb(self, text: str, rng: Random) -> str: # we need this since parent method is abstract
86-
pass
86+
raise NotImplementedError("Should never be called since apply() was overridden")

src/helm/benchmark/augmentations/gender_perturbation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(
9999
source_class: str,
100100
target_class: str,
101101
mapping_file_path: Optional[str] = None,
102-
mapping_file_genders: List[str] = None,
102+
mapping_file_genders: Optional[List[str]] = None,
103103
bidirectional: bool = False,
104104
):
105105
"""Initialize the gender perturbation.

src/helm/benchmark/run.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,12 @@ def run_entries_to_run_specs(
4949
adapter_spec: AdapterSpec = run_spec.adapter_spec
5050
if max_eval_instances is not None and adapter_spec.max_eval_instances is None:
5151
adapter_spec = replace(adapter_spec, max_eval_instances=max_eval_instances)
52-
if num_train_trials is not None or adapter_spec.max_train_instances == 0:
53-
adapter_spec = replace(
54-
adapter_spec, num_train_trials=1 if adapter_spec.max_train_instances == 0 else num_train_trials
55-
)
52+
53+
if adapter_spec.max_train_instances == 0:
54+
adapter_spec = replace(adapter_spec, num_train_trials=1)
55+
elif num_train_trials is not None:
56+
adapter_spec = replace(adapter_spec, num_train_trials=num_train_trials)
57+
5658
run_spec = replace(run_spec, adapter_spec=adapter_spec)
5759

5860
# Append groups

src/helm/benchmark/run_expander.py

+1-30
Original file line numberDiff line numberDiff line change
@@ -67,35 +67,6 @@ def sanitize(value):
6767
]
6868

6969

70-
class ReplaceRunSpecValueRunExpander(RunExpander):
71-
"""
72-
Replace a single field (e.g., max_train_instances) with a list of values (e.g., 0, 1, 2).
73-
"""
74-
75-
def __init__(self, value):
76-
"""
77-
`value` is either the actual value to use or a lookup into the values dict.
78-
"""
79-
self.name = type(self).name
80-
if value in type(self).values_dict:
81-
self.values = type(self).values_dict[value]
82-
else:
83-
self.values = [value]
84-
85-
def expand(self, run_spec: RunSpec) -> List[RunSpec]:
86-
def sanitize(value):
87-
return str(value).replace("/", "_")
88-
89-
return [
90-
replace(
91-
run_spec,
92-
name=f"{run_spec.name},{self.name}={sanitize(value)}",
93-
metrics=value,
94-
)
95-
for value in self.values
96-
]
97-
98-
9970
class InstructionsRunExpander(RunExpander):
10071
"""
10172
Set the instructions of the prompt.
@@ -530,7 +501,7 @@ def gender(
530501
source_class: str,
531502
target_class: str,
532503
mapping_file_path: Optional[str] = None,
533-
mapping_file_genders: Tuple[str] = None,
504+
mapping_file_genders: Optional[Tuple[str]] = None,
534505
bidirectional: bool = False,
535506
) -> PerturbationSpec:
536507
return PerturbationSpec(

src/helm/common/cache.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from abc import ABC, abstractmethod
1+
from abc import abstractmethod
2+
import contextlib
23
from dataclasses import dataclass
34
import json
45
from typing import Dict, Callable, Generator, Iterable, Optional, Tuple
@@ -103,19 +104,13 @@ def cache_stats_key(self) -> str:
103104
return self.main.cache_stats_key
104105

105106

106-
class KeyValueStore(ABC):
107+
class KeyValueStore(contextlib.AbstractContextManager):
107108
"""Key value store that persists writes."""
108109

109110
@property
110111
def path(self):
111112
return self._path
112113

113-
def __enter__(self) -> "KeyValueStore":
114-
pass
115-
116-
def __exit__(self, exc_type, exc_value, traceback) -> None:
117-
pass
118-
119114
@abstractmethod
120115
def contains(self, key: Dict) -> bool:
121116
pass
@@ -150,11 +145,9 @@ def __init__(self, path: str):
150145

151146
def __enter__(self) -> "_SqliteKeyValueStore":
152147
self._sqlite_dict.__enter__()
153-
super().__enter__()
154148
return self
155149

156150
def __exit__(self, exc_type, exc_value, traceback) -> None:
157-
super().__exit__(exc_type, exc_value, traceback)
158151
self._sqlite_dict.__exit__(exc_type, exc_value, traceback)
159152

160153
def contains(self, key: Dict) -> bool:
@@ -204,11 +197,10 @@ def __init__(self, uri: str, collection_name: str):
204197
super().__init__()
205198

206199
def __enter__(self) -> "_MongoKeyValueStore":
207-
super().__enter__()
208200
return self
209201

210202
def __exit__(self, exc_type, exc_value, traceback) -> None:
211-
super().__exit__(exc_type, exc_value, traceback)
203+
return
212204

213205
def _canonicalize_key(self, key: Dict) -> SON:
214206
serialized = json.dumps(key, sort_keys=True)

src/helm/proxy/clients/ai21_client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def parse_token(raw: Dict, first: bool) -> Token:
129129
top_logprobs=top_logprobs,
130130
)
131131

132-
def parse_sequence(raw: Dict, first: bool, finish_reason: Dict = None) -> Sequence:
132+
def parse_sequence(raw: Dict, first: bool, finish_reason: Optional[Dict] = None) -> Sequence:
133133
text = raw["text"]
134134
tokens = [parse_token(token, first and i == 0) for i, token in enumerate(raw["tokens"])]
135135
logprob = sum(token.logprob for token in tokens)

src/helm/proxy/services/service.py

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def get_toxicity_scores(self, auth: Authentication, request: PerspectiveAPIReque
108108
"""Get toxicity scores for a batch of text."""
109109
pass
110110

111+
@abstractmethod
111112
def make_critique_request(self, auth: Authentication, request: CritiqueRequest) -> CritiqueRequestResult:
112113
"""Get responses to a critique request."""
113114
pass

0 commit comments

Comments
 (0)