Skip to content

Commit

Permalink
Update train.jl: p -> u.p in OptimizationState variable
Browse files Browse the repository at this point in the history
  • Loading branch information
facusapienza21 authored Nov 13, 2024
1 parent b7f5773 commit 6171c4a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ function train(data::AD,
callback = function (p, l)
push!(losses, l)
if length(losses) % 100 == 0
_, l_dict = f_loss(p)
_, l_dict = f_loss(p.u)
@printf "Iteration: [%5d / %5d] \t Loss: %.9f = Empirical: %.9f + Regularization: %.9f \n" length(losses) (params.niter_ADAM+params.niter_LBFGS) sum(values(l_dict)) l_dict["Empirical"] (sum(values(l_dict))-l_dict["Empirical"])
end
if params.train_initial_condition
p.u0 ./= norm(p.u0)
p.u.u0 ./= norm(p.u.u0)
end
return false
end
Expand Down Expand Up @@ -130,4 +130,4 @@ function train(data::AD,
return Results=θ_trained, u0=u0_trained, U=U, st=st,
fit_times=fit_times, fit_directions=fit_directions,
fit_rotations=fit_rotations, losses=losses)
end
end

0 comments on commit 6171c4a

Please sign in to comment.