
Less memory/communication-intensive pmap adjoint #1188

Open · wants to merge 6 commits into master
Conversation

@gaurav-arya commented on Mar 21, 2022

The current pmap adjoint collects all the pullbacks from the workers on the root process, which is very memory-intensive. Instead, we'd like to persist the pullbacks on the workers, and use them in the backward pass. This requires the backward pass to use the same matching of workers to array indices as the forward pass, as opposed to load balancing.

Note: I had originally made a PR to ChainRules (JuliaDiff/ChainRules.jl#566), but there were a lot of tricky details, so I've decided to PR directly to Zygote instead. This will hopefully be easier to review because there is an existing pmap rule which we can compare against directly; hopefully I can re-open the ChainRules PR in the future once this approach has been tried and tested in Zygote.

What this rule does

In the forward pass, pmap dynamically assigns workers to array elements; each worker evaluates the function f and sends the value back to the root (the work usually done by pmap is shown in black in the image). In the old rule, this process was modified simply by replacing f with _pullback, so that both the primal and the pullback were sent back to the root. However, since reverse-mode differentiation stores intermediate values, the pullback can be much larger than the primal if the function f has large intermediates.
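
For reference, the old rule's shape is roughly the following (a simplified sketch, not the actual Zygote rule, which goes through `_pullback` and handles more argument forms; Zygote is assumed to be loaded on all workers):

```julia
using Distributed, Zygote

# Simplified sketch of the old-style adjoint: every worker returns both the
# primal value and its pullback closure, so all pullbacks get serialized back
# to the root process.
function old_pmap_adjoint(f, xs)
    ys_and_backs = pmap(x -> Zygote.pullback(f, x), xs)  # (primal, pullback) pairs
    ys    = map(first, ys_and_backs)                      # primal values
    backs = map(last, ys_and_backs)                       # pullbacks, all held on the root
    function pmap_pullback(Δ)
        # Each pullback closes over f's intermediate values, all of which were
        # shipped back to the root above; that is the memory/communication cost.
        map((back, δ) -> back(δ)[1], backs, Δ)
    end
    return ys, pmap_pullback
end
```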

This can cause an order-of-magnitude slowdown in some cases; see e.g.
benchmark (4).pdf (a toy example; I ran this benchmark against the ChainRules PR, but will update it for this ported Zygote rule) and SciML/SciMLSensitivity.jl#612.

Thus, this rule instead persists the pullbacks on the local processes without communicating them back to the root. To use these pullbacks in the backward pass, we must first remember the matching that pmap made in the forward pass. The unique_IDs and positions arrays are used to save this matching, as shown in blue in the image. Then, we explicitly reuse the exact same matching in the backward pass, sending each worker all of the pullback computations it needs to perform (note that these are sent together, so batch_size > 1 is handled efficiently).
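
A rough sketch of this matching idea (the names `PULLBACK_CACHE`, `forward`, and `backward` are illustrative, not the PR's actual code, and Zygote is assumed to be loaded on every worker):

```julia
using Distributed

# Worker-local pullback storage plus helpers, defined on every process.
@everywhere begin
    const PULLBACK_CACHE = Dict{Int,Any}()
    store_pullback!(id, back) = (PULLBACK_CACHE[id] = back; nothing)
    apply_pullbacks(ids, δs) = [PULLBACK_CACHE[id](δ)[1] for (id, δ) in zip(ids, δs)]
end

function forward(f, xs)
    results = pmap(collect(enumerate(xs))) do (i, x)
        y, back = Zygote.pullback(f, x)
        store_pullback!(i, back)          # persist the pullback on this worker
        (y, myid(), i)                    # send back only the primal + (worker, ID)
    end
    ys         = map(r -> r[1], results)
    worker_ids = map(r -> r[2], results)  # which worker handled each element
    unique_IDs = map(r -> r[3], results)  # key of the stored pullback on that worker
    return ys, worker_ids, unique_IDs
end

function backward(Δ, worker_ids, unique_IDs)
    Δx = Vector{Any}(undef, length(Δ))
    # Reuse the forward matching: each worker gets all of its cotangents in one message.
    @sync for w in unique(worker_ids)
        positions = findall(==(w), worker_ids)
        @async begin
            Δx[positions] .= remotecall_fetch(apply_pullbacks, w,
                                              unique_IDs[positions], Δ[positions])
        end
    end
    return Δx
end
```

With this shape, the forward pass only ships the primal plus two small integers per element back to the root, and the backward pass sends one batch of cotangents per worker rather than one message per element.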

[Diagram: the pmap forward pass (black) and the saved worker-to-index matching reused in the backward pass (blue).]

Disadvantages

  • Less robust. If a worker dies between the forward and backward pass, the pullback is lost. Also, in the backward pass, tasks aren't retried if they error, whereas pmap has an option for that in the forward pass.
  • Possible overhead. The extra bookkeeping here can cause some overhead. I would welcome ideas to reduce it or to handle special cases differently; perhaps we'd want to provide an option to fall back to the old pmap rule in some cases. On the other hand, I think one normally only uses pmap when the function f is expensive enough to justify communicating back and forth between processes in the first place.

Use of DistributedArrays

In my understanding, worker processes don't know anything about the scope on the root process, so to store the pullbacks locally one needs to reinvent scope by e.g. using a global dictionary with a counter. I saw that this same logic already exists in DistributedArrays, so I thought it would be nicer to add that as a dependency rather than write that messy logic here.
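
For illustration, the kind of worker-local storage this provides looks roughly like the following (a hypothetical sketch, not the PR's actual code; the constructor function of a DArray runs on each worker and builds that worker's local chunk):

```julia
using Distributed, DistributedArrays

# A DArray whose local parts live on the workers gives us "worker-local scope"
# for free, instead of hand-rolling a global Dict plus counter on each worker.
n = 8  # number of array elements, purely illustrative
pullbacks = DArray((n,)) do inds
    # Runs on each worker: allocate storage for the pullbacks of the elements
    # assigned to this worker's local chunk.
    Vector{Any}(undef, length(inds[1]))
end

# A worker only ever touches its own chunk, e.g.:
# localpart(pullbacks)[j] = back
```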

@ToucheSir (Member)

In the interests of not having a substantial amount of work like this sitting around, do you have some numbers on the performance benefits? If they're good, I don't see any major objections to getting it merged :)
