Skip to content

Commit

Permalink
Revert "Revert D65620886 (pytorch#3582)"
Browse files Browse the repository at this point in the history
This reverts commit 87db593.
  • Loading branch information
avbokovoy committed Jan 22, 2025
1 parent b858408 commit 668dea2
Showing 1 changed file with 311 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,159 @@ using namespace fbgemm_gpu;
{%- endif %}
{%- endmacro %}

{#-/*
Splitted version of load_and_accumulate macro. This code chunk describes
the weights load in forward kernel. Set up the WeightRow and load quantization
parameters. Shortcut store for nobag mode.
The main difference is in whether the slices are loaded from the embedding
table or cache.
NOTE: The decision was made to define this code chunk as a Jinja macro
instead of inline C++ function, since the compiler might not be able to
inline the code.
In-code variables that are defined outside:
emb_t, cache_t, cache_t
idx_j
inner_j
D_emb
lxu_cache_weights
{{ locs_or_addrs_idx }}_j
idx_weight_j
VEC_WIDTH
D
kThreadGroupSize
output_j
*/#}
{%- macro load_weights(from_cache) %}
{%- if from_cache %}
const cache_t* cache_weights;
{%- if ssd %}
cache_weights = reinterpret_cast<const cache_t*>(
*reinterpret_cast<uint64_t*>(&{{ locs_or_addrs_idx }}_j));
{%- else %}
cache_weights = reinterpret_cast<const cache_t*>(
&lxu_cache_weights[{{ locs_or_addrs_idx }}_j][0]);
{%- endif %}
{%- endif %}
{#-/* Set the weights row */#}
{%- if is_rocm %}
const auto weights_row = rocm::WeightRowAccessorVec2
{%- else %}
const auto weights_row = WeightRowAccessor
{%- endif %}
<
emb_t,
cache_t,
cache_t,
{%- if from_cache %}
true
{%- else %}
false
{%- endif %}
>(
{%- if from_cache %}
// Pass nullptr to avoid calling &weights[idx_j * D_emb], which loads
// memory into the registers as a side effect
nullptr,
// Load from the cache
cache_weights,
{%- else %}
// Load from the embedding table
&weights[idx_j * D_emb],
// Pass nullptr bc we are loading from the embedding table
nullptr,
{%- endif %}
D);

{#-/* Set the quantization params */#}
{%- if from_cache %}
// Assume cache is FP16/FP32, which doesn't require quantization params
const auto qparams = make_float2(0.0f, 0.0f);
{%- else %}
// Load the quantization params from the embedding table row if emb_t == uint8_t
const auto qparams = weights_row.load_qparams();
{%- endif %}

{%- if not nobag %}
// Iterate over the row in the weights table, in 4-element strides
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0; i < kMaxVecsPerThread; ++i)
{
// Load the slice of the weights
int32_t d = (i * kThreadGroupSize + threadIdx.x) * VEC_WIDTH;
d = (d < D) ? d : 0;
const auto weights_slice = weights_row.load(d, qparams);
vals[inner_j * kMaxVecsPerThread + i] = weights_slice;
}

{%- else %}
for (int32_t i = 0; i < D; i += kThreadGroupSize * VEC_WIDTH) {
const int32_t d = i + threadIdx.x * VEC_WIDTH;
if (d < D) {
// Since there is no pooling, simply copy the weights to output
const auto weights_slice = weights_row.load(d, qparams);
{%- if is_index_select %}
// output is 1D (because the stride can be irregular)
weights_slice.store(&output[output_offset + output_j * output_stride + d]);
{%- else %}
// output is 2D
weights_slice.store(&output[output_j][d]);
{%- endif %}
}
}
{%- endif %}
{%- endmacro %}

{#-/*
Splitted version of load_and_accumulate macro. This code chunk
describes the weights accumulate step in the forward kernel.
Accumulate the slices of values from the row. Does nothing for
nobag mode assuming all the work is done in load() macro.
The main difference is in whether the slices are loaded from the embedding
table or cache.
NOTE: The decision was made to define this code chunk as a Jinja macro
instead of inline C++ function, since the compiler might not be able to
inline the code.
In-code variables that are defined outside:
emb_t, cache_t, cache_t
idx_j
inner_j
D_emb
lxu_cache_weights
cache_idx_j
idx_weight_j
VEC_WIDTH
D
kThreadGroupSize
output_j
*/#}
{%- macro accumulate_and_store(from_cache) %}
{%- if not nobag %}
// Iterate over the row in the weights table, in 4-element strides
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && (i * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D;
++i) {
{%- if is_gwd_kernel %}
// Scale weights with global weight decay
vals[inner_j * kMaxVecsPerThread + i].mul_(global_weight_decay_j);
{%- endif %}
{%- if weighted %}
// Accumulate the weights * positional weight
accumulators[i].fma_(vals[inner_j * kMaxVecsPerThread + i], idx_weight_j);
{%- else %}
// Accumulate the weights
accumulators[i].add_(vals[inner_j * kMaxVecsPerThread + i]);
{%- endif %}
}
{%- endif %}
{%- endmacro %}

{#-/*
This code chunk contains the implementation body of the kernel, and is
defined as a Jinja macro to be copy-pasted directly into the kernel as
Expand Down Expand Up @@ -203,8 +356,162 @@ using namespace fbgemm_gpu;
at::acc_type<cache_t, true> idx_weight = l < L ? indice_weights[indices_start + l] : 0;
{%- endif %}

{%- if is_rocm %}
{%- if not nobag %}
rocm::Vec2T<cache_t> vals[kManualUnrollLength * kMaxVecsPerThread];
{%- endif %}
// Iterate over kThreadGroupSize indices
for (auto outer_j = 0; outer_j < kThreadGroupSize && l_start + outer_j < L - L % kManualUnrollLength; outer_j += kManualUnrollLength)
{
{%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
// Load index from thread j in the group
[[maybe_unused]] int64_t idx_j_[kManualUnrollLength];
for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
{
idx_j_[inner_j] = SHFL_SYNC(idx, outer_j + inner_j);
}
{%- endif %}
{%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
// Load cache's index from thread j in the group
[[maybe_unused]] int32_t {{ locs_or_addrs_idx }}_j_[kManualUnrollLength];
for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
{
{{ locs_or_addrs_idx }}_j_[inner_j] = use_lxu_cache ? SHFL_SYNC({{ locs_or_addrs_idx }}, outer_j + inner_j) : 0;
}
{%- endif %}

{%- if weighted %}
// Load positional weight index from thread j in the group
at::acc_type<cache_t, true> idx_weight_j_[kManualUnrollLength];
for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
{
idx_weight_j_[inner_j] = SHFL_SYNC(idx_weight, outer_j + inner_j);
}
{%- endif %}


for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
{
auto j = outer_j + inner_j;
{%- if is_index_select %}
int64_t output_j = L_start + l_start + j;
{%- elif nobag %}
int64_t output_j = indices_start + l_start + j;
{%- endif %}

{%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
[[maybe_unused]] int64_t idx_j = idx_j_[inner_j];
{%- endif %}
{%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
[[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }}_j
= use_lxu_cache ? {{ locs_or_addrs_idx }}_j_[inner_j] : 0;

{%- endif %}
{%- if weighted %}
at::acc_type<cache_t, true> idx_weight_j = idx_weight_j_[inner_j];
{%- endif %}



{#/**************************************************************/#}
{#-/*
This is the main switch that determines how we are to load and
accumulate weights, and is determined by Jinja-time, compile-time,
and run-time variables.
*/#}

{%- if dense %}
{#-/* If it's dense, cache is not supported, so load from the embedding table */#}
{{- load_weights(false) }}

{%- elif lxu_miss_rate == "cache_conflict_miss_rate::all" %}
{#-/* Else if we know we have a 100% miss rate, then always fetch from the embedding table */#}
{{- load_weights(false) }}

{%- elif lxu_miss_rate == "cache_conflict_miss_rate::zero" %}
{#-/* Else if we know we have a 0% miss rate, then always fetch from the cache */#}
{{ load_weights(true) }}
{%- else %}
{#-/* Else we defer to run-time selection */#}
if (placement == PlacementType::MANAGED_CACHING
&& {{ locs_or_addrs_idx }}_j != kCacheLocationMissing
) {
{#-/* If the row is available in the cache, fetch from the cache */#}
{{ load_weights(true) }}
} else {
{#-/* Else fetch from the embedding table */#}
{{ load_weights(false) }}
}

{%- endif %}
{#/**************************************************************/#}
}
{%- if not nobag %}
for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
{
auto j = outer_j + inner_j;

{%- if is_index_select %}
int64_t output_j = L_start + l_start + j;
{%- elif nobag %}
int64_t output_j = indices_start + l_start + j;
{%- endif %}

{%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
[[maybe_unused]] int64_t idx_j = idx_j_[inner_j];
{%- endif %}
{%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
[[maybe_unused]] int32_t {{ locs_or_addrs_idx }}_j = {{ locs_or_addrs_idx }}_j_[inner_j];
{%- endif %}
{%- if weighted %}
at::acc_type<cache_t, true> idx_weight_j = idx_weight_j_[inner_j];
{%- endif %}
{%- if is_gwd_kernel %}
const auto global_weight_decay_j = SHFL_SYNC(global_weight_decay, j);
{%- endif %}

{#/**************************************************************/#}
{#-/*
This is the main switch that determines how we are to load and
accumulate weights, and is determined by Jinja-time, compile-time,
and run-time variables.
*/#}

{%- if dense %}
{#-/* If it's dense, cache is not supported, so load from the embedding table */#}
{{- accumulate_and_store(false) }}

{%- elif lxu_miss_rate == "cache_conflict_miss_rate::all" %}
{#-/* Else if we know we have a 100% miss rate, then always fetch from the embedding table */#}
{{- accumulate_and_store(false) }}

{%- elif lxu_miss_rate == "cache_conflict_miss_rate::zero" %}
{#-/* Else if we know we have a 0% miss rate, then always fetch from the cache */#}
{{ accumulate_and_store(true) }}
{%- else %}
{#-/* Else we defer to run-time selection */#}
if (placement == PlacementType::MANAGED_CACHING
&& {{ locs_or_addrs_idx }}_j != kCacheLocationMissing) {
{#-/* If the row is available in the cache, fetch from the cache */#}
{{ accumulate_and_store(true) }}
} else {
{#-/* Else fetch from the embedding table */#}
{{ accumulate_and_store(false) }}
}

{%- endif %}
{#/**************************************************************/#}
}
{%- endif %}
}
{%- endif %}

{%- if is_rocm %}
for(auto j = L - L % kManualUnrollLength; j < kThreadGroupSize && l_start + j < L; ++j) {
{%- else %}
// Iterate over kThreadGroupSize indices
for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) {
{%- endif %}
{%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
// Load index from thread j in the group
[[maybe_unused]] int64_t idx_j = SHFL_SYNC(idx, j);
Expand Down Expand Up @@ -370,6 +677,10 @@ batch_index_select_dim0_codegen_forward_kernel(
{%- else %}
constexpr int VEC_WIDTH = 4;
{%- endif %}
{%- if is_rocm %}
// Unroll factor for ROCm devices
constexpr int kManualUnrollLength = 4;
{%- endif %}

// Determine the linearized warp ID, and exit early if needed
int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
Expand Down

0 comments on commit 668dea2

Please sign in to comment.