Commit

first commit
Nikola committed Jul 11, 2023
1 parent 2fbffae commit e73665d
Showing 8 changed files with 303 additions and 9 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,6 +1,6 @@
MIT License

-Copyright (c) 2023 [email protected]
+Copyright (c) 2023 [email protected]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
27 changes: 23 additions & 4 deletions Project.toml
@@ -1,13 +1,32 @@
name = "SSIMLoss"
uuid = "4d144c37-fac8-4848-9cd1-d34bc60114c1"
authors = ["[email protected]"]
version = "1.0.0-DEV"
authors = ["[email protected]"]
version = "1.0.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
SSIMLossCUDAExt = ["CUDA", "cuDNN"]

[compat]
julia = "1"
julia = "1.9"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test"]
test = ["Test", "CUDA", "cuDNN", "Zygote", "Images", "TestImages", "ImageQualityIndexes"]
7 changes: 7 additions & 0 deletions ext/SSIMLossCUDAExt.jl
@@ -0,0 +1,7 @@
module SSIMLossCUDAExt

using SSIMLoss, CUDA, cuDNN

SSIMLoss.ssim_kernel(x::AnyCuArray{T, N}) where {T, N} = cu(ssim_kernel(T, N))

end
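This single method is the entire GPU story: `ssim_kernel(x)` returns a kernel on the same device as `x`, so `ssim` and the losses run on `CuArray`s without further changes. A hedged sketch, assuming a functional CUDA setup:

```julia
using SSIMLoss, CUDA, cuDNN

x = CUDA.rand(Float32, 64, 64, 1, 1)
k = ssim_kernel(x)   # dispatches to the extension method above
k isa CuArray        # true: the Gaussian kernel now lives on the GPU
```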
105 changes: 104 additions & 1 deletion src/SSIMLoss.jl
@@ -1,5 +1,108 @@
module SSIMLoss

-# Write your package code here.
using NNlib, MLUtils, Statistics
using ChainRulesCore

include("utils.jl")
export ssim_kernel

"""
ssim(x, y, kernel=ssim_kernel(x); peakval=1, crop=true, dims=:)
Return the [structural similarity index
measure](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM) between
two signals. SSIM is computed via the mean of a sliding window of
statistics computed between the two signals. By default, the sliding window is
a Gaussian with side-length 11 in each signal dimension and σ=1.5. `crop=false` will pad `x` and `y`
such that the sliding window computes statistics centered at every pixel of the input (via same-size convolution).
`ssim` computes statistics independently over channel and batch dimensions.
`x` and `y` may be 3D/4D/5D tensors with channel and batch-dimensions.
`peakval=1` is the standard for image comparisons, but in practice should be
set to the maximum value of your signal type.
`dims` determines which dimensions to average the computed statistics over. If
`dims=1:ndims(x)-1`, SSIM will be computed for each batch-element separately.
The results of `ssim` are matched against those of
[ImageQualityIndexes](https://github.com/JuliaImages/ImageQualityIndexes.jl)
for grayscale and RGB images (i.e. `x`, `y` of size (N1, N2, 1, B) and (N1, N2, 3, B), respectively).
See also [`ssim_loss`](@ref), [`ssim_loss_fast`](@ref).
"""
function ssim(x::AbstractArray{T,N}, y::AbstractArray{T,N}, kernel=ssim_kernel(x); peakval=T(1.0), crop=true, dims=:) where {T,N}
_check_sizes(x, y)

# apply same kernel on each channel dimension separately via groups=in_channels
groups = size(x, N-1)
kernel = repeat(kernel, ones(Int, N-1)..., groups)

# constants to avoid division by zero
SSIM_K = (0.01, 0.03)
C₁, C₂ = @. T(peakval * SSIM_K)^2

# crop==true -> valid-sized conv (do nothing),
# otherwise, pad for same-sized conv
if !crop
# from Flux.jl:src/layers/conv.jl (calc_padding)
padding = Tuple(mapfoldl(i -> [cld(i, 2), fld(i,2)], vcat, size(kernel)[1:N-2] .- 1))
x = pad_symmetric(x, padding)
y = pad_symmetric(y, padding)
end

μx = conv(x, kernel; groups=groups)
μy = conv(y, kernel; groups=groups)
μx² = μx.^2
μy² = μy.^2
μxy = μx.*μy
σx² = conv(x.^2, kernel; groups=groups) .- μx²
σy² = conv(y.^2, kernel; groups=groups) .- μy²
σxy = conv(x.*y, kernel; groups=groups) .- μxy

ssim_map = @. (2μxy + C₁)*(2σxy + C₂)/((μx² + μy² + C₁)*(σx² + σy² + C₂))
return mean(ssim_map, dims=dims)
end

"""
ssim_loss(x, y, kernel=ssim_kernel(x); peakval=1, crop=true, dims=:)
Computes `1 - ssim(x, y)`, suitable for use as a loss function with gradient descent.
For faster training, it is recommended to create the kernel once and reuse it, e.g.,
```julia
kernel = ssim_kernel(Float32, 4) |> gpu
# or, for faster computation, a normalized averaging kernel as in `ssim_loss_fast`:
# kernel = ones(Float32, 5, 5, 1, 1) ./ 25 |> gpu
for (x, y) in dataloader
x, y = (x, y) .|> gpu
grads = gradient(model) do m
x̂ = m(y)
ssim_loss(x, x̂, kernel)
end
# update the model ...
end
```
See [`ssim`](@ref) for a detailed description of SSIM and the above arguments.
See also [`ssim_loss_fast`](@ref).
"""
ssim_loss(x::AbstractArray{T}, args...; kws...) where T = one(T) .- ssim(x, args...; kws...) # .- keeps array outputs (dims != :) working

"""
ssim_loss_fast(x, y; kernel_length=5, peakval=1, crop=true, dims=:)
Computes `ssim_loss` with an averaging kernel instead of a large Gaussian
kernel for faster computation. `kernel_length` specifies the averaging kernel
side-length in each signal dimension of x, y. See [`ssim`](@ref) for a
detailed description of SSIM and the above arguments.
See also [`ssim_loss`](@ref).
"""
function ssim_loss_fast(x::AbstractArray{T, N}, y::AbstractArray{T, N}; kernel_length=5, kws...) where {T, N}
kernel = ones_like(x, (kernel_length*ones(Int, N-2)..., 1, 1))
kernel = kernel ./ sum(kernel)
return ssim_loss(x, y, kernel; kws...)
end

export ssim, ssim_loss, ssim_loss_fast

end
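For reference, a minimal CPU usage sketch of the exported API (array sizes here are illustrative, not prescribed):

```julia
using SSIMLoss

x = rand(Float32, 64, 64, 1, 2)   # (height, width, channels, batch)
y = rand(Float32, 64, 64, 1, 2)

ssim(x, y)             # scalar: statistics averaged over all dimensions
ssim(x, y; dims=1:3)   # one SSIM value per batch element
ssim_loss(x, y)        # 1 - ssim(x, y), for gradient descent
ssim_loss_fast(x, y)   # 5×5 averaging kernel instead of the 11×11 Gaussian
```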
53 changes: 53 additions & 0 deletions src/utils.jl
@@ -0,0 +1,53 @@
# Gaussian kernel std=1.5, length=11
const SSIM_KERNEL =
[0.00102838008447911,
0.007598758135239185,
0.03600077212843083,
0.10936068950970002,
0.2130055377112537,
0.26601172486179436,
0.2130055377112537,
0.10936068950970002,
0.03600077212843083,
0.007598758135239185,
0.00102838008447911]

"""
ssim_kernel(T, N)
Return Gaussian kernel with σ=1.5 and side-length 11 for use in [`ssim`](@ref).
Returned kernel will be `N-2` dimensional of type `T`.
"""
function ssim_kernel(T::Type, N::Integer)
if N-2 == 1
kernel = SSIM_KERNEL
elseif N-2 == 2
kernel = SSIM_KERNEL*SSIM_KERNEL'
elseif N-2 == 3
ks = length(SSIM_KERNEL)
kernel = reshape(SSIM_KERNEL*SSIM_KERNEL', 1, ks, ks).*SSIM_KERNEL
else
throw("SSIM is only implemented for 3D/4D/5D inputs, dimension=$N provided.")
end
return reshape(T.(kernel), size(kernel)..., 1, 1)
end
ChainRulesCore.@non_differentiable ssim_kernel(T::Any, N::Any)

"""
ssim_kernel(x::AbstractArray{T, N}) where {T, N}
Return Gaussian kernel with σ=1.5 and side-length 11 for use in [`ssim`](@ref).
Returned array will be on the same device as `x`.
"""
ssim_kernel(x::AbstractArray{T, N}) where {T, N} = ssim_kernel(T, N)
ChainRulesCore.@non_differentiable ssim_kernel(x::Any)

function _check_sizes(x::AbstractArray, y::AbstractArray)
for d in 1:max(ndims(x), ndims(y))
size(x,d) == size(y,d) || throw(DimensionMismatch(
"loss function expects size(ŷ) = $(size(ŷ)) to match size(y) = $(size(y))"
))
end
end
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1
ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
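A quick sketch of the shapes `ssim_kernel` produces, following directly from the definitions above:

```julia
using SSIMLoss

k3 = ssim_kernel(Float32, 3)   # 3D input: size (11, 1, 1)
k4 = ssim_kernel(Float32, 4)   # 4D input: size (11, 11, 1, 1)
k5 = ssim_kernel(Float32, 5)   # 5D input: size (11, 11, 11, 1, 1)
sum(k4) ≈ 1                    # outer products of the normalized 1D kernel stay normalized
```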
36 changes: 36 additions & 0 deletions test/cuda.jl
@@ -0,0 +1,36 @@
CUDA.allowscalar(false)

@testset "type $T, ndims=$N, $loss" for T=(Float64, Float32, Float16), N=1:3, loss in (ssim, ssim_loss, ssim_loss_fast)
# see https://github.com/FluxML/NNlib.jl/issues/505
# Float16 conv is broken for 5D tensors
if T==Float16 && (N==3 || loss in (ssim, ssim_loss))
continue
end

x_cpu = rand(T, 16*ones(Int, N)..., 2, 2)
y_cpu = rand(T, 16*ones(Int, N)..., 2, 2)
x_gpu = cu(x_cpu)
y_gpu = cu(y_cpu)

@testset "sanity check" begin
@test ssim(x_gpu, x_gpu) ≈ 1
@test ssim_loss(x_gpu, x_gpu) ≈ 0
@test ssim_loss_fast(x_gpu, x_gpu) ≈ 0
end

@testset "cpu == gpu" loss(x_cpu, y_cpu) loss(x_gpu, y_gpu)

@testset "grad cpu == gpu" begin
out_cpu, back_cpu = pullback((x, y) -> loss(x, y), x_cpu, y_cpu)
c = randn(T)
gs_cpu = back_cpu(c)

out_gpu, back_gpu = pullback((x, y) -> loss(x, y), x_gpu, y_gpu)
gs_gpu = back_gpu(c)

@test collect(out_cpu) ≈ collect(out_gpu)
for (g_cpu, g_gpu) in zip(gs_cpu, gs_gpu)
@test collect(g_cpu) ≈ collect(g_gpu)
end
end
end
23 changes: 20 additions & 3 deletions test/runtests.jl
@@ -1,6 +1,23 @@
-using SSIMLoss
using Test
+using SSIMLoss, CUDA, cuDNN

@testset "SSIMLoss.jl" begin
# Write your tests here.
using Zygote
using Random
using ChainRulesTestUtils

using MLUtils
using Images, TestImages, ImageQualityIndexes

Random.seed!(0)

@testset "SSIMLoss" begin
include("ssim.jl")

@testset "CUDA" begin
if CUDA.functional()
include("cuda.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end
end
end
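The suite runs the usual way; per the branch above, GPU tests are skipped with a warning on machines where CUDA is not functional:

```julia
using Pkg
Pkg.test("SSIMLoss")
```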
59 changes: 59 additions & 0 deletions test/ssim.jl
@@ -0,0 +1,59 @@
# monarch_color_256 and fabio_color_256 testimages
# used to obtain below numbers.
# true/false denote `assess_ssim(...; crop=true/false)`
const iqi_rgb_true = 0.1299260389807608
const iqi_gry_true = 0.13380159790218638
const iqi_rgb_false = 0.13683875886675542
const iqi_gry_false = 0.14181793989104552

@testset "IQI consistency" begin
# color-image testing
# ssim values for monarch-fabio
@test SSIMLoss.SSIM_KERNEL == ImageQualityIndexes.SSIM_KERNEL.parent

# get reference images
imx_rgb = testimage("monarch_color_256")
imy_rgb = testimage("fabio_color_256")
imx_gry = Gray.(imx_rgb)
imy_gry = Gray.(imy_rgb)
x_rgb = permutedims(channelview(imx_rgb), (2, 3, 1)) .|> Float64 |> unsqueeze(dims=4)
y_rgb = permutedims(channelview(imy_rgb), (2, 3, 1)) .|> Float64 |> unsqueeze(dims=4)
x_gry = imx_gry .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4)
y_gry = imy_gry .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4)

# 8 tests enumerating rgb/gray, crop/nocrop, iqi/flux vs. ref
for (ssim_iqi, crop) in
zip(((iqi_rgb_true, iqi_gry_true), (iqi_rgb_false, iqi_gry_false)), (true, false))

for (imx, imy, x, y, ssim_ref) in
zip((imx_rgb, imx_gry), (imy_rgb, imy_gry), (x_rgb, x_gry), (y_rgb, y_gry), ssim_iqi)

color = eltype(imx) <: RGB ? "RGB" : "Gray"
@testset "crop=$crop, color=$color" begin
# make sure ImageQualityIndexes reproduces the reference values
@test assess_ssim(imx, imy; crop=crop) ≈ ssim_ref
# test this package against IQI
@test ssim(x, y; crop=crop) ≈ ssim_ref atol=1e-6
end
end
end
end

@testset "$T, ndims=$N" for T in (Float64, Float32, Float16), N=1:3
x = rand(T, 16*ones(Int, N)..., 2, 2)
y = rand(T, 16*ones(Int, N)..., 2, 2)

@testset "sanity check" begin
@test ssim(x, x) ≈ 1
@test ssim_loss(x, x) ≈ 0
@test ssim_loss_fast(x, x) ≈ 0
end

@testset "$f" for f in (ssim, ssim_loss, ssim_loss_fast)
@testset "no spurious promotions" begin
fwd, back = pullback(f, x, y)
@test fwd isa T
@test eltype(back(one(T))[1]) == T
end
end
end
