ggml: improve ggml_backend_cuda_cpy_tensor_async #13818
Conversation
The existing row split implementation does the mul mat and then immediately gathers the tensors, and it is implemented per backend.
I also started working on tensor parallelism, see #13776. I would be happy to leave the implementation to you if you're interested in working on it.
ggml/src/ggml-backend.cpp
Outdated
if (input_backend->iface.synchronize) {
    // async copy succeeded, need to synchronize the input backend to ensure the copy is done before the split backend uses it
    input_backend->iface.synchronize(input_backend);
}
A synchronization after an async copy is not necessary. The way async copy is intended to work is roughly explained here:
llama.cpp/ggml/include/ggml-backend.h
Lines 108 to 112 in 7fe03e7
// asynchronous copy
// the copy is performed after all the currently queued operations in backend_src
// backend_dst will wait for the copy to complete before performing other operations
// automatic fallback to sync copy if async is not supported
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
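As a minimal usage sketch of that contract (assuming backend handles and a graph already built elsewhere; this is not code from the PR):

#include "ggml-backend.h"

// sketch: per the contract above, no explicit sync is needed after the async copy,
// because backend_dst will not run operations queued later until the copy is done
static void copy_then_compute(ggml_backend_t backend_src, ggml_backend_t backend_dst,
                              ggml_tensor * src, ggml_tensor * dst, ggml_cgraph * graph) {
    ggml_backend_tensor_copy_async(backend_src, backend_dst, src, dst);
    // work queued on backend_dst from here on sees the copied data
    ggml_backend_graph_compute(backend_dst, graph);
    // only needed if the caller wants to reuse or free the source buffer right away:
    // ggml_backend_synchronize(backend_src);
}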
That makes sense; however, I was receiving garbage output without this sync when using layer parallel. I suspect it's because the stream being used for the async copy is the source and not the dest, as specified in the code comments.
How should I proceed here? I was wary of changing the existing stream behavior.
I don't see a problem with the current implementation. The copy is performed on the source stream so that it happens at the end of all queued operations in the source backend. Then the destination stream waits on an event until the copy is complete, which ensures that any operations added later to the destination backend are executed after the copy has completed.
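In plain CUDA runtime terms, the ordering being described is roughly the following (stream and event names are illustrative, not the actual ggml-cuda internals):

#include <cuda_runtime.h>

// approximation of the current behavior: copy on the source stream, then make the
// destination stream wait on an event recorded right after the copy
static void copy_on_src_stream_then_wait(void * dst, const void * src, size_t nbytes,
                                         cudaStream_t src_stream, cudaStream_t dst_stream,
                                         cudaEvent_t copy_done /* cudaEventDisableTiming */) {
    // runs after everything already queued on the source backend
    cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, src_stream);
    cudaEventRecord(copy_done, src_stream);
    // anything queued on the destination backend afterwards runs only once the copy is done
    cudaStreamWaitEvent(dst_stream, copy_done, 0);
}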
The issue is that I need to manage the dest stream syncing manually, as I am queueing multiple asynchronous memcpys and then performing one synchronize after the gathers, allowing for concurrent transfers. Otherwise, copying to the dest context is serialized. Without this change, I am unable to get tensor parallel working faster than 50% utilization per GPU (with 2 GPUs), since each GPU ends up waiting for the other. I had a workaround that used a different thread to avoid the blocking, but that seems like a lot of unnecessary overhead.
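For reference, the scatter pattern being described would look roughly like this at the CUDA level (per-destination streams and buffers are placeholders): several copies are put in flight and a single wait happens at the end, instead of one wait per copy.

#include <cuda_runtime.h>
#include <vector>

// enqueue one device-to-device copy per peer GPU, then synchronize once, so the
// transfers can overlap instead of being serialized by a per-copy sync
static void scatter_then_sync_once(const std::vector<void *> & dsts, const std::vector<int> & dst_devs,
                                   const void * src, int src_dev, size_t nbytes,
                                   const std::vector<cudaStream_t> & streams) {
    for (size_t i = 0; i < dsts.size(); ++i) {
        cudaMemcpyPeerAsync(dsts[i], dst_devs[i], src, src_dev, nbytes, streams[i]);
    }
    for (cudaStream_t s : streams) {
        cudaStreamSynchronize(s); // one wait at the end, after all copies are queued
    }
}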
I should note that I am also not using ggml_backend_tensor_copy_async; I am calling the device memcpy and sync directly, since similar sync behavior exists there. Maybe ggml-backend.cpp should call that instead?
llama.cpp/ggml/src/ggml-backend.cpp
Line 408 in a8ea03d
ggml_backend_synchronize(backend_src);
If this is intended behavior, this may just be a gap in the API: there is no way to start multiple asynchronous memcpys without blocking on the destination for each copy.
I suppose you could remove the event wait on the dst stream at the end of the async copy, and transfer the responsibility of synchronizing the dst backend to the application.
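That variant would amount to something like the sketch below (illustrative names, not a proposed patch): the copy stays on the source stream, but there is no event wait on the destination, so the application has to synchronize before the destination backend consumes the data.

#include <cuda_runtime.h>

// suggested variant: drop the event wait on the destination stream and leave
// synchronization entirely to the application
static void copy_on_src_stream_no_dst_wait(void * dst, const void * src, size_t nbytes,
                                           cudaStream_t src_stream) {
    cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, src_stream);
    // no cudaEventRecord / cudaStreamWaitEvent here; the application synchronizes
    // (e.g. via ggml_backend_synchronize) once all the copies it needs are issued
}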
I updated the change to make the async memcpy happen on the dst stream, as that's where further ops on that data will presumably be queued. It's the responsibility of the caller to ensure the src is safe to use for the memcpy until it is complete, i.e., calling synchronize on the src backend if necessary; in some cases no synchronize is needed at all. I saw that CANN is the only other backend that implements asynchronous memcpy, so I updated that as well.
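In CUDA terms, the dst-stream version described here is roughly the following (illustrative, not the actual patch); the extra responsibility sits on the caller's side:

#include <cuda_runtime.h>

// copy queued on the *destination* stream, where the ops consuming the data will
// also be queued, so no event wait is needed on the destination side
static void copy_on_dst_stream(void * dst, const void * src, size_t nbytes,
                               cudaStream_t src_stream, cudaStream_t dst_stream,
                               bool src_has_pending_writes) {
    if (src_has_pending_writes) {
        // caller-side responsibility: make sure src is safe to read before copying
        cudaStreamSynchronize(src_stream);
    }
    cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, dst_stream);
}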
Actually, I took a further look at the pipeline parallel batching implementation, and I think my last change would have negatively affected performance. I've updated the change to leave the src stream copy intact and then issue a final sync once all inputs are sent. This way the input sync doesn't serialize the memcpys.
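At the ggml-backend level, the described pattern is roughly the sketch below (placeholder names, and the assumption that the final synchronize targets the input backends; this is not the actual scheduler code):

#include "ggml-backend.h"
#include <vector>

// queue every input copy first (still on the source stream, as before), then
// synchronize once at the end instead of after each copy, so the input sync no
// longer serializes the transfers
static void send_split_inputs(const std::vector<ggml_backend_t> & input_backends,
                              ggml_backend_t split_backend,
                              const std::vector<ggml_tensor *> & inputs,
                              const std::vector<ggml_tensor *> & input_copies) {
    for (size_t i = 0; i < inputs.size(); ++i) {
        ggml_backend_tensor_copy_async(input_backends[i], split_backend, inputs[i], input_copies[i]);
    }
    for (ggml_backend_t b : input_backends) {
        ggml_backend_synchronize(b); // single sync once all inputs are in flight (assumed target)
    }
}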
Make device to device actually async; right now it syncs on dst. Implement host to/from device async.
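For the host to/from device part, the underlying CUDA calls look roughly like this (illustrative sketch, not the PR's code); the copies are only truly asynchronous when the host buffer is pinned:

#include <cuda_runtime.h>

// host <-> device async copies; requires pinned host memory (cudaMallocHost /
// cudaHostRegister) to avoid the runtime falling back to a staged, blocking copy
static void host_to_device_async(void * dev_dst, const void * host_src, size_t nbytes, cudaStream_t stream) {
    cudaMemcpyAsync(dev_dst, host_src, nbytes, cudaMemcpyHostToDevice, stream);
}

static void device_to_host_async(void * host_dst, const void * dev_src, size_t nbytes, cudaStream_t stream) {
    cudaMemcpyAsync(host_dst, dev_src, nbytes, cudaMemcpyDeviceToHost, stream);
}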
I'm definitely interested, but don't let me stop you.
@slaren @JohannesGaessler Using this model to test because it's dense enough to saturate GPU cores, the initial results:
Before, on layer split: 9 tokens/sec, GPU usage at 50/50 in nvidia-smi.
The new implementation also performs better than the existing implementation on smaller models: it has no GPU communication other than after the RMS norm calls that require a tensor gather (the existing implementation gathers after every mul mat), and that frequent GPU-to-GPU overhead kills performance. On a smaller Qwen 3 32B dense model:
Since the new GPU implementation is a wrapping backend, it should also work with heterogeneous devices (if their respective backends implement the new backend requirements).
The very much work-in-progress branch is here: https://github.com/koush/llama.cpp/tree/parallel
Should I use this pull request to consolidate all my backend API changes before opening a pull request for the tensor parallelism backend?
I would prefer if everything is in the same PR, so that it is clear what the motivation for the changes to the backend interface is.
I'm working on a general tensor parallel backend. It leverages asynchronous tensor copies. I found that the pipeline was stalling here when the async call was actually sync.