Skip to content

Commit

Permalink
add weight and gradient logging (#35)
Browse files Browse the repository at this point in the history
* WIP: add weight and gradient logging

* update to `log_arrays!` and add doctests

* make test more robust

* add tests

* add `step_logger!` call

* bump version

* bump minimum Julia version to 1.6

* Update src/LighthouseFlux.jl
  • Loading branch information
ericphanson authored May 3, 2022
1 parent 9e500f2 commit 0af2901
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 23 deletions.
17 changes: 9 additions & 8 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.5'
- '1.6'
- '1'
os:
- ubuntu-latest
Expand Down Expand Up @@ -48,13 +48,14 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: '1'
- run: |
julia --project=docs -e '
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()'
- run: julia --project=docs docs/make.jl
version: 1.6
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
# Get codecov for doctests
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v2
with:
file: lcov.info
11 changes: 7 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
name = "LighthouseFlux"
uuid = "56a5d6c5-c9a8-4db3-ae3d-7c3fdb50c563"
authors = ["Beacon Biosignals, Inc."]
version = "0.3.4"
version = "0.3.5"

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Lighthouse = "ac2c24cd-07f0-4848-96b2-1b82c3ea0e59"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CairoMakie = "0.6.2"
CairoMakie = "0.6.2, 0.7"
Flux = "0.10.4, 0.11, 0.12"
Lighthouse = "0.12, 0.13"
Functors = "0.2.8"
Lighthouse = "0.14.6"
StableRNGs = "1.0"
Zygote = "0.4.13, 0.5, 0.6"
julia = "1.5"
julia = "1.6"

[extras]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Expand Down
3 changes: 2 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LighthouseFlux = "56a5d6c5-c9a8-4db3-ae3d-7c3fdb50c563"

[compat]
Documenter = "0.25"
Documenter = "0.27"
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Documenter

makedocs(; modules=[LighthouseFlux], sitename="LighthouseFlux",
authors="Beacon Biosignals and other contributors",
pages=["API Documentation" => "index.md"])
pages=["API Documentation" => "index.md"],
strict=true)

