Less memory/communication-intensive pmap adjoint #1188
The current pmap adjoint collects all the pullbacks from the workers on the root process, which is very memory-intensive. Instead, we'd like to persist the pullbacks on the workers, and use them in the backward pass. This requires the backward pass to use the same matching of workers to array indices as the forward pass, as opposed to load balancing.
Note: I had originally made a PR to ChainRules (JuliaDiff/ChainRules.jl#566), but there were a lot of tricky details, so I've decided to PR directly to Zygote instead. This will hopefully be easier to review because there is an existing `pmap` rule to compare against directly; I hope to re-open the ChainRules PR in the future once this has been tried and tested in Zygote.

## What this rule does
In the forward pass, `pmap` dynamically assigns workers to array elements; each worker evaluates the function `f` (the work normally done by `pmap` is shown in black in the image) and sends the value back to the root. The old rule simply modified this process by replacing `f` with `_pullback`, so that both the primal and the pullback were sent back to the root. However, since backward differentiation is memory-intensive, the pullback can be much larger than the primal if the function `f` has large intermediate values. This can cause an order-of-magnitude slowdown in some cases; see e.g. benchmark (4).pdf (a toy example; I ran this benchmark with the ChainRules PR but will update it for this ported Zygote rule) and SciML/SciMLSensitivity.jl#612.
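To make the approach described below concrete, here is a simplified, single-process sketch. All names (`forward_store`, `LOCAL_PULLBACKS`, `NEXT_ID`) are hypothetical, and a hand-written pullback stands in for `Zygote._pullback`; the PR's actual bookkeeping is done differently (via DistributedArrays):

```julia
using Distributed  # provides pmap; runs serially when no workers are added

# Hypothetical worker-local store: each process keeps the pullbacks it
# computed, instead of shipping them back to the root.
const LOCAL_PULLBACKS = Dict{Int,Any}()
const NEXT_ID = Ref(0)

# Forward pass: keep the pullback on the process that computed it and
# send back only the primal value plus a unique ID.
function forward_store(pb, x)
    y, back = pb(x)
    id = (NEXT_ID[] += 1)
    LOCAL_PULLBACKS[id] = back
    return y, id
end

# Hand-written pullback for f(x) = x^2, standing in for Zygote._pullback(f, x).
square_pullback(x) = (x^2, ȳ -> 2x * ȳ)

xs = [1.0, 2.0, 3.0]
out = pmap(x -> forward_store(square_pullback, x), xs)
ys, ids = first.(out), last.(out)   # ids record the forward-pass matching

# Backward pass: look up each stored pullback by ID on the same process,
# so only small cotangents cross process boundaries.
x̄s = [LOCAL_PULLBACKS[id](1.0) for id in ids]
```

With no extra workers `pmap` runs on the root, so the store lives in a single process; with real workers, each process would hold only the pullbacks for the elements it computed, and the backward pass would route each cotangent to the process holding the matching ID.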
Thus, this rule instead persists the pullbacks on the local processes without communicating them back to the root. To use these pullbacks in the backward pass, we must first remember the matching that `pmap` made in the forward pass. The `unique_IDs` and `positions` arrays are used to save this matching, as shown in blue in the image. Then, we must explicitly reuse the exact same matching in the backward pass, sending each worker all of the dual computations it needs to do (note that these are all sent together, so `batch_size > 1` is handled efficiently).

## Disadvantages
- The backward pass cannot load-balance, since it must reuse the forward pass's matching of workers to array indices, whereas `pmap` has an option for that in the forward pass.
- The extra bookkeeping and communication may slow down the `pmap` rule in some cases. On the other hand, I think that one normally does a `pmap` only when the function `f` is very expensive, enough to justify communicating back and forth between processors.

## Use of DistributedArrays
In my understanding, worker processes don't know anything about the scope on the root process, so to store the pullbacks locally one needs to reinvent scope, e.g. by using a global dictionary with a counter. I saw that this same logic was already implemented in DistributedArrays, so I thought it would be nicer to add that as a dependency rather than rewrite that messy logic here.
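For illustration, the hand-rolled version of that scope-reinvention might look like the following (hypothetical names; this is roughly the kind of logic the PR avoids rewriting by depending on DistributedArrays):

```julia
using Distributed

# A global Dict plus a counter defined on every process lets remote calls
# stash an object (e.g. a pullback) and retrieve it later on that same process.
@everywhere begin
    const STORE = Dict{Int,Any}()
    const COUNTER = Ref(0)
    function stash!(obj)
        id = (COUNTER[] += 1)
        STORE[id] = obj
        return (myid(), id)   # enough information to find the object again
    end
    retrieve(id) = STORE[id]
end

# Stash a closure on some process, then call it there later
# (with no workers added, everything runs on process 1):
pid, id = remotecall_fetch(stash!, first(workers()), x -> x + 1)
result = remotecall_fetch(i -> retrieve(i)(41), pid, id)
```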