-
Notifications
You must be signed in to change notification settings - Fork 14.2k
Add metal count equal op #18314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Add metal count equal op #18314
Conversation
| const size_t smem = pipeline.smem; | ||
| int64_t z = 0; | ||
| ggml_backend_tensor_set(op, &z, 0, sizeof(z)); | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not work, you need to call a separate kernel that fills the buffer with zeros
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a new kernel to memset a buffer to a value. Similar to fill but simpler pipeline and only takes the buffer and value.
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Add metal count equal op
This PR extends the CPU implementations of count_equal to Metal.
The current implementation uses a single thread group, but supports multiple if anything changes. This currently matches the CPU / Cuda implementation in which only takes int32 for src0 and src1. This kernel uses the
atomic_fetch_add_explicit, which only supports returning an int32 adds similar to Cuda. This limits the size of the buffers we can take in to 2^31 - 1.The docs have been updated.
codex generated summary:
Summary
This PR introduces a Metal implementation for
COUNT_EQUALonint32tensors that uses SIMD-group reduction to efficiently compute per-threadgroup partial counts and accumulate the result into the destination buffer using atomic operations.The change improves parallel efficiency over a naïve per-element atomic approach by:
simd_sumKey Changes
kernel_count_equal<int32_t>shmem_i32) and SIMD intrinsics (simd_sum) to aggregate countsatomic_fetch_add_explicitper SIMD group