Fix on TE to support Mcore Vision Encoder CUDA Graph #2657
base: main
Conversation
Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR adds None-safety checks throughout the CUDA Graph capture code to support vision encoder modules. The changes prevent crashes when module outputs contain None: previously, evaluating `.requires_grad` on a None output failed during backward-graph capture.
The fix is minimal and surgical, adding safety checks without changing the underlying logic or control flow.
Confidence Score: 5/5
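To make the failure mode concrete, here is a minimal sketch (the `outputs` tuple below is hypothetical, not taken from the PR):

```python
import torch

# Hypothetical outputs from a graphed module: vision encoder code paths
# can return None alongside real tensors.
outputs = (torch.ones(2, requires_grad=True), None)

# Without the None check, evaluating `o.requires_grad` on the None entry
# raises AttributeError:
#   tuple(o for o in outputs if o.requires_grad)

# The patched filter skips None entries first:
safe = tuple(o for o in outputs if o is not None and o.requires_grad)
print(safe)  # only the tensor that actually requires grad
```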
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant MGC as make_graphed_callables
    participant IMGC as _make_graphed_callables
    participant FG as Forward Graph
    participant BG as Backward Graph
    participant Module
    User->>MGC: Call with modules & sample_args
    MGC->>IMGC: Pass callables & args
    Note over IMGC: Warmup Phase
    IMGC->>Module: Run warmup iterations
    Module-->>IMGC: Return outputs (may contain None)
    Note over IMGC: Graph Capture Phase
    IMGC->>FG: Capture forward pass
    Module-->>FG: Store static outputs
    Note over IMGC: Filter outputs with None check
    IMGC->>IMGC: Check "o is not None and o.requires_grad"
    IMGC->>BG: Capture backward pass
    IMGC->>BG: Create grad tensors for valid outputs
    Note over IMGC: Graph Replay Phase
    IMGC->>User: Return graphed callables
    User->>FG: Call graphed module
    FG->>FG: Replay captured graph
    FG-->>User: Return detached outputs (None-safe)
    User->>BG: Trigger backward
    BG->>BG: Replay captured backward
```
1 file reviewed, no comments
```diff
 with _none_grad_context_wrapper(inputs):
     torch.autograd.backward(
-        tuple(o for o in outputs if o.requires_grad),
+        tuple(o for o in outputs if o is not None and o.requires_grad),
```
Can we refactor that list outside of the torch.autograd.backward call and use that tuple to get grad_tensors (so we don't create the same list twice)?
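A minimal sketch of that suggestion (names are hypothetical; in the actual code the grad tensors are the static buffers used for graph replay):

```python
# Build the filtered tuple once and reuse it, instead of constructing
# the same generator expression for both arguments.
outputs_that_require_grad = tuple(
    o for o in outputs if o is not None and o.requires_grad
)
torch.autograd.backward(
    outputs_that_require_grad,
    grad_tensors=tuple(
        torch.empty_like(o) for o in outputs_that_require_grad
    ),
)
```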
/te-ci pytorch
ptrendx left a comment:
Other than the one small comment, LGTM.
Description
This PR is needed to support CUDA Graph capture for the Megatron-Core (Mcore) vision encoder.
Related MLM PR: NVIDIA/Megatron-LM#3293, NVIDIA/Megatron-LM#3294
Fixes # (issue)
Type of change
Bug fix (non-breaking change which fixes an issue)
Changes
Please list the changes introduced in this PR:
- Add `o is not None` checks when filtering module outputs during CUDA Graph capture in `_make_graphed_callables`, so modules whose outputs include None (e.g., vision encoders) no longer fail during backward-graph capture.
Checklist: