forked from LittleLittleCloud/Torchsharp-stable-diffusion-2
AttentionMaskConverter.cs
using static TorchSharp.torch.nn;
using static TorchSharp.torch;
using TorchSharp.Modules;
using TorchSharp;
namespace SD;
/// <summary>
/// Builds the attention masks used by self-attention layers: a causal mask that blocks
/// future positions and, optionally, positions outside a sliding attention window.
/// </summary>
public class AttentionMaskConverter
{
    private readonly bool is_casual;
    private readonly int? sliding_window;

    public AttentionMaskConverter(bool is_casual, int? sliding_window)
    {
        this.is_casual = is_casual;
        this.sliding_window = sliding_window;
    }

    /// <summary>
    /// Builds a causal 4D mask of shape (batch_size, 1, query_length, key_value_length),
    /// or returns null when no mask is needed.
    /// </summary>
    public Tensor? ToCasual4D(
        int batch_size,
        int query_length,
        int key_value_length,
        ScalarType dtype,
        Device device)
    {
        if (!is_casual)
        {
            throw new ArgumentException("This is not a causal mask converter");
        }

        long[] input_shape = [batch_size, query_length];
        var past_key_values_length = key_value_length - query_length;

        // Create the causal mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len].
        // A mask is only required when there is more than one query token or a sliding
        // window is in use; a single query token may attend to every cached position.
        Tensor? casual_4d_mask = null;
        if (query_length > 1 || this.sliding_window is not null)
        {
            casual_4d_mask = MakeCasualMask(input_shape, dtype, device, past_key_values_length, this.sliding_window);
        }

        return casual_4d_mask;
    }

    /// <summary>
    /// Makes the causal mask used for self-attention: blocked positions are filled with the
    /// dtype's minimum value, allowed positions with 0.
    /// </summary>
    public static Tensor MakeCasualMask(
        long[] input_ids_shape,
        ScalarType dtype,
        Device device,
        int past_key_values_length = 0,
        int? sliding_window = null)
    {
        var bsz = input_ids_shape[0];
        var tgt_len = input_ids_shape[1];
        var min = dtype switch
        {
            ScalarType.Float32 => torch.finfo(dtype).min,
            ScalarType.Float64 => torch.finfo(dtype).min,
            ScalarType.Float16 => -65504.0,
            _ => throw new ArgumentException("Invalid dtype"),
        };

        // Start with every position blocked, then zero out the lower triangle
        // (each query position may attend to itself and to earlier positions).
        var mask = torch.full([tgt_len, tgt_len], min, dtype: dtype, device: device);
        var mask_cond = torch.arange(tgt_len, device: device);
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0);
        mask = mask.to(dtype);

        // Cached (past) key/value positions are always visible, so prepend a block of zeros.
        if (past_key_values_length > 0)
        {
            mask = torch.cat([torch.zeros([tgt_len, past_key_values_length], dtype: dtype, device: device), mask], dim: -1);
        }

        // Optionally block positions that fall outside the sliding attention window.
        // The boolean mask is built from the full mask shape so it still matches when
        // past positions have been prepended above.
        if (sliding_window is int window)
        {
            var diagonal = past_key_values_length - window - 1;
            var context_mask = torch.tril(torch.ones(mask.shape, dtype: ScalarType.Bool, device: device), diagonal: diagonal);
            mask = mask.masked_fill(context_mask, min);
        }

        // mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
        return mask.unsqueeze(0).unsqueeze(0).expand(bsz, 1, tgt_len, tgt_len + past_key_values_length);
    }
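
    // Illustration (an added example, not part of the original source): for tgt_len = 3 and
    // past_key_values_length = 1, the 2D mask built above, before the final expand, is
    //
    //     0    0    min  min
    //     0    0    0    min
    //     0    0    0    0
    //
    // where "min" is the dtype's most negative value; column 0 is the always-visible cached
    // position and the remaining 3x3 block is the causal triangle.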

    /// <summary>
    /// Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`.
    /// </summary>
    /// <param name="input_shape">The input shape should be a tuple that defines `(batch_size, query_length)`.</param>
    public static Tensor? Create4DCasualAttentionMask(
        long[] input_shape,
        ScalarType dtype,
        Device device,
        int past_key_values_length = 0,
        int? sliding_window = null)
    {
        var batch_size = (int)input_shape[0];
        var query_length = (int)input_shape[1];
        var converter = new AttentionMaskConverter(is_casual: true, sliding_window: sliding_window);
        var key_value_length = past_key_values_length + query_length;

        return converter.ToCasual4D(batch_size, query_length, key_value_length, dtype, device);
    }

    /// <summary>
    /// Expands a 2D padding mask of shape (batch_size, src_len), where 1 means "attend",
    /// to (batch_size, 1, tgt_len, src_len) with padded positions set to the dtype's minimum.
    /// </summary>
    public static Tensor ExpandMask(
        Tensor mask,
        ScalarType dtype,
        int? tgt_len = null)
    {
        var bsz = (int)mask.shape[0];
        var src_len = (int)mask.shape[1];
        tgt_len ??= src_len;

        // mask[:, None, None, :] -> (bsz, 1, tgt_len, src_len)
        var expanded_mask = mask.unsqueeze(1).unsqueeze(1).expand(bsz, 1, tgt_len.Value, src_len).to(dtype);

        // Invert the mask: positions to attend to become 0, padded positions become 1
        // and are then filled with the minimum value.
        var inverted_mask = 1.0 - expanded_mask;
        var min = dtype switch
        {
            ScalarType.Float32 => torch.finfo(dtype).min,
            ScalarType.Float64 => torch.finfo(dtype).min,
            ScalarType.Float16 => -65504.0,
            _ => throw new ArgumentException("Invalid dtype"),
        };

        return inverted_mask.masked_fill(inverted_mask.to(ScalarType.Bool), min);
    }
}
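
// Usage sketch (an illustrative addition, not part of the original file): shows how the
// static helpers above might be combined. The batch size, sequence lengths, dtype and
// device below are arbitrary assumptions made for the example.
public static class AttentionMaskConverterExample
{
    public static void Run()
    {
        var device = torch.CPU;

        // Causal mask for a batch of 2 sequences, 5 query tokens and 3 cached (past) tokens:
        // shape (2, 1, 5, 8), 0 where attention is allowed, dtype-min elsewhere.
        var causal = AttentionMaskConverter.Create4DCasualAttentionMask(
            [2, 5],
            ScalarType.Float32,
            device,
            past_key_values_length: 3);

        // Padding mask with 1 = attend, 0 = padded, expanded to the same (2, 1, 5, 8) shape.
        var padding = torch.ones([2, 8], dtype: ScalarType.Float32, device: device);
        var expanded = AttentionMaskConverter.ExpandMask(padding, ScalarType.Float32, tgt_len: 5);

        // The two masks are typically summed before being added to the attention scores.
        var combined = causal is null ? expanded : causal + expanded;
    }
}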