Skip to content

Commit 280bddd

Browse files
Merge pull request #872 from saurabhkgp21/implicitEulerParaller
ImplicitEulerExtrapolation parallel
2 parents b4de315 + 62195d2 commit 280bddd

File tree

4 files changed

+142
-40
lines changed

4 files changed

+142
-40
lines changed

src/algorithms.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,15 @@ struct ImplicitEulerExtrapolation{CS,AD,F,F2} <: OrdinaryDiffEqImplicitExtrapola
5353
max_order::Int
5454
min_order::Int
5555
init_order::Int
56+
threading::Bool
5657
end
5758

5859

5960
ImplicitEulerExtrapolation(;chunk_size=0,autodiff=true,diff_type=Val{:forward},
6061
linsolve=DEFAULT_LINSOLVE,
61-
max_order=10,min_order=1,init_order=5) =
62+
max_order=10,min_order=1,init_order=5,threading=true) =
6263
ImplicitEulerExtrapolation{chunk_size,autodiff,
63-
typeof(linsolve),typeof(diff_type)}(linsolve,max_order,min_order,init_order)
64+
typeof(linsolve),typeof(diff_type)}(linsolve,max_order,min_order,init_order,threading)
6465

6566
struct ExtrapolationMidpointDeuflhard <: OrdinaryDiffEqExtrapolationVarOrderVarStepAlgorithm
6667
n_min::Int # Minimal extrapolation order

src/caches/extrapolation_caches.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ end
7070
@cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache
7171
uprev::uType
7272
u_tmp::uType
73+
u_tmps::Array{uType,1}
7374
utilde::uType
7475
tmp::uType
7576
atmp::uNoUnitsType
7677
k_tmp::rateType
78+
k_tmps::Array{rateType,1}
7779
dtpropose::dtType
7880
T::arrayType
7981
cur_order::Int
@@ -89,7 +91,8 @@ end
8991
tf::TFType
9092
uf::UFType
9193
linsolve_tmp::rateType
92-
linsolve::F
94+
linsolve_tmps::Array{rateType,1}
95+
linsolve::Array{F,1}
9396
jac_config::JCType
9497
grad_config::GCType
9598
end
@@ -121,10 +124,21 @@ end
121124

122125
function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
123126
u_tmp = similar(u)
127+
u_tmps = Array{typeof(u_tmp),1}(undef, Threads.nthreads())
128+
129+
for i=1:Threads.nthreads()
130+
u_tmps[i] = zero(u_tmp)
131+
end
132+
124133
utilde = similar(u)
125134
tmp = similar(u)
126-
k = zero(rate_prototype)
127135
k_tmp = zero(rate_prototype)
136+
k_tmps = Array{typeof(k_tmp),1}(undef, Threads.nthreads())
137+
138+
for i=1:Threads.nthreads()
139+
k_tmps[i] = zero(rate_prototype)
140+
end
141+
128142
cur_order = max(alg.init_order, alg.min_order)
129143
dtpropose = zero(dt)
130144
T = Array{typeof(u),2}(undef, alg.max_order, alg.max_order)
@@ -143,22 +157,36 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
143157
du2 = zero(rate_prototype)
144158

145159
if DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
146-
W = WOperator(f, dt, true)
160+
W_el = WOperator(f, dt, true)
147161
J = nothing # is J = W.J better?
148162
else
149163
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
150-
W = similar(J)
164+
W_el = similar(J)
165+
end
166+
W = Array{typeof(W_el),1}(undef, Threads.nthreads())
167+
for i=1:Threads.nthreads()
168+
W[i] = zero(W_el)
151169
end
152170
tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p)
153171
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
154172
linsolve_tmp = zero(rate_prototype)
155-
linsolve = alg.linsolve(Val{:init},uf,u)
173+
linsolve_tmps = Array{typeof(linsolve_tmp),1}(undef, Threads.nthreads())
174+
175+
for i=1:Threads.nthreads()
176+
linsolve_tmps[i] = zero(rate_prototype)
177+
end
178+
179+
linsolve_el = alg.linsolve(Val{:init},uf,u)
180+
linsolve = Array{typeof(linsolve_el),1}(undef, Threads.nthreads())
181+
for i=1:Threads.nthreads()
182+
linsolve[i] = alg.linsolve(Val{:init},uf,u)
183+
end
156184
grad_config = build_grad_config(alg,f,tf,du1,t)
157185
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2)
158186

