Skip to content

Commit be6f08d

Browse files
authored
Optional dependencies (stanford-crfm#1798)
1 parent 73bf60d commit be6f08d

17 files changed

+122
-37
lines changed

install-dev.sh

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@ pip install --no-binary=protobuf protobuf==3.20.2
1414
# Install all pinned dependencies
1515
pip install -r requirements-freeze.txt
1616
# Install HELM in edit mode
17-
pip install -e .
17+
pip install -e .[all]
18+
# Check dependencies
19+
pip check

pre-commit.sh

-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ if [ "$valid_version" == "False" ]; then
1111
exit 1
1212
fi
1313

14-
pip check
15-
1614
# Python style checks and linting
1715
black --check --diff src scripts || (
1816
echo ""

setup.cfg

+24-14
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ install_requires=
3737
# sqlitedict==2.0.0 is slow! https://github.com/RaRe-Technologies/sqlitedict/issues/152
3838
# Keep sqlitedict version at 1.7.0.
3939
sqlitedict~=1.7.0
40+
bottle~=0.12.23
4041
# TODO: Remove these from common
4142
protobuf~=3.20.2 # Can't use 4.21.0 due to backward incompatibility
4243
pymongo~=4.2.0
@@ -56,15 +57,6 @@ install_requires=
5657
# TODO: Remove after this issue is resolved
5758
scikit-learn~=1.1.2
5859

59-
# Server Extras
60-
bottle~=0.12.23
61-
gunicorn~=20.1.0
62-
63-
# Scenario Extras
64-
gdown~=4.4.0 # For opinions_qa_scenario
65-
sympy~=1.11.1 # For numeracy_scenario
66-
xlrd~=2.0.1 # For ice_scenario: used by pandas.read_excel
67-
6860
# Model Extras
6961
aleph-alpha-client~=2.14.0
7062
anthropic~=0.2.5
@@ -84,20 +76,38 @@ install_requires=
8476

8577
# Metrics Extras
8678
google-api-python-client~=2.64.0 # For perspective_api_client via toxicity_metrics
79+
80+
[options.extras_require]
81+
proxy-server =
82+
gunicorn~=20.1.0
83+
84+
human-evaluation =
85+
scaleapi~=2.13.0
86+
surge-api~=1.1.0
87+
88+
scenarios =
89+
gdown~=4.4.0 # For disinformation_scenario, med_mcqa_scenario, med_qa_scenario: used by ensure_file_downloaded()
90+
sympy~=1.11.1 # For numeracy_scenario
91+
xlrd~=2.0.1 # For ice_scenario: used by pandas.read_excel()
92+
93+
metrics =
8794
numba~=0.56.4 # For copyright_metrics
8895
pytrec_eval==0.5 # For ranking_metrics
8996
sacrebleu~=2.2.1 # For disinformation_metrics, machine_translation_metrics
9097
summ-eval~=0.892 # For summarization_metrics
9198

92-
# Human Evaluation Extras
93-
scaleapi~=2.13.0
94-
surge-api~=1.1.0
95-
96-
# Plots Extras
99+
plots =
97100
colorcet~=3.0.1
98101
matplotlib~=3.6.0
99102
seaborn~=0.11.0
100103

104+
all =
105+
crfm-helm[server]
106+
crfm-helm[human-evaluation]
107+
crfm-helm[scenarios]
108+
crfm-helm[metrics]
109+
crfm-helm[plots]
110+
101111
[options.entry_points]
102112
console_scripts =
103113
helm-run = helm.benchmark.run:main

src/helm/benchmark/metrics/copyright_metrics.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
import re
22
from typing import List, Optional
33

4-
import numba
54
import numpy as np
65
from nltk.tokenize.treebank import TreebankWordTokenizer
76

87
from helm.benchmark.adaptation.request_state import RequestState
98
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
109
from helm.benchmark.scenarios.scenario import Reference
10+
from helm.common.optional_dependencies import handle_module_not_found_error
1111
from helm.common.request import RequestResult
1212
from .metric import Metric
1313
from .metric_name import MetricName
1414
from .metric_service import MetricService
1515
from .statistic import Stat
1616

17+
try:
18+
import numba
19+
except ModuleNotFoundError as e:
20+
handle_module_not_found_error(e)
21+
1722

1823
def _longest_common_prefix_length(s1: np.ndarray, s2: np.ndarray, previous_best: Optional[float] = None) -> float:
1924
"""Compute the length of the longest common prefix."""

src/helm/benchmark/metrics/disinformation_metrics.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from typing import Dict, List, Optional
66

77
import numpy as np
8-
from sacrebleu.metrics import BLEU
98

109
from helm.common.general import ensure_file_downloaded
10+
from helm.common.optional_dependencies import handle_module_not_found_error
1111
from helm.common.request import RequestResult, Sequence
1212
from helm.benchmark.adaptation.request_state import RequestState
1313
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
@@ -16,6 +16,11 @@
1616
from .metric_service import MetricService
1717
from .statistic import Stat
1818

19+
try:
20+
from sacrebleu.metrics import BLEU
21+
except ModuleNotFoundError as e:
22+
handle_module_not_found_error(e)
23+
1924

2025
HUMAN_EVAL_CODALAB_LINK: str = (
2126
"https://worksheets.codalab.org/rest/bundles/0xd8c577022f584f27aead3f00aa771da5/contents/blob/{file_name}"

src/helm/benchmark/metrics/machine_translation_metrics.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from typing import List
2-
from sacrebleu import BLEU
32

43
from helm.benchmark.adaptation.request_state import RequestState
4+
from helm.common.optional_dependencies import handle_module_not_found_error
55
from .metric import Metric
66
from .metric_name import MetricName
77
from .statistic import Stat
88

9+
try:
10+
from sacrebleu.metrics import BLEU
11+
except ModuleNotFoundError as e:
12+
handle_module_not_found_error(e)
13+
914

1015
class MachineTranslationMetric(Metric):
1116
"""

src/helm/benchmark/metrics/ranking_metrics.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from dataclasses import dataclass
22
from typing import Callable, Dict, List, Tuple, Optional
33

4-
import pytrec_eval
5-
64
from helm.benchmark.adaptation.adapters.adapter_factory import ADAPT_RANKING_BINARY
75
from helm.benchmark.adaptation.request_state import RequestState
86
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
7+
from helm.common.optional_dependencies import handle_module_not_found_error
98
from helm.benchmark.scenarios.scenario import unpack_tag, CORRECT_TAG, Reference
109
from helm.common.request import RequestResult
1110
from helm.common.general import binarize_dict
@@ -14,6 +13,11 @@
1413
from .metric_service import MetricService
1514
from .statistic import Stat
1615

16+
try:
17+
import pytrec_eval
18+
except ModuleNotFoundError as e:
19+
handle_module_not_found_error(e)
20+
1721

1822
@dataclass
1923
class RankingObject:

src/helm/benchmark/metrics/summarization_metrics.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
1414
from helm.common.hierarchical_logger import hlog
1515
from helm.common.general import ensure_file_downloaded
16+
from helm.common.optional_dependencies import handle_module_not_found_error
1617
from .metric import Metric, MetricResult
1718
from .metric_name import MetricName
1819
from .metric_service import MetricService
@@ -21,6 +22,7 @@
2122
from .summac.model_summac import SummaCZS
2223
from bert_score import BERTScorer
2324

25+
2426
QAFACTEVAL_CODALAB_LINK: str = (
2527
"https://worksheets.codalab.org/rest/bundles/0xf4de83c1f0d34d7999480223e8f5ab87/contents/blob/"
2628
)
@@ -52,7 +54,11 @@ def __init__(self, task: str, device: str = "cpu"):
5254
# `NameError: name 'stderr' is not defined`
5355
if not spacy.util.is_package("en_core_web_sm"):
5456
spacy.cli.download("en_core_web_sm") # type: ignore
55-
from summ_eval.data_stats_metric import DataStatsMetric
57+
58+
try:
59+
from summ_eval.data_stats_metric import DataStatsMetric
60+
except ModuleNotFoundError as e:
61+
handle_module_not_found_error(e)
5662

5763
self.data_stats_metric = DataStatsMetric()
5864
self.task: str = task

src/helm/benchmark/presentation/create_plots.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,23 @@
66
import os
77
from typing import List, Dict, Optional, Any, Callable, Union, Mapping, Tuple, Set
88

9-
import colorcet
10-
import matplotlib
11-
import matplotlib.pyplot as plt
129
import numpy as np
1310
from scipy.stats import pearsonr
14-
import seaborn as sns
1511

1612
from helm.common.hierarchical_logger import hlog
13+
from helm.common.optional_dependencies import handle_module_not_found_error
1714
from helm.benchmark.presentation.schema import read_schema
1815
from helm.benchmark.presentation.summarize import AGGREGATE_WIN_RATE_COLUMN
1916

17+
try:
18+
import colorcet
19+
import matplotlib
20+
import matplotlib.pyplot as plt
21+
import seaborn as sns
22+
except ModuleNotFoundError as e:
23+
handle_module_not_found_error(e)
24+
25+
2026
sns.set_style("whitegrid")
2127

2228
DOWN_ARROW = "\u2193"

src/helm/benchmark/run_specs.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from .scenarios.scenario import ScenarioSpec
3636
from .scenarios.big_bench_scenario import BIGBenchScenario
3737
from .scenarios.msmarco_scenario import MSMARCOScenario
38-
from .scenarios.numeracy_scenario import get_numeracy_adapter_spec, RELTYPE_INFO
3938
from .scenarios.copyright_scenario import datatag2hash_code
4039
from .scenarios.raft_scenario import get_raft_instructions
4140
from .scenarios.lextreme_scenario import (
@@ -1001,6 +1000,8 @@ def get_raft_spec(subset: str) -> RunSpec:
10011000
def get_numeracy_spec(
10021001
relation_type: str = "linear", mode: str = "function", seed: str = "0", run_solver: str = "False"
10031002
) -> RunSpec:
1003+
from .scenarios.numeracy_scenario import get_numeracy_adapter_spec, RELTYPE_INFO
1004+
10041005
run_solver: bool = True if run_solver == "True" else False # type: ignore
10051006
random_seed = int(seed)
10061007
scenario_spec = ScenarioSpec(

src/helm/benchmark/scenarios/ice_scenario.py

+7
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,16 @@
44
from enum import Enum
55
import pandas as pd
66

7+
from helm.common.optional_dependencies import handle_module_not_found_error
78
from .ice_scenario_pinned_file_order import listdir_with_pinned_file_order
89
from .scenario import Scenario, Instance, TEST_SPLIT, Input
910

11+
try:
12+
# pd.read_excel() uses xlrd
13+
import xlrd # noqa
14+
except ModuleNotFoundError as e:
15+
handle_module_not_found_error(e)
16+
1017

1118
class ICESubset(Enum):
1219
CANADA = "can"

src/helm/benchmark/scenarios/numeracy_scenario.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,23 @@
77
import numpy as np
88
import numpy.typing as npt
99
import random
10-
import sympy
11-
from sympy import Symbol, Poly, diff
12-
from sympy.parsing.sympy_parser import standard_transformations, implicit_multiplication_application
1310
from typing import List, Optional, Tuple, Dict
1411

1512
from helm.benchmark.adaptation.adapters.adapter_factory import ADAPT_GENERATION
1613
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
1714
from helm.benchmark.window_services.tokenizer_service import TokenizerService
1815
from helm.common.authentication import Authentication
16+
from helm.common.optional_dependencies import handle_module_not_found_error
1917
from helm.proxy.services.server_service import ServerService
2018
from .scenario import Scenario, Instance, Reference, TRAIN_SPLIT, TEST_SPLIT, CORRECT_TAG, Input, Output
2119

20+
try:
21+
import sympy
22+
from sympy import Symbol, Poly, diff
23+
from sympy.parsing.sympy_parser import standard_transformations, implicit_multiplication_application
24+
except ModuleNotFoundError as e:
25+
handle_module_not_found_error(e)
26+
2227

2328
# TODO: we shouldn't create an Adapter and TokenizerService in a scenario
2429
# The Adapter and Scenarios should be completely decoupled.

src/helm/common/general.py

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from dataclasses import asdict, is_dataclass
1414

1515
from helm.common.hierarchical_logger import hlog, htrack, htrack_block
16+
from helm.common.optional_dependencies import handle_module_not_found_error
1617

1718

1819
_CREDENTIALS_FILE_NAME = "credentials.conf"
@@ -82,6 +83,10 @@ def ensure_file_downloaded(
8283
# gdown is used to download large files/zip folders from Google Drive.
8384
# It bypasses security warnings which wget cannot handle.
8485
if source_url.startswith("https://drive.google.com"):
86+
try:
87+
import gdown # noqa
88+
except ModuleNotFoundError as e:
89+
handle_module_not_found_error(e)
8590
downloader_executable = "gdown"
8691
tmp_path: str = f"{target_path}.tmp"
8792
shell([downloader_executable, source_url, "-O", tmp_path])
+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class OptionalDependencyNotInstalled(Exception):
2+
pass
3+
4+
5+
def handle_module_not_found_error(e: ModuleNotFoundError):
6+
# TODO: Ask user to install more specific optional dependencies
7+
# e.g. crfm-helm[plots] or crfm-helm[server]
8+
raise OptionalDependencyNotInstalled(
9+
f"Optional dependency {e.name} is not installed. " "Please run `pip install helm-crfm[all]` to install it."
10+
) from e

src/helm/proxy/clients/scale_critique_client.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
from typing import Dict, List, Union, Set, Any
55

66
from cattrs import unstructure
7-
import scaleapi
8-
from scaleapi.tasks import TaskType, TaskStatus
9-
from scaleapi.exceptions import ScaleDuplicateResource
107

118
from helm.common.hierarchical_logger import hlog
129
from helm.common.cache import Cache, CacheConfig
@@ -17,8 +14,16 @@
1714
CritiqueTaskTemplate,
1815
CritiqueResponse,
1916
)
17+
from helm.common.optional_dependencies import handle_module_not_found_error
2018
from helm.proxy.clients.critique_client import CritiqueClient
2119

20+
try:
21+
import scaleapi
22+
from scaleapi.tasks import TaskType, TaskStatus
23+
from scaleapi.exceptions import ScaleDuplicateResource
24+
except ModuleNotFoundError as e:
25+
handle_module_not_found_error(e)
26+
2227

2328
class ScaleCritiqueClientError(Exception):
2429
pass

src/helm/proxy/clients/surge_ai_critique_client.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
import threading
33
from typing import Dict, List
44

5-
import surge
6-
from surge import questions as surge_questions
7-
85
from helm.common.cache import Cache, CacheConfig
96
from helm.common.critique_request import (
107
CritiqueQuestionTemplate,
@@ -14,8 +11,15 @@
1411
CritiqueTaskTemplate,
1512
)
1613
from helm.common.hierarchical_logger import hlog
14+
from helm.common.optional_dependencies import handle_module_not_found_error
1715
from helm.proxy.clients.critique_client import CritiqueClient
1816

17+
try:
18+
import surge
19+
from surge import questions as surge_questions
20+
except ModuleNotFoundError as e:
21+
handle_module_not_found_error(e)
22+
1923

2024
_surge_cache_lock = threading.Lock()
2125

src/helm/proxy/server.py

+7
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,20 @@
1818

1919
from helm.common.authentication import Authentication
2020
from helm.common.hierarchical_logger import hlog
21+
from helm.common.optional_dependencies import handle_module_not_found_error
2122
from helm.common.request import Request
2223
from helm.common.perspective_api_request import PerspectiveAPIRequest
2324
from helm.common.tokenization_request import TokenizationRequest, DecodeRequest
2425
from .accounts import Account
2526
from .services.server_service import ServerService
2627
from .query import Query
2728

29+
try:
30+
import gunicorn # noqa
31+
except ModuleNotFoundError as e:
32+
handle_module_not_found_error(e)
33+
34+
2835
bottle.BaseRequest.MEMFILE_MAX = 1024 * 1024
2936

3037
app = bottle.default_app()

0 commit comments

Comments
 (0)