Skip to content

Conversation

@tomlifu
Copy link
Contributor

@tomlifu tomlifu commented Feb 6, 2026

Description

This PR is needed to support vision encoder CUDA Graph.

Related MLM PR: NVIDIA/Megatron-LM#3293, NVIDIA/Megatron-LM#3294

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Lifu Zhang and others added 2 commits February 6, 2026 10:33
Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR updates transformer_engine/pytorch/graph.py to make CUDA Graph warmup/capture and replay robust when a module’s flattened outputs include None leaves (common for optional/auxiliary outputs in vision encoders).

Concretely, it:

  • Filters torch.autograd.backward(...)’s outputs and grad_tensors to exclude None outputs and only include tensors with requires_grad=True.
  • Makes static grad-output buffer creation (torch.empty_like) None-safe.
  • Avoids calling .detach() on None outputs during graph replay, preserving the original pytree structure.

These changes fit into TE’s make_graphed_callables wrapper around torch.cuda graph capture by ensuring capture/warmup works even when outputs are not all tensors.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk.
  • The change is narrowly scoped to add None guards around existing backward/grad-buffer logic and to avoid detaching None outputs; it prevents concrete crashes when outputs contain None and does not change behavior for normal all-tensor outputs.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/pytorch/graph.py Adds None-safe handling for outputs/grad outputs during CUDA Graph warmup/capture and avoids calling .detach() on None leaves; no correctness regressions found.

Sequence Diagram

sequenceDiagram
  participant U as User code
  participant MG as make_graphed_callables
  participant WU as Warmup
  participant CG as CUDAGraph
  participant AG as Autograd

  U->>MG: call(modules, sample_args/kwargs)
  MG->>WU: run warmup func(*args, **kwargs)
  WU->>WU: flatten outputs (may include None)
  opt training
    WU->>AG: backward(outputs requiring grad)
    Note over WU,AG: PR skips None outputs
  end

  MG->>CG: capture forward graph
  CG->>CG: store static_outputs
  opt training
    MG->>CG: capture backward graph
    CG->>AG: backward(non-None outputs requiring grad)
  end

  U->>MG: replay graphed callable
  MG->>CG: fwd_graph.replay()
  MG-->>U: return detached tensors + None leaves
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

@ptrendx
Copy link
Member

ptrendx commented Feb 10, 2026

/te-ci pytorch

ptrendx
ptrendx previously approved these changes Feb 10, 2026
Copy link
Member

@ptrendx ptrendx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than the one small comment LGTM.

Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>
@ptrendx ptrendx merged commit 8ebb47e into NVIDIA:main Feb 11, 2026
10 of 12 checks passed
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants