diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl index 3bed6a453a..229ca559f2 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl @@ -206,19 +206,24 @@ function build_nlsolver( if nlalg isa NonlinearSolveAlg α = tTypeNoUnits(α) dt = tTypeNoUnits(dt) - if isdae - nlf = (ztmp, z, p) -> begin - tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p - _compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1] + if isnothing(f.nlfunc) + if isdae + nlf = (ztmp, z, p) -> begin + tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p + _compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1] + end + nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f) + else + nlf = (ztmp, z, p) -> begin + tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p + _compute_rhs!( + tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1] + end + nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f) end - nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f) else - nlf = (ztmp, z, p) -> begin - tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p - _compute_rhs!( - tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1] - end - nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f) + nlf = f.nlfunc + nlp_params =() end prob = NonlinearProblem(NonlinearFunction(nlf), ztmp, nlp_params) cache = init(prob, nlalg.alg)