[Misc] Qwen2.5-VL Optimization #13155

Merged · 6 commits · Feb 13, 2025
Changes from all commits
vllm/model_executor/models/qwen2_5_vl.py: 25 additions & 36 deletions
@@ -45,6 +45,7 @@
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -271,8 +272,13 @@ def forward(
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v))
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
q = apply_rotary_pos_emb_vision(q,
rotary_pos_emb,
use_flash_attn=use_flash_attn)
k = apply_rotary_pos_emb_vision(k,
rotary_pos_emb,
use_flash_attn=use_flash_attn)

if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
@@ -296,20 +302,23 @@ def forward(
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA:
seq_length = q.size(1)
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
# Execute attention entry by entry for speed & less VRAM.
outputs = []
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
cu_seqlens[i - 1]:cu_seqlens[i]] = True
output = F.scaled_dot_product_attention(q,
k,
v,
attention_mask,
dropout_p=0.0)
context_layer = rearrange(output, "b h s d -> b s h d ")
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
for x in [q_i, k_i, v_i])
output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
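
Note: the new TORCH_SDPA path drops the `[1, seq, seq]` boolean block-diagonal mask and instead slices `q`, `k`, `v` by `cu_seqlens`, running `scaled_dot_product_attention` on each sequence independently, so the full `seq x seq` mask is never materialized. A minimal standalone sketch of the same pattern (the function name and the `[batch, total_seq, heads, head_dim]` layout are illustrative assumptions, not code from this PR):

```python
import torch
import torch.nn.functional as F

def sdpa_per_sequence(q, k, v, cu_seqlens):
    # q, k, v: [batch, total_seq, heads, head_dim] (assumed layout);
    # cu_seqlens: cumulative lengths, e.g. tensor([0, 7, 19, 30]).
    outputs = []
    for i in range(1, len(cu_seqlens)):
        start, end = cu_seqlens[i - 1], cu_seqlens[i]
        # Each slice attends only within itself, which is exactly what the
        # old block-diagonal mask expressed.
        q_i, k_i, v_i = (x[:, start:end].transpose(1, 2) for x in (q, k, v))
        out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(out_i.transpose(1, 2))  # back to [batch, len_i, heads, head_dim]
    return torch.cat(outputs, dim=1)
```

Peak memory now scales with the longest single sequence rather than the squared total length, at the cost of a Python-level loop over sequences.
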
@@ -327,25 +336,6 @@ def forward(
return output


class Qwen2RMSNorm(nn.Module):

def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Qwen2_5_VisionBlock(nn.Module):

def __init__(
@@ -516,8 +506,7 @@ def __init__(
hidden_size=self.hidden_size,
)

# NOTE: We use torch native RMSNorm here for precision purposes.
norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

vllm/model_executor/models/qwen2_vl.py: 22 additions & 15 deletions
@@ -226,11 +226,15 @@ def apply_rotary_emb_torch(x: torch.Tensor,


def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor) -> torch.Tensor:
freqs: torch.Tensor,
use_flash_attn=False) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
apply_rotary_emb = apply_rotary_emb_torch
if use_flash_attn:
from flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output
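
Note: `apply_rotary_pos_emb_vision` now routes to flash_attn's fused `apply_rotary_emb` kernel when the flash-attention backend is active, and keeps the pure-PyTorch `apply_rotary_emb_torch` defined above as the fallback. For reference, a sketch of what the non-interleaved torch fallback computes (helper names and the `[batch, seq, heads, head_dim]` layout are assumptions, and the rotary dim is taken to cover the whole head):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved rotation: split the last dim in half and swap with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_torch_sketch(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # t: [batch, seq, heads, head_dim]; freqs: [seq, head_dim // 2]
    cos = freqs.cos().repeat(1, 2)[None, :, None, :]  # broadcast over batch and heads
    sin = freqs.sin().repeat(1, 2)[None, :, None, :]
    return (t.float() * cos + rotate_half(t.float()) * sin).type_as(t)
```

The flash_attn version performs the same rotation in a single fused kernel instead of several elementwise ops.
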


@@ -336,20 +340,23 @@ def forward(
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA:
seq_length = q.size(1)
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
# Execute attention entry by entry for speed & less VRAM.
outputs = []
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
cu_seqlens[i - 1]:cu_seqlens[i]] = True
output = F.scaled_dot_product_attention(q,
k,
v,
attention_mask,
dropout_p=0.0)
context_layer = rearrange(output, "b h s d -> b s h d ")
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
for x in [q_i, k_i, v_i])
output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
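
Note: the unchanged XFORMERS branch (only its imports appear in this hunk) expresses the same per-sequence structure declaratively through a block-diagonal attention bias rather than a Python loop. Roughly, that pattern looks like the sketch below (not the file's exact code; sequences are assumed packed into a `[1, total_seq, heads, head_dim]` layout):

```python
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

def xformers_varlen_attention(q, k, v, cu_seqlens):
    # Per-sequence lengths recovered from the cumulative offsets.
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    # The block-diagonal bias restricts attention to within each sequence,
    # mirroring what the TORCH_SDPA loop above does slice by slice.
    attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, kv_seqlen=None)
    return xops.memory_efficient_attention_forward(q, k, v, attn_bias=attn_bias, p=0.0)
```
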