
Add non-causal attention mask to pallas attention #16936

Open · wants to merge 10 commits into main

Conversation

@andyehrenberg (Author)

Not sure if this is actually wanted (or maybe people would prefer to see segment ids as an additional argument so the attention mask is also computed in the kernel).
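For context, here is a minimal reference sketch (plain JAX, not a Pallas kernel) of what applying an explicit dense mask to attention logits means; the names are illustrative, not this PR's actual signature:

```python
import jax
import jax.numpy as jnp

def attention_with_mask(q, k, v, mask):
    # q, k, v: [seq_len, head_dim]; mask: [seq_len, seq_len] bool,
    # True where attention is allowed. Note the mask itself costs
    # O(seq_len ** 2) memory.
    logits = (q @ k.T) / jnp.sqrt(q.shape[-1])
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    return jax.nn.softmax(logits, axis=-1) @ v
```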

@google-cla bot commented Aug 2, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@apaszke (Collaborator) left a comment

Thanks for the PR! This might be a reasonable approach for short sequences, but ultimately the whole point of flash attention is to avoid materializing any values that are quadratic in sequence length, and a full mask is just that. IMO having arguments for segment ids and constructing the mask on the fly would be a better solution.
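A rough sketch of that idea, assuming per-token integer segment ids: inside the kernel, each (block_q, block_k) tile rebuilds only its own slice of the mask, so nothing quadratic in the full sequence length is ever materialized. The names below are assumptions, not the eventual kernel's signature:

```python
import jax.numpy as jnp

def block_mask_from_segment_ids(q_segment_ids, k_segment_ids):
    # q_segment_ids: [block_q] ints, k_segment_ids: [block_k] ints.
    # A query may attend to a key only if both belong to the same
    # segment; the tile is O(block_q * block_k), never O(seq_len ** 2).
    return q_segment_ids[:, None] == k_segment_ids[None, :]
```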

@andyehrenberg (Author)

> Thanks for the PR! This might be a reasonable approach for short sequences, but ultimately the whole point of flash attention is to avoid materializing any values that are quadratic in sequence length, and a full mask is just that. IMO having arguments for segment ids and constructing the mask on the fly would be a better solution.

Ah good point - updated the PR to take in a pad mask and seg ids and compute the attention mask on the fly.
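Roughly, combining the two inside a kernel tile might look like the following (a hedged sketch; `k_pad_mask` and the other names are assumptions, not the PR's actual arguments):

```python
import jax.numpy as jnp

def combined_block_mask(q_segment_ids, k_segment_ids, k_pad_mask):
    # k_pad_mask: [block_k] bool, True for real (non-padding) keys.
    same_segment = q_segment_ids[:, None] == k_segment_ids[None, :]
    return same_segment & k_pad_mask[None, :]
```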

@sharadmv (Collaborator) commented Aug 4, 2023

Thanks for the change! FYI we have a similar PR here: jax-ml/jax-triton#193. It's a little bit out of date since Pallas has since moved here but I think the logic is generally good. Perhaps you could work with @wang12tao to get some combination in?

@andyehrenberg marked this pull request as ready for review on August 9, 2023
@andyehrenberg (Author)

> Thanks for the change! FYI we have a similar PR here: jax-ml/jax-triton#193. It's a little bit out of date since Pallas has since moved here but I think the logic is generally good. Perhaps you could work with @wang12tao to get some combination in?

Of course, happy to collab with @wang12tao!

@andyehrenberg (Author)

@wang12tao Bumping this

@wang12tao (Contributor)

Andy, thanks for the heads up. I will update jax-ml/jax-triton#193 by the end of this week, and then you can decide whether it meets your requirements.

@wang12tao (Contributor)

Hi Andy, 5a578cb was added for segment_ids support. You could consider adding padding mask support, thanks a lot!
