Conversation
Summary of Changes: Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request adds a new guide to the NNX documentation that explains how to use transforms with NNX.
Code Review
This pull request adds a new guide for using NNX transforms in tree mode. The guide is well-structured and covers important transforms like jit, grad, vmap, and scan. However, there are a few issues that need addressing to improve clarity and correctness. The explanation for nnx.grad is misleading as it describes graph-mode features. A code comment in the vmap example has an incorrect shape. The StateSharding example is broken and needs a fix to run. Finally, there's an ambiguous use of nnx.Linear that could confuse readers.
| " weights = Weights(\n", | ||
| " kernel=rngs.uniform((16, 16)),\n", | ||
| " count=jnp.array(0),\n", | ||
| " )\n", |
This code cell fails with a ValueError because the weights.kernel array is not sharded when created, but the in_shardings argument to nnx.graph.jit expects it to be. To make this example runnable, the kernel should be created with the specified sharding before being passed to the Weights constructor.
weights = Weights(
    # Create the kernel already placed with the sharding that
    # in_shardings declares, instead of sharding it after the fact.
    kernel=jax.device_put(
        rngs.uniform((16, 16)),
        jax.sharding.NamedSharding(mesh, jax.P(None, 'devices')),
    ),
    count=jnp.array(0),
)
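With the kernel pre-placed via jax.device_put, its committed sharding already matches the in_shardings declaration, so nnx.graph.jit accepts the input and the ValueError no longer occurs.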
| "source": [ | ||
| "## jit + grad — training step\n", | ||
| "\n", | ||
| "`nnx.grad` differentiates with respect to `nnx.Param` variables by default, treating all other state as non-differentiable. The `wrt` argument accepts any [Filter](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to select which Variable types to differentiate. It handles `split`/`merge`/`clone` internally, so you only need to write the loss function.\n", |
The description of nnx.grad appears to describe its graph-mode behavior, which is confusing in a guide focused on tree mode. For instance, it mentions the wrt argument and automatic split/merge capabilities, which are not features of tree-mode nnx.grad and are not used in the accompanying code example. The example correctly demonstrates manual splitting of parameters. Please update the text to accurately describe the tree-mode behavior, where the user is responsible for separating differentiable variables.
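For reference, a minimal sketch of the tree-mode pattern the text should describe, assuming the nnx.split/nnx.merge helpers are used for the manual separation; model, x, and target are placeholder names, not the notebook's exact code:

import jax
import flax.nnx as nnx

# In tree mode the user separates differentiable state by hand;
# here the nnx.Param leaves are split out, everything else is carried along.
graphdef, params, rest = nnx.split(model, nnx.Param, ...)

def loss_fn(params):
    m = nnx.merge(graphdef, params, rest)
    pred = m(x)
    return ((pred - target) ** 2).mean()

grads = jax.grad(loss_fn)(params)  # gradients only over nnx.Param leaves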
| " return model(x, rngs=rngs) \n", | ||
| "\n", | ||
| "y = batched_forward(model, x, rngs)\n", | ||
| "print(f'{y.shape = !s}') # (1, 5, 10)\n", |
| "m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))\n", | ||
| "m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1))\n", |
This example uses nnx.Linear, which appears to be flax.nnx.Linear, not the custom Linear class defined earlier in this notebook. This can be confusing for readers. The custom Linear has a .w attribute, while flax.nnx.Linear has .kernel and .bias attributes, which are used in the loss_fn. Please clarify that flax.nnx.Linear is being used here, for example by aliasing it differently or adding a comment.
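For illustration, a minimal sketch of the attribute mismatch; CustomLinear is a hypothetical stand-in for the notebook's own Linear class, written here against the stock flax.nnx API rather than the notebook's tree-mode helpers:

import jax
import flax.nnx as nnx

class CustomLinear(nnx.Module):  # stand-in for the Linear defined earlier
    def __init__(self, din, dout, rngs: nnx.Rngs):
        # single .w parameter, unlike flax.nnx.Linear's .kernel/.bias
        self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))

m_custom = CustomLinear(2, 3, rngs=nnx.Rngs(0))
m_flax = nnx.Linear(2, 3, rngs=nnx.Rngs(0))  # flax.nnx.Linear

print(hasattr(m_custom, 'w'))                              # True
print(hasattr(m_flax, 'kernel'), hasattr(m_flax, 'bias'))  # True True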