Skip to content

Commit

Permalink
slip in value residual learning for pairformer stack
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 2, 2024
1 parent 9a6f2da commit 01a4ab2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
41 changes: 30 additions & 11 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,10 @@ def forward(
self,
x: Float['... n d'],
**kwargs
) -> Float['... n d']:
) -> (
Float['... n d'] |
tuple[Float['... n d'] | Any]
):

x = self.norm(x)
return self.fn(x, **kwargs)
Expand Down Expand Up @@ -1395,6 +1398,7 @@ def __init__(
dropout_row_prob = 0.25,
num_register_tokens = 0,
checkpoint = False,
add_value_residual = False,
pairwise_block_kwargs: dict = dict(),
pair_bias_attn_kwargs: dict = dict()
):
Expand Down Expand Up @@ -1430,6 +1434,8 @@ def __init__(

self.layers = layers

self.add_value_residual = add_value_residual

# checkpointing

self.checkpoint = checkpoint
Expand Down Expand Up @@ -1458,6 +1464,8 @@ def to_layers(

) -> Tuple[Float['b n ds'], Float['b n n dp']]:

value_residual = None

for _ in range(self.recurrent_depth):
for (
pairwise_block,
Expand All @@ -1467,7 +1475,13 @@ def to_layers(

pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)

single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
attn_out, attn_values = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = value_residual)

single_repr = single_repr + attn_out

if self.add_value_residual:
value_residual = default(value_residual, attn_values)

single_repr = single_transition(single_repr) + single_repr

return single_repr, pairwise_repr
Expand All @@ -1482,30 +1496,35 @@ def to_checkpointed_layers(

) -> Tuple[Float['b n ds'], Float['b n n dp']]:

inputs = (single_repr, pairwise_repr, mask)
inputs = (single_repr, pairwise_repr, mask, None)

def pairwise_block_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask = inputs
single_repr, pairwise_repr, mask, maybe_value_residual = inputs
pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
return single_repr, pairwise_repr, mask
return single_repr, pairwise_repr, mask, maybe_value_residual
return inner

def pair_bias_attn_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask = inputs
single_repr = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
return single_repr, pairwise_repr, mask
single_repr, pairwise_repr, mask, maybe_value_residual = inputs
attn_out, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
single_repr = single_repr + attn_out

if self.add_value_residual:
maybe_value_residual = default(maybe_value_residual, attn_values)

return single_repr, pairwise_repr, mask, maybe_value_residual
return inner

def single_transition_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask = inputs
single_repr, pairwise_repr, mask, maybe_value_residual = inputs
single_repr = layer(single_repr) + single_repr
return single_repr, pairwise_repr, mask
return single_repr, pairwise_repr, mask, maybe_value_residual
return inner

wrapped_layers = []
Expand All @@ -1524,7 +1543,7 @@ def inner(inputs, *args, **kwargs):
for layer in wrapped_layers:
inputs = checkpoint(layer, inputs)

single_repr, pairwise_repr, _ = inputs
single_repr, pairwise_repr, *_ = inputs
return single_repr, pairwise_repr

@typecheck
Expand Down
5 changes: 4 additions & 1 deletion tests/test_af3.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,12 @@ def test_centre_random_augmentation():
@pytest.mark.parametrize('checkpoint', (True, False))
@pytest.mark.parametrize('recurrent_depth', (1, 2))
@pytest.mark.parametrize('enable_attn_softclamp', (True, False))
@pytest.mark.parametrize('add_value_residual', (True, False))
def test_pairformer(
checkpoint,
recurrent_depth,
enable_attn_softclamp
enable_attn_softclamp,
add_value_residual
):
single = torch.randn(2, 16, 384).requires_grad_()
pairwise = torch.randn(2, 16, 16, 128).requires_grad_()
Expand All @@ -316,6 +318,7 @@ def test_pairformer(
num_register_tokens = 4,
recurrent_depth = recurrent_depth,
checkpoint = checkpoint,
add_value_residual = add_value_residual,
pair_bias_attn_kwargs = dict(
enable_attn_softclamp = enable_attn_softclamp
)
Expand Down

0 comments on commit 01a4ab2

Please sign in to comment.