Skip to content

Commit

Permalink
bench priori done
Browse files Browse the repository at this point in the history
  • Loading branch information
SCiarella committed Dec 5, 2024
1 parent 95ddf5e commit c669ae8
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 189 deletions.
4 changes: 4 additions & 0 deletions simulations/Benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
CoupledNODE = "88291d29-22ea-41b1-bc0b-03785bffce48"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Expand All @@ -35,8 +37,10 @@ Accessors = "0.1"
Adapt = "4"
CUDA = "5"
CairoMakie = "0.12"
ComponentArrays = "0.15.19"
CoupledNODE = "0.0"
Dates = "1"
DifferentialEquations = "7.15.0"
DocStringExtensions = "0.9"
EnumX = "1"
FFTW = "1"
Expand Down
64 changes: 20 additions & 44 deletions simulations/Benchmark/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ params = (;
lims = (T(0), T(1)),
Re = T(6e3),
tburn = T(0.5),
tsim = T(2),
tsim = T(5),
savefreq = 100,
ndns = 64,
nles = [32,],
ndns = 128,
nles = [64,],
filters = (FaceAverage(),),
backend,
icfunc = (setup, psolver, rng) -> random_field(setup, T(0); kp = 20, psolver, rng),
Expand All @@ -141,18 +141,18 @@ params = (;
)

# DNS seeds
ntrajectory = 5
ntrajectory = 8
dns_seeds = splitseed(seeds.dns, ntrajectory)
dns_seeds_train = dns_seeds[1:ntrajectory-2]
dns_seeds_valid = dns_seeds[ntrajectory-1:ntrajectory-1]
dns_seeds_test = dns_seeds[ntrajectory:ntrajectory]

# Create data
docreatedata = true
docreatedata = false
docreatedata && createdata(; params, seeds = dns_seeds, outdir, taskid)

# Computational time
docomp = true
docomp = false
docomp && let
comptime, datasize = 0.0, 0.0
for seed in dns_seeds
Expand Down Expand Up @@ -184,10 +184,10 @@ closure, θ_start, st = CoupledNODE.cnn(;
T = T,
D = params.D,
data_ch = params.D,
radii = [2, 2],
channels = [2, 2],
activations = [tanh, identity],
use_bias = [false, false],
radii = [2, 2, 2, 2],
channels = [8,8,8, 2],
activations = [tanh,tanh,tanh, identity],
use_bias = [true, true,true, false],
rng = Xoshiro(seeds.θ_start),
)

Expand Down Expand Up @@ -221,9 +221,8 @@ end

# Train
let
dotrain = true
nepoch = 10
niter = 20
dotrain = false
nepoch = 300
dotrain && trainprior(;
params,
priorseed = seeds.prior,
Expand All @@ -236,16 +235,8 @@ let
θ_start,
st,
opt = Adam(T(1.0e-3)),
λ = T(5.0e-5),
scheduler = CosAnneal(; l0 = T(1e-6), l1 = T(1e-3), period = nepoch),
nvalid = 64,
batchsize = 16,
displayref = true,
displayupdates = true, # Set to `true` if using CairoMakie
nupdate_callback = 20,
loadcheckpoint = false,
batchsize = 32,
nepoch,
niter,
)
end

Expand Down Expand Up @@ -277,20 +268,14 @@ with_theme(; palette) do
ylims!(-0.05, 1.05)
lines!(
ax,
[Point2f(0, 1), Point2f(priortraining[ig, 1].hist[end][1], 1)];
[Point2f(0, 1), Point2f(priortraining[ig, 1].lhist_val[end][1], 1)];
label = "No closure",
linestyle = :dash,
)
for (ifil, Φ) in enumerate(params.filters)
label = Φ isa FaceAverage ? "FA" : "VA"
lines!(ax, priortraining[ig, ifil].hist; label)
lines!(ax, priortraining[ig, ifil].lhist_val; label)
end
# lines!(
# ax,
# [Point2f(0, 0), Point2f(priortraining[ig, 1].hist[end][1], 0)];
# label = "DNS",
# linestyle = :dash,
# )
end
axes = filter(x -> x isa Axis, fig.content)
linkaxes!(axes...)
Expand All @@ -317,9 +302,8 @@ projectorders = ProjectOrder.First, ProjectOrder.Last

# Train
let
dotrain = false
dotrain = true
nepoch = 10
niter = 10
dotrain && trainpost(;
params,
projectorders,
Expand All @@ -329,21 +313,13 @@ let
postseed = seeds.post,
dns_seeds_train,
dns_seeds_valid,
nsubstep = 5,
nunroll = 10,
ntrajectory = 5,
nunroll = 5,
closure,
θ_start = θ_cnn_prior,
st,
opt = Adam(T(1e-4)),
λ = T(5e-8),
scheduler = CosAnneal(; l0 = T(1e-6), l1 = T(1e-4), period = nepoch),
nunroll_valid = 50,
nupdate_callback = 10,
displayref = false,
displayupdates = true,
loadcheckpoint = false,
nepoch,
niter,
nunroll_valid = 10,
nepoch
)
end

Expand Down
5 changes: 5 additions & 0 deletions simulations/Benchmark/src/Benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ module Benchmark

using Accessors
using Adapt
using ComponentArrays
using CoupledNODE
using CoupledNODE: loss_priori_lux, create_loss_post_lux
using CoupledNODE.NavierStokes: create_right_hand_side_with_closure
using Dates
using DifferentialEquations
using DocStringExtensions
using EnumX
using LinearAlgebra
Expand Down
Loading

0 comments on commit c669ae8

Please sign in to comment.