Skip to content

Commit 158f8c9

Browse files
metal : localized logic in ggml_metal_graph_compute (#4924)
* Metal: Localized logic in `ggml_metal_graph_compute`, minor performance improvement * Whitespace * Collecting command buffer completions on single thread * Whitespace * Reduce diff noise
1 parent 862f5e4 commit 158f8c9

File tree

2 files changed

+17
-21
lines changed

2 files changed

+17
-21
lines changed

ggml-metal.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
// max memory buffers that can be mapped to the device
2929
#define GGML_METAL_MAX_BUFFERS 64
30-
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
3130

3231
struct ggml_tensor;
3332
struct ggml_cgraph;

ggml-metal.m

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,6 @@
170170
id<MTLCommandQueue> queue;
171171
id<MTLLibrary> library;
172172

173-
id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
174-
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
175-
176173
dispatch_queue_t d_queue;
177174

178175
int n_buffers;
@@ -719,41 +716,39 @@ static bool ggml_metal_graph_compute(
719716
@autoreleasepool {
720717

721718
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
722-
723-
const int n_nodes = gf->n_nodes;
724719
edesc.dispatchType = MTLDispatchTypeSerial;
725720

726721
// create multiple command buffers and enqueue them
727722
// then, we encode the graph into the command buffers in parallel
728723

724+
const int n_nodes = gf->n_nodes;
729725
const int n_cb = ctx->n_cb;
726+
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
730727

731-
for (int i = 0; i < n_cb; ++i) {
732-
ctx->command_buffers[i] = [ctx->queue commandBuffer];
728+
id<MTLCommandBuffer> command_buffer_builder[n_cb];
729+
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
730+
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
731+
command_buffer_builder[cb_idx] = command_buffer;
733732

734733
// enqueue the command buffers in order to specify their execution order
735-
[ctx->command_buffers[i] enqueue];
736-
737-
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
734+
[command_buffer enqueue];
738735
}
736+
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
739737

740-
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
741738
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
742739
const int cb_idx = iter;
743740

744741
size_t offs_src0 = 0;
745742
size_t offs_src1 = 0;
746743
size_t offs_dst = 0;
747744

748-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
749-
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
745+
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
746+
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
750747

751748
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
752749
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
753750

754-
for (int ind = node_start; ind < node_end; ++ind) {
755-
const int i = ind;
756-
751+
for (int i = node_start; i < node_end; ++i) {
757752
if (i == -1) {
758753
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
759754
continue;
@@ -2249,12 +2244,14 @@ static bool ggml_metal_graph_compute(
22492244
[command_buffer commit];
22502245
});
22512246

2252-
// check status of command buffers
2247+
// Wait for completion and check status of each command buffer
22532248
// needed to detect if the device ran out-of-memory for example (#1881)
2254-
for (int i = 0; i < n_cb; i++) {
2255-
[ctx->command_buffers[i] waitUntilCompleted];
22562249

2257-
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
2250+
for (int i = 0; i < n_cb; ++i) {
2251+
id<MTLCommandBuffer> command_buffer = command_buffers[i];
2252+
[command_buffer waitUntilCompleted];
2253+
2254+
MTLCommandBufferStatus status = [command_buffer status];
22582255
if (status != MTLCommandBufferStatusCompleted) {
22592256
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
22602257
return false;

0 commit comments

Comments
 (0)