A memory-efficient attention implementation that combines ring buffers with hierarchical tiled computation. It targets encoder-decoder transformers and keeps both memory usage and computation efficient through specialized buffer management and tiling optimizations.
The implementation uses three ring buffers for efficient memory management:
- Forward KV Cache: For inference and non-gradient operations
- Backward KV Cache: For gradient computation during training
- Decoder Feedback Buffer: For decoder self-attention output feedback
Each ring buffer operates with:
- Read and write pointers for active data window management
- Cache-line-aligned data width for efficient prefetching
- Sequential access patterns for optimal memory bandwidth
- Fixed-size buffer with automatic pointer wraparound (sketched below)
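The pointer behaviour described above can be pictured as a small pre-allocated KV cache. The following is a minimal illustrative sketch, assuming one tensor each for keys and values; `RingKVCache` and its fields are hypothetical names rather than classes from this repository:

```python
import torch

class RingKVCache:
    """Fixed-size KV ring buffer with read/write pointers (illustrative only)."""

    def __init__(self, buffer_size, num_heads, head_dim, device="cpu"):
        self.buffer_size = buffer_size
        # Pre-allocated storage: the memory footprint is fixed regardless of sequence length.
        self.k = torch.zeros(buffer_size, num_heads, head_dim, device=device)
        self.v = torch.zeros(buffer_size, num_heads, head_dim, device=device)
        self.write_ptr = 0  # next slot to write
        self.read_ptr = 0   # start of the active data window
        self.filled = 0     # number of valid entries currently stored

    def append(self, k_step, v_step):
        # Write one token's K/V, then advance the write pointer with wraparound.
        self.k[self.write_ptr] = k_step
        self.v[self.write_ptr] = v_step
        self.write_ptr = (self.write_ptr + 1) % self.buffer_size
        if self.filled == self.buffer_size:
            # Buffer is full: the oldest entry was overwritten, so the read pointer advances too.
            self.read_ptr = (self.read_ptr + 1) % self.buffer_size
        else:
            self.filled += 1

    def window(self):
        # Return the active window in sequential (oldest-to-newest) order.
        idx = (self.read_ptr + torch.arange(self.filled, device=self.k.device)) % self.buffer_size
        return self.k[idx], self.v[idx]
```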
- Hardware Efficiency
  - Cache-aligned data access
  - Efficient data prefetching
  - Sequential memory access patterns
  - Fixed memory footprint
- KV Cache Management
  - Simple pointer-based management
  - Linear memory addressing
  - Automatic wraparound handling
  - Efficient encoder-decoder interaction
- Decoder Optimizations
  - Selective matrix multiplication masking (sketched below)
  - Skipping unnecessary data loading
  - Efficient right-shift operation
  - Sequential data arrangement
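A hedged illustration of the selective-masking and skip-loading idea at single-token decode time: instead of materializing a full causal mask, only the valid prefix of the cache is read, so future entries are never loaded or multiplied. `decode_step_attention` is a hypothetical helper, not the repository's kernel:

```python
import torch

def decode_step_attention(q_t, k_cache, v_cache, position):
    """One decoder step attending only to cached keys up to `position`.

    Illustrative shapes: q_t is [H, D]; k_cache and v_cache are [T, H, D].
    """
    valid = position + 1              # tokens 0..position are visible
    k = k_cache[:valid]               # skip loading future entries entirely
    v = v_cache[:valid]
    scores = torch.einsum("hd,thd->ht", q_t, k) / (q_t.shape[-1] ** 0.5)
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("ht,thd->hd", probs, v)
```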
Inspired by FlashAttention, our implementation uses a two-level tiling hierarchy:
- Upper Level (Global Management)
  - Tile scheduling and coordination
  - Memory access pattern optimization
  - Inter-tile dependency handling
- Lower Level (Processing)
  - Dense tiles: Full computation for valid regions
  - Triangle tiles: Special handling for diagonal blocks
  - Skip tiles: Automatic future token skipping (tile classification is sketched below)
- Tile Processing Orders
  - Q-major: Optimizes query tile reuse
  - K-major: Optimizes key/value tile reuse
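A minimal sketch of how a (query tile, key tile) pair might be classified into the three tile kinds under causal masking. The function and enum names are illustrative, not the repository's API:

```python
from enum import Enum

class TileKind(Enum):
    DENSE = "dense"        # entirely below the diagonal: full computation
    TRIANGLE = "triangle"  # straddles the diagonal: needs a per-element causal mask
    SKIP = "skip"          # entirely above the diagonal: future tokens only, skip it

def classify_tile(q_tile, k_tile, tile_size, causal=True):
    """Classify one (query-tile, key-tile) pair for causal attention (illustrative)."""
    if not causal:
        return TileKind.DENSE
    q_start, q_end = q_tile * tile_size, (q_tile + 1) * tile_size - 1
    k_start, k_end = k_tile * tile_size, (k_tile + 1) * tile_size - 1
    if k_end <= q_start:
        return TileKind.DENSE     # every key is visible to every query in the tile
    if k_start > q_end:
        return TileKind.SKIP      # every key lies in the future for every query
    return TileKind.TRIANGLE      # diagonal block: apply the triangular mask
```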
- Configuration
  ```python
  config = RingBufferConfig(
      # Required xFormers fields
      num_heads=8,
      head_dim=64,
      dropout=0.1,
      # Ring buffer configuration
      buffer_size=2048,    # Size of the ring buffers
      block_size=128,      # Processing block size
      tile_size=32,        # Tile size for tiled attention
      head_tile_size=4,    # Number of heads to process together
      prefill_size=512,    # Size of the prefill cache used for skip calculation
      causal=True,         # Whether to use causal masking
  )
  ```
- Model Initialization
  ```python
  # Create model and move to GPU
  model = RingBufferAttention(config).cuda()

  # Set tile processing order
  model.tiling_processor.tile_order = TileOrder.Q_MAJOR  # or TileOrder.K_MAJOR
  ```
- Encoder Phase
  ```python
  encoder_output, buffer_state = model(
      q=query,
      k=key,
      v=value,
      requires_grad=True
  )
  ```
- Decoder Phase
  ```python
  decoder_output = model(
      q=decoder_query,
      k=decoder_key,
      v=decoder_value,
      position=current_pos,
      is_decoder=True,
      att_mask=attention_mask,        # Optional attention mask
      key_padding_mask=padding_mask,  # Optional padding mask
      needs_weights=False,            # Whether to return attention weights
      output_attentions=False         # Whether to return attention outputs
  )
  ```
- Q-major Ordering (Default)
  ```python
  model.tiling_processor.tile_order = TileOrder.Q_MAJOR
  ```
  - Optimizes query tile reuse
  - Preferred when there are more query tiles than key tiles
  - Reduces query access bandwidth
- K-major Ordering
  ```python
  model.tiling_processor.tile_order = TileOrder.K_MAJOR
  ```
  - Optimizes key/value tile reuse
  - Preferred when there are more key tiles than query tiles
  - Reduces key/value access bandwidth
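The two orderings differ only in which loop is outermost. A minimal illustrative sketch of the loop nests, not the repository's scheduler:

```python
def iterate_q_major(num_q_tiles, num_k_tiles):
    # Outer loop over query tiles: each Q tile is loaded once and reused
    # against every K/V tile, reducing query access bandwidth.
    for qi in range(num_q_tiles):
        for ki in range(num_k_tiles):
            yield qi, ki

def iterate_k_major(num_q_tiles, num_k_tiles):
    # Outer loop over key/value tiles: each K/V tile is loaded once and
    # reused against every Q tile, reducing key/value access bandwidth.
    for ki in range(num_k_tiles):
        for qi in range(num_q_tiles):
            yield qi, ki
```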
- O(buffer_size * head_dim) per ring buffer (see the rough estimate below)
- Constant memory usage independent of sequence length
- Additional workspace proportional to tile sizes
- O(n * d) for n tokens and head dimension d
- Reduced memory bandwidth via tiling
- Efficient future token skipping
- Optional attention weight computation
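The per-buffer figure above translates into a quick back-of-the-envelope estimate. The sketch below assumes each ring buffer stores both K and V for all heads in fp16; the layout and dtype actually used by the repository may differ:

```python
def ring_buffer_bytes(buffer_size, num_heads, head_dim, dtype_bytes=2, num_buffers=3):
    # Each buffer holds K and V: buffer_size x num_heads x head_dim elements apiece (assumed layout).
    per_buffer = 2 * buffer_size * num_heads * head_dim * dtype_bytes
    return num_buffers * per_buffer

# With the example configuration (buffer_size=2048, 8 heads, head_dim=64, fp16):
# 3 buffers x 4 MiB = 12 MiB, independent of sequence length.
print(ring_buffer_bytes(2048, 8, 64) / 2**20, "MiB")
```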
- Requirements
  - PyTorch >= 1.8.0
  - xFormers library
  - CUDA-capable GPU (recommended)
- xFormers Integration
  - Registered as a standard attention module
  - Follows xFormers factory patterns
  - Compatible with the xFormers configuration system
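As a rough illustration of the factory pattern mentioned above, xFormers can build registered attention modules from a configuration dictionary through its components registry. The registry name "ring_buffer" and the fields below are assumptions about this repository, and the exact `build_attention` schema can vary between xFormers versions:

```python
from xformers.components.attention import build_attention

# Hypothetical registry name and fields; see the repository's registration code
# and RingBufferConfig for the authoritative schema.
attention = build_attention({
    "name": "ring_buffer",  # assumed name under which the module is registered
    "dropout": 0.1,
    "num_heads": 8,
})
```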
- Training
  ```python
  # Initialize model for training
  model = RingBufferAttention(RingBufferConfig())
  model.tiling_processor.tile_order = TileOrder.Q_MAJOR
  model.train()

  # Forward pass with gradient computation
  output, state = model(query, key, value, requires_grad=True)

  # Decoder feedback step
  decoder_output = model(
      dec_query, dec_key, dec_value,
      position=pos,
      is_decoder=True,
      buffer_state=state
  )
  ```
- Inference
  ```python
  import torch

  # Initialize model for inference
  model = RingBufferAttention(RingBufferConfig())
  model.eval()

  # Forward pass without gradient computation
  with torch.no_grad():
      output, state = model(query, key, value, requires_grad=False)
  ```
Areas for potential improvement:
- Advanced tile processing strategies
- Hardware-specific optimizations
- Auto-tuning systems
- Performance profiling tools
- Memory access pattern analysis
This implementation builds upon:
- FlashAttention paper
- xFormers library
- Transformer optimization techniques
- BSD 3-Clause License. See the LICENSE file for details.
For questions and support, please open an issue on the GitHub repository.