add transforms_tree guide #5307

Open
cgarciae wants to merge 1 commit into nnx-basics-tree from transforms_tree_guide

Conversation

@cgarciae
Collaborator

@cgarciae cgarciae commented Mar 4, 2026

No description provided.

@review-notebook-app

Check out this pull request on ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adds a new guide to the NNX documentation that explains how to use transforms with NNX. The guide covers nnx.jit, nnx.grad, nnx.vmap, and nnx.scan, and it also introduces graph mode APIs. The guide uses a simple Linear layer to demonstrate the usage of these transforms.

Highlights

  • Transforms in NNX: This guide introduces NNX transforms (nnx.jit, nnx.grad, nnx.vmap, nnx.scan) as wrappers over JAX transforms with automatic state propagation.
  • Practical Examples: The guide demonstrates the usage of nnx.jit, nnx.grad, nnx.vmap, and nnx.scan with a simple Linear layer, showcasing forward passes, training steps, and vectorization.
  • Graph Mode APIs: It briefly introduces graph mode APIs like StateAxes, DiffState, and StateSharding for advanced control over transforms, contrasting them with the default tree mode.
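
The core idea the guide describes, NNX transforms as thin wrappers over JAX transforms that thread module state through as a pytree, can be illustrated with plain JAX over a parameter dict. This is a sketch only; `params`, `forward`, and `loss_fn` are illustrative names, not the guide's code:

```python
import jax
import jax.numpy as jnp

# State lives in a pytree of arrays; the transform is applied to a pure
# function of that pytree. nnx.jit/nnx.grad automate this split/merge.
params = {'w': jnp.ones((3, 4)), 'b': jnp.zeros((4,))}

def forward(params, x):
    return x @ params['w'] + params['b']

@jax.jit
def loss_fn(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((2, 3))
y = jnp.zeros((2, 4))

# Gradients come back with the same pytree structure as params.
grads = jax.grad(loss_fn)(params, x, y)
print(jax.tree_util.tree_map(jnp.shape, grads))
```

The appeal of the NNX wrappers is that this manual threading of state happens behind the scenes while keeping the familiar JAX transform surface.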


Changelog
  • docs_nnx/guides/transforms_tree.ipynb
    • Initial commit of transforms_tree.ipynb
  • docs_nnx/guides/transforms_tree.md
    • Initial commit of transforms_tree.md
Activity
  • cgarciae authored the pull request.


@gemini-code-assist (bot) left a comment


Code Review

This pull request adds a new guide for using NNX transforms in tree mode. The guide is well-structured and covers important transforms like jit, grad, vmap, and scan. However, there are a few issues that need addressing to improve clarity and correctness. The explanation for nnx.grad is misleading as it describes graph-mode features. A code comment in the vmap example has an incorrect shape. The StateSharding example is broken and needs a fix to run. Finally, there's an ambiguous use of nnx.Linear that could confuse readers.

Comment on lines +429 to +432
" weights = Weights(\n",
" kernel=rngs.uniform((16, 16)),\n",
" count=jnp.array(0),\n",
" )\n",

critical

This code cell fails with a ValueError because the weights.kernel array is not sharded when created, but the in_shardings argument to nnx.graph.jit expects it to be. To make this example runnable, the kernel should be created with the specified sharding before being passed to the Weights constructor.

weights = Weights(
    kernel=jax.device_put(
        rngs.uniform((16, 16)),
        jax.sharding.NamedSharding(mesh, jax.P(None, 'devices')),
    ),
    count=jnp.array(0),
)
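For readers trying this fix outside the notebook, the sharding step can be checked in isolation. A minimal stand-alone sketch, with the mesh axis name `'devices'` matching the snippet above and `PartitionSpec` written out in place of the `jax.P` alias (which only exists in newer JAX releases):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over whatever devices are available; a single CPU works.
mesh = Mesh(np.array(jax.devices()), ('devices',))

# Place the kernel with the desired sharding *before* handing it to the
# jitted function, which is the fix the review suggests.
kernel = jax.device_put(
    jnp.ones((16, 16)),
    NamedSharding(mesh, PartitionSpec(None, 'devices')),
)
print(kernel.sharding)
```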

"source": [
"## jit + grad — training step\n",
"\n",
"`nnx.grad` differentiates with respect to `nnx.Param` variables by default, treating all other state as non-differentiable. The `wrt` argument accepts any [Filter](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to select which Variable types to differentiate. It handles `split`/`merge`/`clone` internally, so you only need to write the loss function.\n",

high

The description of nnx.grad appears to describe its graph-mode behavior, which is confusing in a guide focused on tree mode. For instance, it mentions the wrt argument and automatic split/merge capabilities, which are not features of tree-mode nnx.grad and are not used in the accompanying code example. The example correctly demonstrates manual splitting of parameters. Please update the text to accurately describe the tree-mode behavior, where the user is responsible for separating differentiable variables.
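The tree-mode contract described here, where the caller separates differentiable state before invoking the transform, looks roughly like this in plain JAX (an analogy, not the notebook's code; `params` and `counts` are illustrative names):

```python
import jax
import jax.numpy as jnp

# In tree mode the user splits state manually: only `params` goes through
# the differentiated argument; `counts` rides along untouched.
params = {'w': jnp.ones((3,))}
counts = {'steps': jnp.array(0)}

def loss_fn(params, counts, x):
    # The count participates in the computation but carries no gradient.
    return jnp.sum(params['w'] * x) + 0.0 * counts['steps']

x = jnp.arange(3.0)

# Differentiate w.r.t. argument 0 (the params) only.
grads = jax.grad(loss_fn, argnums=0)(params, counts, x)
print(grads['w'])  # d/dw sum(w * x) = x, i.e. [0. 1. 2.]
```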

" return model(x, rngs=rngs) \n",
"\n",
"y = batched_forward(model, x, rngs)\n",
"print(f'{y.shape = !s}') # (1, 5, 10)\n",

medium

The comment for the output shape y.shape is incorrect. Given the model and vmap parameters, the shape should be (10, 1, 3), which is also what the cell's output shows.

print(f'{y.shape = !s}')             # (10, 1, 3)
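Where the mapped axis lands in the output can be verified with a small stand-alone `jax.vmap` example. This is a sketch with made-up shapes chosen to reproduce a `(10, 1, 3)` result; it is not the notebook's model:

```python
import jax
import jax.numpy as jnp

def forward(w, x):
    # x: (1, 5) @ w: (5, 3) -> (1, 3) per call
    return x @ w

ws = jnp.ones((10, 5, 3))   # 10 stacked weight matrices
x = jnp.ones((1, 5))        # a single unbatched input

# Map over axis 0 of the weights and broadcast x: the mapped axis is
# prepended to each per-call output (out_axes defaults to 0).
y = jax.vmap(forward, in_axes=(0, None))(ws, x)
print(y.shape)  # (10, 1, 3)
```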

Comment on lines +369 to +370
"m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))\n",
"m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1))\n",

medium

This example uses nnx.Linear, which appears to be flax.nnx.Linear, not the custom Linear class defined earlier in this notebook. This can be confusing for readers. The custom Linear has a .w attribute, while flax.nnx.Linear has .kernel and .bias attributes, which are used in the loss_fn. Please clarify that flax.nnx.Linear is being used here, for example by aliasing it differently or adding a comment.

@cgarciae cgarciae force-pushed the transforms_tree_guide branch 2 times, most recently from 70e42c0 to a6154ee Compare March 5, 2026 03:01
@cgarciae cgarciae force-pushed the transforms_tree_guide branch from a6154ee to 10fb80c Compare March 13, 2026 18:58
@cgarciae cgarciae force-pushed the transforms_tree_guide branch from 10fb80c to f0f07fd Compare March 14, 2026 01:47