Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,7 @@ struct vk_op_diag_mask_push_constants {
struct vk_op_rope_push_constants {
uint32_t rope_mode;
uint32_t ncols;
uint32_t nrows;
uint32_t n_dims;
float freq_scale;
uint32_t p_delta_rows;
Expand Down Expand Up @@ -9083,10 +9084,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
} break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
break;
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
uint32_t nrows = (uint32_t)ggml_nrows(src0);
uint32_t z = 1;
if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
z = CEIL_DIV(nrows, 32768);
nrows = 32768;
}
elements = { nrows, (uint32_t)ne00, z };

} break;
case GGML_OP_GET_ROWS:
elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
Expand Down Expand Up @@ -10014,7 +10025,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);

vk_op_rope_push_constants rope {
(uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
(uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
has_ff, (uint32_t)src0->ne[2], nb01, nb02,
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_multi(i0, i1, pc);
}
5 changes: 4 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_neox(i0, i1, pc);
}
5 changes: 4 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_norm(i0, i1, pc);
}
1 change: 1 addition & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
struct rope_params {
uint rope_mode;
uint ncols;
uint nrows;
uint n_dims;
float freq_scale;
uint p_delta_rows;
Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_vision(i0, i1, pc);
}
3 changes: 3 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7732,6 +7732,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
}

if (all) {
Expand All @@ -7746,6 +7747,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
test_cases.emplace_back(new test_rope(type, { 16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw));
}

if (all) {
Expand All @@ -7759,6 +7761,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
}

test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
Expand Down
Loading