Skip to content

Commit c9b4b5e

Browse files
committed
vulkan: use fewer FA rows for small cache runs
1 parent fd05c51 commit c9b4b5e

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -379,18 +379,18 @@ enum FaCodePath {
379379
};
380380

381381
struct vk_fa_pipeline_state {
382-
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
383-
: HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
382+
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
383+
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
384384

385385
uint32_t HSK, HSV;
386-
bool small_rows;
386+
bool small_rows, small_cache;
387387
FaCodePath path;
388388
bool aligned;
389389
bool f32acc;
390390

391391
bool operator<(const vk_fa_pipeline_state &b) const {
392-
return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
393-
std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
392+
return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
393+
std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
394394
}
395395
};
396396

@@ -2564,10 +2564,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
25642564
static constexpr uint32_t flash_attention_num_small_rows = 32;
25652565
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
25662566

2567-
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
2567+
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
25682568
if (hsv >= 192) {
25692569
return 2;
2570-
} else if ((hsv | hsk) & 8) {
2570+
} else if ((hsv | hsk) & 8 || small_cache) {
25712571
return 4;
25722572
} else {
25732573
return 8;
@@ -2589,9 +2589,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
25892589
}
25902590
}
25912591

2592-
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
2592+
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
25932593
GGML_UNUSED(clamp);
2594-
GGML_UNUSED(hsv);
25952594

25962595
if (path == FA_SCALAR) {
25972596
if (small_rows) {
@@ -2600,9 +2599,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
26002599
if ((hsv | hsk) & 8) {
26012600
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
26022601
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
2603-
return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
2602+
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
26042603
} else {
2605-
return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
2604+
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
26062605
}
26072606
}
26082607
}
@@ -2631,8 +2630,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
26312630
return {64, 64};
26322631
}
26332632

2634-
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
2635-
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
2633+
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
2634+
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
26362635
}
26372636

26382637
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -2974,11 +2973,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
29742973
align, disable_robustness, require_full_subgroups, required_subgroup_size);
29752974
};
29762975

2977-
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
2978-
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
2976+
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
2977+
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
29792978
};
29802979

2981-
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
2980+
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
29822981
// For large number of rows, 128 invocations seems to work best.
29832982
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
29842983
// can't use 256 for D==80.
@@ -2988,7 +2987,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
29882987
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
29892988
? scalar_flash_attention_workgroup_size
29902989
: ((small_rows && (D % 32) == 0) ? 256 : 128);
2991-
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
2990+
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
29922991

29932992
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
29942993
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -3003,21 +3002,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
30033002
uint32_t HSK = fa.first.HSK; \
30043003
uint32_t HSV = fa.first.HSV; \
30053004
bool small_rows = fa.first.small_rows; \
3005+
bool small_cache = fa.first.small_cache; \
30063006
FaCodePath path = fa.first.path; \
30073007
bool aligned = fa.first.aligned; \
30083008
bool f32acc = fa.first.f32acc; \
30093009
if (path == FAPATH) { \
30103010
if (aligned) { \
30113011
if (f32acc) { \
3012-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3012+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
30133013
} else { \
3014-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3014+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
30153015
} \
30163016
} else { \
30173017
if (f32acc) { \
3018-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3018+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
30193019
} else { \
3020-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3020+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
30213021
} \
30223022
} \
30233023
} \
@@ -7990,11 +7990,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
79907990
}
79917991
}
79927992

7993-
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
7993+
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
79947994
// Needs to be kept up to date on shader changes
79957995
GGML_UNUSED(hsv);
79967996
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
7997-
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
7997+
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
79987998
const uint32_t Bc = scalar_flash_attention_Bc;
79997999

80008000
const uint32_t tmpsh = wg_size * sizeof(float);
@@ -8118,14 +8118,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
81188118
uint32_t workgroups_y = (uint32_t)neq2;
81198119
uint32_t workgroups_z = (uint32_t)neq3;
81208120

8121+
const bool small_cache = nek1 < 1024;
8122+
81218123
// For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
81228124
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
81238125
uint32_t max_gqa;
81248126
switch (path) {
81258127
case FA_SCALAR:
81268128
case FA_COOPMAT1:
81278129
// We may switch from coopmat1 to scalar, so use the scalar limit for both
8128-
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
8130+
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
81298131
break;
81308132
case FA_COOPMAT2:
81318133
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@@ -8159,7 +8161,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
81598161

81608162
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
81618163
if (path == FA_SCALAR &&
8162-
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
8164+
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
81638165
small_rows = true;
81648166
}
81658167

@@ -8175,7 +8177,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
81758177
v_stride /= 4;
81768178
}
81778179

8178-
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
8180+
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
81798181
bool aligned = (KV % alignment) == 0 &&
81808182
// the "aligned" shader variant will forcibly align strides, for performance
81818183
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
@@ -8187,7 +8189,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
81878189

81888190
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
81898191

8190-
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
8192+
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
81918193

81928194
vk_pipeline pipeline = nullptr;
81938195

0 commit comments

Comments
 (0)