Perf/graph cache flash attention #1

Open
yzimmermann wants to merge 3 commits into re-Isearch:main from yzimmermann:perf/graph-cache-flash-attention

Conversation

@yzimmermann

No description provided.

Recent ggml moved ggml-metal.metal into the ggml-metal/ subdirectory.
The CMake configure_file() call still pointed at the old top-level
location, causing configure to fail with "File does not exist" on any
fresh checkout against current ggml.

Refactor of bert_forward_batch / bert_build_graph for substantial
speedups and to make bert_forward_batch usable at batch>1.

Graph + allocator cache
-----------------------
Each call to bert_forward_batch previously rebuilt the entire ggml
graph and ran ggml_gallocr_alloc_graph from scratch — on Metal that
included a fresh MTLHeap allocation per call. The new code keys a
per-(max_len, n_batch) cache on bert_ctx and builds the graph + runs
the gallocr planner exactly once per distinct shape. Cached pointers
to the input/output tensors avoid graph_get_tensor lookups on the
hot path. Per-call scratch buffers (token ids, pad mask, pool weights)
are also stored on the cached entry to eliminate heap churn.
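
A minimal sketch of the shape-keyed cache described above, with no ggml dependency. The names `cached_graph`, `graph_cache`, `get` and `n_builds` are hypothetical stand-ins, not the PR's identifiers; in the real code the entry would also hold the ggml graph, the allocator and the cached tensor pointers.

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical cached entry: only the per-call scratch buffers are modelled.
struct cached_graph {
    int max_len = 0;
    int n_batch = 0;
    std::vector<int>   token_ids;  // reused on every call at this shape
    std::vector<float> pad_mask;   // reused on every call at this shape
};

using graph_key = std::pair<int, int>;  // (max_len, n_batch)

struct graph_cache {
    std::map<graph_key, std::unique_ptr<cached_graph>> entries;
    int n_builds = 0;  // counts the expensive builds, for illustration only

    // Return the entry for this shape, building it only on first use.
    cached_graph & get(int max_len, int n_batch) {
        const graph_key key{max_len, n_batch};
        auto it = entries.find(key);
        if (it == entries.end()) {
            auto e = std::make_unique<cached_graph>();
            e->max_len = max_len;
            e->n_batch = n_batch;
            const std::size_t n = static_cast<std::size_t>(max_len) * n_batch;
            e->token_ids.resize(n);
            e->pad_mask.resize(n);
            ++n_builds;  // stand-in for graph build + allocator planning
            it = entries.emplace(key, std::move(e)).first;
        }
        return *it->second;
    }
};
```

Repeated calls at the same `(max_len, n_batch)` return the same entry, so the build cost is paid once per distinct shape; a new shape adds a new entry rather than evicting an old one.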

Flash attention
---------------
The unrolled mul_mat -> scale -> add(mask) -> softmax -> mul_mat
sequence is replaced by ggml_flash_attn_ext. On Metal this uses the
fused SDPA kernel. The attention mask is built as before (outer product
of the pad mask) and cast to f16 for the flash_attn API. F16 accumulate
is used on accelerated backends (significant speedup, negligible drift
for sentence embeddings); CPU is forced to F32 since it has no native
fp16 path and would otherwise pay extra conversion work.
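
As an illustration of the mask construction mentioned above, the sketch below builds the outer product of a 0/1 pad mask for one sequence. The helper names are hypothetical. The second function is an assumption about how such a matrix is typically turned into an additive attention bias; the commit message does not say how the PR encodes masked positions, and the f16 cast is not shown.

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// pad[i] is 1.0f for a real token and 0.0f for padding (one sequence).
// Returns a row-major n x n matrix with m[i*n + j] = pad[i] * pad[j].
std::vector<float> pad_outer_product(const std::vector<float> & pad) {
    const std::size_t n = pad.size();
    std::vector<float> m(n * n);
    for (std::size_t i = 0; i < n; i++) {
        for (std::size_t j = 0; j < n; j++) {
            m[i * n + j] = pad[i] * pad[j];
        }
    }
    return m;
}

// Assumption: convert the 0/1 matrix to an additive bias, 0 where attention
// is allowed and -infinity where either position is padding.
std::vector<float> to_additive_bias(const std::vector<float> & m) {
    std::vector<float> out(m.size());
    for (std::size_t k = 0; k < m.size(); k++) {
        out[k] = m[k] > 0.0f ? 0.0f : -std::numeric_limits<float>::infinity();
    }
    return out;
}
```

With this encoding a row that belongs to a padding token is masked everywhere, so its softmax output is undefined; that is harmless only if padded positions are excluded again at pooling time.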

Batched mean-pool
-----------------
The old mean-pool flattened [n_embd, max_len, n_batch] to a 2D matmul
that summed across the entire batch, which is wrong for n_batch > 1
and crashed at ggml_reshape_2d. Replaced with a 3D mul_mat that
preserves the batch dimension. bert_forward_batch is now bit-exact at
bs=N for N identical inputs vs N invocations at bs=1.
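
The difference between pooling per sequence and summing across the batch can be sketched on plain arrays. This is an illustration under assumed layouts, not the PR's ggml code: `mean_pool_batched` is a hypothetical name, and the row-major `[n_batch][max_len][n_embd]` layout is chosen to correspond to a ggml tensor of shape `[n_embd, max_len, n_batch]`.

```cpp
#include <cstddef>
#include <vector>

// hidden:  row-major [n_batch][max_len][n_embd]
// weights: row-major [n_batch][max_len], e.g. pad_mask / number_of_real_tokens
// returns: row-major [n_batch][n_embd], one pooled vector per sequence
std::vector<float> mean_pool_batched(const std::vector<float> & hidden,
                                     const std::vector<float> & weights,
                                     std::size_t n_embd,
                                     std::size_t max_len,
                                     std::size_t n_batch) {
    std::vector<float> out(n_batch * n_embd, 0.0f);
    for (std::size_t b = 0; b < n_batch; b++) {
        for (std::size_t t = 0; t < max_len; t++) {
            const float w = weights[b * max_len + t];
            for (std::size_t e = 0; e < n_embd; e++) {
                // The batch index b is kept on both sides, so sequences never mix.
                out[b * n_embd + e] += w * hidden[(b * max_len + t) * n_embd + e];
            }
        }
    }
    return out;
}
```

Because each output row only reads its own sequence, pooling N identical inputs in one batch performs the same floating-point operations in the same order as N separate calls at batch size 1, which is the property the commit message reports.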

CPU input-stability workaround
------------------------------
On the CPU backend the gallocr appears to reuse memory of input-flagged
tensors between successive graph_compute calls, despite the input flag.
Metal keeps them stable. positions/token_types are re-uploaded each
forward on CPU only; on accelerated backends they are written once at
graph build.

Results on M-series Apple Silicon, all-MiniLM-L6-v2 Q4_K_M, bs=1:
  Metal max-length: 38.2 ms -> 5.7 ms  (6.7x)
  Metal medium:      6.6 ms -> 3.9 ms  (1.7x)
  CPU  max-length:  56.6 ms -> ~41 ms  (~1.4x)
Metal batched throughput at bs=8 reaches ~90k tok/s; previously bs>1
crashed.

24/24 correctness tests pass; CPU/Metal cosine agreement preserved
(0.9996); bs>1 outputs bit-exact vs bs=1.

examples/bench loads the model once, warms up, then times N forward
passes at varying input lengths and batch sizes. Reports mean / median
/ p95 / min / max latency and tokens-per-second.

Used to validate the perf refactor in the preceding commit:
  ./build/bin/bench -m MODEL [-c] [-w WARMUP] [-n ITERS] [-t THREADS]
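
The warm-up-then-time loop described above could look roughly like the following. This is a sketch, not the code in examples/bench.cpp: `time_runs` and `tokens_per_second` are hypothetical helpers, and the forward pass is abstracted as a callable.

```cpp
#include <chrono>
#include <cstddef>
#include <functional>
#include <vector>

// Run `fn` `warmup` times untimed, then `iters` times timed.
// Returns one latency sample per timed iteration, in milliseconds.
std::vector<double> time_runs(const std::function<void()> & fn, int warmup, int iters) {
    for (int i = 0; i < warmup; i++) {
        fn();  // warm-up runs populate caches and are not measured
    }
    std::vector<double> ms;
    ms.reserve(iters > 0 ? static_cast<std::size_t>(iters) : 0);
    for (int i = 0; i < iters; i++) {
        const auto t0 = std::chrono::steady_clock::now();
        fn();
        const auto t1 = std::chrono::steady_clock::now();
        ms.push_back(std::chrono::duration<double, std::milli>(t1 - t0).count());
    }
    return ms;
}

// Tokens per second for a forward pass over n_tokens with the given mean latency.
double tokens_per_second(int n_tokens, double mean_ms) {
    return mean_ms > 0.0 ? 1000.0 * n_tokens / mean_ms : 0.0;
}
```

`steady_clock` is used because it is monotonic; the per-iteration samples can then be fed to mean, median and percentile helpers.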
Copilot AI review requested due to automatic review settings May 26, 2026 02:01
@yzimmermann yzimmermann reopened this May 26, 2026

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds a new benchmarking example and improves BERT inference performance by caching compute graphs per (sequence length, batch size) and using fused attention where available.

Changes:

  • Add examples/bench.cpp and wire it into examples/CMakeLists.txt
  • Cache ggml graphs/allocations in bert_ctx keyed by (max_len, n_batch) to avoid rebuild + gallocr planning per forward
  • Update Metal shader copy path in top-level CMakeLists.txt

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| examples/bench.cpp | New CLI benchmark tool for latency/throughput across token lengths and batch sizes |
| examples/CMakeLists.txt | Builds/links the new bench example |
| bert.h | Adds an opaque per-context graph cache pointer to support graph reuse |
| bert.cpp | Implements cached-graph construction, cached forward path, and uses ggml_flash_attn_ext |
| CMakeLists.txt | Fixes Metal shader source path for runtime copy |

Comment thread examples/bench.cpp
Comment on lines +28 to +40
static bool parse(int argc, char ** argv, bench_params & p) {
    for (int i = 1; i < argc; i++) {
        std::string a = argv[i];
        if (a == "-m") p.model = argv[++i];
        else if (a == "-c") p.use_cpu = true;
        else if (a == "-t") p.n_threads = std::stoi(argv[++i]);
        else if (a == "-w") p.warmup = std::stoi(argv[++i]);
        else if (a == "-n") p.iters = std::stoi(argv[++i]);
        else if (a == "-h" || a == "--help") { usage(argv[0]); return false; }
        else { fprintf(stderr, "unknown arg: %s\n", a.c_str()); usage(argv[0]); return false; }
    }
    return p.model != nullptr;
}
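
In the snippet above, `argv[++i]` is read without checking that another argument exists, and `std::stoi` throws on non-numeric input. One possible hardening is sketched below. It is not part of the PR: `bench_params_sketch`, its default values, `parse_int` and `parse_checked` are hypothetical, and help handling is omitted for brevity.

```cpp
#include <cstddef>
#include <cstdio>
#include <exception>
#include <string>

// Hypothetical mirror of the PR's bench_params; the defaults are assumptions.
struct bench_params_sketch {
    const char * model     = nullptr;
    bool         use_cpu   = false;
    int          n_threads = 4;
    int          warmup    = 3;
    int          iters     = 20;
};

// Parse a whole string as an int; reject trailing junk and out-of-range values.
static bool parse_int(const char * s, int & out) {
    try {
        const std::string str(s);
        std::size_t pos = 0;
        const int v = std::stoi(str, &pos);
        if (pos != str.size()) return false;
        out = v;
        return true;
    } catch (const std::exception &) {
        return false;  // std::invalid_argument or std::out_of_range
    }
}

static bool parse_checked(int argc, char ** argv, bench_params_sketch & p) {
    for (int i = 1; i < argc; i++) {
        const std::string a = argv[i];
        if (a == "-c") { p.use_cpu = true; continue; }
        if (a != "-m" && a != "-t" && a != "-w" && a != "-n") {
            std::fprintf(stderr, "unknown arg: %s\n", a.c_str());
            return false;
        }
        // Every remaining flag takes a value: check before reading argv[i + 1].
        if (i + 1 >= argc) {
            std::fprintf(stderr, "missing value for %s\n", a.c_str());
            return false;
        }
        const char * v = argv[++i];
        if (a == "-m") { p.model = v; continue; }
        int & dst = (a == "-t") ? p.n_threads : (a == "-w") ? p.warmup : p.iters;
        if (!parse_int(v, dst)) {
            std::fprintf(stderr, "invalid integer for %s: %s\n", a.c_str(), v);
            return false;
        }
    }
    return p.model != nullptr;
}
```

A trailing `-m` with no value, or `-n x7`, now returns `false` with a message instead of reading past `argv` or terminating on an uncaught exception.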
Comment thread examples/bench.cpp
Comment on lines +42 to +56
static double median(std::vector<double> v) {
    std::sort(v.begin(), v.end());
    size_t n = v.size();
    return n % 2 ? v[n/2] : 0.5*(v[n/2 - 1] + v[n/2]);
}

static double percentile(std::vector<double> v, double p) {
    std::sort(v.begin(), v.end());
    size_t idx = (size_t)((v.size() - 1) * p);
    return v[idx];
}

static double mean(const std::vector<double> & v) {
    double s = 0; for (double x : v) s += x; return s / v.size();
}
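
The helpers above index into an empty vector (or divide by zero) when no samples were collected, for example with `-n 0`. A guarded variant might look like the sketch below. The `_safe` names are hypothetical, and returning NaN for empty input is one assumed convention among several reasonable ones.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// All helpers return NaN for empty input instead of reading out of bounds.

static double median_safe(std::vector<double> v) {
    if (v.empty()) return std::nan("");
    std::sort(v.begin(), v.end());
    const std::size_t n = v.size();
    return n % 2 ? v[n / 2] : 0.5 * (v[n / 2 - 1] + v[n / 2]);
}

// p is a fraction in [0, 1]; values outside that range are clamped.
// Uses the same floor-index rule as the original snippet (no interpolation).
static double percentile_safe(std::vector<double> v, double p) {
    if (v.empty()) return std::nan("");
    p = std::min(1.0, std::max(0.0, p));
    std::sort(v.begin(), v.end());
    const std::size_t idx = static_cast<std::size_t>((v.size() - 1) * p);
    return v[idx];
}

static double mean_safe(const std::vector<double> & v) {
    if (v.empty()) return std::nan("");
    double s = 0.0;
    for (double x : v) s += x;
    return s / static_cast<double>(v.size());
}
```

Since both `median_safe` and `percentile_safe` sort a copy, a caller that reports several statistics could sort once and pass the sorted vector to avoid repeated work.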
Comment thread bert.cpp
@@ -64,7 +64,9 @@ For these use the llama.cpp interface.
#include <cmath>
#include <fstream>
#include <algorithm>
Comment thread bert.cpp
Comment on lines +1270 to +1275
// No ggml_cont needed; flash_attn handles non-contiguous q/k/v.
auto proj = [&](ggml_tensor * w, ggml_tensor * b) {
    ggml_tensor * x = ggml_add(ctx0, ggml_mul_mat(ctx0, w, cur), b);
    x = ggml_reshape_4d(ctx0, x, d_head, n_head, max_len, n_batch);
    return ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3));
};
Comment thread bert.cpp
Comment on lines +1173 to +1176
using bert_graph_key = std::pair<int,int>; // (max_len, n_batch)
struct bert_graph_cache_map {
    std::map<bert_graph_key, std::unique_ptr<bert_cached_graph>> entries;
};
Comment thread bert.h
Comment on lines 91 to 99

std::string model_name = "";
std::string model_arch = "bert";

// Opaque graph cache (bert_graph_cache_map *). Built lazily on first forward
// at a given (max_len, batch_size); reused for all subsequent forwards at the
// same shape so we avoid rebuilding the graph and re-running gallocr.
void * graph_cache = nullptr;
};
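
An opaque `void *` member leaves ownership and cleanup to convention. A typed alternative that still keeps the cache type out of the header is sketched below. None of this is in the PR: `graph_cache_map`, `graph_cache_free`, `graph_cache_deleter`, `ctx_sketch` and `g_freed` are hypothetical names used only for illustration.

```cpp
#include <map>
#include <memory>
#include <utility>

// --- header side: the cache type stays incomplete ---
struct graph_cache_map;                      // forward declaration only
void graph_cache_free(graph_cache_map * p);  // defined where the type is complete

struct graph_cache_deleter {
    void operator()(graph_cache_map * p) const { graph_cache_free(p); }
};

struct ctx_sketch {
    // Typed and owning: freed automatically, and accidental copies fail to compile.
    std::unique_ptr<graph_cache_map, graph_cache_deleter> graph_cache;
};

// --- implementation side: the full definition lives here ---
struct graph_cache_map {
    std::map<std::pair<int, int>, int> entries;  // (max_len, n_batch) -> placeholder
};

static int g_freed = 0;  // illustration only: counts destroyed caches

void graph_cache_free(graph_cache_map * p) {
    ++g_freed;
    delete p;
}
```

The custom deleter is what allows `std::unique_ptr` to hold an incomplete type in the header; the trade-off is that the context struct becomes move-only, which may or may not suit existing callers.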