@vcherepanov-nv (Collaborator)
Description

Starting with cuBLASMp 0.8.0, NVSHMEM is no longer used for symmetric memory; NCCL is used instead. This change adapts to the changed API.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Get rid of nvshmem dependency (for cuBLASMp)
  • Update cuBLASMp API usage

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vcherepanov-nv and others added 5 commits January 27, 2026 02:36
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps (Contributor) commented Feb 9, 2026

Greptile Overview

Greptile Summary

This PR updates the cuBLASMp tensor-parallel GEMM integration to match newer cuBLASMp APIs and removes the direct nvshmem dependency for cuBLASMp workspace/symmetric memory. Build changes include dropping NVSHMEM_DIR injection from setup.py when NVTE_WITH_CUBLASMP is enabled and adjusting the common CMake target to include/link against cuBLASMp + NCCL rather than cuBLASMp + nvshmem_host.

On the runtime side, comm_gemm.cpp switches from nvshmem_malloc/free to cuBLASMp-managed workspace allocation (cublasMpMalloc/Free) and registers that workspace for cuBLASMp matmul (cublasMpBufferRegister/Deregister). The header now documents that callers must synchronize streams before destroying the comm-GEMM context to avoid freeing in-flight workspace.

No additional fix-required issues were found in the new changes beyond the already-discussed review threads.

Confidence Score: 4/5

  • This PR looks safe to merge once the already-noted build discovery and stream-synchronization concerns are addressed.
  • The functional changes are mostly mechanical API renames and a straightforward replacement of nvshmem workspace allocation with cuBLASMp allocation/registration. I did not find additional definite runtime/compile failures introduced by the diff itself beyond the two prior review threads; remaining uncertainty is mainly around external library discovery/configuration and the correctness requirements of the cuBLASMp workspace register/free API pairing (not verifiable from this repo).
  • transformer_engine/common/CMakeLists.txt and transformer_engine/common/comm_gemm/comm_gemm.cpp

Important Files Changed

Filename — Overview
setup.py — Removes nvshmem-related CMake flag injection when NVTE_WITH_CUBLASMP is enabled; the NVTE_ENABLE_NVSHMEM path remains unchanged.
transformer_engine/common/CMakeLists.txt — Switches the cuBLASMp integration to link against NCCL instead of nvshmem_host and removes the NVSHMEM include from the NVTE_WITH_CUBLASMP path.
transformer_engine/common/comm_gemm/comm_gemm.cpp — Updates cuBLASMp API calls and replaces nvshmem symmetric workspace allocation with cublasMpMalloc/Free plus buffer register/deregister; makes stream synchronization in destroy() the caller's responsibility.
transformer_engine/common/include/transformer_engine/comm_gemm.h — Documents that callers must synchronize relevant streams before destroying the comm-gemm context.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant TE as transformer_engine (comm_gemm)
    participant cuBLASMp
    participant NCCL

    Caller->>TE: nvte_comm_gemm_ctx_create(ncclComm, nranks, rank)
    TE->>cuBLASMp: cublasMpCreate(stream)
    TE->>cuBLASMp: cublasMpGridCreate(col_major, comm)
    TE->>cuBLASMp: cublasMpGridCreate(row_major, comm)
    TE-->>Caller: ctx

    Caller->>TE: nvte_*_gemm(..., main_stream)
    TE->>cuBLASMp: cublasMpSetStream(handle, main_stream)
    TE->>cuBLASMp: cublasMpMatmul_bufferSize(...)
    alt workspace too small
        TE->>cuBLASMp: cublasMpMalloc(grid_col_major, &workspace, bytes)
        TE->>cuBLASMp: cublasMpBufferRegister(grid_row_major, workspace, bytes)
    end
    TE->>cuBLASMp: cublasMpMatmul(..., workspace, host_workspace)
    TE->>TE: record event on main_stream
    TE->>TE: stream waits on event (overlap stream)

    Caller->>TE: nvte_comm_gemm_ctx_destroy(ctx)
    note over Caller,TE: Caller must synchronize involved streams before destroy
    TE->>cuBLASMp: cublasMpBufferDeregister(...)
    TE->>cuBLASMp: cublasMpFree(...)
    TE-->>Caller: ctx freed

@greptile-apps (Contributor) left a comment

3 files reviewed, 4 comments


Comment on lines +292 to +296

  find_library(NCCL_LIB
    NAMES nccl libnccl
    PATH_SUFFIXES lib
    REQUIRED)
