Commit c945233

Merge pull request stanford-crfm#1143 from stanford-crfm/multiple_completions
Account for multiple completions
2 parents a1aff08 + 586ee38 commit c945233

File tree

2 files changed: +7 -13 lines


src/benchmark/metrics/basic_metrics.py (+6 -12)

@@ -539,22 +539,20 @@ def compute_efficiency_metrics(
         runtime = request_state.result.batch_request_time
         batch_size = request_state.result.batch_size

-        # Compute total number of prompt and output tokens (in first sequence).
+        # Compute total number of prompt and output tokens.
         # Fetch the right `Tokenizer` depending on the model defined in `AdapterSpec`
         # and calculate the number of tokens in the prompt.
         tokenizer_service: TokenizerService = metric_service
         window_service: WindowService = WindowServiceFactory.get_window_service(adapter_spec.model, tokenizer_service)
         prompt: str = request_state.request.prompt
         num_prompt_tokens: int = window_service.get_num_tokens(prompt)

-        # Just take the first completion
-        # TODO: don't we need to take into account all the completions, since
-        # the runtime we get (that's used to compute denoised_runtime) is for
-        # generating all of them?
-        # TODO: we should unify this into num_completion_tokens
-        sequence = request_state.result.completions[0]
-        num_output_tokens: int = len(sequence.tokens)
+        # Total number of tokens in the completion.
+        num_completion_tokens: int = sum([len(completion.tokens) for completion in request_state.result.completions])
         # Don't include prompt in number of generated tokens (e.g., for language modeling).
+        # Assume that tokens for different completions are generated sequentially (instead of batched) when
+        # computing num_output_tokens (for the purpose of runtime estimation).
+        num_output_tokens: int = num_completion_tokens
         if request_state.request.echo_prompt:
             # num_prompt_tokens > num_output_tokens can happen if tokenizer doesn't round trip.
             if num_prompt_tokens <= num_output_tokens:

@@ -591,10 +589,6 @@ def compute_efficiency_metrics(
         else:
             training_energy_cost = None

-        # Total number of tokens in the completion
-        num_completion_tokens = sum([len(completion.tokens) for completion in request_state.result.completions])
-
-        # TODO: unify num_completion_tokens and num_output_tokens
         stats = [
             Stat(MetricName("num_prompt_tokens")).add(num_prompt_tokens),
             Stat(MetricName("num_completion_tokens")).add(num_completion_tokens),

src/benchmark/static/schema.yaml (+1 -1)

@@ -913,7 +913,7 @@ metric_groups:
         split: ${main_split}
       - name: num_prompt_tokens
         split: ${main_split}
-      - name: num_completion_tokens
+      - name: num_output_tokens
         split: ${main_split}
       - name: num_train_trials
         split: ${main_split}
