Skip to content

Commit 75193da

Browse files
Simplify call structure when updating Stats member variables. (stanford-crfm#1927)
1 parent 1aa727b commit 75193da

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

src/helm/benchmark/metrics/statistic.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,18 @@ class Stat:
1717
max: Optional[float] = None
1818
mean: Optional[float] = None
1919
variance: Optional[float] = None
20+
"""This is the population variance (computed by dividing by N rather than N - 1), not the sample variance.
21+
22+
See https://towardsdatascience.com/variance-sample-vs-population-3ddbd29e498a
23+
for details.
24+
"""
25+
2026
stddev: Optional[float] = None
27+
"""This is the population standard deviation (the square root of the population variance), not the sample standard deviation.
28+
29+
See https://towardsdatascience.com/variance-sample-vs-population-3ddbd29e498a
30+
for details.
31+
"""
2132

2233
def add(self, x) -> "Stat":
2334
# Skip None values when aggregating statistics.
@@ -69,22 +80,17 @@ def process(x: Optional[float]) -> str:
6980
else:
7081
return "(0)"
7182

72-
def _update_mean(self):
73-
self.mean = self.sum / self.count if self.count else None
74-
75-
def _update_variance(self):
76-
self._update_mean()
77-
if self.mean is None:
78-
return None
83+
def _update_mean_variance_stddev(self):
84+
if self.count == 0:
85+
# With no elements there is nothing to compute; leave mean, variance, and stddev unchanged.
86+
return
87+
# Update mean
88+
self.mean = self.sum / self.count
89+
# Update variance
7990
pvariance = self.sum_squared / self.count - self.mean**2
8091
self.variance = 0 if pvariance < 0 else pvariance
81-
82-
def _update_stddev(self):
83-
self._update_variance()
84-
self.stddev = math.sqrt(self.variance) if self.variance is not None else None
85-
86-
def _update_mean_variance_stddev(self):
87-
self._update_stddev()
92+
# Update stddev
93+
self.stddev = math.sqrt(self.variance)
8894

8995
def take_mean(self):
9096
"""Return a version of the stat that only has the mean."""

src/helm/benchmark/metrics/test_statistic.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
11
from typing import Dict
22

3+
import pytest
4+
import statistics
5+
36
from .metric_name import MetricName
47
from .statistic import Stat, merge_stat
58

69

10+
def test_stat_add():
11+
stat = Stat(MetricName("some_metric"))
12+
population = list(range(10))
13+
for i in population:
14+
stat.add(i)
15+
assert stat.sum == sum(population)
16+
assert stat.count == 10
17+
assert stat.min == 0
18+
assert stat.max == 9
19+
assert stat.mean == sum(population) / 10
20+
assert stat.variance == pytest.approx(statistics.pvariance(population))
21+
assert stat.stddev == pytest.approx(statistics.pstdev(population))
22+
23+
724
def test_merge_stat():
825
# Ensure that `MetricName`s are hashable
926
metric_name = MetricName("some_metric")
@@ -12,3 +29,15 @@ def test_merge_stat():
1229

1330
assert len(stats) == 1
1431
assert stats[metric_name].sum == 2
32+
assert stats[metric_name].mean == 1
33+
34+
35+
def test_merge_empty_stat():
36+
# This test ensures we guard against division by zero.
37+
metric_name = MetricName("some_metric")
38+
empty_1 = Stat(metric_name)
39+
empty_2 = Stat(metric_name)
40+
merged = empty_1.merge(empty_2)
41+
42+
assert merged.count == 0
43+
assert merged.stddev is None

0 commit comments

Comments
 (0)