Nikola committed Jul 11, 2023
1 parent 2fbffae · commit e73665d
Showing 8 changed files with 303 additions and 9 deletions.
LICENSE
@@ -1,6 +1,6 @@
MIT License

-Copyright (c) 2023 [email protected]
Copyright (c) 2023 [email protected]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Project.toml
@@ -1,13 +1,32 @@
name = "SSIMLoss" | ||
uuid = "4d144c37-fac8-4848-9cd1-d34bc60114c1" | ||
authors = ["[email protected]"] | ||
version = "1.0.0-DEV" | ||
authors = ["[email protected]"] | ||
version = "1.0.0" | ||
|
||
[deps] | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | ||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
|
||
[weakdeps] | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" | ||
|
||
[extensions] | ||
SSIMLossCUDAExt = ["CUDA", "cuDNN"] | ||
|
||
[compat] | ||
julia = "1" | ||
julia = "1.9" | ||
|
||
[extras] | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9" | ||
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" | ||
|
||
[targets] | ||
test = ["Test"] | ||
test = ["Test", "CUDA", "cuDNN", "Zygote", "Images", "TestImages", "ImageQualityIndexes"] |
ext/SSIMLossCUDAExt.jl
@@ -0,0 +1,7 @@
module SSIMLossCUDAExt

using SSIMLoss, CUDA, cuDNN

SSIMLoss.ssim_kernel(x::AnyCuArray{T, N}) where {T, N} = cu(ssim_kernel(T, N))

end
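The single method defined by this extension is what lets the default `kernel=ssim_kernel(x)` argument of `ssim` work for GPU arrays: the Gaussian window is built on the CPU and moved to the device with `cu`. A hypothetical REPL session, assuming a functional CUDA device:

```julia
using SSIMLoss, CUDA, cuDNN

x = CUDA.rand(Float32, 64, 64, 1, 1)   # H × W × channel × batch, on the GPU
k = ssim_kernel(x)                     # dispatches to the AnyCuArray method above
k isa CuArray                          # true: the 11×11 window now lives on the device
ssim(x, x) ≈ 1                         # the default kernel argument works for CuArrays
```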
src/SSIMLoss.jl
@@ -1,5 +1,108 @@
module SSIMLoss

-# Write your package code here.
using NNlib, MLUtils, Statistics
using ChainRulesCore

include("utils.jl")
export ssim_kernel

"""
    ssim(x, y, kernel=ssim_kernel(x); peakval=1, crop=true, dims=:)

Return the [structural similarity index
measure](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM) between
two signals. SSIM is computed as the mean of a sliding window of statistics
computed between the two signals. By default, the sliding window is a Gaussian
with side length 11 in each signal dimension and σ = 1.5. `crop=false` pads `x`
and `y` so that the sliding window computes statistics centered at every pixel
of the input (via a same-size convolution).

`ssim` computes statistics independently over the channel and batch dimensions.
`x` and `y` may be 3D/4D/5D tensors with channel and batch dimensions.
`peakval=1` is the standard for image comparisons, but in practice it should be
set to the maximum value of your signal type.
`dims` determines which dimensions the computed statistics are averaged over.
If `dims=1:ndims(x)-1`, SSIM is computed for each batch element separately.

The results of `ssim` are matched against those of
[ImageQualityIndexes](https://github.com/JuliaImages/ImageQualityIndexes.jl)
for grayscale and RGB images (i.e. `x`, `y` of size (N1, N2, 1, B) for
grayscale and (N1, N2, 3, B) for color images, respectively).

See also [`ssim_loss`](@ref), [`ssim_loss_fast`](@ref).
"""
function ssim(x::AbstractArray{T,N}, y::AbstractArray{T,N}, kernel=ssim_kernel(x); peakval=T(1.0), crop=true, dims=:) where {T,N}
    _check_sizes(x, y)

    # apply the same kernel to each channel separately via groups=in_channels
    groups = size(x, N-1)
    kernel = repeat(kernel, ones(Int, N-1)..., groups)

    # constants to avoid division by zero
    SSIM_K = (0.01, 0.03)
    C₁, C₂ = @. T(peakval * SSIM_K)^2

    # crop==true -> valid-sized conv (do nothing),
    # otherwise, pad for same-sized conv
    if !crop
        # from Flux.jl:src/layers/conv.jl (calc_padding)
        padding = Tuple(mapfoldl(i -> [cld(i, 2), fld(i, 2)], vcat, size(kernel)[1:N-2] .- 1))
        x = pad_symmetric(x, padding)
        y = pad_symmetric(y, padding)
    end

    μx  = conv(x, kernel; groups=groups)
    μy  = conv(y, kernel; groups=groups)
    μx² = μx.^2
    μy² = μy.^2
    μxy = μx.*μy
    σx² = conv(x.^2, kernel; groups=groups) .- μx²
    σy² = conv(y.^2, kernel; groups=groups) .- μy²
    σxy = conv(x.*y, kernel; groups=groups) .- μxy

    ssim_map = @. (2μxy + C₁)*(2σxy + C₂) / ((μx² + μy² + C₁)*(σx² + σy² + C₂))
    return mean(ssim_map, dims=dims)
end

"""
    ssim_loss(x, y, kernel=ssim_kernel(x); peakval=1, crop=true, dims=:)

Compute `1 - ssim(x, y)`, suitable for use as a loss function with gradient
descent. For faster training, it is recommended to store a kernel and reuse it,
e.g.,
```julia
kernel = ssim_kernel(Float32, 4) |> gpu
# or alternatively, for faster computation
# kernel = ones(Float32, 5, 5, 1, num_channels) |> gpu

for (x, y) in dataloader
    x, y = (x, y) .|> gpu
    grads = gradient(model) do m
        x̂ = m(y)
        ssim_loss(x, x̂, kernel)
    end
    # update the model ...
end
```
See [`ssim`](@ref) for a detailed description of SSIM and the above arguments.
See also [`ssim_loss_fast`](@ref).
"""
ssim_loss(x::AbstractArray{T}, args...; kws...) where T = one(T) - ssim(x, args...; kws...)

"""
    ssim_loss_fast(x, y; kernel_length=5, peakval=1, crop=true, dims=:)

Compute `ssim_loss` with an averaging kernel instead of the large Gaussian
kernel, for faster computation. `kernel_length` specifies the averaging-kernel
side length in each signal dimension of `x`, `y`. See [`ssim`](@ref) for a
detailed description of SSIM and the above arguments.
See also [`ssim_loss`](@ref).
"""
function ssim_loss_fast(x::AbstractArray{T, N}, y::AbstractArray{T, N}; kernel_length=5, kws...) where {T, N}
    kernel = ones_like(x, (kernel_length*ones(Int, N-2)..., 1, 1))
    kernel = kernel ./ sum(kernel)
    return ssim_loss(x, y, kernel; kws...)
end

export ssim, ssim_loss, ssim_loss_fast

end
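For orientation, a minimal CPU-only sketch of the exported API (array sizes here are arbitrary and not part of the commit):

```julia
using SSIMLoss

x = rand(Float32, 64, 64, 3, 4)                           # H × W × channels × batch
y = clamp.(x .+ 0.05f0 .* randn(Float32, size(x)), 0, 1)

s_all   = ssim(x, y)               # scalar SSIM averaged over the whole batch
s_batch = ssim(x, y; dims=1:3)     # per-batch-element SSIM, size (1, 1, 1, 4)
l       = ssim_loss(x, y)          # 1 - ssim(x, y), for gradient descent
l_fast  = ssim_loss_fast(x, y)     # 5×5 averaging kernel instead of the Gaussian
```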
src/utils.jl
@@ -0,0 +1,53 @@
# Gaussian kernel std=1.5, length=11
const SSIM_KERNEL =
    [0.00102838008447911,
     0.007598758135239185,
     0.03600077212843083,
     0.10936068950970002,
     0.2130055377112537,
     0.26601172486179436,
     0.2130055377112537,
     0.10936068950970002,
     0.03600077212843083,
     0.007598758135239185,
     0.00102838008447911]

"""
    ssim_kernel(T, N)

Return a Gaussian kernel with σ = 1.5 and side length 11 for use in [`ssim`](@ref).
The returned kernel has an `N-2`-dimensional window and element type `T`.
"""
function ssim_kernel(T::Type, N::Integer)
    if N-2 == 1
        kernel = SSIM_KERNEL
    elseif N-2 == 2
        kernel = SSIM_KERNEL*SSIM_KERNEL'
    elseif N-2 == 3
        ks = length(SSIM_KERNEL)
        kernel = reshape(SSIM_KERNEL*SSIM_KERNEL', 1, ks, ks).*SSIM_KERNEL
    else
        throw("SSIM is only implemented for 3D/4D/5D inputs, dimension=$N provided.")
    end
    return reshape(T.(kernel), size(kernel)..., 1, 1)
end
ChainRulesCore.@non_differentiable ssim_kernel(T::Any, N::Any)

"""
    ssim_kernel(x::AbstractArray{T, N}) where {T, N}

Return a Gaussian kernel with σ = 1.5 and side length 11 for use in [`ssim`](@ref).
The returned array is on the same device as `x`.
"""
ssim_kernel(x::Array{T, N}) where {T, N} = ssim_kernel(T, N)
ChainRulesCore.@non_differentiable ssim_kernel(x::Any)

function _check_sizes(x::AbstractArray, y::AbstractArray)
    for d in 1:max(ndims(x), ndims(y))
        size(x,d) == size(y,d) || throw(DimensionMismatch(
            "loss function expects size(x) = $(size(x)) to match size(y) = $(size(y))"
        ))
    end
end
_check_sizes(ŷ, y) = nothing  # pass-through, for constant label e.g. y = 1
ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
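A quick illustration of the kernel helper's shapes (the concrete calls below are illustrative, not part of the commit):

```julia
using SSIMLoss

k2 = ssim_kernel(Float32, 4)    # 2-D Gaussian window for 4D inputs, size (11, 11, 1, 1)
sum(k2) ≈ 1                     # the window is normalized
k1 = ssim_kernel(Float32, 3)    # 1-D window for 3D inputs, size (11, 1, 1)
k  = ssim_kernel(rand(Float32, 32, 32, 1, 1))   # same as k2, via the array method
```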
test/cuda.jl
@@ -0,0 +1,36 @@
CUDA.allowscalar(false)

@testset "type $T, ndims=$N, $loss" for T=(Float64, Float32, Float16), N=1:3, loss in (ssim, ssim_loss, ssim_loss_fast)
    # see https://github.com/FluxML/NNlib.jl/issues/505
    # Float16 conv is broken for 5D tensors
    if T==Float16 && (N==3 || loss in (ssim, ssim_loss))
        continue
    end

    x_cpu = rand(T, 16*ones(Int, N)..., 2, 2)
    y_cpu = rand(T, 16*ones(Int, N)..., 2, 2)
    x_gpu = cu(x_cpu)
    y_gpu = cu(y_cpu)

    @testset "sanity check" begin
        @test ssim(x_gpu, x_gpu) ≈ 1
        @test ssim_loss(x_gpu, x_gpu) ≈ 0
        @test ssim_loss_fast(x_gpu, x_gpu) ≈ 0
    end

    @testset "cpu == gpu" begin
        @test loss(x_cpu, y_cpu) ≈ loss(x_gpu, y_gpu)
    end

    @testset "grad cpu == gpu" begin
        out_cpu, back_cpu = pullback((x, y) -> loss(x, y), x_cpu, y_cpu)
        c = randn(T)
        gs_cpu = back_cpu(c)

        out_gpu, back_gpu = pullback((x, y) -> loss(x, y), x_gpu, y_gpu)
        gs_gpu = back_gpu(c)

        @test collect(out_cpu) ≈ collect(out_gpu)
        for (g_cpu, g_gpu) in zip(gs_cpu, gs_gpu)
            @test collect(g_cpu) ≈ collect(g_gpu)
        end
    end
end
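The CPU-vs-GPU gradient comparison relies on `Zygote.pullback`, which returns the primal value together with a backpropagator that is seeded with a cotangent and returns one gradient per argument. In isolation (independent of this package):

```julia
using Zygote

f(a, b) = sum(abs2, a .- b)
out, back = Zygote.pullback(f, rand(3), rand(3))
ga, gb = back(1.0)    # gradients of f with respect to both arguments, seeded with 1.0
```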
test/runtests.jl
@@ -1,6 +1,23 @@
-using SSIMLoss
using Test
using SSIMLoss, CUDA, cuDNN

-@testset "SSIMLoss.jl" begin
-    # Write your tests here.
using Zygote
using Random
using ChainRulesTestUtils

using MLUtils
using Images, TestImages, ImageQualityIndexes

Random.seed!(0)

@testset "SSIMLoss" begin
    include("ssim.jl")

    @testset "CUDA" begin
        if CUDA.functional()
            include("cuda.jl")
        else
            @warn "CUDA unavailable, not testing GPU support"
        end
    end
end
test/ssim.jl
@@ -0,0 +1,59 @@
# monarch_color_256 and fabio_color_256 testimages
# used to obtain the numbers below.
# true/false denote `assess_ssim(...; crop=true/false)`
const iqi_rgb_true  = 0.1299260389807608
const iqi_gry_true  = 0.13380159790218638
const iqi_rgb_false = 0.13683875886675542
const iqi_gry_false = 0.14181793989104552

@testset "IQI consistency" begin
    # color-image testing
    # ssim values for monarch-fabio
    @test SSIMLoss.SSIM_KERNEL == ImageQualityIndexes.SSIM_KERNEL.parent

    # get reference images
    imx_rgb = testimage("monarch_color_256")
    imy_rgb = testimage("fabio_color_256")
    imx_gry = Gray.(imx_rgb)
    imy_gry = Gray.(imy_rgb)
    x_rgb = permutedims(channelview(imx_rgb), (2, 3, 1)) .|> Float64 |> unsqueeze(dims=4)
    y_rgb = permutedims(channelview(imy_rgb), (2, 3, 1)) .|> Float64 |> unsqueeze(dims=4)
    x_gry = imx_gry .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4)
    y_gry = imy_gry .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4)

    # 8 tests enumerating rgb/gray, crop/nocrop, iqi/SSIMLoss vs. reference
    for (ssim_iqi, crop) in
        zip(((iqi_rgb_true, iqi_gry_true), (iqi_rgb_false, iqi_gry_false)), (true, false))

        for (imx, imy, x, y, ssim_ref) in
            zip((imx_rgb, imx_gry), (imy_rgb, imy_gry), (x_rgb, x_gry), (y_rgb, y_gry), ssim_iqi)

            color = eltype(imx) <: RGB ? "RGB" : "Gray"
            @testset "crop=$crop, color=$color" begin
                # make sure IQI matches the reference value
                @test assess_ssim(imx, imy; crop=crop) ≈ ssim_ref
                # test this package against IQI
                @test ssim(x, y; crop=crop) ≈ ssim_ref atol=1e-6
            end
        end
    end
end

@testset "$T, ndims=$N" for T in (Float64, Float32, Float16), N=1:3
    x = rand(T, 16*ones(Int, N)..., 2, 2)
    y = rand(T, 16*ones(Int, N)..., 2, 2)

    @testset "sanity check" begin
        @test ssim(x, x) ≈ 1
        @test ssim_loss(x, x) ≈ 0
        @test ssim_loss_fast(x, x) ≈ 0
    end

    @testset "$f" for f in (ssim, ssim_loss, ssim_loss_fast)
        @testset "no spurious promotions" begin
            fwd, back = pullback(f, x, y)
            @test fwd isa T
            @test eltype(back(one(T))[1]) == T
        end
    end
end