Add non-causal attention mask to pallas attention #16936
base: main
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks for the PR! This might be a reasonable approach for short sequences, but ultimately the whole point of flash attention is to avoid materializing any values that are quadratic in sequence length, and a full mask is just that. IMO having arguments for segment ids and constructing the mask on the fly would be a better solution.
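For context, a minimal sketch of what building a mask tile on the fly from segment ids could look like (function names and block shapes here are illustrative assumptions, not code from this PR or from the Pallas kernels):

```python
import jax.numpy as jnp

def block_mask_from_segment_ids(q_segment_ids, k_segment_ids):
    # q_segment_ids: [block_q] int32, k_segment_ids: [block_k] int32.
    # Tokens may attend to each other only if they share a segment id, so the
    # quadratic [seq_len, seq_len] mask never needs to be materialized; only
    # one [block_q, block_k] tile exists at a time.
    return q_segment_ids[:, None] == k_segment_ids[None, :]

def masked_scores(scores, q_segment_ids, k_segment_ids):
    # scores: [block_q, block_k] raw attention logits for one tile.
    mask = block_mask_from_segment_ids(q_segment_ids, k_segment_ids)
    return jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
```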
Ah good point - updated the PR to take in a pad mask and segment ids and compute the attention mask on the fly.
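As a rough illustration of combining a per-token padding mask with segment ids when forming each tile's mask (again, the names and shapes are hypothetical, not the PR's actual kernel code):

```python
import jax.numpy as jnp

def combined_tile_mask(q_seg, k_seg, q_valid, k_valid):
    # q_seg: [block_q] int32, k_seg: [block_k] int32 segment ids.
    # q_valid: [block_q] bool, k_valid: [block_k] bool (True = real token, False = padding).
    same_segment = q_seg[:, None] == k_seg[None, :]   # [block_q, block_k]
    not_padded = q_valid[:, None] & k_valid[None, :]  # [block_q, block_k]
    # A query may attend to a key only if both are real tokens in the same segment.
    return same_segment & not_padded
```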
Thanks for the change! FYI we have a similar PR here: jax-ml/jax-triton#193. It's a little bit out of date since Pallas has since moved here but I think the logic is generally good. Perhaps you could work with @wang12tao to get some combination in?
Of course, happy to collab with @wang12tao!
@wang12tao Bumping this
Andy, thanks for the heads up. I will update jax-ml/jax-triton#193 by the end of this week, and then you can decide whether it meets your requirements.
Hi Andy, 5a578cb adds segment_ids support. You could consider adding padding mask support, thanks a lot!
Not sure if this is actually wanted (or maybe people would prefer to see segment ids as an additional argument so the attention mask is also computed in the kernel).