Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1][Pixtral-HF] Add custom slice_encoder_output for Pixtral #13080

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
678c291
Patch multi_modal_placeholders to RequestOutput
Nov 16, 2024
a1cdcb3
pipe multi_modal_placeholders from input to final output
Nov 17, 2024
f60964a
[V1] Add code owners for V1 (#10397)
WoosukKwon Nov 16, 2024
578e482
[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)
youkaichao Nov 17, 2024
7fa97cf
[V1] Refactor model executable interface for all text-only language m…
ywang96 Nov 17, 2024
629f512
[CI/Build] Fix IDC hpu [Device not found] issue (#10384)
xuechendi Nov 17, 2024
7539ab8
[Bugfix][CPU] Fix CPU embedding runner with tensor parallel (#10394)
Isotr0py Nov 17, 2024
bce660d
[platforms] refactor cpu code (#10402)
youkaichao Nov 17, 2024
305708b
[Hardware] [HPU]add `mark_step` for hpu (#10239)
jikunshang Nov 17, 2024
871a773
[Bugfix] Fix mrope_position_delta in non-last prefill chunk (#10403)
imkero Nov 17, 2024
242bb53
[Misc] Enhance offline_inference to support user-configurable paramet…
wchen61 Nov 17, 2024
f5312d3
Fix initialization
DarkLight1337 Nov 18, 2024
439e324
Run isort
DarkLight1337 Nov 18, 2024
60815f2
isort
DarkLight1337 Nov 18, 2024
ec46755
isort
DarkLight1337 Nov 18, 2024
ce3ae6f
[Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10…
Isotr0py Nov 18, 2024
466b2cf
[Bugfix] Ignore ray reinit error when current platform is ROCm or XPU…
HollowMan6 Nov 18, 2024
3f092ce
update RequestOutput.__init__() to take `multi_modal_placeholders` as…
Nov 18, 2024
dd8427e
update RequestOutput.__init__() to take `multi_modal_placeholders` as…
Nov 18, 2024
76ac8b0
update RequestOutput.__init__() to take `multi_modal_placeholders` as…
Nov 18, 2024
904e925
Merge branch 'vllm-project:main' into main
lk-chen Nov 18, 2024
c963a25
disable mypy type check
Nov 18, 2024
470fbd3
disable mypy type check
Nov 18, 2024
550be23
remove unnecessary debug code
Nov 18, 2024
1eb4d96
Merge branch 'vllm-project:main' into main
lk-chen Nov 18, 2024
9b002b0
Merge branch 'vllm-project:main' into main
lk-chen Nov 20, 2024
436beb2
Merge branch 'vllm-project:main' into main
lk-chen Nov 25, 2024
bbc6420
Merge branch 'vllm-project:main' into main
lk-chen Dec 12, 2024
5254415
Merge branch 'vllm-project:main' into main
lk-chen Feb 5, 2025
6077919
Merge branch 'vllm-project:main' into main
lk-chen Feb 6, 2025
d8f785a
[V1] Enhance check when slicing encoder output
lk-chen Feb 5, 2025
89f243b
[V1][Pixtral-HF] Add custom `slice_encoder_output` for Pixtral
lk-chen Feb 11, 2025
1d37090
Merge branch 'vllm-project:main' into main
lk-chen Feb 11, 2025
098b444
Merge branch 'main' into pixtral_hf
lk-chen Feb 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors)
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
Expand Down Expand Up @@ -507,6 +507,43 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
if (config.text_config.architectures is None
and config.text_config.model_type == "mistral"):
config.text_config.architectures = ["MistralForCausalLM"]

def _slice_encoder_output(
    mm_input: MultiModalKwargs,
    encoder_output: torch.Tensor,
    mm_pos: PlaceholderRange,
    num_computed_tokens: int,
    num_scheduled_tokens: int,
) -> torch.Tensor:
    """Return the slice of ``encoder_output`` covered by the tokens
    scheduled in this step.

    Pixtral-HF interleaves a special token (e.g. ``image_break``) after
    every row of image-feature tokens in the *placeholder* positions,
    while ``encoder_output`` contains only the image-feature embeddings.
    Placeholder positions therefore do not map 1:1 onto encoder-output
    positions; this helper converts the scheduled token window into the
    corresponding encoder-output index range before slicing.

    Args:
        mm_input: Multimodal kwargs for this image; must contain
            ``pixel_values``.
        encoder_output: Vision-encoder embeddings for this image
            (one entry per image-feature token, special tokens excluded).
        mm_pos: Placeholder range of the image in the prompt;
            ``mm_pos["offset"]`` is the first placeholder position.
        num_computed_tokens: Prompt tokens already computed for this
            request.
        num_scheduled_tokens: Prompt tokens scheduled in this step.
    """
    assert "pixel_values" in mm_input
    image_input = mm_input["pixel_values"]
    # NOTE(review): assumes pixel_values has layout (..., height, width)
    # — width from shape[-1], height from shape[-2]; confirm against the
    # Pixtral-HF processor. nrows is currently unused.
    ncols, nrows = get_pixtral_hf_image_feature_grid_size(
        self.config.vision_config,
        image_width=image_input.shape[-1],
        image_height=image_input.shape[-2],
    )
    placeholder_start = mm_pos["offset"]

    # Turn placeholder position into encoder output position
    def placeholder_pos_to_encoder_output_pos(
            placeholder_pos: int) -> int:
        # Each placeholder row spans (ncols + 1) tokens (ncols image
        # tokens plus one special token) but contributes only ncols
        # encoder outputs, so drop one position per completed row.
        return placeholder_pos % (ncols + 1) + placeholder_pos // (
            ncols + 1) * ncols

    # Clamp the scheduled window to the encoder output's bounds:
    # positions before the placeholder map below 0, and the window may
    # extend past the end of the image.
    start_idx = max(
        placeholder_pos_to_encoder_output_pos(num_computed_tokens -
                                              placeholder_start),
        0)
    end_idx = min(
        placeholder_pos_to_encoder_output_pos(
            num_computed_tokens + num_scheduled_tokens -
            placeholder_start), len(encoder_output))
    assert start_idx <= end_idx, (
        f"{start_idx=} should be no greater than {end_idx=}")
    return encoder_output[start_idx:end_idx]

# Exposed as an attribute so the V1 GPU model runner can detect (via
# hasattr) and use this model-specific slicing logic.
self.slice_encoder_output = _slice_encoder_output

if (config.projector_hidden_act is None
and config.vision_config.hidden_act == "gelu"):
config.projector_hidden_act = "gelu"
Expand Down
27 changes: 21 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,15 +755,30 @@ def _gather_encoder_outputs(
# in the decoder's KV cache.
continue

start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
if hasattr(self.model, "slice_encoder_output"):
# Per-model custom logic to slice the encoder output. Some
# models (e.g. Pixtral) have dynamic number of special
# tokens (e.g. image_break) in the middle of placeholder
# positions. This allows the model to calculate
# encoder_output slices taking into account the special
# tokens.
encoder_outputs.append(
self.model.slice_encoder_output(
mm_input=req_state.mm_inputs[i],
encoder_output=encoder_output,
mm_pos=pos_info,
num_computed_tokens=num_computed_tokens,
num_scheduled_tokens=num_scheduled_tokens))
else:
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
assert start_idx < end_idx
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs

def get_model(self) -> nn.Module:
Expand Down