159187

160-
ImplicitEulerExtrapolationCache(uprev,u_tmp,utilde,tmp,atmp,k_tmp,dtpropose,T,cur_order,work,A,step_no,
161-
du1,du2,J,W,tf,uf,linsolve_tmp,linsolve,jac_config,grad_config)
188+
ImplicitEulerExtrapolationCache(uprev,u_tmp,u_tmps,utilde,tmp,atmp,k_tmp,k_tmps,dtpropose,T,cur_order,work,A,step_no,
189+
du1,du2,J,W,tf,uf,linsolve_tmp,linsolve_tmps,linsolve,jac_config,grad_config)
162190
end
163191

164192

src/derivative_utils.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,57 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_
392392
return nothing
393393
end
394394

395+
# Indexed variant of `calc_W!`: every write to the W operator/matrix goes to the single
# entry `W[W_index]` of the `W` collection unpacked from `cache`.
#
# Arguments
#   integrator   - supplies `t`, `dt`, `uprev`, `u`, `f`, `p`; its `eigen_est` and
#                  `destats.nw` fields may be updated here.
#   cache        - mutable cache providing `J` and the indexable `W`.
#   dtgamma      - scalar handed to `Wfact`/`Wfact_t`, `set_gamma!` and `jacobian2W!`.
#   repeat_step  - forwarded to `do_newJ` when the algorithm is a `NewtonAlgorithm`.
#   W_index      - index of the entry of `W` to update.
#   W_transform  - selects the `Wfact_t` path and is forwarded to `W[W_index].transform`
#                  / `jacobian2W!` (default `false`).
#
# Always returns `nothing`; results are delivered by mutating `W[W_index]`, the
# integrator, and (for Newton algorithms) `cache.nlsolver`.
#
# NOTE(review): only `W` is indexed. `J` is unpacked from `cache` without an index and
# `calc_J!` receives the whole `cache`; `integrator.eigen_est` and
# `integrator.destats.nw` are also updated without any visible synchronization.
# Elsewhere in this diff the method is called with `Threads.threadid()` from inside
# `Threads.@threads` - confirm these shared writes are safe under concurrent calls.
function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_index::Int, W_transform=false)
396+
@unpack t,dt,uprev,u,f,p = integrator
397+
@unpack J,W = cache
398+
alg = unwrap_alg(integrator, true)
399+
mass_matrix = integrator.f.mass_matrix
400+
is_compos = integrator.alg isa CompositeAlgorithm
401+
isnewton = alg isa NewtonAlgorithm
402+
403+
# User-supplied W factorizations: let `f` fill `W[W_index]` directly and return early.
# For composite algorithms an eigenvalue estimate is also stored on the integrator.
if W_transform && DiffEqBase.has_Wfact_t(f)
404+
f.Wfact_t(W[W_index], u, p, dtgamma, t)
405+
is_compos && (integrator.eigen_est = opnorm(LowerTriangular(W[W_index]), Inf) + inv(dtgamma)) # TODO: better estimate
406+
return nothing
407+
elseif !W_transform && DiffEqBase.has_Wfact(f)
408+
f.Wfact(W[W_index], u, p, dtgamma, t)
409+
if is_compos
410+
opn = opnorm(LowerTriangular(W[W_index]), Inf)
411+
integrator.eigen_est = (opn + one(opn)) / dtgamma # TODO: better estimate
412+
end
413+
return nothing
414+
end
415+
416+
# fast path
417+
# we only want to factorize the linear operator once
418+
# Defaults: recompute both J and W unless a branch below decides otherwise.
new_jac = true
419+
new_W = true
420+
# Linear problem (plain ODEFunction, or the first part of a SplitFunction under a split
# algorithm): skip the J/W bookkeeping and jump straight to the J2W label.
if (f isa ODEFunction && islinear(f.f)) || (integrator.alg isa SplitAlgorithms && f isa SplitFunction && islinear(f.f1.f))
421+
new_jac = false
422+
@goto J2W # Jump to W calculation directly, because we already have J
423+
end
424+
425+
# check if we need to update J or W
426+
# Only Newton algorithms consult the nonlinear solver's reuse logic; all other
# algorithms always rebuild J and W.
W_dt = isnewton ? cache.nlsolver.cache.W_dt : dt # TODO: RosW
427+
new_jac = isnewton ? do_newJ(integrator, alg, cache, repeat_step) : true
428+
new_W = isnewton ? do_newW(integrator, cache.nlsolver, new_jac, W_dt) : true
429+
430+
# calculate W
431+
# Branch 1: `f` has a Jacobian with a non-structured `jac_prototype`; `W[W_index]` is
# updated through `update_coefficients!` / `set_gamma!` rather than from `J`.
if DiffEqBase.has_jac(f) && f.jac_prototype !== nothing && !ArrayInterface.isstructured(f.jac_prototype)
432+
isnewton || DiffEqBase.update_coefficients!(W[W_index],uprev,p,t) # we will call `update_coefficients!` in NLNewton
433+
# The label sits inside this branch, so the linear fast path above lands here
# regardless of the `if` condition and bypasses `update_coefficients!`.
# NOTE(review): that path therefore assumes `W[W_index]` has a `transform` field and
# supports `set_gamma!` - confirm for linear problems without `jac_prototype`.
@label J2W
434+
W[W_index].transform = W_transform; set_gamma!(W[W_index], dtgamma)
435+
else # concrete W using jacobian from `calc_J!`
436+
new_jac && calc_J!(integrator, cache, is_compos)
437+
new_W && jacobian2W!(W[W_index], mass_matrix, dtgamma, J, W_transform)
438+
end
439+
# For Newton algorithms, record whether W was rebuilt; `set_W_dt!` runs only when
# `set_new_W!` returns a truthy value.
if isnewton
440+
set_new_W!(cache.nlsolver, new_W) && DiffEqBase.set_W_dt!(cache.nlsolver, dt)
441+
end
442+
# Count W constructions in the solver statistics.
new_W && (integrator.destats.nw += 1)
443+
return nothing
444+
end
445+
395446
function calc_W!(nlsolver, integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_transform=false)
396447
@unpack t,dt,uprev,u,f,p = integrator
397448
@unpack J,W = nlsolver.cache

