ImplicitEulerExtrapolation parallel #872
Changes from 1 commit
Cache changes — per-thread copies of the work buffers are added to `ImplicitEulerExtrapolationCache` and allocated in `alg_cache`:

@@ -70,10 +70,12 @@ end
  @cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache
    uprev::uType
    u_tmp::uType
+   u_tmps::Array{uType,1}
    utilde::uType
    tmp::uType
    atmp::uNoUnitsType
    k_tmp::rateType
+   k_tmps::Array{rateType,1}
    dtpropose::dtType
    T::arrayType
    cur_order::Int
@@ -89,7 +91,8 @@ end
    tf::TFType
    uf::UFType
    linsolve_tmp::rateType
-   linsolve::F
+   linsolve_tmps::Array{rateType,1}
+   linsolve::Array{F,1}
    jac_config::JCType
    grad_config::GCType
  end

(Review threads on the `k_tmp::rateType` and `linsolve_tmp::rateType` lines were marked as resolved by saurabhkgp21.)

@@ -121,10 +124,21 @@ end
  function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
    u_tmp = similar(u)
+   u_tmps = Array{typeof(u_tmp),1}(undef, Threads.nthreads())
+
+   for i=1:Threads.nthreads()
+     u_tmps[i] = zero(u_tmp)
+   end
+

Review comment on the `u_tmps` initialization: "one improvement that here is just to make the first ones in the arrays equal, i.e. …"
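A minimal sketch of one way to read that suggestion (hypothetical code, not taken from the PR): let the first slot of each per-thread array alias the existing single buffer instead of allocating another zeroed copy:

```julia
u = rand(4)                                 # stand-in for the state vector
u_tmp = similar(u)
u_tmps = Array{typeof(u_tmp),1}(undef, Threads.nthreads())
u_tmps[1] = u_tmp                           # first slot reuses the existing buffer
for i in 2:Threads.nthreads()
    u_tmps[i] = zero(u_tmp)                 # remaining threads get their own copies
end
```

The same trick would presumably apply to `k_tmps` and `linsolve_tmps` below.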
The `alg_cache` diff continues:

    utilde = similar(u)
    tmp = similar(u)
    k = zero(rate_prototype)
    k_tmp = zero(rate_prototype)
+   k_tmps = Array{typeof(k_tmp),1}(undef, Threads.nthreads())
+
+   for i=1:Threads.nthreads()
+     k_tmps[i] = zero(rate_prototype)
+   end
+
    cur_order = max(alg.init_order, alg.min_order)
    dtpropose = zero(dt)
    T = Array{typeof(u),2}(undef, alg.max_order, alg.max_order)
@@ -143,22 +157,36 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
    du2 = zero(rate_prototype)

    if DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
-     W = WOperator(f, dt, true)
+     W_el = WOperator(f, dt, true)
      J = nothing # is J = W.J better?
    else
      J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
-     W = similar(J)
+     W_el = similar(J)
    end
+   W = Array{typeof(W_el),1}(undef, Threads.nthreads())
+   for i=1:Threads.nthreads()
+     W[i] = zero(W_el)
+   end
    tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p)
    uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
    linsolve_tmp = zero(rate_prototype)
-   linsolve = alg.linsolve(Val{:init},uf,u)
+   linsolve_tmps = Array{typeof(linsolve_tmp),1}(undef, Threads.nthreads())
+
+   for i=1:Threads.nthreads()
+     linsolve_tmps[i] = zero(rate_prototype)
+   end
+
+   linsolve_el = alg.linsolve(Val{:init},uf,u)
+   linsolve = Array{typeof(linsolve_el),1}(undef, Threads.nthreads())
+   for i=1:Threads.nthreads()
+     linsolve[i] = alg.linsolve(Val{:init},uf,u)
+   end
    grad_config = build_grad_config(alg,f,tf,du1,t)
    jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2)

-   ImplicitEulerExtrapolationCache(uprev,u_tmp,utilde,tmp,atmp,k_tmp,dtpropose,T,cur_order,work,A,step_no,
-                       du1,du2,J,W,tf,uf,linsolve_tmp,linsolve,jac_config,grad_config)
+   ImplicitEulerExtrapolationCache(uprev,u_tmp,u_tmps,utilde,tmp,atmp,k_tmp,k_tmps,dtpropose,T,cur_order,work,A,step_no,
+                       du1,du2,J,W,tf,uf,linsolve_tmp,linsolve_tmps,linsolve,jac_config,grad_config)
  end
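The allocation strategy above is the same for every work buffer (`u_tmps`, `k_tmps`, `W`, `linsolve`, `linsolve_tmps`): one slot per thread, later indexed with `Threads.threadid()` inside `perform_step!`, so the two workers never touch the same scratch memory. A stripped-down, standalone sketch of that pattern (hypothetical names, not code from the PR):

```julia
# One scratch buffer per thread, allocated once up front.
scratch = [zeros(4) for _ in 1:Threads.nthreads()]

Threads.@threads for i in 1:2
    buf = scratch[Threads.threadid()]   # each task writes only to its own slot
    buf .= i                            # mutate freely without racing another thread
end
```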
`calc_W!` changes — a new method takes a `W_index` and writes into the corresponding per-thread `W` slot:

@@ -392,6 +392,57 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_
    return nothing
  end

+ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_index::Int, W_transform=false)
+   @unpack t,dt,uprev,u,f,p = integrator
+   @unpack J,W = cache
+   alg = unwrap_alg(integrator, true)
+   mass_matrix = integrator.f.mass_matrix
+   is_compos = integrator.alg isa CompositeAlgorithm
+   isnewton = alg isa NewtonAlgorithm
+
+   if W_transform && DiffEqBase.has_Wfact_t(f)
+     f.Wfact_t(W[W_index], u, p, dtgamma, t)
+     is_compos && (integrator.eigen_est = opnorm(LowerTriangular(W[W_index]), Inf) + inv(dtgamma)) # TODO: better estimate
+     return nothing
+   elseif !W_transform && DiffEqBase.has_Wfact(f)
+     f.Wfact(W[W_index], u, p, dtgamma, t)
+     if is_compos
+       opn = opnorm(LowerTriangular(W[W_index]), Inf)
+       integrator.eigen_est = (opn + one(opn)) / dtgamma # TODO: better estimate
+     end
+     return nothing
+   end
+
+   # fast pass
+   # we only want to factorize the linear operator once
+   new_jac = true
+   new_W = true
+   if (f isa ODEFunction && islinear(f.f)) || (integrator.alg isa SplitAlgorithms && f isa SplitFunction && islinear(f.f1.f))
+     new_jac = false
+     @goto J2W # Jump to W calculation directly, because we already have J
+   end
+
+   # check if we need to update J or W
+   W_dt = isnewton ? cache.nlsolver.cache.W_dt : dt # TODO: RosW
+   new_jac = isnewton ? do_newJ(integrator, alg, cache, repeat_step) : true
+   new_W = isnewton ? do_newW(integrator, cache.nlsolver, new_jac, W_dt) : true
+
+   # calculate W
+   if DiffEqBase.has_jac(f) && f.jac_prototype !== nothing && !ArrayInterface.isstructured(f.jac_prototype)
+     isnewton || DiffEqBase.update_coefficients!(W[W_index],uprev,p,t) # we will call `update_coefficients!` in NLNewton
+     @label J2W
+     W[W_index].transform = W_transform; set_gamma!(W[W_index], dtgamma)
+   else # concrete W using jacobian from `calc_J!`
+     new_jac && calc_J!(integrator, cache, is_compos)
+     new_W && jacobian2W!(W[W_index], mass_matrix, dtgamma, J, W_transform)
+   end
+   if isnewton
+     set_new_W!(cache.nlsolver, new_W) && DiffEqBase.set_W_dt!(cache.nlsolver, dt)
+   end
+   new_W && (integrator.destats.nw += 1)
+   return nothing
+ end
+
  function calc_W!(nlsolver, integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_transform=false)
    @unpack t,dt,uprev,u,f,p = integrator
    @unpack J,W = nlsolver.cache

Review discussion on the `@unpack J,W = cache` line of the new method:

- "this is just silly. Why not just pass in …"
- "@kanav99 @huanglangwen this is something we should follow up on. It's because our original …"
- "Does it mean …"
- "it should have …"
- "But it won't work on OOP."
- "For OOP it will need to return W."
- "I think we should have a pseudo-nlsolver kind of thing for algorithms which don't have an actual nlsolver like Rosenbrocks, the struct should have aliases to the …"
- "Note that algorithms with nlsolver struct are already well handled for parallel applications."
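The (truncated) thread above debates whether threading an index through `calc_W!` is the right interface, or whether the per-thread `W` object should be handed over directly, the way caches that carry an nlsolver already expose it. A toy sketch contrasting the two call shapes, with hypothetical function names (not OrdinaryDiffEq API):

```julia
# Purely illustrative stand-ins for the two interface options discussed above.
Ws = [zeros(2, 2) for _ in 1:Threads.nthreads()]

fill_W_indexed!(Ws, idx, dtgamma) = (Ws[idx] .= dtgamma; nothing)   # index-based, as in this commit
fill_W_direct!(W_el, dtgamma)     = (W_el .= dtgamma; nothing)      # pass the slot itself

fill_W_indexed!(Ws, Threads.threadid(), 0.1)
fill_W_direct!(Ws[Threads.threadid()], 0.1)
```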
`perform_step!` changes (`ImplicitEulerExtrapolationCache`) — the serial order loop becomes a two-task `Threads.@threads` loop over the rows of the extrapolation table, using the per-thread buffers and `W`/`linsolve` slots:

@@ -249,27 +249,39 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationCache,repeat_
    @unpack t,dt,uprev,u,f,p = integrator
    @unpack u_tmp,k_tmp,T,utilde,atmp,dtpropose,cur_order,A = cache
    @unpack J,W,uf,tf,linsolve_tmp,jac_config = cache
+   @unpack u_tmps, k_tmps, linsolve_tmps = cache

    max_order = min(size(T)[1],cur_order+1)

-   for i in 1:max_order
-     dt_temp = dt/(2^(i-1)) # Romberg sequence
-     calc_W!(integrator, cache, dt_temp, repeat_step)
-     k_tmp = copy(integrator.fsalfirst)
-     u_tmp = copy(uprev)
-     for j in 1:2^(i-1)
-       linsolve_tmp = dt_temp*k_tmp
-       cache.linsolve(vec(k_tmp), W, vec(linsolve_tmp), !repeat_step)
-       @.. k_tmp = -k_tmp
-       @.. u_tmp = u_tmp + k_tmp
-       f(k_tmp, u_tmp,p,t+j*dt_temp)
-     end
-     @.. T[i,1] = u_tmp
-     for j in 2:i
-       @.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
-     end
-   end
+   let max_order=max_order, uprev=uprev, dt=dt, p=p, t=t, T=T, W=W,
+       integrator=integrator, cache=cache, repeat_step = repeat_step,
+       k_tmps=k_tmps, u_tmps=u_tmps
+     Threads.@threads for i in 1:2
+       startIndex = (i == 1) ? 1 : max_order
+       endIndex = (i == 1) ? max_order - 1 : max_order
+       for index in startIndex:endIndex
+         dt_temp = dt/(2^(index-1)) # Romberg sequence
+         calc_W!(integrator, cache, dt_temp, repeat_step, Threads.threadid())
+         k_tmps[Threads.threadid()] = copy(integrator.fsalfirst)
+         u_tmps[Threads.threadid()] = copy(uprev)
+         for j in 1:2^(index-1)
+           @.. linsolve_tmps[Threads.threadid()] = dt_temp*k_tmps[Threads.threadid()]
+           cache.linsolve[Threads.threadid()](vec(k_tmps[Threads.threadid()]), W[Threads.threadid()], vec(linsolve_tmps[Threads.threadid()]), !repeat_step)
+           @.. k_tmps[Threads.threadid()] = -k_tmps[Threads.threadid()]
+           @.. u_tmps[Threads.threadid()] = u_tmps[Threads.threadid()] + k_tmps[Threads.threadid()]
+           f(k_tmps[Threads.threadid()], u_tmps[Threads.threadid()],p,t+j*dt_temp)
+         end
+         @.. T[index,1] = u_tmps[Threads.threadid()]
+       end
+     end
+     for i in 2:max_order
+       for j in 2:i
+         @.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
+       end
+     end
+   end

    integrator.dt = dt

    if integrator.opts.adaptive
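The work split in the threaded loop above is deliberately lopsided: row `i` of the extrapolation table costs `2^(i-1)` implicit Euler substeps, so the last row alone costs about as much as all earlier rows combined (`2^(m-1)` versus `2^(m-1) - 1`). One task therefore takes rows `1:max_order-1` and the other only row `max_order`. A standalone sketch of just that partition (illustrative, outside the integrator):

```julia
max_order = 5
Threads.@threads for i in 1:2
    startIndex = (i == 1) ? 1 : max_order
    endIndex   = (i == 1) ? max_order - 1 : max_order
    for index in startIndex:endIndex
        cost = 2^(index - 1)   # implicit Euler substeps needed for this row
        println("task $i on thread $(Threads.threadid()) computes row $index (cost $cost)")
    end
end
```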
`perform_step!` changes (`ImplicitEulerExtrapolationConstantCache`) — the `threading == false` branch and the `println(Threads.threadid())` call are removed, leaving the threaded version:

@@ -332,50 +344,29 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationConstantCache
    max_order = min(size(T)[1], cur_order+1)

-   if integrator.alg.threading == false
-     for i in 1:max_order
-       dt_temp = dt/(2^(i-1)) # Romberg sequence
-       W = calc_W!(integrator, cache, dt_temp, repeat_step)
-       k_copy = integrator.fsalfirst
-       u_tmp = uprev
-       for j in 1:2^(i-1)
-         k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev))
-         integrator.destats.nsolve += 1
-         u_tmp = u_tmp + k
-         k_copy = f(u_tmp, p, t+j*dt_temp)
-       end
-       T[i,1] = u_tmp
-       # Richardson Extrapolation
-       for j in 2:i
-         T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
-       end
-     end
-   else
-     let max_order=max_order, dt=dt, integrator=integrator, cache=cache, repeat_step=repeat_step,
-       uprev=uprev, T=T
-       Threads.@threads for i in 1:2
-         println(Threads.threadid())
-         startIndex = (i==1) ? 1 : max_order
-         endIndex = (i==1) ? max_order-1 : max_order
-         for index in startIndex:endIndex
-           dt_temp = dt/(2^(index-1)) # Romberg sequence
-           W = calc_W!(integrator, cache, dt_temp, repeat_step)
-           k_copy = integrator.fsalfirst
-           u_tmp = uprev
-           for j in 1:2^(index-1)
-             k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev))
-             integrator.destats.nsolve += 1
-             u_tmp = u_tmp + k
-             k_copy = f(u_tmp, p, t+j*dt_temp)
-           end
-           T[index,1] = u_tmp
-         end
-       end
-       for i=2:max_order
-         for j=2:i
-           T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
-         end
-       end
-     end
-   end
+   let max_order=max_order, dt=dt, integrator=integrator, cache=cache, repeat_step=repeat_step,
+     uprev=uprev, T=T
+     Threads.@threads for i in 1:2
+       startIndex = (i==1) ? 1 : max_order
+       endIndex = (i==1) ? max_order-1 : max_order
+       for index in startIndex:endIndex
+         dt_temp = dt/(2^(index-1)) # Romberg sequence
+         W = calc_W!(integrator, cache, dt_temp, repeat_step)
+         k_copy = integrator.fsalfirst
+         u_tmp = uprev
+         for j in 1:2^(index-1)
+           k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev))
+           integrator.destats.nsolve += 1
+           u_tmp = u_tmp + k
+           k_copy = f(u_tmp, p, t+j*dt_temp)
+         end
+         T[index,1] = u_tmp
+       end
+     end
+
+     for i=2:max_order
+       for j=2:i
+         T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
+       end
+     end
+   end

Review comment on the `# Romberg sequence` line: "why Romberg?"
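On that "why Romberg?" question: the `2^(j-1)` factors in the `T[i,j]` update are tied to that choice of step sequence. A short derivation sketch, assuming the standard Aitken–Neville recurrence for a first-order base method with step numbers n_i:

```math
T_{i,1} = \text{implicit Euler with } n_i = 2^{i-1} \text{ substeps of size } h/n_i, \\
T_{i,j} = T_{i,j-1} + \frac{T_{i,j-1} - T_{i-1,j-1}}{n_i/n_{i-j+1} - 1}
        = \frac{2^{j-1}\,T_{i,j-1} - T_{i-1,j-1}}{2^{j-1} - 1},
\qquad \text{since } \frac{n_i}{n_{i-j+1}} = \frac{2^{i-1}}{2^{i-j}} = 2^{j-1}.
```

A different step sequence (e.g. the harmonic sequence n_i = i) would change only those coefficients, not the structure of the loops.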