Perf/graph cache flash attention#1
Open
yzimmermann wants to merge 3 commits into
Open
Conversation
Recent ggml moved ggml-metal.metal into the ggml-metal/ subdirectory. The CMake configure_file() call still pointed at the old top-level location, causing configure to fail with "File does not exist" on any fresh checkout against current ggml.
Refactor of bert_forward_batch / bert_build_graph for substantial speedups and to make bert_forward_batch usable at batch>1. Graph + allocator cache ----------------------- Each call to bert_forward_batch previously rebuilt the entire ggml graph and ran ggml_gallocr_alloc_graph from scratch — on Metal that included a fresh MTLHeap allocation per call. The new code keys a per-(max_len, n_batch) cache on bert_ctx and builds the graph + runs the gallocr planner exactly once per distinct shape. Cached pointers to the input/output tensors avoid graph_get_tensor lookups on the hot path. Per-call scratch buffers (token ids, pad mask, pool weights) are also stored on the cached entry to eliminate heap churn. Flash attention --------------- The unrolled mul_mat -> scale -> add(mask) -> softmax -> mul_mat sequence is replaced by ggml_flash_attn_ext. On Metal this uses the fused SDPA kernel. The attention mask is built as before (outer product of the pad mask) and cast to f16 for the flash_attn API. F16 accumulate is used on accelerated backends (significant speedup, negligible drift for sentence embeddings); CPU is forced to F32 since it has no native fp16 path and would otherwise pay extra conversion work. Batched mean-pool ----------------- The old mean-pool flattened [n_embd, max_len, n_batch] to a 2D matmul that summed across the entire batch, which is wrong for n_batch > 1 and crashed at ggml_reshape_2d. Replaced with a 3D mul_mat that preserves the batch dimension. bert_forward_batch is now bit-exact at bs=N for N identical inputs vs N invocations at bs=1. CPU input-stability workaround ------------------------------ On the CPU backend the gallocr appears to reuse memory of input-flagged tensors between successive graph_compute calls, despite the input flag. Metal keeps them stable. positions/token_types are re-uploaded each forward on CPU only; on accelerated backends they are written once at graph build. Results on M-series Apple Silicon, all-MiniLM-L6-v2 Q4_K_M, bs=1: Metal max-length: 38.2 ms -> 5.7 ms (6.7x) Metal medium: 6.6 ms -> 3.9 ms (1.7x) CPU max-length: 56.6 ms -> ~41 ms (~1.4x) Metal batched throughput at bs=8 reaches ~90k tok/s; previously bs>1 crashed. 24/24 correctness tests pass; CPU/Metal cosine agreement preserved (0.9996); bs>1 outputs bit-exact vs bs=1.
examples/bench loads the model once, warms up, then times N forward passes at varying input lengths and batch sizes. Reports mean / median / p95 / min / max latency and tokens-per-second. Used to validate the perf refactor in the preceding commit: ./build/bin/bench -m MODEL [-c] [-w WARMUP] [-n ITERS] [-t THREADS]
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds a new benchmarking example and improves BERT inference performance by caching compute graphs per (sequence length, batch size) and using fused attention where available.
Changes:
- Add
examples/bench.cppand wire it intoexamples/CMakeLists.txt - Cache ggml graphs/allocations in
bert_ctxkeyed by(max_len, n_batch)to avoid rebuild + gallocr planning per forward - Update Metal shader copy path in top-level
CMakeLists.txt
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| examples/bench.cpp | New CLI benchmark tool for latency/throughput across token lengths and batch sizes |
| examples/CMakeLists.txt | Builds/links the new bench example |
| bert.h | Adds an opaque per-context graph cache pointer to support graph reuse |
| bert.cpp | Implements cached-graph construction, cached forward path, and uses ggml_flash_attn_ext |
| CMakeLists.txt | Fixes Metal shader source path for runtime copy |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+28
to
+40
| static bool parse(int argc, char ** argv, bench_params & p) { | ||
| for (int i = 1; i < argc; i++) { | ||
| std::string a = argv[i]; | ||
| if (a == "-m") p.model = argv[++i]; | ||
| else if (a == "-c") p.use_cpu = true; | ||
| else if (a == "-t") p.n_threads = std::stoi(argv[++i]); | ||
| else if (a == "-w") p.warmup = std::stoi(argv[++i]); | ||
| else if (a == "-n") p.iters = std::stoi(argv[++i]); | ||
| else if (a == "-h" || a == "--help") { usage(argv[0]); return false; } | ||
| else { fprintf(stderr, "unknown arg: %s\n", a.c_str()); usage(argv[0]); return false; } | ||
| } | ||
| return p.model != nullptr; | ||
| } |
Comment on lines
+42
to
+56
| static double median(std::vector<double> v) { | ||
| std::sort(v.begin(), v.end()); | ||
| size_t n = v.size(); | ||
| return n % 2 ? v[n/2] : 0.5*(v[n/2 - 1] + v[n/2]); | ||
| } | ||
|
|
||
| static double percentile(std::vector<double> v, double p) { | ||
| std::sort(v.begin(), v.end()); | ||
| size_t idx = (size_t)((v.size() - 1) * p); | ||
| return v[idx]; | ||
| } | ||
|
|
||
| static double mean(const std::vector<double> & v) { | ||
| double s = 0; for (double x : v) s += x; return s / v.size(); | ||
| } |
| @@ -64,7 +64,9 @@ For these use the llama.cpp interface. | |||
| #include <cmath> | |||
| #include <fstream> | |||
| #include <algorithm> | |||
Comment on lines
+1270
to
+1275
| // No ggml_cont needed; flash_attn handles non-contiguous q/k/v. | ||
| auto proj = [&](ggml_tensor * w, ggml_tensor * b) { | ||
| ggml_tensor * x = ggml_add(ctx0, ggml_mul_mat(ctx0, w, cur), b); | ||
| x = ggml_reshape_4d(ctx0, x, d_head, n_head, max_len, n_batch); | ||
| return ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3)); | ||
| }; |
Comment on lines
+1173
to
+1176
| using bert_graph_key = std::pair<int,int>; // (max_len, n_batch) | ||
| struct bert_graph_cache_map { | ||
| std::map<bert_graph_key, std::unique_ptr<bert_cached_graph>> entries; | ||
| }; |
Comment on lines
91
to
99
|
|
||
| std::string model_name = ""; | ||
| std::string model_arch = "bert"; | ||
|
|
||
| // Opaque graph cache (bert_graph_cache_map *). Built lazily on first forward | ||
| // at a given (max_len, batch_size); reused for all subsequent forwards at the | ||
| // same shape so we avoid rebuilding the graph and re-running gallocr. | ||
| void * graph_cache = nullptr; | ||
| }; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.