src/perform_step/extrapolation_perform_step.jl

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -249,27 +249,39 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationCache,repeat_
249249
@unpack t,dt,uprev,u,f,p = integrator
250250
@unpack u_tmp,k_tmp,T,utilde,atmp,dtpropose,cur_order,A = cache
251251
@unpack J,W,uf,tf,linsolve_tmp,jac_config = cache
252+
@unpack u_tmps, k_tmps, linsolve_tmps = cache
252253

253254
max_order = min(size(T)[1],cur_order+1)
254255

255-
for i in 1:max_order
256-
dt_temp = dt/(2^(i-1)) # Romberg sequence
257-
calc_W!(integrator, cache, dt_temp, repeat_step)
258-
k_tmp = copy(integrator.fsalfirst)
259-
u_tmp = copy(uprev)
260-
for j in 1:2^(i-1)
261-
linsolve_tmp = dt_temp*k_tmp
262-
cache.linsolve(vec(k_tmp), W, vec(linsolve_tmp), !repeat_step)
263-
@.. k_tmp = -k_tmp
264-
@.. u_tmp = u_tmp + k_tmp
265-
f(k_tmp, u_tmp,p,t+j*dt_temp)
266-
end
256+
let max_order=max_order, uprev=uprev, dt=dt, p=p, t=t, T=T, W=W,
257+
integrator=integrator, cache=cache, repeat_step = repeat_step,
258+
k_tmps=k_tmps, u_tmps=u_tmps
259+
Threads.@threads for i in 1:2
260+
startIndex = (i == 1) ? 1 : max_order
261+
endIndex = (i == 1) ? max_order - 1 : max_order
262+
for index in startIndex:endIndex
263+
dt_temp = dt/(2^(index-1)) # Romberg sequence
264+
calc_W!(integrator, cache, dt_temp, repeat_step, Threads.threadid())
265+
k_tmps[Threads.threadid()] = copy(integrator.fsalfirst)
266+
u_tmps[Threads.threadid()] = copy(uprev)
267+
for j in 1:2^(index-1)
268+
@.. linsolve_tmps[Threads.threadid()] = dt_temp*k_tmps[Threads.threadid()]
269+
cache.linsolve[Threads.threadid()](vec(k_tmps[Threads.threadid()]), W[Threads.threadid()], vec(linsolve_tmps[Threads.threadid()]), !repeat_step)
270+
@.. k_tmps[Threads.threadid()] = -k_tmps[Threads.threadid()]
271+
@.. u_tmps[Threads.threadid()] = u_tmps[Threads.threadid()] + k_tmps[Threads.threadid()]
272+
f(k_tmps[Threads.threadid()], u_tmps[Threads.threadid()],p,t+j*dt_temp)
273+
end
267274

