Add support for int32_t indices in TBE training (2I/N)
Summary: - Add `int32_t` support to `::internal::csr2csc`, for eventual `int32_t` indices support in TBE CPU

Reviewed By: jianyuh

Differential Revision: D67920539
q10 authored and facebook-github-bot committed Jan 7, 2025
1 parent 8bf5bbf commit b5fea5d
Showing 3 changed files with 76 additions and 63 deletions.
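The change is easiest to read with the eventual call pattern in mind: templating `csr2csc` on `index_t` lets TBE CPU call sites dispatch on the dtype of the offsets/indices tensors instead of assuming `int64_t`. A minimal sketch of such a dispatch follows (the wrapper below is hypothetical and not part of this diff; `AT_DISPATCH_INDEX_TYPES` is ATen's stock int32/int64 dispatch macro, which defines `index_t` inside the lambda):

#include <ATen/Dispatch.h>
#include "fbgemm_gpu/embedding_forward_split_cpu.h"

// Hypothetical wrapper: dispatch csr2csc on the runtime index dtype.
void csr2csc_dispatch(
    internal::HyperCompressedSparseColumn& csc,
    int B,
    const at::Tensor& offsets, // torch::kInt32 or torch::kInt64
    const at::Tensor& indices, // same dtype as offsets
    const at::Tensor& weights, // float per-sample weights
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings) {
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "csr2csc", [&] {
    // index_t is int32_t or int64_t, chosen by offsets.scalar_type().
    ::internal::csr2csc(
        csc,
        B,
        offsets.accessor<index_t, 1>(),
        indices.accessor<index_t, 1>(),
        weights.accessor<float, 1>(),
        pooling_mode,
        table_to_feature_offset,
        num_embeddings);
  });
}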
108 changes: 55 additions & 53 deletions fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
@@ -380,18 +380,18 @@ namespace internal {

namespace {

template <typename scalar_t, bool IS_VALUE_PAIR>
template <typename index_t, typename scalar_t, bool IS_VALUE_PAIR>
void csr2csc_template_(
HyperCompressedSparseColumn& csc,
int B,
const at::TensorAccessor<int64_t, 1>& csr_offsets,
const at::TensorAccessor<int64_t, 1>& csr_indices,
const at::TensorAccessor<index_t, 1>& csr_offsets,
const at::TensorAccessor<index_t, 1>& csr_indices,
const at::TensorAccessor<scalar_t, 1>& csr_weights,
int64_t pooling_mode,
const int* table_to_feature_offset,
int64_t num_embeddings) {
csc.num_non_zero_columns = 0;
int64_t nnz = csr_offsets[table_to_feature_offset[1] * B] -
const auto nnz = csr_offsets[table_to_feature_offset[1] * B] -
csr_offsets[table_to_feature_offset[0] * B];
if (nnz == 0) {
return;
@@ -407,7 +407,7 @@ void csr2csc_template_(
[[maybe_unused]] int column_ptr_curr = 0;
bool is_shared_table =
table_to_feature_offset[1] > table_to_feature_offset[0] + 1;
auto NS = csr_offsets[table_to_feature_offset[1] * B] -
const auto NS = csr_offsets[(size_t)table_to_feature_offset[1] * B] -
csr_offsets[table_to_feature_offset[0] * B];

using pair_t = std::pair<int, scalar_t>;
@@ -432,9 +432,9 @@
#pragma omp parallel for
for (int b = 0; b < B; ++b) {
const auto FBb = feature * B + b;
int64_t pool_begin = csr_offsets[FBb];
int64_t pool_end = csr_offsets[FBb + 1];
int64_t L = pool_end - pool_begin;
const auto pool_begin = csr_offsets[FBb];
const auto pool_end = csr_offsets[FBb + 1];
const auto L = pool_end - pool_begin;
// MEAN pooling will not work with indice_weights!
double scale_factor =
(static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
@@ -581,47 +581,48 @@ void csr2csc_template_(
assert(column_ptr_curr == nnz);
}

#define INSTANTIATE_BATCHED_CSR2CSC(SCALAR_T) \
template void csr2csc_template_<SCALAR_T, true>( \
HyperCompressedSparseColumn & csc, \
int B, \
const at::TensorAccessor<int64_t, 1>& csr_offsets, \
const at::TensorAccessor<int64_t, 1>& csr_indices, \
const at::TensorAccessor<SCALAR_T, 1>& csr_weights, \
int64_t pooling_mode, \
const int* table_to_feature_offset, \
int64_t num_embeddings); \
\
template void csr2csc_template_<SCALAR_T, false>( \
HyperCompressedSparseColumn & csc, \
int B, \
const at::TensorAccessor<int64_t, 1>& csr_offsets, \
const at::TensorAccessor<int64_t, 1>& csr_indices, \
const at::TensorAccessor<SCALAR_T, 1>& csr_weights, \
int64_t pooling_mode, \
const int* table_to_feature_offset, \
#define INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, is_value_pair) \
template void csr2csc_template_<index_t, scalar_t, is_value_pair>( \
HyperCompressedSparseColumn & csc, \
int B, \
const at::TensorAccessor<index_t, 1>& csr_offsets, \
const at::TensorAccessor<index_t, 1>& csr_indices, \
const at::TensorAccessor<scalar_t, 1>& csr_weights, \
int64_t pooling_mode, \
const int* table_to_feature_offset, \
int64_t num_embeddings);

INSTANTIATE_BATCHED_CSR2CSC(float)
INSTANTIATE_BATCHED_CSR2CSC(double)
#undef INSTANTIATE_BATCHED_CSR2CSC
#define INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, scalar_t) \
INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, true); \
INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, false);

#define INSTANTIATE_CSR2CSC_TEMPLATE_2(index_t) \
INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, float); \
INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, double);

INSTANTIATE_CSR2CSC_TEMPLATE_2(int32_t);
INSTANTIATE_CSR2CSC_TEMPLATE_2(int64_t);

#undef INSTANTIATE_CSR2CSC_TEMPLATE_2
#undef INSTANTIATE_CSR2CSC_TEMPLATE_1
#undef INSTANTIATE_CSR2CSC_TEMPLATE_0

} // namespace

template <typename scalar_t>
template <typename index_t, typename scalar_t>
void csr2csc(
HyperCompressedSparseColumn& csc,
int B,
const at::TensorAccessor<int64_t, 1>& csr_offsets,
const at::TensorAccessor<int64_t, 1>& csr_indices,
const at::TensorAccessor<index_t, 1>& csr_offsets,
const at::TensorAccessor<index_t, 1>& csr_indices,
const at::TensorAccessor<scalar_t, 1>& csr_weights,
int64_t pooling_mode,
const int* table_to_feature_offset,
int64_t num_embeddings) {
bool has_weights = csr_weights.data() != nullptr;
if (has_weights ||
static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) {
csr2csc_template_<scalar_t, /*IS_VALUE_PAIR=*/true>(
csr2csc_template_<index_t, scalar_t, /*IS_VALUE_PAIR=*/true>(
csc,
B,
csr_offsets,
@@ -631,7 +632,7 @@ void csr2csc(
table_to_feature_offset,
num_embeddings);
} else {
csr2csc_template_<scalar_t, /*IS_VALUE_PAIR=*/false>(
csr2csc_template_<index_t, scalar_t, /*IS_VALUE_PAIR=*/false>(
csc,
B,
csr_offsets,
@@ -643,25 +644,26 @@
}
}

template void csr2csc<float>(
HyperCompressedSparseColumn& csc,
int B,
const at::TensorAccessor<int64_t, 1>& csr_offsets,
const at::TensorAccessor<int64_t, 1>& csr_indices,
const at::TensorAccessor<float, 1>& csr_weights,
int64_t pooling_mode,
const int* table_to_feature_offset,
int64_t num_embeddings);
#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t) \
template void csr2csc<index_t, scalar_t>( \
HyperCompressedSparseColumn & csc, \
int B, \
const at::TensorAccessor<index_t, 1>& csr_offsets, \
const at::TensorAccessor<index_t, 1>& csr_indices, \
const at::TensorAccessor<scalar_t, 1>& csr_weights, \
int64_t pooling_mode, \
const int* table_to_feature_offset, \
int64_t num_embeddings);

template void csr2csc<double>(
HyperCompressedSparseColumn& csc,
int B,
const at::TensorAccessor<int64_t, 1>& csr_offsets,
const at::TensorAccessor<int64_t, 1>& csr_indices,
const at::TensorAccessor<double, 1>& csr_weights,
int64_t pooling_mode,
const int* table_to_feature_offset,
int64_t num_embeddings);
#define INSTANTIATE_CSR2CSC_1(index_t) \
INSTANTIATE_CSR2CSC_0(index_t, float); \
INSTANTIATE_CSR2CSC_0(index_t, double);

INSTANTIATE_CSR2CSC_1(int32_t);
INSTANTIATE_CSR2CSC_1(int64_t);

#undef INSTANTIATE_CSR2CSC_1
#undef INSTANTIATE_CSR2CSC_0

} // namespace internal

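One design note on the instantiation block above: the three macro levels generate the full cross-product of template arguments, with level 2 iterating over index types, level 1 over scalar types, and level 0 over the `IS_VALUE_PAIR` flag, so `INSTANTIATE_CSR2CSC_TEMPLATE_2(int32_t)` plus `INSTANTIATE_CSR2CSC_TEMPLATE_2(int64_t)` yield eight explicit instantiations of `csr2csc_template_`. A distilled standalone version of the same pattern (illustrative names, not from this diff):

#include <cstdint>

template <typename index_t, typename scalar_t, bool FLAG>
void kernel(index_t, scalar_t) {} // stand-in for csr2csc_template_

#define INSTANTIATE_0(index_t, scalar_t, flag) \
  template void kernel<index_t, scalar_t, flag>(index_t, scalar_t);

#define INSTANTIATE_1(index_t)        \
  INSTANTIATE_0(index_t, float, true) \
  INSTANTIATE_0(index_t, float, false)

INSTANTIATE_1(int32_t) // two explicit instantiations
INSTANTIATE_1(int64_t) // four in total

#undef INSTANTIATE_1
#undef INSTANTIATE_0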
6 changes: 3 additions & 3 deletions fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h
@@ -116,12 +116,12 @@ struct HyperCompressedSparseColumn {
}
};

template <typename scalar_t>
template <typename index_t, typename scalar_t>
void csr2csc(
HyperCompressedSparseColumn& csc,
int B,
const at::TensorAccessor<int64_t, 1>& csr_offsets,
const at::TensorAccessor<int64_t, 1>& csr_indices,
const at::TensorAccessor<index_t, 1>& csr_offsets,
const at::TensorAccessor<index_t, 1>& csr_indices,
const at::TensorAccessor<scalar_t, 1>& csr_weights,
int64_t pooling_mode,
const int* table_to_feature_offset,
25 changes: 18 additions & 7 deletions fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp
@@ -15,20 +15,23 @@
#include "fbgemm_gpu/embedding_forward_split_cpu.h"
#include "torch/types.h" // @manual=//caffe2:torch-cpp-cpu

TEST(CpuKernelTest, csr2csc_test) {
template <c10::ScalarType DType, typename T>
void test_csr2csc() {
internal::HyperCompressedSparseColumn csc;
int B = 2;
at::Tensor offsets = torch::tensor({0, 4, 8});
at::Tensor indices = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9});
at::Tensor offsets =
torch::tensor({0, 4, 8}, torch::TensorOptions().dtype(DType));
at::Tensor indices = torch::tensor(
{1, 2, 4, 5, 4, 3, 2, 9}, torch::TensorOptions().dtype(DType));
int64_t pooling_mode = (int64_t)fbgemm_gpu::PoolingMode::SUM;
int table_to_feature_offset[2] = {0, 1};
int num_embeddings = 10;

::internal::csr2csc(
csc,
B,
offsets.accessor<int64_t, 1>(),
indices.accessor<int64_t, 1>(),
offsets.accessor<T, 1>(),
indices.accessor<T, 1>(),
at::TensorAccessor<at::acc_type<float, true>, 1>(
nullptr, nullptr, nullptr), // no weights
pooling_mode,
@@ -61,8 +64,8 @@ TEST(CpuKernelTest, csr2csc_test) {
::internal::csr2csc(
csc_weighted,
B,
offsets.accessor<int64_t, 1>(),
indices.accessor<int64_t, 1>(),
offsets.accessor<T, 1>(),
indices.accessor<T, 1>(),
indice_weights.accessor<at::acc_type<float, true>, 1>(),
pooling_mode,
table_to_feature_offset,
@@ -88,3 +91,11 @@ TEST(CpuKernelTest, csr2csc_test) {
EXPECT_EQ(expect_weights[i], csc_weighted.weights[i]);
}
}

TEST(CpuKernelTest, csr2csc_test_int32) {
test_csr2csc<torch::kInt32, int32_t>();
}

TEST(CpuKernelTest, csr2csc_test_int64) {
test_csr2csc<torch::kInt64, int64_t>();
}
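The tests cover both index widths by stamping out one shared function template under two `TEST` cases. GoogleTest's typed tests would express the same coverage with a single test body; a sketch, assuming the usual `TYPED_TEST` machinery is available (`c10::CppTypeToScalarType` maps the C++ type back to the dtype so it does not have to be passed separately):

#include <cstdint>
#include <gtest/gtest.h>
#include <c10/core/ScalarType.h>

template <typename T>
class Csr2CscTypedTest : public ::testing::Test {};

using IndexTypes = ::testing::Types<int32_t, int64_t>;
TYPED_TEST_SUITE(Csr2CscTypedTest, IndexTypes);

TYPED_TEST(Csr2CscTypedTest, csr2csc) {
  // TypeParam is int32_t or int64_t; derive the matching c10 dtype.
  test_csr2csc<c10::CppTypeToScalarType<TypeParam>::value, TypeParam>();
}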
