Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Jan 22, 2026

  • Add an alltoallv implementation using GPU-initiated comms (SM-driven NVLink), taking only GPU buffers, even for the alltoallv metadata such as splitSize. Available through the kCuda backend. Requires the recv buffer to be allocated as symmetric memory (a usage sketch follows this list).
  • Add a CUDA backend for dispatch and combine that avoids GPU-to-CPU synchronization (compared to the NCCL-backed version).
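
For illustration only, here is a sketch of how the new kCuda path fits together, assembled from the dispatch code quoted later in this thread. The helper name, tag strings, and tensor shapes are hypothetical; only the calls added by this PR (prepareAlltoallvMetadata, SymmetricTensor, alltoallvWithCudaBackend, alltoallvBarrier) are taken from the excerpts, and the PR's cuda_p2p.h / SymmetricTensor headers are assumed to be included:

#include <string>
#include <vector>
#include <cuda.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// Hypothetical helper, not part of the PR: exchanges a 2-D CUDA payload whose rows
// are split per destination rank according to a CUDA `send_counts` tensor.
at::Tensor exampleAlltoallv(
    const at::Tensor& send_counts, // CUDA tensor: rows destined to each rank
    const at::Tensor& payload,     // CUDA tensor: [num_rows, hidden]
    int64_t world_size) {
  // Exchange per-rank counts and derive offsets (the TCPStore handshake happens here).
  auto metadata = prepareAlltoallvMetadata(send_counts, /*tag=*/"example_a2av");

  // Stage the payload in a symmetric send buffer sized for the worst case across ranks.
  auto send_sym = SymmetricTensor::allocate(
      {metadata.max_send_total, payload.size(1)}, payload.scalar_type(), payload.device());
  send_sym.narrow(0, 0, payload.size(0)).copy_(payload);

  // The recv buffer must be symmetric memory so peers can write into it directly.
  auto recv_sym = SymmetricTensor::allocate(
      {metadata.max_recv, payload.size(1)}, payload.scalar_type(), payload.device());
  SymmetricTensor recv_handle(recv_sym);
  recv_handle.setupRemoteHandles("example_a2av_recv");

  std::vector<void*> recv_ptrs(world_size);
  for (int64_t rank = 0; rank < world_size; ++rank) {
    recv_ptrs[rank] = recv_handle.remoteTensor(rank).data_ptr();
  }

  auto stream = static_cast<CUstream>(at::cuda::getDefaultCUDAStream().stream());
  alltoallvWithCudaBackend(send_sym, recv_sym, metadata, recv_ptrs, stream);
  alltoallvBarrier("example_a2av");

  // Only the first total_recv rows of recv_sym contain valid data.
  return recv_sym.narrow(0, 0, metadata.total_recv);
}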

@samnordmann
Collaborator Author

!test

@samnordmann samnordmann changed the title add kernel based a2av and cuda backend for d/c Add kernel based alltoallv and cuda backend for MoE dispatch and combine Jan 22, 2026
@greptile-apps
Contributor

greptile-apps bot commented Jan 22, 2026

Greptile Overview

Greptile Summary

Adds GPU-initiated alltoallv implementation using SM-driven NVLink for the new kCuda backend, avoiding CPU synchronization compared to the existing NCCL backend. The implementation uses symmetric memory for direct GPU-to-GPU writes and runtime NVRTC compilation of the CUDA kernel.

Key changes:

  • New alltoallv.cu CUDA kernel performing byte-level data transfer across ranks using remote symmetric memory pointers
  • prepareAlltoallvMetadata exchanges send/recv counts via TCPStore and computes offsets for variable-sized alltoallv operations (a sketch of the offset computation follows this list)
  • doMoeDispatch and doMoeCombine now support both NCCL and CUDA backends with conditional branching
  • CUDA backend path allocates symmetric tensors, exchanges IPC handles, and launches GPU kernels for direct peer-to-peer transfer
  • Tests parameterized to validate both backends
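
To make the offset computation concrete, here is an illustrative sketch using at::Tensor ops. The counts matrix and field names mirror how the metadata is used in the code quoted later in this review, but the real implementation may differ, and the meaning of recv_offsets in particular is an inferred reading:

#include <ATen/ATen.h>

// Sketch only. counts_matrix[i][j] = number of rows rank i sends to rank j,
// assembled from the per-rank counts exchanged through the TCPStore.
struct OffsetsSketch {
  at::Tensor send_counts, send_offsets, recv_counts, recv_offsets;
  int64_t total_recv, max_recv, max_send_total;
};

OffsetsSketch computeOffsets(const at::Tensor& counts_matrix, int64_t my_rank) {
  OffsetsSketch out;
  out.send_counts = counts_matrix[my_rank];           // what this rank sends to each peer
  out.recv_counts = counts_matrix.select(1, my_rank); // what each peer sends to this rank
  // Exclusive prefix sum: where each peer's chunk starts in this rank's send buffer.
  out.send_offsets = at::cumsum(out.send_counts, 0) - out.send_counts;
  // One plausible reading of recv_offsets: where this rank's chunk starts in each
  // peer's recv buffer (the view a push-style kernel needs). The PR may define it
  // differently.
  out.recv_offsets = (at::cumsum(counts_matrix, 0) - counts_matrix)[my_rank];
  out.total_recv = out.recv_counts.sum().item<int64_t>();
  // Symmetric buffers are sized by the worst case across ranks.
  out.max_recv = counts_matrix.sum(0).max().item<int64_t>();
  out.max_send_total = counts_matrix.sum(1).max().item<int64_t>();
  return out;
}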

Minor issues:

  • Hardcoded CUDA include paths in cuda_p2p.cpp:74-75 may cause portability issues
  • Unused helper function alltoallvBarrierKey defined but never called

Confidence Score: 4/5

  • Safe to merge with minor portability considerations
  • The implementation is well-structured with proper error checking, comprehensive tests, and clear separation between NCCL and CUDA backend paths. The core logic for metadata exchange, symmetric memory allocation, and GPU kernel launch appears correct. Score reduced from 5 to 4 due to hardcoded CUDA paths that could cause build failures on non-standard CUDA installations, though this is easily addressable.
  • Pay attention to csrc/multidevice/cuda_p2p.cpp for the hardcoded CUDA include paths

Important Files Changed

  • csrc/multidevice/alltoallv.cu: New CUDA kernel for GPU-initiated alltoallv using byte-level copying across symmetric memory
  • csrc/multidevice/cuda_p2p.cpp: Implemented alltoallv with runtime compilation, metadata exchange via TCPStore, and symmetric memory coordination; contains hardcoded CUDA paths
  • csrc/multidevice/dispatch_combine.cpp: Added CUDA backend path for MoE dispatch/combine using symmetric tensors and alltoallv, refactored NCCL path into conditional branches

Sequence Diagram

sequenceDiagram
    participant App as Application
    participant Dispatch as doMoeDispatch
    participant PrepMeta as prepareAlltoallvMetadata
    participant TCPStore as TCP Store
    participant SymMem as SymmetricTensor
    participant AlltoallvKernel as alltoallv_kernel (GPU)
    participant Combine as doMoeCombine

    App->>Dispatch: x, topk_idx, topk_weights, num_experts
    
    Note over Dispatch: CUDA Backend Path
    Dispatch->>Dispatch: Compute rank_for_token, n_tokens_to_rank
    Dispatch->>PrepMeta: n_tokens_to_rank, tag="moe_dispatch_counts"
    
    PrepMeta->>PrepMeta: Copy send_counts to CPU
    PrepMeta->>TCPStore: set(alltoallvCountsKey(tag, my_rank), send_counts)
    loop For each rank
        PrepMeta->>TCPStore: get(alltoallvCountsKey(tag, rank))
    end
    PrepMeta->>PrepMeta: Build counts_matrix, compute offsets
    PrepMeta->>PrepMeta: barrier()
    PrepMeta->>TCPStore: deleteKey for all ranks
    PrepMeta-->>Dispatch: AlltoallvMetadata (recv_counts, offsets, max values)
    
    Dispatch->>SymMem: allocate symmetric buffers (send_x, recv_x, etc.)
    Dispatch->>SymMem: setupRemoteHandles("moe_dispatch_recv_x", ...)
    Note over SymMem: Exchange IPC handles via TCPStore
    
    loop For each payload (x, topk_idx, topk_weights, src_idx)
        Dispatch->>AlltoallvKernel: send, recv, metadata, recv_ptrs
        Note over AlltoallvKernel: GPU directly writes to remote symmetric memory
        AlltoallvKernel->>AlltoallvKernel: Copy send[offset] to recv_ptrs[peer][offset]
    end
    
    Dispatch->>Dispatch: alltoallvBarrier("moe_dispatch_counts")
    Dispatch->>Dispatch: Narrow recv buffers to total_recv
    Dispatch-->>App: DispatchResult (recv_x, recv_topk_idx, recv_src_idx, ...)
    
    App->>Combine: x, src_idx, n_tokens_to_rank, n_tokens_from_rank
    Note over Combine: Similar alltoallv flow in reverse
    Combine->>PrepMeta: n_tokens_from_rank, tag="moe_combine_counts"
    PrepMeta-->>Combine: AlltoallvMetadata
    Combine->>SymMem: allocate and setup symmetric buffers
    Combine->>AlltoallvKernel: alltoallv payloads back
    Combine->>Combine: index_copy to restore original order
    Combine-->>App: CombineResult (combined_x)
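
To make the kernel's role in the diagram concrete, here is a hedged sketch of what a byte-level alltoallv kernel along these lines could look like, with one grid dimension per peer rank and a grid-stride copy over the chunk. It is written from the behavior described above and from the launchAlltoallvKernel arguments quoted later in this thread, not copied from the PR's alltoallv.cu, so names and layout may differ:

// Sketch only; real implementations would typically use vectorized copies.
extern "C" __global__ void alltoallv_sketch_kernel(
    const char* send,            // local send buffer (viewed as bytes)
    const uint64_t* recv_ptrs,   // per-peer base pointers into remote symmetric memory
    const int64_t* send_offsets, // element offset of each peer's chunk in `send`
    const int64_t* send_counts,  // number of elements destined to each peer
    const int64_t* recv_offsets, // element offset reserved for this rank on each peer
    int64_t element_size) {      // bytes per element
  const int64_t peer = blockIdx.y; // one grid row per destination rank
  const int64_t n_bytes = send_counts[peer] * element_size;
  const char* src = send + send_offsets[peer] * element_size;
  // This rank writes into the slot reserved for it in the peer's symmetric recv buffer.
  char* dst = reinterpret_cast<char*>(recv_ptrs[peer]) +
      recv_offsets[peer] * element_size;
  // Grid-stride loop over the x dimension.
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n_bytes;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    dst[i] = src[i];
  }
}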


@greptile-apps greptile-apps bot left a comment

8 files reviewed, 1 comment

Comment on lines 36 to 38
if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) {
GTEST_SKIP() << "Backend " << backend << " not available.";
}

logic: checking wrong backend constant - should check backend parameter, not hardcoded kNccl

Suggested change
if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) {
GTEST_SKIP() << "Backend " << backend << " not available.";
}
if (!communicator_->isBackendAvailable(backend)) {
GTEST_SKIP() << "Backend " << backend << " not available.";
}
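
For context, this check typically sits in the parameterized fixture's SetUp. Below is a hedged sketch with the fix applied; the fixture, suite, and base-class names are hypothetical, and only communicator_->isBackendAvailable and the two backend enumerators are taken from this PR (namespaces omitted):

#include <gtest/gtest.h>

// Hypothetical fixture sketch; the PR's actual test file may be structured differently.
class DispatchCombineTestSketch
    : public MultiDeviceTest,
      public testing::WithParamInterface<CommunicatorBackend> {
 protected:
  void SetUp() override {
    MultiDeviceTest::SetUp();
    const CommunicatorBackend backend = GetParam();
    // Skip based on the parameter, not a hardcoded kNccl.
    if (!communicator_->isBackendAvailable(backend)) {
      GTEST_SKIP() << "Backend " << backend << " not available.";
    }
  }
};

TEST_P(DispatchCombineTestSketch, Placeholder) {
  // Real tests would exercise doMoeDispatch/doMoeCombine with GetParam().
}

INSTANTIATE_TEST_SUITE_P(
    Backends,
    DispatchCombineTestSketch,
    testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kCuda));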

Base automatically changed from dispatch_combine/stub to main February 9, 2026 14:43
@github-actions

github-actions bot commented Feb 10, 2026

Review updated until commit 3828247

Description

  • Add kernel-based alltoallv implementation using GPU-initiated communications (SM-driven NVLink)

  • Implement CUDA backend for MoE dispatch and combine operations to avoid GPU-CPU synchronization

  • Add support for both NCCL and CUDA backends in dispatch/combine with runtime backend selection

  • Create new test suite for alltoallv functionality and extend existing tests to cover both backends

Changes walkthrough

Relevant files

Enhancement

  cuda_p2p.cpp: Implement kernel-based alltoallv with CUDA backend (csrc/multidevice/cuda_p2p.cpp)
  • Add launchAlltoallvKernel function with NVRTC compilation and kernel launching
  • Add prepareAlltoallvMetadata function for alltoallv metadata preparation
  • Add alltoallvWithCudaBackend function for GPU-initiated alltoallv operations
  • Add serialization/deserialization helpers and barrier functions
  • +315/-0

  dispatch_combine.cpp: Add CUDA backend support for MoE dispatch/combine (csrc/multidevice/dispatch_combine.cpp)
  • Add dual backend support (NCCL and CUDA) for dispatch and combine operations
  • Implement CUDA backend path using symmetric tensors and the alltoallv kernel
  • Maintain existing NCCL implementation as fallback option
  • Add CUDA stream handling and symmetric memory management
  • +201/-59

  alltoallv.cu: Implement CUDA kernel for alltoallv operations (csrc/multidevice/alltoallv.cu)
  • Implement alltoallv_kernel CUDA kernel for peer-to-peer data transfer
  • Handle multi-dimensional block launching with a peer-rank dimension
  • Perform byte-level memory operations with proper offset calculations
  • +36/-0

  cuda_p2p.h: Add alltoallv function declarations and metadata struct (csrc/multidevice/cuda_p2p.h)
  • Add AlltoallvMetadata struct definition for alltoallv parameters (a reconstructed sketch follows this walkthrough)
  • Declare alltoallv preparation and execution functions
  • Add barrier function declaration for synchronization
  • +26/-0

Tests

  test_multidevice_alltoallv.cpp: Add comprehensive alltoallv test suite (tests/cpp/test_multidevice_alltoallv.cpp)
  • Create new test file for alltoallv CUDA backend functionality
  • Test asymmetric alltoallv patterns with multiple ranks
  • Verify data integrity across distributed operations
  • +80/-0

  test_multidevice_dispatch_combine.cpp: Parameterize dispatch/combine tests for multiple backends (tests/cpp/test_multidevice_dispatch_combine.cpp)
  • Convert tests to parameterized format supporting both NCCL and CUDA backends
  • Add backend availability checks and skip logic
  • Extend test coverage to include dispatch-only and combine-only scenarios
  • +32/-20

Documentation

  dispatch_combine.h: Update dispatch/combine API documentation (csrc/multidevice/dispatch_combine.h)
  • Update documentation to reflect CUDA backend support
  • Clarify tensor shape requirements for topk operations
  • +4/-4

Configuration changes

  CMakeLists.txt: Add alltoallv files to build system (CMakeLists.txt)
  • Add new alltoallv test file to test compilation
  • Include alltoallv.cu in runtime files for NVRTC compilation
  • +2/-0     
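
As noted in the cuda_p2p.h entry above, the struct's actual declaration is not quoted in this thread. The following sketch only reconstructs the fields that appear in the code excerpts below and may not match the real header:

// Reconstructed from usage in the excerpts below; not the actual declaration.
struct AlltoallvMetadata {
  int64_t world_size = 0;      // number of participating ranks
  at::Tensor send_counts;      // [world_size] elements this rank sends to each peer
  at::Tensor send_offsets;     // [world_size] start of each peer's chunk in the send buffer
  at::Tensor recv_counts;      // [world_size] elements this rank receives from each peer
  at::Tensor recv_offsets;     // [world_size] offsets used by the kernel when writing to peers
  int64_t total_recv = 0;      // total elements received by this rank
  int64_t max_recv = 0;        // worst-case recv total across ranks (sizes symmetric buffers)
  int64_t max_send_total = 0;  // worst-case send total across ranks
  int64_t max_send_bytes = 0;  // largest per-peer chunk; scaled by element size at launch
};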

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Performance Concern

    The NVRTC compilation and module loading happen lazily on the first call (guarded by static module/kernel pointers). This could introduce significant startup overhead. Consider pre-compiling the kernel or implementing a compilation cache to avoid repeated compilation overhead (a sketch follows the snippet below).

    static CUmodule module = nullptr;
    static CUfunction kernel = nullptr;
    
    if (module == nullptr) {
      nvrtcProgram prog;
      NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
          &prog,
          nvfuser_resources::alltoallv_cu,
          "alltoallv.cu",
          0,
          nullptr,
          nullptr));
    
      int major = 0;
      int minor = 0;
      int device = 0;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device));
      cudaDeviceProp prop;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device));
      major = prop.major;
      minor = prop.minor;
    
      std::string arch_arg = "--gpu-architecture=compute_" +
          std::to_string(major) + std::to_string(minor);
      std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"};
      // NVRTC needs CUDA headers to compile alltoallv.cu.
      opts.push_back("-I/usr/local/cuda/include");
      opts.push_back("-I/usr/local/cuda/include/cccl");
    
      nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data());
      if (res != NVRTC_SUCCESS) {
        size_t logSize;
        NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize));
        std::vector<char> log(logSize);
        NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data()));
        NVF_ERROR(false, "Alltoallv kernel compilation failed:\n", log.data());
      }
    
      size_t ptxSize;
      NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize));
      std::vector<char> ptx(ptxSize);
      NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));
      NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
    
      CUresult load_result = cuModuleLoadData(&module, ptx.data());
      if (load_result != CUDA_SUCCESS) {
        constexpr size_t kLogSize = 8192;
        char error_log[kLogSize];
        char info_log[kLogSize];
        CUjit_option options[] = {
            CU_JIT_ERROR_LOG_BUFFER,
            CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            CU_JIT_INFO_LOG_BUFFER,
            CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
            CU_JIT_LOG_VERBOSE};
        void* option_values[] = {
            (void*)error_log,
            (void*)kLogSize,
            (void*)info_log,
            (void*)kLogSize,
            (void*)1};
        cuModuleLoadDataEx(&module, ptx.data(), 5, options, option_values);
        NVF_ERROR(
            false,
            "Alltoallv kernel module load failed with error: ",
            load_result,
            "\nInfo Log:\n",
            info_log,
            "\nError Log:\n",
            error_log);
      }
    
      NVFUSER_CUDA_SAFE_CALL(
          cuModuleGetFunction(&kernel, module, "alltoallv_kernel"));
    }
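
    One way to address this, sketched here only as an illustration (the helper name and placement are hypothetical), is to wrap the one-time compilation in std::call_once so it can be triggered eagerly, for example during backend initialization, and is also safe under concurrent first calls:

    #include <mutex>

    // Hypothetical wrapper around the compilation logic shown above.
    CUfunction getAlltoallvKernel() {
      static std::once_flag compiled;
      static CUmodule module = nullptr;
      static CUfunction kernel = nullptr;
      std::call_once(compiled, []() {
        // ... NVRTC compile + cuModuleLoadData + cuModuleGetFunction,
        // exactly as in the snippet above ...
      });
      return kernel;
    }

    // Calling getAlltoallvKernel() once during communicator/backend setup would move
    // the compilation cost off the first alltoallv on the critical path.
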
    Documentation Gap

    The PR mentions "Requires recv buffer to be allocated as symmetric memory" but this constraint is not clearly documented in the API. Add documentation about this requirement and potential limitations.

    void alltoallvWithCudaBackend(
        const at::Tensor& send,
        const at::Tensor& recv,
        const AlltoallvMetadata& metadata,
        const std::vector<void*>& recv_ptrs,
        CUstream stream) {
      NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA.");
      NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
      NVF_CHECK(
          (int64_t)recv_ptrs.size() == metadata.world_size,
          "recv_ptrs size must match world size.");
    
      auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
      auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options);
      auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>();
      for (int64_t rank = 0; rank < metadata.world_size; ++rank) {
        ptrs[rank] =
            static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank]));
      }
      auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device());
    
      const int64_t elem_stride =
          metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
      NVF_CHECK(
          metadata.max_send_total == 0 ||
              send.numel() % metadata.max_send_total == 0,
          "alltoallv send numel must be divisible by max_send_total.");
      NVF_CHECK(
          metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0,
          "alltoallv recv numel must be divisible by max_recv.");
    
      auto send_offsets = metadata.send_offsets;
      auto send_counts = metadata.send_counts;
      auto recv_offsets = metadata.recv_offsets;
      int64_t max_send_bytes = metadata.max_send_bytes;
      if (elem_stride > 1) {
        send_offsets = metadata.send_offsets * elem_stride;
        send_counts = metadata.send_counts * elem_stride;
        recv_offsets = metadata.recv_offsets * elem_stride;
        max_send_bytes = metadata.max_send_bytes * elem_stride;
      }
    
      launchAlltoallvKernel(
          send.data_ptr(),
          reinterpret_cast<const uint64_t*>(recv_ptrs_cuda.data_ptr<int64_t>()),
          send_offsets.data_ptr<int64_t>(),
          send_counts.data_ptr<int64_t>(),
          recv_offsets.data_ptr<int64_t>(),
          metadata.world_size,
          send.element_size(),
          max_send_bytes * send.element_size(),
          stream);
    }
    Performance Validation

    The PR claims the CUDA backend "avoids gpu->cpu sync" compared to NCCL, but no performance data or benchmarks are provided. Include performance metrics to validate this claim and demonstrate the actual benefit (a timing sketch follows the snippet below).

    NVF_CHECK(
        backend == CommunicatorBackend::kCuda,
        "Only CUDA and NCCL backends are supported for MoeDispatch.");
    
    auto metadata =
        prepareAlltoallvMetadata(n_tokens_to_rank, "moe_dispatch_counts");
    auto n_tokens_from_rank = metadata.recv_counts;
    const int64_t total_recv = metadata.total_recv;
    const int64_t max_recv = metadata.max_recv;
    
    // Allocate symmetric buffers for send/recv payloads.
    auto send_x_sym = SymmetricTensor::allocate(
        {metadata.max_send_total, hidden}, x.scalar_type(), x.device());
    send_x_sym.narrow(0, 0, num_tokens).copy_(send_x);
    auto send_topk_idx_sym = SymmetricTensor::allocate(
        {metadata.max_send_total, topk_idx.size(1)},
        topk_idx.scalar_type(),
        x.device());
    send_topk_idx_sym.narrow(0, 0, num_tokens).copy_(send_topk_idx);
    auto send_topk_weights_sym = SymmetricTensor::allocate(
        {metadata.max_send_total, topk_weights.size(1)},
        topk_weights.scalar_type(),
        x.device());
    send_topk_weights_sym.narrow(0, 0, num_tokens).copy_(send_topk_weights);
    auto send_src_idx_sym = SymmetricTensor::allocate(
        {metadata.max_send_total}, send_src_idx.scalar_type(), x.device());
    send_src_idx_sym.narrow(0, 0, num_tokens).copy_(send_src_idx);
    
    auto recv_x_sym = SymmetricTensor::allocate(
        {max_recv, hidden}, x.scalar_type(), x.device());
    auto recv_topk_idx_sym = SymmetricTensor::allocate(
        {max_recv, topk_idx.size(1)}, topk_idx.scalar_type(), x.device());
    auto recv_topk_weights_sym = SymmetricTensor::allocate(
        {max_recv, topk_weights.size(1)}, topk_weights.scalar_type(), x.device());
    auto recv_src_idx_sym = SymmetricTensor::allocate(
        {max_recv}, send_src_idx.scalar_type(), x.device());
    
    SymmetricTensor recv_x_handle(recv_x_sym);
    SymmetricTensor recv_topk_idx_handle(recv_topk_idx_sym);
    SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym);
    SymmetricTensor recv_src_idx_handle(recv_src_idx_sym);
    recv_x_handle.setupRemoteHandles("moe_dispatch_recv_x");
    recv_topk_idx_handle.setupRemoteHandles("moe_dispatch_recv_topk_idx");
    recv_topk_weights_handle.setupRemoteHandles("moe_dispatch_recv_topk_weights");
    recv_src_idx_handle.setupRemoteHandles("moe_dispatch_recv_src_idx");
    
    std::vector<void*> recv_x_ptrs(world_size);
    std::vector<void*> recv_topk_idx_ptrs(world_size);
    std::vector<void*> recv_topk_weights_ptrs(world_size);
    std::vector<void*> recv_src_idx_ptrs(world_size);
    for (int64_t rank = 0; rank < world_size; ++rank) {
      recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr();
      recv_topk_idx_ptrs[rank] =
          recv_topk_idx_handle.remoteTensor(rank).data_ptr();
      recv_topk_weights_ptrs[rank] =
          recv_topk_weights_handle.remoteTensor(rank).data_ptr();
      recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr();
    }
    
    auto stream =
        static_cast<CUstream>(at::cuda::getDefaultCUDAStream().stream());
    alltoallvWithCudaBackend(
        send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream);
    alltoallvWithCudaBackend(
        send_topk_idx_sym,
        recv_topk_idx_sym,
        metadata,
        recv_topk_idx_ptrs,
        stream);
    alltoallvWithCudaBackend(
        send_topk_weights_sym,
        recv_topk_weights_sym,
        metadata,
        recv_topk_weights_ptrs,
        stream);
    alltoallvWithCudaBackend(
        send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream);
    alltoallvBarrier("moe_dispatch_counts");
    
    auto recv_x = recv_x_sym.narrow(0, 0, total_recv);
    auto recv_topk_idx = recv_topk_idx_sym.narrow(0, 0, total_recv);
    auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv);
    auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv);
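
    If benchmarks are added, a minimal timing sketch of the kernel-driven exchange could look like the following. The timed calls and variable names come from the snippet above; the harness itself is hypothetical and ignores warm-up and inter-rank skew:

    #include <chrono>

    // Wall-clock timing of one payload exchange on the kCuda path (sketch only).
    NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream));
    auto t0 = std::chrono::steady_clock::now();
    alltoallvWithCudaBackend(send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream);
    alltoallvBarrier("moe_dispatch_counts");
    NVFUSER_CUDA_RT_SAFE_CALL(cudaStreamSynchronize(stream));
    auto t1 = std::chrono::steady_clock::now();
    double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
    // Compare against the NCCL-backed doMoeDispatch path to substantiate the claim.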
    

    Test failures

  • (Medium, 1) Thunder nvFuser CUDA produces zero output for nanoGPT autograd test (test_networks.py)
    Test name (A100): thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32
  • (Low, 1) Minor numerical mismatch in Thunder vs Torch instance_norm nvFuser CUDA tests (float32, H100)
    Test name (H100): thunder.tests.test_ops.test_core_vs_torch_consistency_instance_norm_nvfuser_cuda_thunder.dtypes.float32

    @samnordmann
    Collaborator Author

    !test


    @greptile-apps greptile-apps bot left a comment

    3 files reviewed, 2 comments

    Comment on lines +74 to +75
    opts.push_back("-I/usr/local/cuda/include");
    opts.push_back("-I/usr/local/cuda/include/cccl");

    hardcoded CUDA include paths may break on non-standard installations

    Suggested change
    opts.push_back("-I/usr/local/cuda/include");
    opts.push_back("-I/usr/local/cuda/include/cccl");
    // Use the CUDA_HOME environment variable (or CMake-detected paths) instead.
    const char* cuda_home_env = std::getenv("CUDA_HOME");
    const std::string cuda_home = cuda_home_env ? cuda_home_env : "/usr/local/cuda";
    // Keep the option strings in named variables so they outlive the opts vector;
    // pushing c_str() of temporaries would leave dangling pointers.
    const std::string cuda_include = "-I" + cuda_home + "/include";
    const std::string cccl_include = "-I" + cuda_home + "/include/cccl";
    opts.push_back(cuda_include.c_str());
    opts.push_back(cccl_include.c_str());

    Comment on lines +172 to +174
    std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) {
    return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank);
    }

    unused function - alltoallvBarrierKey is defined but never called

    @samnordmann samnordmann requested a review from nsarka February 10, 2026 19:21

    @wujingyue wujingyue left a comment

    LGTM otherwise

    NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
    &prog,
    nvfuser_resources::alltoallv_cu,
    "alltoallv.cu",

    @wujingyue wujingyue Feb 11, 2026

    Why nvrtc? Can't we simply alltoallv_kernel<<<...>>>?

    backend,
    CommunicatorBackend::kNccl,
    "Only NCCL backend is supported for MoeDispatch.");
    NVF_CHECK(

    NVF_CHECK_EQ
