Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
SCiarella committed Dec 4, 2024
1 parent bedb0b8 commit 95ddf5e
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ test/test_data/*.jld2

# Benchmarking
simulations/Benchmark/output
simulations/Benchmark/Manifest.toml
8 changes: 4 additions & 4 deletions simulations/Benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ Accessors = "0.1"
Adapt = "4"
CUDA = "5"
CairoMakie = "0.12"
CoupledNODE = "0.0.2"
CoupledNODE = "0.0"
Dates = "1"
DocStringExtensions = "0.9"
EnumX = "1"
FFTW = "1"
IncompressibleNavierStokes = "2"
IncompressibleNavierStokes = "^2.0.1"
JLD2 = "0.5"
LaTeXStrings = "1"
LinearAlgebra = "1"
Expand All @@ -49,9 +49,9 @@ Lux = "1"
LuxCUDA = "0.3"
MLUtils = "0.4"
NNlib = "0.9"
NeuralClosure = "1"
NeuralClosure = "1.0.0"
Observables = "0.5"
Optimisers = "0.3, 0.4"
ParameterSchedulers = "0.4"
SparseArrays = "1"
julia = "1.9"
julia = "1.10"
3 changes: 2 additions & 1 deletion simulations/Benchmark/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ params = (;
Re = T(6e3),
tburn = T(0.5),
tsim = T(2),
savefreq = 50,
savefreq = 100,
ndns = 64,
nles = [32,],
filters = (FaceAverage(),),
Expand All @@ -137,6 +137,7 @@ params = (;
bodyforce = (dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y),
issteadybodyforce = true,
processors = (; log = timelogger(; nupdate = 100)),
Δt = T(1e-3),
)

# DNS seeds
Expand Down
28 changes: 12 additions & 16 deletions simulations/Benchmark/src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ createdata(; params, seeds, outdir, taskid) =
ispath(datadir) || mkpath(datadir)
push!(filenames, f)
end
data = create_les_data(; params..., rng = Xoshiro(seed), filenames)
data = create_les_data(; params..., rng = Xoshiro(seed), filenames, Δt = params.Δt)
@info(
"Trajectory info:",
data[1].comptime / 60,
Expand Down Expand Up @@ -83,11 +83,15 @@ function trainprior(;
figfile = joinpath(figdir, splitext(basename(priorfile))[1] * ".pdf")
checkfile = join(splitext(priorfile), "_checkpoint")
batchseed, validseed = splitseed(priorseed, 2) # Same seed for all training setups
setup = getsetup(; params, nles)
#data_train =
# map(s -> namedtupleload(getdatafile(outdir, nles, Φ, s)), dns_seeds_train)
#data_valid =
# map(s -> namedtupleload(getdatafile(outdir, nles, Φ, s)), dns_seeds_valid)

# Read the data in the format expected by the CoupledNODE
T = eltype(params.Re)
setup = []
for nl in nles
x = ntuple-> LinRange(T(0.0), T(1.0), nl + 1), params.D)
push!(setup, Setup(; x = x, Re = params.Re))
end

# Read the data in the format expected by the CoupledNODE
data_train = []
for s in dns_seeds_train
Expand All @@ -99,18 +103,10 @@ function trainprior(;
data_i = namedtupleload(getdatafile(outdir, nles, Φ, s))
push!(data_valid, hcat(data_i))
end
@show length(data_train)
@show typeof(data_train)
@assert false
@show size(data_train[1].u)
@show size(data_train[1])
io_train = CoupledNODE.NavierStokes.create_io_arrays_priori(data_train, setup)
@assert false
# io_valid = CoupledNODE.NavierStokes.create_io_arrays_priori(data_valid, setup)
io_valid = CoupledNODE.NavierStokes.create_io_arrays_priori(data_valid, setup)
θ = device(θ_start)
dataloader_prior = CoupledNODE.NavierStokes.create_dataloader_prior(data_train; batchsize = batchsize,rng=dns_seeds_train)
@info dataloader_prior()
@assert false
dataloader_prior = CoupledNODE.NavierStokes.create_dataloader_prior(io_train[1]; batchsize = batchsize,rng=Random.Xoshiro(dns_seeds_train[1]))
train_data_priori = dataloader_prior()
loss = loss_priori_lux(closure, θ, st, train_data_priori)
optstate = Optimisers.setup(opt, θ)
Expand Down

0 comments on commit 95ddf5e

Please sign in to comment.