Get rid of nvshmem dependency for cuBLASMp integration #2661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Overview

Greptile Summary

This PR updates the cuBLASMp tensor-parallel GEMM integration to match newer cuBLASMp APIs and removes the direct nvshmem dependency for cuBLASMp workspace/symmetric memory. Build changes include dropping nvshmem discovery and linking for cuBLASMp and linking NCCL instead. On the runtime side, the workspace is now allocated and registered through cuBLASMp (cublasMpMalloc / cublasMpBufferRegister) rather than nvshmem symmetric memory. No additional fix-required issues were found in the new changes beyond the already-discussed review threads. Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant TE as transformer_engine (comm_gemm)
    participant cuBLASMp
    participant NCCL
    Caller->>TE: nvte_comm_gemm_ctx_create(ncclComm, nranks, rank)
    TE->>cuBLASMp: cublasMpCreate(stream)
    TE->>cuBLASMp: cublasMpGridCreate(col_major, comm)
    TE->>cuBLASMp: cublasMpGridCreate(row_major, comm)
    TE-->>Caller: ctx
    Caller->>TE: nvte_*_gemm(..., main_stream)
    TE->>cuBLASMp: cublasMpSetStream(handle, main_stream)
    TE->>cuBLASMp: cublasMpMatmul_bufferSize(...)
    alt workspace too small
        TE->>cuBLASMp: cublasMpMalloc(grid_col_major, &workspace, bytes)
        TE->>cuBLASMp: cublasMpBufferRegister(grid_row_major, workspace, bytes)
    end
    TE->>cuBLASMp: cublasMpMatmul(..., workspace, host_workspace)
    TE->>TE: record event on main_stream
    TE->>TE: stream waits on event (overlap stream)
    Caller->>TE: nvte_comm_gemm_ctx_destroy(ctx)
    note over Caller,TE: Caller must synchronize involved streams before destroy
    TE->>cuBLASMp: cublasMpBufferDeregister(...)
    TE->>cuBLASMp: cublasMpFree(...)
    TE-->>Caller: ctx freed
```
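The "workspace too small" branch in the diagram is a grow-only cache: the buffer is (re)allocated and re-registered only when a request exceeds the cached size, so repeated GEMMs of the same or smaller size reuse one registration. A toy Python model of that policy (not the real TE code; `CommGemmCtx` and its members are illustrative stand-ins):

```python
class CommGemmCtx:
    """Toy model of the workspace caching shown in the sequence diagram."""

    def __init__(self):
        self.workspace = None
        self.workspace_bytes = 0
        self.alloc_calls = 0  # stands in for cublasMpMalloc + cublasMpBufferRegister

    def _ensure_workspace(self, needed: int):
        # Mirrors the "alt workspace too small" branch: only grow, never shrink.
        if needed > self.workspace_bytes:
            self.alloc_calls += 1
            self.workspace = bytearray(needed)
            self.workspace_bytes = needed

    def gemm(self, needed_bytes: int):
        self._ensure_workspace(needed_bytes)
        # ... here the real code would enqueue cublasMpMatmul on main_stream ...


ctx = CommGemmCtx()
ctx.gemm(1 << 20)   # first call allocates
ctx.gemm(1 << 10)   # smaller request reuses the cached buffer
ctx.gemm(1 << 21)   # larger request triggers a regrow
print(ctx.alloc_calls)  # → 2
```

This avoids paying the allocation/registration cost on every GEMM call at the price of holding the high-water-mark buffer until the context is destroyed.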
3 files reviewed, 4 comments
```diff
+find_library(NCCL_LIB
+             NAMES nccl libnccl
+             PATH_SUFFIXES lib
+             REQUIRED)
-target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
+target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
```
NCCL library not discoverable
`find_library(NCCL_LIB ...)` is missing a `PATHS`/hint variable (unlike `CUBLASMP_LIB`, which uses `${CUBLASMP_DIR}`). Unless NCCL is already in the default linker search paths, enabling `NVTE_WITH_CUBLASMP` will fail at configure time with `NCCL_LIB` not found. This PR should add a way to point CMake at NCCL (env var / CMake cache var) and pass it from `setup.py` similarly to `CUBLASMP_DIR`.
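One possible shape for such a hint, mirroring how `CUBLASMP_DIR` is handled (the `NCCL_DIR` cache variable and `NCCL_HOME` env var names are assumptions, not existing project conventions):

```cmake
# Hypothetical: let the caller point CMake at an NCCL install,
# falling back to the default search paths if neither hint is set.
set(NCCL_DIR "" CACHE PATH "Root of the NCCL installation")
find_library(NCCL_LIB
             NAMES nccl libnccl
             HINTS "${NCCL_DIR}" ENV NCCL_HOME
             PATH_SUFFIXES lib lib64
             REQUIRED)
```

`setup.py` could then pass `-DNCCL_DIR=...` the same way it already passes `-DCUBLASMP_DIR=...`.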
Additional Comments (3)
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
3 files reviewed, 1 comment
```diff
 void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
   NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
-  nvshmemx_sync_all_on_stream(ctx->stream.get());
+  if (ctx->workspace) {
+    NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
+    NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
+  }
   delete ctx;
```
Unsafe workspace free
`nvte_comm_gemm_ctx_destroy` can deregister/free `ctx->workspace` while `cublasMpMatmul` work on the user-provided `main_stream` is still in flight. `cublasmp_gemm` sets the cuBLASMp handle stream to `main_stream` (comm_gemm.cpp:389) and uses `ctx->workspace` in the enqueue (comm_gemm.cpp:436-439), but `destroy()` does not synchronize `main_stream` (or otherwise ensure completion) before calling `cublasMpBufferDeregister`/`cublasMpFree`.
This can become a use-after-free if a caller destroys the ctx shortly after launching a comm GEMM. Either synchronize the relevant stream(s) before freeing, or explicitly document (in `comm_gemm.h`) that callers must synchronize `main_stream` before calling `destroy()`.
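The ordering bug is independent of CUDA specifics: freeing a buffer that asynchronously enqueued work still reads is a use-after-free unless the free is ordered after completion. A small Python model (with a thread standing in for a CUDA stream; all names are illustrative) shows the safe ordering:

```python
import threading


class FakeStream:
    """Toy stand-in for a CUDA stream: launched work completes asynchronously."""

    def __init__(self):
        self._done = threading.Event()

    def launch(self, fn):
        threading.Thread(target=lambda: (fn(), self._done.set())).start()

    def synchronize(self):
        self._done.wait()


log = []
stream = FakeStream()
workspace = {"alive": True}


def matmul():
    # Pretend kernel that touches the workspace buffer.
    assert workspace["alive"], "use-after-free!"
    log.append("matmul")


stream.launch(matmul)
stream.synchronize()        # the fix: wait for in-flight work first
workspace["alive"] = False  # then deregister/free (cublasMpBufferDeregister/cublasMpFree)
log.append("freed")
print(log)  # → ['matmul', 'freed']
```

Dropping the `synchronize()` call would make the free race the in-flight "kernel", which is exactly the hazard the review comment describes.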
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
4 files reviewed, no comments
```c
 *
 * \param[in] ctx Context to destroy.
 *
 * It's the caller's respondibility to synchronize all streams involved before calling this function.
```
Suggested change:

```diff
- * It's the caller's respondibility to synchronize all streams involved before calling this function.
+ * It's the caller's responsibility to synchronize all streams involved before calling this function.
```
```diff
         f"nvidia-cublasmp-cu{cuda_version()[0]}"
     ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
     cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
-    nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
```
There is also explicit nvshmem usage in `transformer_engine/common/nvshmem_api` - I assume this removal would result in failure of that functionality?
I don't think so; the nvshmem discovery / CMake knobs for cuBLASMp (which this PR removes) and for `nvshmem_api` are independent.
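The `setup.py` lines above follow a common discovery pattern: an explicit environment variable wins, otherwise the path is resolved from the installed pip wheel's metadata. A hedged, self-contained sketch of that pattern (`find_pkg_dir` and the example names are hypothetical, not functions in this repo):

```python
import os
from importlib import metadata


def find_pkg_dir(env_var: str, dist_name: str, subpath: str) -> str:
    """Resolve a dependency dir: explicit env var wins, else locate the pip wheel."""
    override = os.getenv(env_var)
    if override:
        return override
    # Falls back to the directory shipped inside the installed wheel,
    # e.g. metadata.distribution("nvidia-cublasmp-cu12").locate_file(...).
    return str(metadata.distribution(dist_name).locate_file(subpath))


# Env override takes precedence (hypothetical values):
os.environ["NCCL_HOME"] = "/opt/nccl"
print(find_pkg_dir("NCCL_HOME", "nvidia-nccl-cu12", "nvidia/nccl"))  # → /opt/nccl
```

Because the cuBLASMp and nvshmem lookups are separate calls of this pattern, deleting the cuBLASMp-side nvshmem lookup leaves any other nvshmem consumer's discovery untouched, which is the point being made in the comment above.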
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
4 files reviewed, no comments
@vcherepanov-nv cuBLASMp 0.8.0 added the API you asked for: https://docs.nvidia.com/cuda/cublasmp/usage/functions.html#cublasmpgetstatusstring |
Description
Starting with cuBLASMp 0.8.0, the library is moving away from using nvshmem for symmetric memory and uses NCCL instead.
This change adapts the integration to the changed API.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: