@@ -539,22 +539,20 @@ def compute_efficiency_metrics(
539
539
runtime = request_state .result .batch_request_time
540
540
batch_size = request_state .result .batch_size
541
541
542
- # Compute total number of prompt and output tokens (in first sequence) .
542
+ # Compute total number of prompt and output tokens.
543
543
# Fetch the right `Tokenizer` depending on the model defined in `AdapterSpec`
544
544
# and calculate the number of tokens in the prompt.
545
545
tokenizer_service : TokenizerService = metric_service
546
546
window_service : WindowService = WindowServiceFactory .get_window_service (adapter_spec .model , tokenizer_service )
547
547
prompt : str = request_state .request .prompt
548
548
num_prompt_tokens : int = window_service .get_num_tokens (prompt )
549
549
550
- # Just take the first completion
551
- # TODO: don't we need to take into account all the completions, since
552
- # the runtime we get (that's used to compute denoised_runtime) is for
553
- # generating all of them?
554
- # TODO: we should unify this into num_completion_tokens
555
- sequence = request_state .result .completions [0 ]
556
- num_output_tokens : int = len (sequence .tokens )
550
+ # Total number of tokens summed across all completions.
551
+ num_completion_tokens : int = sum ([len (completion .tokens ) for completion in request_state .result .completions ])
557
552
# Don't include prompt in number of generated tokens (e.g., for language modeling).
553
+ # Assume that tokens for different completions are generated sequentially (instead of batched) when
554
+ # computing num_output_tokens (for the purpose of runtime estimation).
555
+ num_output_tokens : int = num_completion_tokens
558
556
if request_state .request .echo_prompt :
559
557
# num_prompt_tokens > num_output_tokens can happen if tokenizer doesn't round trip.
560
558
if num_prompt_tokens <= num_output_tokens :
@@ -591,10 +589,6 @@ def compute_efficiency_metrics(
591
589
else :
592
590
training_energy_cost = None
593
591
594
- # Total number of tokens in the completion
595
- num_completion_tokens = sum ([len (completion .tokens ) for completion in request_state .result .completions ])
596
-
597
- # TODO: unify num_completion_tokens and num_output_tokens
598
592
stats = [
599
593
Stat (MetricName ("num_prompt_tokens" )).add (num_prompt_tokens ),
600
594
Stat (MetricName ("num_completion_tokens" )).add (num_completion_tokens ),
0 commit comments