-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Fix TrainableBilateralFilter 3D input validation (#7444) #8729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Fix TrainableBilateralFilter 3D input validation (#7444) #8729
Conversation
- Fix dimension comparison to use spatial dims instead of total dims - Add validation for minimum input dimensions - Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma) - Move spatial dimension validation before unsqueeze operations The forward() method was incorrectly comparing self.len_spatial_sigma (number of spatial dimensions) with len(input_tensor.shape) (total dimensions including batch and channel), causing valid 3D inputs to be rejected. Fixes Project-MONAI#7444 Signed-off-by: Abdoulaye Diallo <[email protected]>
📝 WalkthroughWalkthroughThis change fixes dimension handling in TrainableBilateralFilter and TrainableJointBilateralFilter. The fix corrects a variable reference error in error messages (ken_spatial_sigma → len_spatial_sigma), adds validation requiring minimum 3 input dimensions, and refactors dimension-checking logic to properly distinguish between batch, channel, and spatial dimensions. The changes enable proper support for 3D images by aligning dimension comparisons with spatial dimensionality rather than total input dimensions. Estimated code review effort🎯 2 (Simple) | ⏱️ ~12 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/layers/filtering.py (1)
406-430:⚠️ Potential issue | 🟠 Major
TrainableJointBilateralFilter.forward()not updated with the same fix.This method still uses
len_inputdirectly instead of computingspatial_dims = len_input - 2. It will reject valid 3D inputs just like the original bug inTrainableBilateralFilter. Also missing the minimum dimension validation added to the other class.Proposed fix
def forward(self, input_tensor, guidance_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " "Please use multiple parallel filter layers if you want " "to filter multiple channels." ) if input_tensor.shape != guidance_tensor.shape: raise ValueError( "Shape of input image must equal shape of guidance image." f"Got {input_tensor.shape} and {guidance_tensor.shape}." ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. - if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") prediction = TrainableJointBilateralFilterFunction.apply( input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. - if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction
🤖 Fix all issues with AI agents
In `@monai/networks/layers/filtering.py`:
- Around line 223-225: The error message uses self.len_spatial_sigma which is
not assigned in the branch; fix by referencing the actual expected spatial
dimension attribute or ensuring self.len_spatial_sigma is initialized before
this check: either assign self.len_spatial_sigma = self.spatial_ndim (or the
class's existing spatial-dimension attribute) earlier in the initializer, or
change the ValueError message to use the computed expected dimension (e.g.,
self.spatial_ndim or len(self.spatial_shape)) instead of self.len_spatial_sigma
so the attribute is defined when raising the error in the spatial_sigma
validation.
- Around line 395-398: The else branch references an undefined attribute
self.len_spatial_sigma; fix it by using a defined value (e.g., compute
len_spatial = len(self.spatial_sigma) or use self.spatial_ndim) when building
the error message in the failing branch of the initializer (same place as
TrainableBilateralFilter.__init__). Replace self.len_spatial_sigma with the
actual computed length (len(self.spatial_sigma) or self.spatial_ndim) so the
ValueError message prints a valid expected-dimension value.
| raise ValueError( | ||
| f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." | ||
| f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: self.len_spatial_sigma is undefined when this branch executes.
The else branch runs when len(spatial_sigma) is not 1, 2, or 3. At that point self.len_spatial_sigma has never been assigned, so this raises AttributeError before the intended ValueError.
Proposed fix
else:
raise ValueError(
- f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
+ f"len(spatial_sigma) must be 1, 2, or 3, got {len(spatial_sigma)}."
)🧰 Tools
🪛 Ruff (0.14.14)
[warning] 223-225: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@monai/networks/layers/filtering.py` around lines 223 - 225, The error message
uses self.len_spatial_sigma which is not assigned in the branch; fix by
referencing the actual expected spatial dimension attribute or ensuring
self.len_spatial_sigma is initialized before this check: either assign
self.len_spatial_sigma = self.spatial_ndim (or the class's existing
spatial-dimension attribute) earlier in the initializer, or change the
ValueError message to use the computed expected dimension (e.g.,
self.spatial_ndim or len(self.spatial_shape)) instead of self.len_spatial_sigma
so the attribute is defined when raising the error in the spatial_sigma
validation.
| else: | ||
| raise ValueError( | ||
| f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." | ||
| f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same bug: self.len_spatial_sigma undefined in else branch.
Identical issue as TrainableBilateralFilter.__init__.
Proposed fix
else:
raise ValueError(
- f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
+ f"len(spatial_sigma) must be 1, 2, or 3, got {len(spatial_sigma)}."
)🧰 Tools
🪛 Ruff (0.14.14)
[warning] 396-398: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@monai/networks/layers/filtering.py` around lines 395 - 398, The else branch
references an undefined attribute self.len_spatial_sigma; fix it by using a
defined value (e.g., compute len_spatial = len(self.spatial_sigma) or use
self.spatial_ndim) when building the error message in the failing branch of the
initializer (same place as TrainableBilateralFilter.__init__). Replace
self.len_spatial_sigma with the actual computed length (len(self.spatial_sigma)
or self.spatial_ndim) so the ValueError message prints a valid
expected-dimension value.
The forward() method was incorrectly comparing self.len_spatial_sigma (number of spatial dimensions) with len(input_tensor.shape) (total dimensions including batch and channel), causing valid 3D inputs to be rejected.
Fixes #7444
Description
This PR fixes a validation bug in
TrainableBilateralFilterthat incorrectly rejected valid 3D inputs with shape(B, C, H, W, D).Root Cause: The
forward()method comparedself.len_spatial_sigma(spatial dimensions = 3) withlen(input_tensor.shape)(total dimensions = 5), causing a dimension mismatch error for valid inputs.Solution: Calculate
spatial_dims = len(input_tensor.shape) - 2to exclude batch and channel dimensions, then compare againstself.len_spatial_sigma.Example of fixed behavior:
This fix also improves error messages and adds validation for inputs with insufficient dimensions.
Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.Notes on Testing
The existing unit tests for
TrainableBilateralFilter(24 tests) require the C++ extension and were skipped locally (expected behavior with@skip_if_no_cpp_extensiondecorator). These tests will run automatically in CI.I verified the fix logic with custom local tests for 1D, 2D, and 3D cases (see examples in description above).
Linting and code formatting checks passed:
No new tests were added as the existing 24 unit tests already cover the behavior. No docstring or documentation changes were needed as this is purely a bug fix in validation logic.