Skip to content

Commit

Permalink
Add weighted mode support
Browse files Browse the repository at this point in the history
  • Loading branch information
avbokovoy committed Jan 24, 2025
1 parent bd507dc commit 53a0d13
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no

// Define {{ emb_weight_type }} kernel invocation macro
#define X(DeviceOnly, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \
- {%-if is_rocm and not nobag and not weighted %}
+ {%-if is_rocm and not nobag %}
const int32_t num_uint4_loads_per_row = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, SparseType::{{ emb_weight_type }}, row_alignment), sizeof(uint4)); \
constexpr int32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); \
const int32_t num_packed_bags = NumUint4LoadsPerRow > num_uint4_loads_per_row && !std::is_same_v<output_t, uint8_t> && SparseType::{{ emb_weight_type }} != SparseType::FP32 ? NumUint4LoadsPerRow / num_uint4_loads_per_row : 1; \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
__shared__ AllBuffers buffers;

{% if weighted %}
- typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight];
+ typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow];
__shared__ AllIndiceWeights buffers_indice_weights;
{% endif %}

Expand Down Expand Up @@ -187,7 +187,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
uint4 data = valid ? row_data_v[inner_i] : zeros;
buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_idx] = data;
{% if weighted %}
- buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
+ buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
{% endif %}
}
}
Expand Down Expand Up @@ -218,7 +218,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
cp_async_zfill_cg<sizeof(uint4)>(&buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_idx], &row[row_load_idx], valid);

{% if weighted %}
- buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
+ buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
{% endif %}
}
{%- if is_rocm %}
Expand Down Expand Up @@ -253,7 +253,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
{% endif %}

{% if weighted %}
- float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx];
+ float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_idx];
{% endif %}

using scalar_t = {{ emb_weight_type.cpp_type_name }};
Expand Down

0 comments on commit 53a0d13

Please sign in to comment.