|
170 | 170 | id<MTLCommandQueue> queue;
|
171 | 171 | id<MTLLibrary> library;
|
172 | 172 |
|
173 |
| - id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS]; |
174 |
| - id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS]; |
175 |
| - |
176 | 173 | dispatch_queue_t d_queue;
|
177 | 174 |
|
178 | 175 | int n_buffers;
|
@@ -719,41 +716,39 @@ static bool ggml_metal_graph_compute(
|
719 | 716 | @autoreleasepool {
|
720 | 717 |
|
721 | 718 | MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
722 |
| - |
723 |
| - const int n_nodes = gf->n_nodes; |
724 | 719 | edesc.dispatchType = MTLDispatchTypeSerial;
|
725 | 720 |
|
726 | 721 | // create multiple command buffers and enqueue them
|
727 | 722 | // then, we encode the graph into the command buffers in parallel
|
728 | 723 |
|
| 724 | + const int n_nodes = gf->n_nodes; |
729 | 725 | const int n_cb = ctx->n_cb;
|
| 726 | + const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; |
730 | 727 |
|
731 |
| - for (int i = 0; i < n_cb; ++i) { |
732 |
| - ctx->command_buffers[i] = [ctx->queue commandBuffer]; |
| 728 | + id<MTLCommandBuffer> command_buffer_builder[n_cb]; |
| 729 | + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { |
| 730 | + id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; |
| 731 | + command_buffer_builder[cb_idx] = command_buffer; |
733 | 732 |
|
734 | 733 | // enqueue the command buffers in order to specify their execution order
|
735 |
| - [ctx->command_buffers[i] enqueue]; |
736 |
| - |
737 |
| - ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc]; |
| 734 | + [command_buffer enqueue]; |
738 | 735 | }
|
| 736 | + const id<MTLCommandBuffer> *command_buffers = command_buffer_builder; |
739 | 737 |
|
740 |
| - const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; |
741 | 738 | dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
|
742 | 739 | const int cb_idx = iter;
|
743 | 740 |
|
744 | 741 | size_t offs_src0 = 0;
|
745 | 742 | size_t offs_src1 = 0;
|
746 | 743 | size_t offs_dst = 0;
|
747 | 744 |
|
748 |
| - id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx]; |
749 |
| - id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx]; |
| 745 | + id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx]; |
| 746 | + id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; |
750 | 747 |
|
751 | 748 | const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
752 | 749 | const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
753 | 750 |
|
754 |
| - for (int ind = node_start; ind < node_end; ++ind) { |
755 |
| - const int i = ind; |
756 |
| - |
| 751 | + for (int i = node_start; i < node_end; ++i) { |
757 | 752 | if (i == -1) {
|
758 | 753 | [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
759 | 754 | continue;
|
@@ -2249,12 +2244,14 @@ static bool ggml_metal_graph_compute(
|
2249 | 2244 | [command_buffer commit];
|
2250 | 2245 | });
|
2251 | 2246 |
|
2252 |
| - // check status of command buffers |
| 2247 | + // Wait for completion and check status of each command buffer |
2253 | 2248 | // needed to detect if the device ran out-of-memory for example (#1881)
|
2254 |
| - for (int i = 0; i < n_cb; i++) { |
2255 |
| - [ctx->command_buffers[i] waitUntilCompleted]; |
2256 | 2249 |
|
2257 |
| - MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status]; |
| 2250 | + for (int i = 0; i < n_cb; ++i) { |
| 2251 | + id<MTLCommandBuffer> command_buffer = command_buffers[i]; |
| 2252 | + [command_buffer waitUntilCompleted]; |
| 2253 | + |
| 2254 | + MTLCommandBufferStatus status = [command_buffer status]; |
2258 | 2255 | if (status != MTLCommandBufferStatusCompleted) {
|
2259 | 2256 | GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
2260 | 2257 | return false;
|
|
0 commit comments