
ggml: improve ggml_backend_cuda_cpy_tensor_async #13818


Closed
koush wants to merge 1 commit from the cuda-memcpy-async branch

Conversation


@koush koush commented May 27, 2025

I'm working on a general tensor parallel backend that leverages asynchronous tensor copies. I found that the pipeline was stalling here because the async call was actually synchronous.

@koush koush force-pushed the cuda-memcpy-async branch from f76085e to 5225eaa on May 27, 2025 03:53
@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on May 27, 2025
@koush koush changed the title from "ggml: fix ggml_backend_cuda_cpy_tensor_async device to device to actually be async" to "ggml: improve ggml_backend_cuda_cpy_tensor_async" on May 27, 2025
Author

koush commented May 27, 2025

The existing row split implementation does the mul mat and then immediately gathers the tensors, and it is implemented per backend. My current approach leaves the tensors on the GPU for further unary and binary ops; they eventually need to be gathered for the RoPE and RMS norm ops. It's around 15% faster than a single GPU (and much faster than row splitting, which is slower than a single GPU on CUDA), but graph execution is currently disabled, and once I get that sorted it should improve significantly.

@koush koush force-pushed the cuda-memcpy-async branch from 5225eaa to 1984136 on May 27, 2025 04:04
@JohannesGaessler
Collaborator

I also started working on tensor parallelism, see #13776. I would be happy to leave the implementation to you if you're interested in working on it.

Comment on lines 1398 to 1401
if (input_backend->iface.synchronize) {
// async copy succeeded, need to synchronize the input backend to ensure the copy is done before the split backend uses it
input_backend->iface.synchronize(input_backend);
}
Member

A synchronization after an async copy is not necessary. The way async copy is intended to work is roughly explained here:

// asynchronous copy
// the copy is performed after all the currently queued operations in backend_src
// backend_dst will wait for the copy to complete before performing other operations
// automatic fallback to sync copy if async is not supported
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
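
For reference, a minimal sketch of the usage that contract implies (the ggml_backend_graph_compute call and the graph variable are only placeholders for "other operations" queued on the destination backend):

// queued behind all pending work on backend_src; backend_dst will not run
// later work until the copy has completed
ggml_backend_tensor_copy_async(backend_src, backend_dst, src, dst);

// no explicit ggml_backend_synchronize() is needed in between: work submitted
// to backend_dst afterwards is already ordered after the copy
ggml_backend_graph_compute(backend_dst, graph);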

Author

@koush koush May 27, 2025

That makes sense; however, I was getting garbage output without this sync when using layer parallel. I suspect it's because the stream used for the async copy is the source rather than the dest, as specified in the code comments.

How should I proceed here? I was wary of changing the existing stream behavior.

Member

@slaren slaren May 27, 2025

I don't see a problem with the current implementation. The copy is performed on the source stream so that it happens at the end of all queued operations in the source backend. Then the destination stream waits on an event until the copy is complete, which ensures that any operations added later to the destination backend are executed after the copy has completed.
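
At the CUDA level this corresponds roughly to the following event pattern (a sketch only; src_stream, dst_stream, and copy_event are illustrative names, not the actual variables in ggml-cuda):

// 1. queue the copy on the source stream, behind all work already queued there
cudaMemcpyAsync(dst_ptr, src_ptr, nbytes, cudaMemcpyDeviceToDevice, src_stream);
// 2. record an event on the source stream marking the point where the copy is done
cudaEventRecord(copy_event, src_stream);
// 3. make the destination stream wait on that event: anything queued on it later
//    runs only after the copy has finished, and neither host thread blocks
cudaStreamWaitEvent(dst_stream, copy_event, 0);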

Author

@koush koush May 27, 2025

The issue is that I need to manage the dest stream syncing manually: I am queueing multiple asynchronous memcpy calls and then performing one synchronize after the gathers, which allows for concurrent transfers. Otherwise, copying to the dest context is serialized. Without this change, I am unable to get tensor parallel running above 50% utilization per GPU (with 2 GPUs), since each GPU ends up waiting for the other. I had a workaround that used a different thread to avoid the blocking, but that seems like a lot of unnecessary overhead.
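
A sketch of the pattern described above, with hypothetical n_splits/split_backends/srcs/dsts names: all copies are queued up front so the transfers can run concurrently, and a single synchronize per destination backend follows once everything is queued.

// queue one async copy per destination backend; none of these should block
for (int i = 0; i < n_splits; i++) {
    ggml_backend_tensor_copy_async(backend_src, split_backends[i], srcs[i], dsts[i]);
}
// ... queue the gathers / further ops ...
// one synchronize per backend at the end, rather than an implicit wait per copy
for (int i = 0; i < n_splits; i++) {
    ggml_backend_synchronize(split_backends[i]);
}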

Author

@koush koush May 27, 2025

I should note that I am also not using ggml_backend_tensor_copy_async; I am calling the device memcpy and sync directly, since similar sync behavior exists there. Maybe ggml-backend.cpp should call that instead?

ggml_backend_synchronize(backend_src);

If this is the intended behavior, then this may just be a gap in the API: there is no way to start multiple asynchronous memcpys without blocking on the destination for each copy.

Member

I suppose you could remove the event wait on the dst stream at the end of the async copy, and transfer the responsibility of synchronizing the dst backend to the application.

Author

I updated the change to make the async memcpy happen on the dst stream, as that's presumably where further ops on that data will occur. It's the responsibility of the caller to ensure the src is safe to use until the memcpy is complete, i.e., by calling synchronize on the src backend if necessary. In some cases no synchronize is needed at all. I saw that CANN is the only other backend that implements asynchronous memcpy, so I updated that as well.
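
Roughly what this revision amounts to at the CUDA level (a sketch, not the exact diff; the pointer and stream names are illustrative):

// queue the copy on the *destination* stream, where the ops that consume the
// data will run; no event wait is inserted on behalf of the caller
cudaMemcpyAsync(dst_ptr, src_ptr, nbytes, cudaMemcpyDeviceToDevice, dst_stream);
// because the copy no longer runs on the source stream, work queued later on the
// source backend could overwrite src before the copy executes; per the comment
// above, the caller must prevent that (e.g. with a synchronize) when reuse is possible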

Author

@koush koush May 28, 2025

Actually, I took a further look at the pipeline parallel batching implementation, and I think my last change would have negatively affected performance. I've updated the change to leave the src stream copy intact and then issue a final sync once all inputs are sent. This way the input sync doesn't serialize the memcpys.
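
A hedged sketch of that final shape, with illustrative names (input_backends, split_backend, inputs, input_copies) and following the earlier diff hunk in synchronizing the input backends: every input copy is queued first, on its source stream as before, and the sync is issued once at the end instead of after each copy.

// queue all input copies first so the transfers can overlap with each other
for (int i = 0; i < n_inputs; i++) {
    ggml_backend_tensor_copy_async(input_backends[i], split_backend, inputs[i], input_copies[i]);
}
// a single sync once all inputs are sent, instead of one per copy, so the
// sync no longer serializes the memcpys
for (int i = 0; i < n_inputs; i++) {
    ggml_backend_synchronize(input_backends[i]);
}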

@koush koush force-pushed the cuda-memcpy-async branch from daa3f79 to f23582b on May 27, 2025 22:19
Make device to device actually async; right now it syncs on dst.
Implement host to/from device async.
@koush koush force-pushed the cuda-memcpy-async branch from f23582b to 7966d05 on May 28, 2025 01:28
Author

koush commented May 29, 2025

I also started working on tensor parallelism, see #13776 . I would be happy to leave the implementation to you if you're interested in working on it.

I'm definitely interested, but don't let me stop you.

Author

koush commented Jun 1, 2025

@slaren @JohannesGaessler
I've got a decent starting implementation for tensor parallelism now. It is dependent on this change and some other minor ones in the common backend code.

Testing with this model because it's dense enough to saturate the GPU cores, the initial results for bartowski/nvidia_Llama-3_1-Nemotron-Ultra-253B-v1-GGUF:IQ4_NL on 2x RTX Pro 6000 are:

Before on layer split: 9 tokens/sec, GPU usage at 50/50 in nvidia-smi
Existing tensor split: 13 tokens/sec, GPU usage at 65/65 in nvidia-smi
New tensor split backend: 16 tokens/sec, GPU usage at 90/90 in nvidia-smi, a 25% improvement

The new implementation also performs better than the existing one on smaller models: it has no GPU communication except after the RMS norm calls that require a tensor gather (the existing implementation gathers after every mul mat), and that frequent GPU-to-GPU overhead kills performance.

On a smaller Qwen 3 32B dense model:
Before on layer split: 26 tokens/sec
Existing tensor split: 19 tokens/sec
New tensor split backend: 36 tokens/sec

Since the new GPU implementation is a wrapping backend, it should also work with heterogeneous devices (if their respective backends implement the new backend requirements).

The very much work in progress is here: https://github.com/koush/llama.cpp/tree/parallel

Should I use this pull request to consolidate all my backend API changes before opening a pull request for the tensor parallelism backend?

Member

slaren commented Jun 1, 2025

Should I use this pull request to consolidate all my backend API changes before opening a pull request for the tensor parallelism backend?

I would prefer everything to be in the same PR so that the motivation for the changes to the backend interface is clear.

@koush koush closed this Jun 4, 2025