-
Notifications
You must be signed in to change notification settings - Fork 790
Added PSNR metric for image quality evaluation #4377
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov Report❌ Patch coverage is
❌ Your project check has failed because the head coverage (68.89%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #4377 +/- ##
==========================================
+ Coverage 68.86% 68.89% +0.03%
==========================================
Files 1412 1413 +1
Lines 168243 168462 +219
==========================================
+ Hits 115856 116068 +212
- Misses 52387 52394 +7 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR implements the Peak Signal-to-Noise Ratio (PSNR) metric for evaluating image quality, addressing issue #4312 which tracks the implementation of various image quality metrics for the Burn training framework.
Changes:
- Added PSNRMetric implementation with configurable max_val and epsilon parameters for numerical stability
- Supports tensors with 2 or more dimensions (default 4D) with batch processing
- Includes 12 comprehensive test cases covering various scenarios including edge cases
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| crates/burn-train/src/metric/vision/psnr.rs | Implements PSNRMetric, PSNRInput, and PSNRConfig with the formula 10 * log10(MAX² / MSE), including comprehensive tests |
| crates/burn-train/src/metric/vision/mod.rs | Exports the new PSNR module alongside existing dice module |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if item.outputs.dims() != item.targets.dims() { | ||
| panic!( | ||
| "Outputs and targets must have the same dimensions. Got {:?} and {:?}", | ||
| item.outputs.dims(), | ||
| item.targets.dims() | ||
| ); | ||
| } | ||
|
|
Copilot
AI
Jan 25, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The shape validation in the update method is redundant because the same check is already performed in PSNRInput::new() at lines 42-46. Since PSNRInput can only be constructed with matching shapes, this check will never fail in practice. Consider removing this duplicate validation to avoid unnecessary panic branches.
| if item.outputs.dims() != item.targets.dims() { | |
| panic!( | |
| "Outputs and targets must have the same dimensions. Got {:?} and {:?}", | |
| item.outputs.dims(), | |
| item.targets.dims() | |
| ); | |
| } |
|
|
||
| /// Creates a new PSNR metric with a custom config. | ||
| pub fn with_config(config: PSNRConfig) -> Self { | ||
| let name = MetricName::new(format!("PSNR ({}D)", D)); |
Copilot
AI
Jan 25, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The naming convention for metrics with dimensions varies in the codebase. The DiceMetric uses the format "{D}D Dice Metric" (e.g., "4D Dice Metric"), while PSNRMetric uses "PSNR ({}D)" (e.g., "PSNR (4D)"). For consistency with the established convention in dice.rs (line 106), consider using the format "{}D PSNR" or "{D}D PSNR Metric" instead of "PSNR ({}D)".
| let ln_10 = 10.0_f64.ln(); | ||
|
|
||
| // Add epsilon to MSE to avoid division by zero and log(0) when MSE is exactly 0 | ||
| let mse_safe = mse_flat.clone() + self.config.epsilon; |
Copilot
AI
Jan 25, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The .clone() call on mse_flat at line 142 is unnecessary since mse_flat is only used once after this line. The addition operation consumes the tensor anyway, so the clone adds unnecessary overhead. Consider using mse_flat directly: let mse_safe = mse_flat + self.config.epsilon;
| let mse_safe = mse_flat.clone() + self.config.epsilon; | |
| let mse_safe = mse_flat + self.config.epsilon; |
| /// Maximum pixel value (1.0 for normalized images, 255.0 for 8-bit images). | ||
| pub max_val: f64, | ||
| /// Small value added to MSE to avoid numerical instability in log calculation. | ||
| /// Should be small enough to allow high PSNR values but representable in f32 (default: 1e-12). |
Copilot
AI
Jan 25, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation comment mentions "representable in f32" but the epsilon field is of type f64, and the implementation uses f64 for calculations. This could be misleading. Consider either updating the comment to reference f64, or clarifying that the comment refers to backend float precision rather than the config type. For example: "Should be small enough to allow high PSNR values (default: 1e-12)."
| /// Should be small enough to allow high PSNR values but representable in f32 (default: 1e-12). | |
| /// Should be small enough to allow high PSNR values (default: 1e-12). |
| pub fn with_config(config: PSNRConfig) -> Self { | ||
| let name = MetricName::new(format!("PSNR ({}D)", D)); | ||
| Self { | ||
| name, | ||
| state: NumericMetricState::default(), | ||
| _b: PhantomData, | ||
| config, | ||
| } | ||
| } |
Copilot
AI
Jan 25, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dimension validation (D >= 2) is missing from the with_config constructor, but is present in PSNRInput::new(). For consistency with DiceMetric (see dice.rs:107), which validates dimensions in its with_config method, consider adding the assertion here as well. This would provide earlier feedback if an invalid dimension is used and matches the pattern established in dice.rs.
Checklist
Related Issues/PRs
Image quality metrics #4312
Changes
Implemented Peak Signal-to-Noise Ratio (PSNR) metric in
crates/burn-train/src/metric/vision/psnr.rsfor evaluating image quality.The metric computes PSNR using the formula
10 * log10(MAX² / MSE)where MAX is the maximum pixel value and MSE is the mean squared error.The implementation supports tensors with 2 or more dimensions (default 4D for images), where dimension 0 is treated as batch and dimensions 1..D-1 are reduced to compute per-image MSE.
The metric includes configurable
max_val(default 1.0 for normalized images, 255.0 for 8-bit) andepsilon(default 1e-12) for numerical stability.Testing
Added 12 test cases covering perfect matches, small errors, large errors, batch processing, custom max_val, shape validation, 3D inputs, zero MSE handling, monotonicity property, mixed batches, running averages, and state clearing.