deploydocs(; repo="github.com/beacon-biosignals/LighthouseFlux.jl.git", devbranch="main")
7 changes: 7 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ LighthouseFlux.loss
LighthouseFlux.loss_and_prediction
LighthouseFlux.evaluate_chain_in_debug_mode
```

## Internal functions

```@docs
LighthouseFlux.gather_weights_gradients
LighthouseFlux.fforeach_pairs
```
97 changes: 94 additions & 3 deletions src/LighthouseFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ module LighthouseFlux

using Zygote: Zygote
using Flux: Flux
using Lighthouse: Lighthouse, classes, log_resource_info!, log_value!
using Lighthouse: Lighthouse, classes, log_resource_info!, log_values!, log_arrays!, step_logger!
using Functors
using Statistics

export FluxClassifier

Expand Down Expand Up @@ -81,6 +83,92 @@ function loss_and_prediction(model, input_batch, other_batch_arguments...)
return (loss(model, input_batch, other_batch_arguments...), model(input_batch))
end

# Modified from `Functors.fmap`
"""
fforeach_pairs(F, x, keys=(); exclude=Functors.isleaf, cache=IdDict(),
prune=Functors.NoKeyword(), combine=(ks, k) -> (ks..., k))
Walks the Functors.jl-compatible graph `x` (by calling `pairs ∘ Functors.children`), applying
`F(parent_key, child)` at each step along the way. Here `parent_key` is the `key` part of a
key-value pair returned from `pairs ∘ Functors.children`, combined with the previous `parent_key`
by `combine`.
## Example
```jldoctest ex
julia> using Functors, LighthouseFlux
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7))));
julia> LighthouseFlux.fforeach_pairs((k,v) -> @show((k, v)), m)
(k, v) = ((:x,), Bar([1, 2, 3]))
(k, v) = ((:x, :x), [1, 2, 3])
(k, v) = ((:y,), (4, 5, Bar(Foo(6, 7))))
(k, v) = ((:y, 1), 4)
(k, v) = ((:y, 2), 5)
(k, v) = ((:y, 3), Bar(Foo(6, 7)))
(k, v) = ((:y, 3, :x), Foo(6, 7))
(k, v) = ((:y, 3, :x, :x), 6)
(k, v) = ((:y, 3, :x, :y), 7)
```
The `combine` argument can be used to customize how the keys are combined. For example
```jldoctest ex
julia> LighthouseFlux.fforeach_pairs((k,v) -> @show((k, v)), m, ""; combine=(ks, k) -> string(ks, "/", k))
(k, v) = ("/x", Bar([1, 2, 3]))
(k, v) = ("/x/x", [1, 2, 3])
(k, v) = ("/y", (4, 5, Bar(Foo(6, 7))))
(k, v) = ("/y/1", 4)
(k, v) = ("/y/2", 5)
(k, v) = ("/y/3", Bar(Foo(6, 7)))
(k, v) = ("/y/3/x", Foo(6, 7))
(k, v) = ("/y/3/x/x", 6)
(k, v) = ("/y/3/x/y", 7)
```
"""
function fforeach_pairs(F, x, keys=(); exclude=Functors.isleaf, cache=IdDict(),
                        prune=Functors.NoKeyword(), combine=(ks, k) -> (ks..., k))
    # `walk` visits every (key, child) pair of one node: it first invokes the user
    # callback `F` with the key-path accumulated so far (extended via `combine`),
    # then recurses into the child through `f`. Its `for` loop evaluates to
    # `nothing`, so interior nodes cache `nothing` below.
    walk = (f, x) -> for (k, v) in pairs(Functors.children(x))
        F(combine(keys, k), v)
        f(k, v)
    end
    # Shared/repeated nodes are visited only once (mirrors `Functors.fmap`):
    # on a repeat, return the `prune` value if one was supplied, otherwise the
    # cached entry. NOTE(review): for a repeated leaf the cached entry is the
    # `(keys, x)` tuple stored below, not `nothing` — confirm callers ignore
    # this early-return value, since the normal path always returns `nothing`.
    haskey(cache, x) && return prune isa Functors.NoKeyword ? cache[x] : prune
    # Leaves (per `exclude`) are recorded with their key-path; interior nodes
    # are walked recursively, threading `combine`/`exclude`/`cache`/`prune`
    # through so the whole traversal shares one cache.
    cache[x] = exclude(x) ? (keys, x) :
               walk((k, x) -> fforeach_pairs(F, x, combine(keys, k); combine, exclude,
                                             cache, prune), x)
    return nothing
end

"""
gather_weights_gradients(classifier, gradients)
Collects the weights and gradients from `classifier` into a `Dict`.
"""
function gather_weights_gradients(classifier, gradients)
values = Dict{String, Any}()
fforeach_pairs(classifier.model, "";
combine=(ks, k) -> string(ks, "/", k)) do k, v
if haskey(gradients, v)
values[string("train/gradients", k)] = gradients[v]
end
if v isa AbstractArray
values[string("train/weights", k)] = v
end
end
return values
end

#####
##### Lighthouse `AbstractClassifier` Interface
#####
Expand All @@ -98,20 +186,23 @@ Lighthouse.onecold(classifier::FluxClassifier, label) = classifier.onecold(label
function Lighthouse.train!(classifier::FluxClassifier, batches, logger)
Flux.trainmode!(classifier.model)
weights = Zygote.Params(classifier.params)
for batch in batches
for (i, batch) in enumerate(batches)
train_loss, back = log_resource_info!(logger, "train/forward_pass";
suffix="_per_batch") do
f = () -> loss(classifier.model, batch...)
return Zygote.pullback(f, weights)
end
log_value!(logger, "train/loss_per_batch", train_loss)
log_values!(logger, ("train/loss_per_batch" => train_loss,
"train/batch_index" => i))
gradients = log_resource_info!(logger, "train/reverse_pass"; suffix="_per_batch") do
return back(Zygote.sensitivity(train_loss))
end
log_resource_info!(logger, "train/update"; suffix="_per_batch") do
Flux.Optimise.update!(classifier.optimiser, weights, gradients)
return nothing
end
log_arrays!(logger, gather_weights_gradients(classifier, gradients))
step_logger!(logger)
end
Flux.testmode!(classifier.model)
return nothing
Expand Down
23 changes: 17 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test, StableRNGs
using LighthouseFlux, Lighthouse, Flux, Random
using Statistics

using CairoMakie
CairoMakie.activate!(type="png")
Expand Down Expand Up @@ -36,19 +37,18 @@ end
logger = Lighthouse.LearnLogger(joinpath(tmpdir, "logs"), "test_run")
limit = 4
let counted = 0
upon_loss_decrease = Lighthouse.upon(logger,
"test_set_prediction/mean_loss_per_epoch";
condition=<, initial=Inf)
# Every epoch has the same number of batches, namely 100
upon_batch_index_same = Lighthouse.upon(logger,
"train/batch_index";
condition=(==), initial=100)
callback = n -> begin
upon_loss_decrease() do _
upon_batch_index_same() do _
counted += n
end
end
elected = majority.((rng,), eachrow(votes), (1:length(classes),))
Lighthouse.learn!(classifier, logger, () -> train_batches, () -> test_batches,
votes, elected; epoch_limit=limit, post_epoch_callback=callback)
# NOTE: the RNG chosen above just happens to allow this to work every time,
# since the loss happens to actually "improve" on the random data each epoch
@test counted == sum(1:limit)
end
for key in [
Expand All @@ -65,6 +65,14 @@ end
"train/update/gc_time_in_seconds_per_batch"
"train/update/allocations_per_batch"
"train/update/memory_in_mb_per_batch"
"train/gradients/chain/1/weight"
"train/gradients/chain/2/bias"
"train/gradients/chain/2/weight"
"train/gradients/chain/1/bias"
"train/weights/chain/1/weight"
"train/weights/chain/2/bias"
"train/weights/chain/2/weight"
"train/weights/chain/1/bias"
]
@test length(logger.logged[key]) == length(train_batches) * limit
end
Expand Down Expand Up @@ -98,6 +106,9 @@ end
onehot=(x -> fill(x, length(classes))), onecold=sum)
@test Lighthouse.onehot(classifier, 3) == fill(3, length(classes))
@test Lighthouse.onecold(classifier, [0.31, 0.43, 0.13]) == 0.87

@test mean(classifier.model.chain[1].weight) ≈ logger.logged["train/weights/chain/1/weight"][end]
@test mean(classifier.model.chain[2].weight) ≈ logger.logged["train/weights/chain/2/weight"][end]
end
end

Expand Down

2 comments on commit 0af2901

@ericphanson
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/59599

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.5 -m "<description of version>" 0af2901ca206f8f2325c0eb0f77a8c4edb769655
git push origin v0.3.5

Please sign in to comment.