
clip.cpp / gguf-py: Support for Qwen2.5 VL - WIP / REVIEW NEEDED (#11483) #12119

Closed · wants to merge 7 commits
Changes from all commits
23 changes: 23 additions & 0 deletions convert_hf_to_gguf.py
@@ -2236,6 +2236,29 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
continue
yield name, data

# No additional handling seems to be needed for Qwen2.5.
@Model.register("Qwen2_5_VLForConditionalGeneration")
class Qwen2_5_VLModel(Model):
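    # NOTE: the text model appears unchanged from Qwen2-VL, so the same GGUF architecture is reused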
model_arch = gguf.MODEL_ARCH.QWEN2VL

def set_gguf_parameters(self):
super().set_gguf_parameters()
mrope_section = self.hparams["rope_scaling"]["mrope_section"]
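        # pad to a fixed 4 entries for the GGUF field, e.g. [16, 24, 24] -> [16, 24, 24, 0]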
mrope_section += [0] * max(0, 4 - len(mrope_section))
self.gguf_writer.add_rope_dimension_sections(mrope_section)

def set_vocab(self):
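        # Qwen checkpoints ship a BPE tokenizer (tokenizer.json) rather than a SentencePiece model,
        # so the SentencePiece load typically raises FileNotFoundError and we fall back to the GPT-2 path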
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()

def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for name, data in super().get_tensors():
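            # skip the vision tower; "visual.*" tensors are converted separately into the mmproj (CLIP) GGUF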
if name.startswith("visual."):
continue
yield name, data


@Model.register("WavTokenizerDec")
class WavTokenizerDecModel(Model):
176 changes: 117 additions & 59 deletions examples/llava/clip.cpp
@@ -9,25 +9,25 @@
#include "ggml-backend.h"
#include "gguf.h"

#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

#ifdef GGML_USE_SYCL
#include "ggml-sycl.h"
#endif

#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif

#ifdef GGML_USE_CANN
#include "ggml-cann.h"
#endif

#ifdef GGML_USE_VULKAN
#include "ggml-vulkan.h"
#endif

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
@@ -106,6 +106,8 @@ static std::string format(const char * fmt, ...) {
#define KEY_HAS_GLM_PROJ "clip.has_glm_projector"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
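// new GGUF metadata for Qwen2.5-VL: a model flag plus the vision tower's RMSNorm epsilon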
#define KEY_IS_QWEN2_5 "clip.is_qwen2_5"
#define KEY_RMS_NORM_EPS "clip.%s.attention.rms_norm_epsilon"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_USE_SILU "clip.use_silu"
#define KEY_N_EMBD "clip.%s.embedding_length"
@@ -583,6 +585,7 @@ struct clip_ctx {
bool has_minicpmv_projector = false;
bool has_glm_projector = false;
bool has_qwen2vl_merger = false;
bool is_qwen2_5 = false;
int minicpmv_version = 2;

struct clip_vision_model vision_model;
@@ -734,7 +737,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
if (ctx->has_minicpmv_projector) {
int pos_w = image_size_width/patch_size;
int pos_h = image_size_height/patch_size;
if (ctx->is_qwen2_5) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2048, pos_w * pos_h, 1);
}
else if (ctx->minicpmv_version == 2) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
}
else if (ctx->minicpmv_version == 3) {
@@ -774,8 +780,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
if (ctx->is_qwen2_5) {
// Qwen2.5 uses RMSNorm: no mean-centering and no bias
cur = ggml_rms_norm(ctx0, cur, eps);
cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
} else {
// standard LayerNorm with bias
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
model.layers[il].ln_1_b);
}
}

// self-attention
@@ -834,22 +846,47 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
if (ctx->is_qwen2_5) {
// Qwen2.5 uses RMSNorm: no mean-centering and no bias
cur = ggml_rms_norm(ctx0, cur, eps);
cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
} else {
// standard LayerNorm with bias
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w),
model.layers[il].ln_2_b);
}
}
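
For reference, the two normalizations above differ only in mean-centering and the bias term:

$\mathrm{LayerNorm}(x) = w \odot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + b, \qquad \mathrm{RMSNorm}(x) = w \odot \frac{x}{\sqrt{\frac{1}{n}\sum_i x_i^2 + \varepsilon}}$

which is why the Qwen2.5 branches use ggml_rms_norm and carry no ln_1_b / ln_2_b bias tensors.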

// MLP
if (ctx->is_qwen2_5) {
// Qwen2.5 uses a SiLU-gated MLP: ff_i_w holds gate_proj and the
// ff_i_b slot is repurposed to hold the up_proj weight
struct ggml_tensor * gate = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
struct ggml_tensor * up = ggml_mul_mat(ctx0, model.layers[il].ff_i_b, cur); // using ff_i_b as up_proj weight

// apply SiLU to the gate, then multiply with the up projection
gate = ggml_silu_inplace(ctx0, gate);
cur = ggml_mul(ctx0, gate, up);

// down projection
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
} else {
// original MLP
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);

if (ctx->use_gelu) {
cur = ggml_gelu_inplace(ctx0, cur);
} else if (ctx->use_silu) {
cur = ggml_silu_inplace(ctx0, cur);
} else {
cur = ggml_gelu_quick_inplace(ctx0, cur);
}

cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
}
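
For reference, the gated branch above computes a SwiGLU-style MLP, with the up-projection weight stored in the repurposed ff_i_b slot:

$\mathrm{MLP}(x) = W_{\mathrm{down}}\left(\mathrm{SiLU}(W_{\mathrm{gate}}\,x) \odot W_{\mathrm{up}}\,x\right), \qquad \mathrm{SiLU}(z) = z \cdot \sigma(z)$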

// residual 2
cur = ggml_add(ctx0, embeddings, cur);
@@ -1085,7 +1122,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
const int d_head = 128;
int n_head = hidden_size/d_head;
int num_query = 96;
if (ctx->is_qwen2_5) {
hidden_size = 2048;
n_head = hidden_size/d_head;
num_query = 64;
}
else if (ctx->minicpmv_version == 2) {
hidden_size = 4096;
n_head = hidden_size/d_head;
num_query = 96;
@@ -1296,30 +1338,30 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
}

#ifdef GGML_USE_CUDA
new_clip->backend = ggml_backend_cuda_init(0);
LOG_INF("%s: CLIP using CUDA backend\n", __func__);
#endif

#ifdef GGML_USE_METAL
new_clip->backend = ggml_backend_metal_init();
LOG_INF("%s: CLIP using Metal backend\n", __func__);
#endif

#ifdef GGML_USE_CANN
new_clip->backend = ggml_backend_cann_init(0);
LOG_INF("%s: CLIP using CANN backend\n", __func__);
#endif

#ifdef GGML_USE_VULKAN
new_clip->backend = ggml_backend_vk_init(0);
LOG_INF("%s: CLIP using Vulkan backend\n", __func__);
#endif

#ifdef GGML_USE_SYCL
new_clip->backend = ggml_backend_sycl_init(0);
LOG_INF("%s: CLIP using SYCL backend\n", __func__);
#endif
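
// if no GPU backend was compiled in (or its init returned NULL), fall back to the CPU backend below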

if (!new_clip->backend) {
new_clip->backend = ggml_backend_cpu_init();
@@ -1360,6 +1402,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
// GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search

idx = gguf_find_key(ctx, KEY_IS_QWEN2_5);
if (idx != -1) {
new_clip->is_qwen2_5 = gguf_get_val_bool(ctx, idx);
}
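// when the key is absent (pre-Qwen2.5 mmproj files), is_qwen2_5 keeps its default of false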

GGML_ASSERT(new_clip->has_vision_encoder);
GGML_ASSERT(!new_clip->has_text_encoder);

@@ -2942,7 +2989,10 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_3_b->ne[0];
}
if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
if (ctx->is_qwen2_5) {
return 2048;
}
else if (ctx->minicpmv_version == 2) {
return 4096;
}
else if (ctx->minicpmv_version == 3) {
@@ -2956,6 +3006,11 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
}
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
// For Qwen2.5, the output dimension is 2048 instead of 3584
if (ctx->is_qwen2_5) {
LOG_INF("%s: Qwen2.5 detected, using output dimension 2048\n", __func__);
return 2048;
}
return ctx->vision_model.mm_1_b->ne[0];
}

@@ -2976,6 +3031,9 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
return ctx->has_qwen2vl_merger;
}

bool clip_is_qwen2_5vl(const struct clip_ctx * ctx) {
return ctx->is_qwen2_5;
}
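
A minimal caller-side sketch of the new accessor (the model filename is hypothetical; clip_model_load, clip_free and clip_n_mmproj_embd are existing clip.h APIs):

struct clip_ctx * ctx = clip_model_load("mmproj-qwen2.5-vl.gguf", /*verbosity=*/1);
if (ctx && clip_is_qwen2_5vl(ctx)) {
// Qwen2.5 path: the merger output is 2048-dim, see clip_n_mmproj_embd() above
}
clip_free(ctx);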

// Determine the number of encoder layers to iterate over
int get_deepest_feature_layer(const struct clip_ctx * ctx) {
1 change: 1 addition & 0 deletions examples/llava/clip.h
@@ -98,6 +98,7 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2_5vl(const struct clip_ctx * ctx);

CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