268-
@.. T[i,1] = u_tmp
269-
for j in 2:i
270-
@.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
275+
@.. T[index,1] = u_tmps[Threads.threadid()]
276+
end
277+
end
278+
for i in 2:max_order
279+
for j in 2:i
280+
@.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
281+
end
271282
end
272283
end
284+
273285
integrator.dt = dt
274286

275287
if integrator.opts.adaptive
@@ -332,23 +344,33 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationConstantCache
332344

333345
max_order = min(size(T)[1], cur_order+1)
334346

335-
for i in 1:max_order
336-
dt_temp = dt/(2^(i-1)) # Romberg sequence
337-
W = calc_W!(integrator, cache, dt_temp, repeat_step)
338-
k_copy = integrator.fsalfirst
339-
u_tmp = uprev
340-
for j in 1:2^(i-1)
341-
k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev))
342-
integrator.destats.nsolve += 1
343-
u_tmp = u_tmp + k
344-
k_copy = f(u_tmp, p, t+j*dt_temp)
347+
let max_order=max_order, dt=dt, integrator=integrator, cache=cache, repeat_step=repeat_step,
348+
uprev=uprev, T=T
349+
Threads.@threads for i in 1:2
350+
startIndex = (i==1) ? 1 : max_order
351+
endIndex = (i==1) ? max_order-1 : max_order
352+
for index in startIndex:endIndex
353+
dt_temp = dt/(2^(index-1)) # Romberg sequence
354+
W = calc_W!(integrator, cache, dt_temp, repeat_step)
355+
k_copy = integrator.fsalfirst
356+
u_tmp = uprev
357+
for j in 1:2^(index-1)
358+
k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev))
359+
integrator.destats.nsolve += 1
360+
u_tmp = u_tmp + k
361+
k_copy = f(u_tmp, p, t+j*dt_temp)
362+
end
363+
T[index,1] = u_tmp
364+
end
345365
end
346-
T[i,1] = u_tmp
347-
# Richardson Extrapolation
348-
for j in 2:i
349-
T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
366+
367+
for i=2:max_order
368+
for j=2:i
369+
T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
370+
end
350371
end
351372
end
373+
352374
integrator.destats.nf += 2^(max_order) - 1
353375
integrator.dt = dt
354376

@@ -391,9 +413,9 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationConstantCache
391413

392414
# Use extrapolated value of u
393415
integrator.u = T[cache.cur_order, cache.cur_order]
394-
k = f(integrator.u, p, t+dt)
416+
k_temp = f(integrator.u, p, t+dt)
395417
integrator.destats.nf += 1
396-
integrator.fsallast = k
418+
integrator.fsallast = k_temp
397419
integrator.k[1] = integrator.fsalfirst
398420
integrator.k[2] = integrator.fsallast
399421
end

0 commit comments

Comments
 (0)