Skip to content

Commit

Permalink
june update
Browse files Browse the repository at this point in the history
  • Loading branch information
mv-lab committed Jun 6, 2023
1 parent 05d3612 commit b5e59ed
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 20 deletions.
98 changes: 78 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,59 +1,117 @@
# AI Image Signal Processing and ISPs
# AI Image Signal Processing and Computational Photography
## Deep learning for low-level computer vision and imaging

[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2201.03210)
[![isp](https://img.shields.io/badge/ISP-paper-lightgreen)](https://arxiv.org/abs/2201.03210)
[![lpienet](https://img.shields.io/badge/LPIENet-paper-lightpink)](https://arxiv.org/abs/2210.13552)
[![bokeh](https://img.shields.io/badge/Bokeh-paper-9cf)](https://openaccess.thecvf.com/content/CVPR2023W/NTIRE/papers/Seizinger_Efficient_Multi-Lens_Bokeh_Effect_Rendering_and_Transformation_CVPRW_2023_paper.pdf)
[![ntire23](https://img.shields.io/badge/NTIRE-CVPR23-lightcyan)](https://cvlai.net/ntire/2023/)
![visitors](https://visitor-badge.glitch.me/badge?page_id=mv-lab/AISP)


[Marcos V. Conde](https://scholar.google.com/citations?user=NtB1kjYAAAAJ&hl=en), [Radu Timofte](https://scholar.google.com/citations?user=u3MwH5kAAAAJ&hl=en)
**[Marcos V. Conde](https://scholar.google.com/citations?user=NtB1kjYAAAAJ&hl=en), [Radu Timofte](https://scholar.google.com/citations?user=u3MwH5kAAAAJ&hl=en)**

[Computer Vision Lab, CAIDAS, University of Würzburg](https://www.informatik.uni-wuerzburg.de/computervision/home/)
[Computer Vision Lab, CAIDAS, University of Würzburg](https://www.informatik.uni-wuerzburg.de/computervision/home/)

---------------------------------------------------

> **Topics** This repository contains material for RAW image processing, RAW image reconstruction and synthesis, learned Image Signal Processing (ISP), Image Enhancement and Restoration (denoising, deblurring), Multi-lense Bokeh effect rendering, and much more! 📷
<br>

#### Official repository for the following works:

1. **[Efficient Multi-Lens Bokeh Effect Rendering and Transformation](https://openaccess.thecvf.com/content/CVPR2023W/NTIRE/papers/Seizinger_Efficient_Multi-Lens_Bokeh_Effect_Rendering_and_Transformation_CVPRW_2023_paper.pdf)** at **CVPR NTIRE 2023**.
1. **[Perceptual Image Enhancement for Smartphone Real-Time Applications](https://arxiv.org/abs/2210.13552) (LPIENet) at WACV 2023.**
1. **[Reversed Image Signal Processing and RAW Reconstruction. AIM 2022 Challenge Report](aim22-reverseisp/) ECCV, AIM 2022**
1. **[Model-Based Image Signal Processors via Learnable Dictionaries](https://ojs.aaai.org/index.php/AAAI/article/view/19926) AAAI 2022 Oral**
1. **[Model-Based Image Signal Processors via Learnable Dictionaries](https://arxiv.org/abs/2201.03210) AAAI 2022 Oral**
1. [MAI 2022 Learned ISP Challenge](#mai-2022-learned-isp-challenge) Complete Baseline solution
1. [Citation and Acknowledgement](#citation-and-acknowledgement) | [Contact](#contact)
1. [Citation and Acknowledgement](#citation-and-acknowledgement) | [Contact](#contact) for any inquiries.

**News 🚀🚀**

- [11/2022] LPIENet release soon!
- will try to keep the repo updated on a monthly basis ✏️
- [06/2023] Lens-to-lens bokeh effect transformation and NTIRE 2023 material coming soon.
- [01/202] LPIENet material is out
- [10/2022] Reversed ISP and RAW Reconstruction material presented at AIM workshop ECCV 2022 is now available! [check here](aim22-reverseisp/)

---------------------------------------------------
| | | | |
|:--- |:--- |:--- |:---|
| <a href="https://openaccess.thecvf.com/content/CVPR2023W/NTIRE/papers/Seizinger_Efficient_Multi-Lens_Bokeh_Effect_Rendering_and_Transformation_CVPRW_2023_paper.pdf"><img src="media/papers/bokeh-ntire23.png" width="300" border="0"></a> | <a href="https://arxiv.org/abs/2210.13552"><img src="media/papers/lpienet-wacv23.png" width="300" border="0"></a> | <a href="https://arxiv.org/abs/2210.11153"><img src="media/papers/reisp-aim22.png" width="255" border="0"></a> | <a href="https://arxiv.org/abs/2201.03210"><img src="media/papers/isp-aaai22.png" width="300" border="0"></a>
| | | | |

## [AIM 2022 Reversed ISP Challenge](aim22-reverseisp/)
------

### [Track 1 - S7](https://codalab.lisn.upsaclay.fr/competitions/5079) | [Track 2 - P20](https://codalab.lisn.upsaclay.fr/competitions/5080)
## [Perceptual Image Enhancement for Smartphone Real-Time Applications](https://arxiv.org/abs/2210.13552) (WACV '23)

<a href="https://data.vision.ee.ethz.ch/cvl/aim22/"><img src="https://i.ibb.co/VJ7SSQj/aim-challenge-teaser.png" alt="aim-challenge-teaser" width="400" border="0"></a>
*This work was presented at the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV) 2023.*

In this challenge, we look for solutions to recover RAW readings from the camera using only the corresponding RGB images processed by the in-camera ISP. Successful solutions should generate plausible RAW images, and by doing this, other downstream tasks like Denoising, Super-resolution or Colour Constancy can benefit from such synthetic data generation. Click [here to read more information](aim22-reverseisp/README.md) about the challenge.
> Recent advances in camera designs and imaging pipelines allow us to capture high-quality images using smartphones. However, due to the small size and lens limitations of the smartphone cameras, we commonly find artifacts or degradation in the processed images e.g., noise, diffraction artifacts, blur, and HDR overexposure.
We propose LPIENet, a lightweight network for perceptual image enhancement, with the focus on deploying it on smartphones.

### Starter guide and code 🔥
The code is available at **[lpienet](lpienet/)** including versions in Pytorch and Tensorflow. We also include the model conversion to TFLite, so you can generate the corresponding `.tflite` file and run the model using the `AI Benchmark` app on android devices.
In *[lpienet-tflite.ipynb](lpienet/lpienet-tflite.ipynb)* you can find a complete tutorial to transform the model to tflite.

- **[aim-starter-code.ipynb](aim22-reverseisp/official-starter-code.ipynb)** - Simple dataloading and visualization of RGB-RAW pairs + other utils.
- **[aim-baseline.ipynb](aim22-reverseisp/official-baseline.ipynb)** - End-to-end guide to load the data, train a simple UNet model and make your first submission!
**Contributions**
- The model can process 4K images under 1s on commercial smartphones.
- We achieve competitive results in comparison to SOTA methods in relevant benchmarks for denoising, deblurring and HDR correction. For example the SIDD benchmark.
- We reduce NAFNet number of MACs (or FLOPs) by 50 times.

<details>
<summary>Click here to read the abstract</summary>
<p>Recent advances in camera designs and imaging pipelines allow us to capture high-quality images using smartphones. However, due to the small size and lens limitations of the smartphone cameras, we commonly find artifacts or degradation in the processed images. The most common unpleasant effects are noise artifacts, diffraction artifacts, blur, and HDR overexposure. Deep learning methods for image restoration can successfully remove these artifacts. However, most approaches are not suitable for real-time applications on mobile devices due to their heavy computation and memory requirements.

In this paper, we propose LPIENet, a lightweight network for perceptual image enhancement, with the focus on deploying it on smartphones. Our experiments show that, with much fewer parameters and operations, our model can deal with the mentioned artifacts and achieve competitive performance compared with state-of-the-art methods on standard benchmarks. Moreover, to prove the efficiency and reliability of our approach, we deployed the model directly on commercial smartphones and evaluated its performance. Our model can process 2K resolution images under 1 second in mid-level commercial smartphones.
<br>
</p>
</details>
<br>



<a href="https://arxiv.org/abs/2210.13552"><img src="media/lpienet.png" alt="lpienet" width="800" border="0"></a>


| | |
| :--- | :--- |
| <img src="lpienet/lpienet-app.png" width="300" border="0"> | <img src="lpienet/lpienet-plot.png" width="450" border="0"> |
| | |

<br>

------

## [Model-Based Image Signal Processors via Learnable Dictionaries](https://ojs.aaai.org/index.php/AAAI/article/view/19926) (AAAI '22 Oral)
## [Model-Based Image Signal Processors via Learnable Dictionaries](https://mv-lab.github.io/model-isp22/) (AAAI '22 Oral)

*This work was presented at the 36th AAAI Conference on Artificial Intelligence, Spotlight (15%)*

[Project website](https://mv-lab.github.io/model-isp22/) where you can find the poster, presentation and more information.

> Hybrid model-based and data-driven approach for modelling ISPs using learnable dictionaries. We explore RAW image reconstruction and improve downstream tasks like RAW Image Denoising via raw data augmentation-synthesis.

<img src="mbispld/mbispld.png" alt="mbdlisp" width="600" border="0">
<a href="https://ojs.aaai.org/index.php/AAAI/article/view/19926/19685"><img src="mbispld/mbispld.png" alt="mbdlisp" width="800" border="0"></a>


If you have implementation questions or you need qualitative samples for comparison, please contact me. You can download the figure/illustration of our method in [mbispld](mbispld/mbispld.pdf).

<br>

------

## [AIM 2022 Reversed ISP Challenge](aim22-reverseisp/)

This work was presented at the European Conference on Computer Vision (ECCV) 2022, AIM workshop.

### [Track 1 - S7](https://codalab.lisn.upsaclay.fr/competitions/5079) | [Track 2 - P20](https://codalab.lisn.upsaclay.fr/competitions/5080)

<a href="https://data.vision.ee.ethz.ch/cvl/aim22/"><img src="https://i.ibb.co/VJ7SSQj/aim-challenge-teaser.png" alt="aim-challenge-teaser" width="500" border="0"></a>

In this challenge, we look for solutions to recover RAW readings from the camera using only the corresponding RGB images processed by the in-camera ISP. Successful solutions should generate plausible RAW images, and by doing this, other downstream tasks like Denoising, Super-resolution or Colour Constancy can benefit from such synthetic data generation. Click [here to read more information](aim22-reverseisp/README.md) about the challenge.

The code will be released soon. If you have implementation questions or you need qualitative samples for comparison, please contact me.
### Starter guide and code 🔥

We provide the figure/illustration of our method in [mbispld](mbispld/mbispld.pdf).
- **[aim-starter-code.ipynb](aim22-reverseisp/official-starter-code.ipynb)** - Simple dataloading and visualization of RGB-RAW pairs + other utils.
- **[aim-baseline.ipynb](aim22-reverseisp/official-baseline.ipynb)** - End-to-end guide to load the data, train a simple UNet model and make your first submission!

------

Expand Down Expand Up @@ -94,4 +152,4 @@ We test the model on AI Benchmark. The model average latency is 60ms using a inp

## Contact

Marcos Conde (marcos.conde-osorio@uni-wuerzburg.de) and Radu Timofte ([email protected]) are the contact persons and direct managers of the AIM challenge. Please add in the email subject "AIM22 Reverse ISP Challenge" or "AISP"
Marcos Conde ([email protected]) is the contact persons and co-organizer of NTIRE and AIM challenges.
Binary file added lpienet/lpienet-app.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added lpienet/lpienet-plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
175 changes: 175 additions & 0 deletions lpienet/lpienet-pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
Experiment options:
- Clip input range?!
- Sequential or parallel attention, which order?
- Spatial attention options (see CBAM paper)
- Which down and up sampling method? Pool, Conv, Shuffle, Interpolation
- Add vs. concat skips
- Add FMEN-like Unshuffle/Shuffle
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class AttentionBlock(nn.Module):
def __init__(self, dim: int):
super(AttentionBlock, self).__init__()
self._spatial_attention_conv = nn.Conv2d(2, dim, kernel_size=3, padding=1)

# Channel attention MLP
self._channel_attention_conv0 = nn.Conv2d(1, dim, kernel_size=1, padding=0)
self._channel_attention_conv1 = nn.Conv2d(dim, dim, kernel_size=1, padding=0)

self._out_conv = nn.Conv2d(2 * dim, dim, kernel_size=1, padding=0)

def forward(self, x: torch.Tensor):
if len(x.shape) != 4:
raise ValueError(f"Expected [B, C, H, W] input, got {x.shape}.")

# Spatial attention
mean = torch.mean(x, dim=1, keepdim=True) # Mean/Max on C axis
max, _ = torch.max(x, dim=1, keepdim=True)
spatial_attention = torch.cat([mean, max], dim=1) # [B, 2, H, W]
spatial_attention = self._spatial_attention_conv(spatial_attention)
spatial_attention = torch.sigmoid(spatial_attention) * x

# Channel attention. TODO: Correct that it only uses average pool contrary to CBAM?
# NOTE/TODO: This differs from CBAM as it uses Channel pooling, not spatial pooling!
# In a way, this is 2x spatial attention
channel_attention = torch.relu(self._channel_attention_conv0(mean))
channel_attention = self._channel_attention_conv1(channel_attention)
channel_attention = torch.sigmoid(channel_attention) * x

attention = torch.cat([spatial_attention, channel_attention], dim=1) # [B, 2*dim, H, W]
attention = self._out_conv(attention)
return x + attention


# TODO: This is not named in the paper right?
# It is sort of the InverseResidualBlock but w/o the Channel and Spatial Attentions and without another Conv after ReLU
class InverseBlock(nn.Module):
def __init__(self, input_channels: int, channels: int):
super(InverseBlock, self).__init__()

self._conv0 = nn.Conv2d(input_channels, channels, kernel_size=1)
self._dw_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
self._conv1 = nn.Conv2d(channels, channels, kernel_size=1)
self._conv2 = nn.Conv2d(input_channels, channels, kernel_size=1)

def forward(self, x: torch.Tensor):
features = self._conv0(x)
features = F.elu(self._dw_conv(features)) # TODO: Paper is ReLU, authors do ELU
features = self._conv1(features)

# TODO: The BaseBlock has residuals and one path of convolutions, not 2 separate paths - is this different on purpose?
x = torch.relu(self._conv2(x))
return x + features


class BaseBlock(nn.Module):
def __init__(self, channels: int):
super(BaseBlock, self).__init__()

self._conv0 = nn.Conv2d(channels, channels, kernel_size=1)
self._dw_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
self._conv1 = nn.Conv2d(channels, channels, kernel_size=1)

self._conv2 = nn.Conv2d(channels, channels, kernel_size=1)
self._conv3 = nn.Conv2d(channels, channels, kernel_size=1)

def forward(self, x: torch.Tensor):
features = self._conv0(x)
features = F.elu(self._dw_conv(features)) # TODO: ELU or ReLU?
features = self._conv1(features)
x = x + features

features = F.elu(self._conv2(x))
features = self._conv3(features)
return x + features


class AttentionTail(nn.Module):
def __init__(self, channels: int):
super(AttentionTail, self).__init__()

self._conv0 = nn.Conv2d(channels, channels, kernel_size=7, padding=3)
self._conv1 = nn.Conv2d(channels, channels, kernel_size=5, padding=2)
self._conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

def forward(self, x: torch.Tensor):
attention = torch.relu(self._conv0(x))
attention = torch.relu(self._conv1(attention))
attention = torch.sigmoid(self._conv2(attention))
return x * attention


class LPIENet(nn.Module):
def __init__(self, input_channels: int, output_channels: int, encoder_dims: List[int], decoder_dims: List[int]):
super(LPIENet, self).__init__()

if len(encoder_dims) != len(decoder_dims) + 1 or len(decoder_dims) < 1:
raise ValueError(f"Unexpected encoder and decoder dims: {encoder_dims}, {decoder_dims}.")

if input_channels != output_channels:
raise NotImplementedError()

# TODO: We will need an explicit decoder head, consider Unshuffle & Shuffle

encoders = []
for i, encoder_dim in enumerate(encoder_dims):
input_dim = input_channels if i == 0 else encoder_dims[i - 1]
encoders.append(
nn.Sequential(
nn.Conv2d(input_dim, encoder_dim, kernel_size=3, padding=1),
BaseBlock(encoder_dim), # TODO: one or two base blocks?
BaseBlock(encoder_dim),
AttentionBlock(encoder_dim),
)
)
self._encoders = nn.ModuleList(encoders)

decoders = []
for i, decoder_dim in enumerate(decoder_dims):
input_dim = encoder_dims[-1] if i == 0 else decoder_dims[i - 1] + encoder_dims[-i - 1]
decoders.append(
nn.Sequential(
nn.Conv2d(input_dim, decoder_dim, kernel_size=3, padding=1),
BaseBlock(decoder_dim),
BaseBlock(decoder_dim),
AttentionBlock(decoder_dim),
)
)
self._decoders = nn.ModuleList(decoders)

self._inverse_bock = InverseBlock(encoder_dims[0] + decoder_dims[-1], output_channels)
self._attention_tail = AttentionTail(output_channels)

def forward(self, x: torch.Tensor):
if len(x.shape) != 4:
raise ValueError(f"Expected [B, C, H, W] input, got {x.shape}.")
global_residual = x

encoder_outputs = []
for i, encoder in enumerate(self._encoders):
x = encoder(x)
if i != len(self._encoders) - 1:
encoder_outputs.append(x)
x = F.max_pool2d(x, kernel_size=2)

for i, decoder in enumerate(self._decoders):
x = decoder(x)
x = F.interpolate(x, scale_factor=2, mode="bilinear")
x = torch.cat([x, encoder_outputs.pop()], dim=1)

x = self._inverse_bock(x)
x = self._attention_tail(x)
return x + global_residual


model = LPIENet(3, 3, [4, 8, 16], [8, 4])
x = torch.rand(1, 3, 16, 16)
out = model(x)
print(out.shape)
1 change: 1 addition & 0 deletions lpienet/lpienet-tflite.ipynb

Large diffs are not rendered by default.

Binary file added media/lpienet.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added media/papers/bokeh-ntire23.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added media/papers/isp-aaai22.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added media/papers/lpienet-wacv23.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added media/papers/reisp-aim22.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit b5e59ed

Please sign in to comment.