Add kernel based alltoallv and cuda backend for MoE dispatch and combine #5863
base: main
Conversation
!test
Greptile Overview

Greptile Summary

Adds a GPU-initiated alltoallv kernel and a CUDA backend for MoE dispatch and combine.

Key changes:

Minor issues:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant App as Application
participant Dispatch as doMoeDispatch
participant PrepMeta as prepareAlltoallvMetadata
participant TCPStore as TCP Store
participant SymMem as SymmetricTensor
participant AlltoallvKernel as alltoallv_kernel (GPU)
participant Combine as doMoeCombine
App->>Dispatch: x, topk_idx, topk_weights, num_experts
Note over Dispatch: CUDA Backend Path
Dispatch->>Dispatch: Compute rank_for_token, n_tokens_to_rank
Dispatch->>PrepMeta: n_tokens_to_rank, tag="moe_dispatch_counts"
PrepMeta->>PrepMeta: Copy send_counts to CPU
PrepMeta->>TCPStore: set(alltoallvCountsKey(tag, my_rank), send_counts)
loop For each rank
PrepMeta->>TCPStore: get(alltoallvCountsKey(tag, rank))
end
PrepMeta->>PrepMeta: Build counts_matrix, compute offsets
PrepMeta->>PrepMeta: barrier()
PrepMeta->>TCPStore: deleteKey for all ranks
PrepMeta-->>Dispatch: AlltoallvMetadata (recv_counts, offsets, max values)
Dispatch->>SymMem: allocate symmetric buffers (send_x, recv_x, etc.)
Dispatch->>SymMem: setupRemoteHandles("moe_dispatch_recv_x", ...)
Note over SymMem: Exchange IPC handles via TCPStore
loop For each payload (x, topk_idx, topk_weights, src_idx)
Dispatch->>AlltoallvKernel: send, recv, metadata, recv_ptrs
Note over AlltoallvKernel: GPU directly writes to remote symmetric memory
AlltoallvKernel->>AlltoallvKernel: Copy send[offset] to recv_ptrs[peer][offset]
end
Dispatch->>Dispatch: alltoallvBarrier("moe_dispatch_counts")
Dispatch->>Dispatch: Narrow recv buffers to total_recv
Dispatch-->>App: DispatchResult (recv_x, recv_topk_idx, recv_src_idx, ...)
App->>Combine: x, src_idx, n_tokens_to_rank, n_tokens_from_rank
Note over Combine: Similar alltoallv flow in reverse
Combine->>PrepMeta: n_tokens_from_rank, tag="moe_combine_counts"
PrepMeta-->>Combine: AlltoallvMetadata
Combine->>SymMem: allocate and setup symmetric buffers
Combine->>AlltoallvKernel: alltoallv payloads back
Combine->>Combine: index_copy to restore original order
Combine-->>App: CombineResult (combined_x)
```
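The central step in this flow is the GPU-side copy into remote symmetric memory (the "Copy send[offset] to recv_ptrs[peer][offset]" step above). Below is a minimal sketch of what such a copy kernel can look like; the kernel name, parameter names, and layout are illustrative assumptions, not the PR's actual `alltoallv_kernel` signature.

```cuda
// Illustrative sketch of a GPU-initiated alltoallv copy (not the PR's kernel).
// One block per peer; threads stride over that peer's slice of the send buffer
// and store it into the peer's symmetric receive buffer (remote writes go over NVLink).
__global__ void alltoallv_copy_sketch(
    const char* send,             // local send buffer, packed by destination rank
    char* const* recv_ptrs,       // per-peer base pointers into symmetric recv buffers
    const int64_t* send_offsets,  // byte offset of each peer's slice in `send`
    const int64_t* send_counts,   // bytes destined for each peer
    const int64_t* recv_offsets)  // byte offset where this rank's data lands on each peer
{
  const int peer = blockIdx.x;
  const char* src = send + send_offsets[peer];
  char* dst = recv_ptrs[peer] + recv_offsets[peer];
  for (int64_t i = threadIdx.x; i < send_counts[peer]; i += blockDim.x) {
    dst[i] = src[i];
  }
}
```

A production kernel would add vectorized copies and a completion/flush step (hence the `alltoallvBarrier` call in the diagram), but the structure is the same: counts and offsets come from the exchanged metadata, and the GPU writes to peers without a host round trip.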
8 files reviewed, 1 comment
```cpp
if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) {
  GTEST_SKIP() << "Backend " << backend << " not available.";
}
```
logic: checking wrong backend constant - should check backend parameter, not hardcoded kNccl
Suggested change:

```diff
-  if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) {
+  if (!communicator_->isBackendAvailable(backend)) {
     GTEST_SKIP() << "Backend " << backend << " not available.";
   }
```
Review updated until commit 3828247

Description
Relevant files:
- Enhancement
- Tests
- Documentation
- Configuration changes
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- 🔒 No security concerns identified
- ⚡ Recommended focus areas for review
Performance Concern
Test failures
- (Medium, 1) Thunder nvFuser CUDA produces zero output for nanoGPT autograd test (test_networks.py)
  - Test (A100): thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 ❌
- (Low, 1) Minor numerical mismatch in Thunder vs Torch instance_norm nvFuser CUDA tests (float32, H100)
  - Test (H100): thunder.tests.test_ops.test_core_vs_torch_consistency_instance_norm_nvfuser_cuda_thunder.dtypes.float32 ❌
!test
3 files reviewed, 2 comments
```cpp
opts.push_back("-I/usr/local/cuda/include");
opts.push_back("-I/usr/local/cuda/include/cccl");
```
hardcoded CUDA include paths may break on non-standard installations
Suggested change:

```diff
-    opts.push_back("-I/usr/local/cuda/include");
-    opts.push_back("-I/usr/local/cuda/include/cccl");
+    // Use the CUDA_HOME environment variable (or a CMake-detected path)
+    // instead of assuming the default install location.
+    const char* cuda_home_env = std::getenv("CUDA_HOME");
+    const std::string cuda_home = cuda_home_env ? cuda_home_env : "/usr/local/cuda";
+    // Note: if `opts` stores raw `const char*`, these strings must be kept alive
+    // for as long as the options are used; pushing std::string values avoids that.
+    opts.push_back("-I" + cuda_home + "/include");
+    opts.push_back("-I" + cuda_home + "/include/cccl");
```
```cpp
std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) {
  return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank);
}
```
unused function - alltoallvBarrierKey is defined but never called
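For context, per-rank keys like this are typically consumed by a store-based barrier; the sketch below uses a c10d-style Store (set/wait) purely to illustrate the intended use. It is a hypothetical illustration, not code from the PR.

```cpp
#include <torch/csrc/distributed/c10d/Store.hpp>

// Hypothetical store-based barrier built on per-rank keys (illustration only).
void alltoallvBarrierSketch(
    c10d::Store& store,
    const std::string& tag,
    int64_t my_rank,
    int64_t world_size) {
  // Publish this rank's key to announce arrival at the barrier.
  store.set(alltoallvBarrierKey(tag, my_rank), std::vector<uint8_t>{1});
  // Block until every rank's key is visible in the store.
  for (int64_t rank = 0; rank < world_size; ++rank) {
    store.wait({alltoallvBarrierKey(tag, rank)});
  }
}
```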
wujingyue left a comment
LGTM otherwise
```cpp
NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
    &prog,
    nvfuser_resources::alltoallv_cu,
    "alltoallv.cu",
```
Why nvrtc? Can't we simply alltoallv_kernel<<<...>>>?
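For reference, the statically compiled alternative the comment points at would look roughly like the sketch below; the kernel signature and launch configuration are assumptions for illustration, not the PR's actual alltoallv.cu interface.

```cuda
// Sketch: launching the kernel directly if it were compiled with nvcc instead
// of being JIT-compiled through NVRTC (signature is illustrative).
__global__ void alltoallv_kernel(
    const char* send,
    char* const* recv_ptrs,
    const int64_t* send_offsets,
    const int64_t* send_counts,
    const int64_t* recv_offsets);

void launchAlltoallv(
    const char* send,
    char* const* recv_ptrs,
    const int64_t* send_offsets,
    const int64_t* send_counts,
    const int64_t* recv_offsets,
    int num_peers,
    cudaStream_t stream) {
  // One block per peer on the caller's stream; no NVRTC program or module
  // management is needed for a static launch.
  alltoallv_kernel<<<num_peers, 256, 0, stream>>>(
      send, recv_ptrs, send_offsets, send_counts, recv_offsets);
}
```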
```cpp
    backend,
    CommunicatorBackend::kNccl,
    "Only NCCL backend is supported for MoeDispatch.");
NVF_CHECK(
```
NVF_CHECK_EQ
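That is, presumably collapsing the two-argument equality check into the dedicated macro, along the lines of the sketch below (assuming NVF_CHECK_EQ takes the usual (lhs, rhs, message...) form):

```cpp
NVF_CHECK_EQ(
    backend,
    CommunicatorBackend::kNccl,
    "Only NCCL backend is supported for MoeDispatch.");
```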
- A kernel-based `alltoallv` implementation using GPU-initiated communication (SM-driven NVLink), taking only GPU buffers, even for the alltoallv "metadata" such as `splitSize`. Available through the `kCuda` backend. Requires the recv buffer to be allocated as symmetric memory.
- A `Cuda` backend for dispatch and combine, which avoids GPU->CPU sync (compared to the NCCL-backed version).
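Based on the sequence diagram above, the application-level round trip looks roughly like the sketch below; the function names come from the diagram, while the result-field names and the run_local_experts helper are assumptions for illustration, not the PR's exact API.

```cpp
// Hypothetical dispatch -> expert -> combine round trip (names inferred from
// the sequence diagram; not the PR's exact signatures).
auto dispatched = doMoeDispatch(x, topk_idx, topk_weights, num_experts);

// Run the local experts on tokens received from all ranks (hypothetical helper).
auto expert_out = run_local_experts(dispatched.recv_x, dispatched.recv_topk_idx);

// Send expert outputs back to their source ranks and restore the original order.
auto combined = doMoeCombine(
    expert_out,
    dispatched.recv_src_idx,
    n_tokens_to_rank,
    n_tokens_from_rank);
// combined.combined_x holds the tokens back in their original order.
```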