forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path SparseBinaryOpIntersectionKernel.cpp
107 lines (91 loc) · 3.83 KB
/
SparseBinaryOpIntersectionKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/TensorIterator.h>
namespace at {
namespace native {
namespace {
// CPU launch policy handed to _sparse_binary_op_intersection_kernel_out as a
// template-template argument: it runs an elementwise functor over a
// TensorIterator by forwarding to cpu_kernel.
template <typename func_t>
struct CPUKernelLauncher {
  static void launch(TensorIteratorBase& it, const func_t& fn) {
    // Looping (and any parallelism) is delegated entirely to cpu_kernel.
    cpu_kernel(it, fn);
  }
};
// Value combiner for sparse multiplication: apply(a, b) yields the product of
// two values of the same scalar type.
struct MulOp {
  // Generic case: relies on the scalar type's own operator*.
  template <typename scalar_t>
  static scalar_t apply(scalar_t a, scalar_t b) {
    return a * b;
  }

  // bool case: logical AND. For two bool arguments this non-template overload
  // is an equally good match as the template and is therefore preferred by
  // overload resolution; every other type still uses the template. Using `&&`
  // avoids `bool * bool` (int promotion + conversion back), which GCC's
  // -Wint-in-bool-context flags.
  static bool apply(bool a, bool b) {
    return a && b;
  }
};
// Value-combination step for a sparse binary op evaluated on the intersection
// of two sparse operands.
//
// For every element k walked by the iterator built by
// make_value_selection_intersection_iter, writes
//   binary_op_t::apply(lhs_values[lhs_select_idx[k]], rhs_values[rhs_select_idx[k]])
// into the iterator's output tensor, where each selected index is applied
// along dim 0 (scaled by that tensor's stride(0)). The output tensor is
// returned.
//
// NOTE(review): scalar_t is dispatched on the output's dtype and index_t on
// lhs_select_idx's dtype only; lhs/rhs values and rhs_select_idx are read with
// those same types, so callers presumably pass matching dtypes — confirm.
template <typename binary_op_t>
struct CPUValueSelectionIntersectionKernel {
  static Tensor apply(
      const Tensor& lhs_values,
      const Tensor& lhs_select_idx,
      const Tensor& rhs_values,
      const Tensor& rhs_select_idx) {
    auto iter = make_value_selection_intersection_iter(
        lhs_values, lhs_select_idx, rhs_values, rhs_select_idx);
    auto res_values = iter.tensor(0);

    // A selected index i addresses element i * stride(0) relative to the
    // values pointer the iterator provides.
    const auto lhs_row_stride = lhs_values.stride(0);
    const auto rhs_row_stride = rhs_values.stride(0);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
        ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16,
        res_values.scalar_type(), "binary_op_intersection_cpu", [&] {
          AT_DISPATCH_INDEX_TYPES(
              lhs_select_idx.scalar_type(), "binary_op_intersection_cpu", [&] {
                // Loads the scalar located `idx * row_stride` elements past a
                // raw values pointer.
                auto gather = [](const char* values_bytes, index_t idx, int64_t row_stride) -> scalar_t {
                  return reinterpret_cast<const scalar_t*>(values_bytes)[idx * row_stride];
                };

                auto loop = [&](char** data, const int64_t* strides, int64_t n) {
                  // Operand slots: [0] output values, [1] lhs values,
                  // [2] lhs selection indices, [3] rhs values,
                  // [4] rhs selection indices.
                  char* out_bytes = data[0];
                  const char* lhs_bytes = data[1];
                  const char* lhs_idx_bytes = data[2];
                  const char* rhs_bytes = data[3];
                  const char* rhs_idx_bytes = data[4];

                  for (int64_t k = 0; k < n; ++k) {
                    const auto lhs_idx = *reinterpret_cast<const index_t*>(lhs_idx_bytes);
                    const auto rhs_idx = *reinterpret_cast<const index_t*>(rhs_idx_bytes);
                    auto* RESTRICT out = reinterpret_cast<scalar_t*>(out_bytes);
                    *out = binary_op_t::apply(
                        gather(lhs_bytes, lhs_idx, lhs_row_stride),
                        gather(rhs_bytes, rhs_idx, rhs_row_stride));

                    // Step every operand to its next element.
                    out_bytes += strides[0];
                    lhs_bytes += strides[1];
                    lhs_idx_bytes += strides[2];
                    rhs_bytes += strides[3];
                    rhs_idx_bytes += strides[4];
                  }
                };
                iter.for_each(loop, at::internal::GRAIN_SIZE);
              });
        });

    return res_values;
  }
};
// CPU kernel for mul_sparse_sparse_out_stub: writes the product of sparse x and
// y into `result` via the shared intersection driver, which presumably combines
// only the values present in both operands (confirm in
// SparseBinaryOpIntersectionCommon.h). Values are combined with MulOp.
void mul_sparse_sparse_out_cpu_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y) {
  // Launch policy: CPUKernelLauncher; value-selection kernel specialized on MulOp.
  _sparse_binary_op_intersection_kernel_out<
      CPUKernelLauncher,
      CPUValueSelectionIntersectionKernel<MulOp>>(result, x, y);
}
}
// Register mul_sparse_sparse_out_cpu_kernel for mul_sparse_sparse_out_stub.
// The same function pointer is installed in the DEFAULT slot and in each
// vectorized CPU-capability slot (AVX512, AVX2, VSX, ZVECTOR), so every
// capability slot of the stub points at this one scalar-loop kernel.
REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
}}