From c9b4b5eabf50713d36736335014808748b4202f4 Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Mon, 22 Dec 2025 08:27:22 +0100
Subject: [PATCH] vulkan: use fewer FA rows for small cache runs

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 56 ++++++++++++++--------------
 1 file changed, 29 insertions(+), 27 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index a871f85afb..59b76003c7 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -379,18 +379,18 @@ enum FaCodePath {
 };

 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}

     uint32_t HSK, HSV;
-    bool small_rows;
+    bool small_rows, small_cache;
     FaCodePath path;
     bool aligned;
     bool f32acc;

     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
-            std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
+        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
+            std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
     }
 };

@@ -2564,10 +2564,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;

-static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
+static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
     if (hsv >= 192) {
         return 2;
-    } else if ((hsv | hsk) & 8) {
+    } else if ((hsv | hsk) & 8 || small_cache) {
         return 4;
     } else {
         return 8;
@@ -2589,9 +2589,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
     }
 }

-static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
     GGML_UNUSED(clamp);
-    GGML_UNUSED(hsv);

     if (path == FA_SCALAR) {
         if (small_rows) {
@@ -2600,9 +2599,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
             if ((hsv | hsk) & 8) {
                 // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
                 // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
-                return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
+                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
             } else {
-                return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
+                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
             }
         }
     }
@@ -2631,8 +2630,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
     return {64, 64};
 }

-static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
-    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
+static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
+    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
 }

 static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -2974,11 +2973,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                  align, disable_robustness, require_full_subgroups, required_subgroup_size);
     };

-    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
+    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
     };

-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
@@ -2988,7 +2987,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
                            ? scalar_flash_attention_workgroup_size
                            : ((small_rows && (D % 32) == 0) ? 256 : 128);
-        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
+        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);

         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -3003,21 +3002,22 @@
         uint32_t HSK = fa.first.HSK; \
         uint32_t HSV = fa.first.HSV; \
         bool small_rows = fa.first.small_rows; \
+        bool small_cache = fa.first.small_cache; \
         FaCodePath path = fa.first.path; \
         bool aligned = fa.first.aligned; \
         bool f32acc = fa.first.f32acc; \
         if (path == FAPATH) { \
             if (aligned) { \
                 if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
                 } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
                 } \
             } else { \
                 if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
                 } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
                 } \
             } \
         } \
@@ -7990,11 +7990,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }

-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
     // Needs to be kept up to date on shader changes
     GGML_UNUSED(hsv);
     const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
+    const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
     const uint32_t Bc = scalar_flash_attention_Bc;

     const uint32_t tmpsh = wg_size * sizeof(float);
@@ -8118,6 +8118,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t workgroups_y = (uint32_t)neq2;
     uint32_t workgroups_z = (uint32_t)neq3;

+    const bool small_cache = nek1 < 1024;
+
     // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
     // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
     uint32_t max_gqa;
@@ -8125,7 +8127,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     case FA_SCALAR:
     case FA_COOPMAT1:
         // We may switch from coopmat1 to scalar, so use the scalar limit for both
-        max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
+        max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
         break;
     case FA_COOPMAT2:
         max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@@ -8159,7 +8161,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

     // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
     if (path == FA_SCALAR &&
-        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
+        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
         small_rows = true;
     }

@@ -8175,7 +8177,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         v_stride /= 4;
     }

-    uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
+    uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
     bool aligned = (KV % alignment) == 0 &&
         // the "aligned" shader variant will forcibly align strides, for performance
         (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
@@ -8187,7 +8189,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;

-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
+    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);

     vk_pipeline pipeline = nullptr;
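
Illustrative sketch (not part of the patch, compiles standalone): it reproduces the row-selection heuristic this commit introduces, where a KV length (nek1) below 1024 counts as a "small cache" and caps the scalar flash-attention row count at 4 instead of 8, so short-context runs are split across more workgroups. The helper body is copied from the patched get_fa_scalar_num_large_rows(); the main() harness, the sample KV lengths, and the head sizes are made up for demonstration.

// sketch.cpp - illustrative only; mirrors get_fa_scalar_num_large_rows() from the patch.
#include <cstdint>
#include <cstdio>

// Same decision tree as the patched helper: very large value head sizes get 2 rows,
// non-multiple-of-16 head sizes or a small KV cache get 4, everything else gets 8.
static uint32_t fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
    if (hsv >= 192) {
        return 2;
    } else if ((hsv | hsk) & 8 || small_cache) {
        return 4;
    } else {
        return 8;
    }
}

int main() {
    // nek1 is the KV sequence length; the patch flags anything below 1024 as a small cache.
    const uint32_t kv_lengths[] = {256, 512, 1024, 4096};
    const uint32_t hsk = 128, hsv = 128;

    for (uint32_t nek1 : kv_lengths) {
        const bool small_cache = nek1 < 1024;
        printf("KV=%4u  small_cache=%d  rows=%u\n",
               nek1, small_cache ? 1 : 0, fa_scalar_num_large_rows(hsk, hsv, small_cache));
    }
    return 0;
}

With HSK = HSV = 128 this prints rows=4 for KV=256 and KV=512, and rows=8 for KV=1024 and KV=4096, which is the behavioural change described by the subject line.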