Commit
Add support for sliced attention
daemon committed Nov 30, 2022
1 parent 0b03e03 commit 54c1285
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions daam/trace.py
@@ -226,6 +226,40 @@ def _unravel_attn(self, x):
        maps = torch.stack(maps, 0)  # shape: (tokens, heads, height, width)
        return maps.permute(1, 0, 2, 3).contiguous()  # shape: (heads, tokens, height, width)

    def _hooked_sliced_attention(hk_self, self, query, key, value, sequence_length, dim):
        """
        Monkey-patched version of :py:func:`.CrossAttention._sliced_attention` to capture attentions and
        aggregate them. Mirrors :py:func:`._hooked_attention`, but computes attention in slices along the
        batch dimension to bound peak memory use.
        """
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        # When slicing is disabled, process the whole batch as a single slice.
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]

        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            # With beta=0, baddbmm ignores the (uninitialized) input tensor and computes
            # alpha * (Q @ K^T), i.e. the scaled attention scores for this slice.
            attn_slice = torch.baddbmm(
                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
                query[start_idx:end_idx],
                key[start_idx:end_idx].transpose(-1, -2),
                beta=0,
                alpha=self.scale,
            )
            attn_slice = attn_slice.softmax(dim=-1)
            factor = int(math.sqrt(hk_self.latent_hw // attn_slice.shape[1]))

            # Only record cross-attention maps, i.e. slices whose key length matches the text context size.
            if attn_slice.shape[-1] == hk_self.context_size:
                maps = hk_self._unravel_attn(attn_slice)  # shape: (heads, tokens, height, width)

                for head_idx, heatmap in enumerate(maps):
                    hk_self.heat_maps.update(factor, hk_self.layer_idx, head_idx, heatmap)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def _hooked_attention(hk_self, self, query, key, value):
        """
        Monkey-patched version of :py:func:`.CrossAttention._attention` to capture attentions and aggregate them.
@@ -264,6 +298,7 @@ def _hooked_attention(hk_self, self, query, key, value):

    def _hook_impl(self):
        self.monkey_patch('_attention', self._hooked_attention)
        self.monkey_patch('_sliced_attention', self._hooked_sliced_attention)
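        # `monkey_patch` (presumably provided by the hook base class elsewhere in daam)
        # swaps the module's bound method for the hooked version and restores the
        # original on unhook. With both `_attention` and `_sliced_attention` patched,
        # heat maps are captured whether or not attention slicing is enabled.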

    @property
    def num_heat_maps(self):
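For context, a minimal usage sketch of what this commit enables. Assumptions not confirmed by the diff: daam exposes a `trace` context manager with a `compute_global_heat_map` method, and diffusers' `StableDiffusionPipeline.enable_attention_slicing()` routes attention through `_sliced_attention`.

    # Sketch only; `trace` / `compute_global_heat_map` usage is assumed from daam's README of this era.
    from diffusers import StableDiffusionPipeline
    from daam import trace

    pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4').to('cuda')
    pipe.enable_attention_slicing()  # attention now goes through `_sliced_attention`

    prompt = 'a photo of an astronaut riding a horse'
    with trace(pipe) as tc:
        pipe(prompt)
        # Heat maps were aggregated by the hooked (sliced) attention calls above.
        heat_map = tc.compute_global_heat_map(prompt)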