- target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
+ target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
NCCL library not discoverable

find_library(NCCL_LIB ...) is missing a PATHS/hint variable (unlike CUBLASMP_LIB, which uses ${CUBLASMP_DIR}). Unless NCCL is already in the default linker search paths, enabling NVTE_WITH_CUBLASMP will fail at configure time with NCCL_LIB not found. This PR should add a way to point CMake at NCCL (env var / CMake cache var) and pass it from setup.py similarly to CUBLASMP_DIR.

@greptile-apps (Contributor) commented Feb 9, 2026

Additional Comments (3)

transformer_engine/common/CMakeLists.txt
Undefined NVSHMEM_DIR include

In the NVTE_WITH_CUBLASMP block, target_include_directories(... ${NVSHMEM_DIR}/include) still references NVSHMEM_DIR, but this PR removed the only place that was setting it (setup.py). With NVTE_WITH_CUBLASMP=ON, CMake will now expand an unset var here, which can break configuration/build (and it also contradicts the goal of removing nvshmem for cuBLASMp).


transformer_engine/common/comm_gemm/comm_gemm.cpp
Uninitialized workspace fields

NVTECommGemmCtx has raw workspace/workspace_size fields, but nvte_comm_gemm_ctx_create()'s designated initializer never sets them. That means cublasmp_gemm() will read uninitialized ctx->workspace_size and may also treat an uninitialized ctx->workspace as non-null (leading to invalid deregister/free calls). These need explicit initialization (e.g., workspace=nullptr, workspace_size=0) in the create path.


transformer_engine/common/comm_gemm/comm_gemm.cpp
Workspace leak + missing cleanup

After switching from nvshmem_malloc/free to cublasMpMalloc/cublasMpFree + cublasMpBufferRegister, nvte_comm_gemm_ctx_destroy() now just deletes the ctx and never deregisters/frees ctx->workspace. Since workspace is a raw pointer (not RAII-managed), this leaks device memory and leaves the cuBLASMp registration state dirty for the lifetime of the process.

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps (Contributor) left a comment

3 files reviewed, 1 comment


Comment on lines 479 to 485

  void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
    NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
-   nvshmemx_sync_all_on_stream(ctx->stream.get());
+   if (ctx->workspace) {
+     NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
+     NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
+   }
    delete ctx;

Unsafe workspace free

nvte_comm_gemm_ctx_destroy can deregister/free ctx->workspace while cublasMpMatmul work on the user-provided main_stream is still in flight. cublasmp_gemm sets the cuBLASMp handle stream to main_stream (comm_gemm.cpp:389) and uses ctx->workspace in the enqueue (comm_gemm.cpp:436-439), but destroy() does not synchronize main_stream (or otherwise ensure completion) before calling cublasMpBufferDeregister/cublasMpFree.

This can become a use-after-free if a caller destroys the ctx shortly after launching a comm GEMM. Either synchronize the relevant stream(s) before freeing, or explicitly document (in comm_gemm.h) that callers must synchronize main_stream before calling destroy().

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps (Contributor) left a comment

4 files reviewed, no comments


*
* \param[in] ctx Context to destroy.
*
* It's the caller's respondibility to synchronize all streams involved before calling this function.
Member:

Suggested change
* It's the caller's respondibility to synchronize all streams involved before calling this function.
* It's the caller's responsibility to synchronize all streams involved before calling this function.

f"nvidia-cublasmp-cu{cuda_version()[0]}"
).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
Member:

There is also explicit nvshmem usage in transformer_engine/common/nvshmem_api - I assume this removal would result in failure of that functionality?

@vcherepanov-nv (Collaborator, Author):

I don't think so; the NVSHMEM discovery / CMake knobs for cuBLASMp (which this PR removes) and for nvshmem_api are independent.

@ptrendx (Member) commented Feb 10, 2026

A general comment as well: if we intend to require cuBLASMp 0.8.0, we should say so in the installation docs. I don't think the docs currently cover building with cuBLASMp enabled; we should change that (between this PR and #2443). @denera FYI.

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps (Contributor) left a comment

4 files reviewed, no comments


@almogsegal

@vcherepanov-nv cuBLASMp 0.8.0 added the API you asked for: https://docs.nvidia.com/cuda/cublasmp/usage/functions.html#cublasmpgetstatusstring
I suggest considering using it in the NVTE_CHECK_CUBLASMP macro.
