[xla:cpu] Optimize Thunk::OkExecuteEvent
name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time   385µs ± 2%   378µs ± 4%  -1.82%
BM_SelectAndScatterF32/256/process_time  1.58ms ± 2%  1.56ms ± 2%  -1.77%
BM_SelectAndScatterF32/512/process_time  7.24ms ± 4%  7.07ms ± 6%  -2.39%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657437409
ezhulenev authored and tensorflower-gardener committed Jul 31, 2024
1 parent 5843411 commit e477c1a
Showing 30 changed files with 148 additions and 113 deletions.
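The change at the heart of this commit: Thunk previously re-resolved a function-local static (OkEvent()) every time a thunk completed inline, and callers compared events against that accessor. Now each Thunk caches a reference to the process-wide OK event (OkExecuteEventSingleton()) in an ok_event_ member at construction, so returning the OK event and checking IsOkExecuteEvent become a member load and a pointer comparison on the hot path. The sketch below shows the shape of this pattern with simplified stand-in types; it is not the actual tsl::AsyncValue machinery used in XLA.

```cpp
#include <memory>

struct Event {};  // Stand-in for an always-available execute event.

// Process-wide singleton, created on first use and intentionally never
// destroyed, so references to it stay valid for the whole process lifetime.
const std::shared_ptr<Event>& OkEventSingleton() {
  static auto* owner = new std::shared_ptr<Event>(std::make_shared<Event>());
  return *owner;
}

class Thunk {
 public:
  Thunk() : ok_event_(OkEventSingleton()) {}  // cache once per instance

  // Hot path: hand back the cached copy instead of re-entering the
  // static-local accessor on every execution.
  const std::shared_ptr<Event>& OkExecuteEvent() const { return ok_event_; }

  // Inline-completion check reduces to a pointer comparison.
  bool IsOkExecuteEvent(const std::shared_ptr<Event>& event) const {
    return event.get() == ok_event_.get();
  }

 private:
  std::shared_ptr<Event> ok_event_;
};
```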
15 changes: 0 additions & 15 deletions third_party/shardy/temporary.patch
@@ -1,15 +0,0 @@
diff --git i/third_party/llvm/workspace.bzl w/third_party/llvm/workspace.bzl
index 76a13a4..9345d8d 100644
--- i/third_party/llvm/workspace.bzl
+++ w/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
- LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4"
- LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8"
+ LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575"
+ LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824"

tf_http_archive(
name = name,
4 changes: 2 additions & 2 deletions third_party/shardy/workspace.bzl
@@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
SHARDY_COMMIT = "c87ce5b404305927c7a169b305ba0dc1c304e4ce"
SHARDY_SHA256 = "2fa411cfb31f351f2cdad997db0ccb8f9898bad3421f2a78889703bb75bd054c"
SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a"
SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2"

tf_http_archive(
name = "shardy",
15 changes: 0 additions & 15 deletions third_party/xla/third_party/shardy/temporary.patch
@@ -1,15 +0,0 @@
diff --git i/third_party/llvm/workspace.bzl w/third_party/llvm/workspace.bzl
index 76a13a4..9345d8d 100644
--- i/third_party/llvm/workspace.bzl
+++ w/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
- LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4"
- LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8"
+ LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575"
+ LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824"

tf_http_archive(
name = name,
4 changes: 2 additions & 2 deletions third_party/xla/third_party/shardy/workspace.bzl
@@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
SHARDY_COMMIT = "c87ce5b404305927c7a169b305ba0dc1c304e4ce"
SHARDY_SHA256 = "2fa411cfb31f351f2cdad997db0ccb8f9898bad3421f2a78889703bb75bd054c"
SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a"
SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2"

tf_http_archive(
name = "shardy",
24 changes: 12 additions & 12 deletions third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc
@@ -186,18 +186,18 @@ void AddGraphTraceActivityEvent(CuptiEventCollectorDelegate &collector,
AnnotationMap::AnnotationInfo info = collector.annotation_map.LookUp(
graph_trace->deviceId, graph_trace->correlationId);
collector.receive(CuptiTracerEvent{
.type = CuptiTracerEventType::CudaGraph,
.source = CuptiTracerEventSource::Activity,
.name = absl::StrCat("CudaGraphExec:", graph_trace->graphId),
.annotation = info.annotation,
.nvtx_range = info.nvtx_range,
.start_time_ns = graph_trace->start,
.end_time_ns = graph_trace->end,
.device_id = graph_trace->deviceId,
.correlation_id = graph_trace->correlationId,
.context_id = graph_trace->contextId,
.stream_id = graph_trace->streamId,
.graph_id = graph_trace->graphId,
/* .type = */ CuptiTracerEventType::CudaGraph,
/* .source = */ CuptiTracerEventSource::Activity,
/* .name = */ absl::StrCat("CudaGraphExec:", graph_trace->graphId),
/* .annotation = */ info.annotation,
/* .nvtx_range = */ info.nvtx_range,
/* .start_time_ns = */ graph_trace->start,
/* .end_time_ns = */ graph_trace->end,
/* .device_id = */ graph_trace->deviceId,
/* .correlation_id = */ graph_trace->correlationId,
/* .context_id = */ graph_trace->contextId,
/* .stream_id = */ graph_trace->streamId,
/* .graph_id = */ graph_trace->graphId,
});
}
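Note on the hunk above: `.field = value` designated initializers are a C++20 feature (GCC and Clang accept them earlier only as an extension), so the cpp-17-fixes change referenced in the commit message rewrites them as positional initializers and keeps the field names as `/* .field = */` comments for readability. A minimal illustration with a hypothetical struct:

```cpp
struct Point {
  int x;
  int y;
};

// C++20 only (designated initializers):
//   Point p{.x = 1, .y = 2};

// C++17-compatible form used in this commit, field names kept as comments:
Point p{/* .x = */ 1, /* .y = */ 2};
```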

@@ -56,7 +56,7 @@ struct MemcpyDetails {
int8_t dst_mem_kind;

// ID of the hardware channel on which this operation ran.
uint32_t channel_id = -1;
uint32_t channel_id = static_cast<uint32_t>(-1);
// CUpti_ChannelType of the channel above.
int8_t channel_type = 0; // CUPTI_CHANNEL_TYPE_INVALID
};
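The cast above makes the all-bits-set sentinel explicit: `uint32_t channel_id = -1;` is legal but relies on an implicit signed-to-unsigned conversion that stricter warning configurations flag, which is presumably why it was rewritten. Both forms produce the same value:

```cpp
#include <cstdint>

// Both initializers yield 0xFFFFFFFF (4294967295); the cast states the
// sentinel intent explicitly and avoids sign-conversion warnings.
uint32_t implicit_sentinel = -1;  // may warn under strict builds
uint32_t explicit_sentinel = static_cast<uint32_t>(-1);
static_assert(static_cast<uint32_t>(-1) == 0xFFFFFFFFu, "all bits set");
```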
2 changes: 2 additions & 0 deletions third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc
@@ -535,7 +535,9 @@ std::optional<DynamicOrStaticInteger> EvaluateWhileLoopParamInitValue(

namespace internal {

#if !defined(_MSC_VER)
constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
#endif

std::optional<EvalErrorDetail> ParseEvalErrorDetail(const absl::Status& error) {
auto error_detail = error.GetPayload(kEvalErrorDetailUrl);
4 changes: 4 additions & 0 deletions third_party/xla/xla/hlo/evaluator/hlo_evaluator.h
@@ -530,7 +530,11 @@ enum class EvalErrorDetail : uint32_t {
kDynamicValueDependence = 0,
};

#if defined(_MSC_VER)
extern const absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
#else
extern const absl::string_view kEvalErrorDetailUrl;
#endif

std::optional<EvalErrorDetail> ParseEvalErrorDetail(const absl::Status& error);
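This hunk pairs with the `#if !defined(_MSC_VER)` guard added in hlo_evaluator.cc above: on most compilers the header keeps a plain `extern` declaration and the `constexpr` definition stays in the .cc, while MSVC builds get the initializer directly in the header; presumably this works around that compiler's handling of an `extern` declaration later completed by a `constexpr` definition. A single-file approximation of the non-MSVC split, with hypothetical names:

```cpp
// In the real code the declaration lives in hlo_evaluator.h and the
// definition in hlo_evaluator.cc; both lines are shown here in one TU.
extern const int kWidgetLimit;     // header-style declaration
constexpr int kWidgetLimit = 42;   // source-style constexpr definition

static_assert(kWidgetLimit == 42, "usable in constant expressions here");
```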

2 changes: 1 addition & 1 deletion third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -2129,7 +2129,7 @@ PJRT_Error* PJRT_Layouts_MemoryLayout_Serialize(
PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE, args->struct_size));

PJRT_Layouts_SerializedLayout* s_layout = new PJRT_Layouts_SerializedLayout{
.serialized = args->layout->layout->Serialize()};
/* .serialized = */ args->layout->layout->Serialize()};
args->serialized_layout = s_layout;
args->serialized_bytes = s_layout->serialized.data();
args->serialized_bytes_size = s_layout->serialized.size();
14 changes: 8 additions & 6 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
@@ -199,13 +199,15 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options,
#endif
}

STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(
#if TENSORFLOW_USE_ROCM
RocmName(),
STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(RocmName(),
std::make_unique<StreamExecutorGpuCompiler>());
});
#else
CudaName(),
#endif
std::make_unique<StreamExecutorGpuCompiler>());
STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(CudaName(),
std::make_unique<StreamExecutorGpuCompiler>());
});
#endif
} // namespace xla
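The restructuring above duplicates the whole STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(...) block under each branch of `#if TENSORFLOW_USE_ROCM` instead of putting the `#if`/`#else` inside the macro's argument list. Preprocessor directives inside macro arguments are undefined behavior in standard C and C++, so hoisting the conditional outside the macro invocation is the portable form; that is the likely motivation here. A minimal stand-alone illustration, with REGISTER as a hypothetical stand-in for the registration macro:

```cpp
#include <iostream>

// Hypothetical registration macro standing in for
// STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER.
#define REGISTER(name) \
  static const int kRegistered = (std::cout << (name) << '\n', 0)

// Portable: the preprocessor selects between two complete macro invocations.
// Placing #if/#else *inside* a single REGISTER(...) argument list would be
// undefined behavior.
#if defined(TENSORFLOW_USE_ROCM)
REGISTER("rocm");
#else
REGISTER("cuda");
#endif

int main() { return 0; }
```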
8 changes: 4 additions & 4 deletions third_party/xla/xla/service/cpu/runtime/conv_impl.h
@@ -41,7 +41,7 @@ void EigenConv2DImpl(
Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation,
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count,
std::optional<std::function<void()>> done_callback = std::nullopt) {
std::optional<std::function<void()>> done_callback) {
const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
Eigen::Aligned>
input(lhs, input_batch, input_x, input_y, input_channels);
@@ -129,7 +129,7 @@ void EigenConv3DImpl(
Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation,
Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation,
Eigen::Index feature_group_count,
std::optional<std::function<void()>> done_callback = std::nullopt) {
std::optional<std::function<void()>> done_callback) {
using ConstTType =
Eigen::TensorMap<Eigen::Tensor<const ScalarType, 5, Eigen::RowMajor>,
Eigen::Aligned>;
@@ -223,7 +223,7 @@ void EigenConv3DImpl(
Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \
std::optional<std::function<void()>> done_callback = std::nullopt)
std::optional<std::function<void()>> done_callback)

CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
@@ -249,7 +249,7 @@ CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float);
Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \
Eigen::Index feature_group_count, \
std::optional<std::function<void()>> done_callback = std::nullopt)
std::optional<std::function<void()>> done_callback)

CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
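The `= std::nullopt` defaults were dropped from these parameter lists: the CONV2D_EXTERN_TEMPLATE/CONV3D_EXTERN_TEMPLATE macros expand into explicit instantiation declarations that repeat the signature, and stricter C++17 compilers reject a default argument there (restating a default in a later declaration is also ill-formed), so the defaults were removed everywhere and call sites now pass `std::nullopt` explicitly — see the runtime_conv2d, runtime_conv3d, and runtime_single_threaded_conv2d hunks below. A simplified sketch of the same shape, with `Conv` as a hypothetical stand-in:

```cpp
#include <functional>
#include <optional>

template <typename T>
void Conv(T value, std::optional<std::function<void()>> done_callback) {
  if (done_callback) (*done_callback)();
  (void)value;
}

// Explicit instantiation declaration: may not carry "= std::nullopt" here.
extern template void Conv<float>(float, std::optional<std::function<void()>>);

// Explicit instantiation definition (in the real code this lives in a .cc).
template void Conv<float>(float, std::optional<std::function<void()>>);

int main() {
  // Callers spell the optional argument out instead of relying on a default.
  Conv<float>(1.0f, std::nullopt);
  return 0;
}
```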
10 changes: 7 additions & 3 deletions third_party/xla/xla/service/cpu/runtime/thunk.cc
@@ -85,6 +85,10 @@ std::string_view Thunk::KindToString(Kind kind) {
return "while";
}
}
Thunk::Thunk(Kind kind, Info info)
: kind_(kind),
info_(std::move(info)),
ok_event_(OkExecuteEventSingleton()) {}

absl::StatusOr<Thunk::CollectiveExecuteParams>
Thunk::CollectiveExecuteParams::Create(
@@ -150,13 +154,13 @@ Thunk::CustomCallExecuteParams::CustomCallExecuteParams(
allocator(allocator),
ffi_execution_context(ffi_execution_context) {}

const tsl::AsyncValueOwningRef<Thunk::ExecuteEvent>* Thunk::OkEvent() {
static tsl::AsyncValueOwningRef<ExecuteEvent>* owner = [] {
tsl::AsyncValueRef<Thunk::ExecuteEvent> Thunk::OkExecuteEventSingleton() {
static tsl::AsyncValueOwningRef<ExecuteEvent>* singleton = [] {
auto* storage = new tsl::internal::AsyncValueStorage<ExecuteEvent>();
return new tsl::AsyncValueOwningRef<ExecuteEvent>(
tsl::MakeAvailableAsyncValueRef<ExecuteEvent>(*storage));
}();
return owner;
return singleton->AsRef();
}
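The singleton above intentionally leaks: the AsyncValueStorage and the owning ref are heap-allocated once inside an immediately-invoked lambda and never destroyed, so the OK event outlives every thunk, stays valid even during static destruction at process exit, and OkExecuteEventSingleton() can hand out non-owning references that are safe to compare by pointer. A generic sketch of this "leaky singleton" shape (Widget is purely illustrative):

```cpp
struct Widget {
  int value = 0;
};

// Construct on first use, never destroy: references handed out remain valid
// for the entire process lifetime, including during shutdown.
Widget& OkWidgetSingleton() {
  static Widget* singleton = [] { return new Widget(); }();
  return *singleton;
}

// Pointer identity is enough to recognize the singleton later.
bool IsOkWidget(const Widget& w) { return &w == &OkWidgetSingleton(); }
```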

Thunk::ExecuteState::ExecuteState(int64_t num_tasks)
26 changes: 14 additions & 12 deletions third_party/xla/xla/service/cpu/runtime/thunk.h
@@ -110,7 +110,7 @@ class Thunk {
using Task = std::function<void()>;
using TaskRunner = absl::AnyInvocable<void(Task)>;

Thunk(Kind kind, Info info) : kind_(kind), info_(std::move(info)) {}
Thunk(Kind kind, Info info);

Thunk(const Thunk&) = delete;
Thunk& operator=(const Thunk&) = delete;
@@ -286,18 +286,20 @@ class Thunk {
// An execute event that becomes ready when all tasks are completed.
using ExecuteEvent = tsl::Chain;

// Returns non-reference-counted async value ref for thunks executed in the
// caller thread to avoid reference counting overhead.
static tsl::AsyncValueRef<ExecuteEvent> OkExecuteEvent() {
return OkEvent()->AsRef();
}
// Returns non-reference-counted async value ref in constructed state.
// Returned async value is a per-process singleton stored in a storage with a
// static duration, and can be safely compared using pointer equality.
static tsl::AsyncValueRef<ExecuteEvent> OkExecuteEventSingleton();

// Returns `OkExecuteEventSingleton()` cached by this thunk instance.
tsl::AsyncValueRef<ExecuteEvent> OkExecuteEvent() const { return ok_event_; }

static bool IsOkExecuteEvent(tsl::AsyncValuePtr<ExecuteEvent> event) {
return event == OkEvent()->AsPtr();
bool IsOkExecuteEvent(const tsl::AsyncValueRef<ExecuteEvent>& event) const {
return event == ok_event_;
}

static bool IsOkExecuteEvent(const tsl::AsyncValueRef<ExecuteEvent>& event) {
return IsOkExecuteEvent(event.AsPtr());
bool IsOkExecuteEvent(tsl::AsyncValuePtr<ExecuteEvent> event) const {
return event == ok_event_.AsPtr();
}

// Thunk execution must be asynchronous and never block the caller thread,
@@ -339,10 +341,10 @@
}

private:
static const tsl::AsyncValueOwningRef<Thunk::ExecuteEvent>* OkEvent();

Kind kind_;
Info info_;

tsl::AsyncValueRef<ExecuteEvent> ok_event_;
};

std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
10 changes: 5 additions & 5 deletions third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
@@ -144,7 +144,7 @@ tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent> ThunkExecutor::Execute(
const Thunk::ExecuteParams& params) {
// Short-circuit execution of trivial thunk sequences.
if (ABSL_PREDICT_FALSE(thunk_sequence_.empty())) {
return Thunk::OkExecuteEvent();
return Thunk::OkExecuteEventSingleton();
}
if (ABSL_PREDICT_FALSE(thunk_sequence_.size() == 1)) {
return thunk_sequence_[0]->Execute(params);
@@ -181,7 +181,7 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {
auto execute_event = thunk.Execute(params);

// Fast path for thunks executed inline and returned OkExecuteEvent.
if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) {
if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) {
continue;
}

@@ -207,7 +207,7 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {

// If we got to the end of the sequence it means that all thunks have
// succeeded.
return Thunk::OkExecuteEvent();
return Thunk::OkExecuteEventSingleton();
}

void ThunkExecutor::ResumeExecuteSequential(
@@ -218,7 +218,7 @@ void ThunkExecutor::ResumeExecuteSequential(
auto execute_event = thunk.Execute(params);

// Fast path for thunks executed inline and returned OkExecuteEvent.
if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) {
if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) {
continue;
}

@@ -281,7 +281,7 @@ void ThunkExecutor::Execute(ExecuteState* state,
Thunk& thunk = *state->executor->thunk_sequence_[id];
tsl::AsyncValueRef<ExecuteEvent> execute_event =
ABSL_PREDICT_FALSE(state->abort.load(std::memory_order_relaxed))
? Thunk::OkExecuteEvent()
? Thunk::OkExecuteEventSingleton()
: thunk.Execute(params);

if (ABSL_PREDICT_TRUE(execute_event.IsAvailable())) {
6 changes: 4 additions & 2 deletions third_party/xla/xla/service/cpu/runtime_conv2d.cc
@@ -15,6 +15,8 @@ limitations under the License.

#include "xla/service/cpu/runtime_conv2d.h"

#include <optional>

#define EIGEN_USE_THREADS

#include "absl/base/dynamic_annotations.h"
@@ -41,7 +43,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF32(
kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
col_stride, padding_top, padding_bottom, padding_left, padding_right,
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
feature_group_count);
feature_group_count, std::nullopt);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16(
@@ -63,5 +65,5 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16(
kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
col_stride, padding_top, padding_bottom, padding_left, padding_right,
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
feature_group_count);
feature_group_count, std::nullopt);
}
6 changes: 4 additions & 2 deletions third_party/xla/xla/service/cpu/runtime_conv3d.cc
@@ -15,6 +15,8 @@ limitations under the License.

#include "xla/service/cpu/runtime_conv3d.h"

#include <optional>

#define EIGEN_USE_THREADS

#include "absl/base/dynamic_annotations.h"
@@ -44,7 +46,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF32(
y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before,
padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
rhs_z_dilation, feature_group_count);
rhs_z_dilation, feature_group_count, std::nullopt);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16(
@@ -69,5 +71,5 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16(
y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before,
padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
rhs_z_dilation, feature_group_count);
rhs_z_dilation, feature_group_count, std::nullopt);
}
@@ -15,6 +15,8 @@ limitations under the License.

#include "xla/service/cpu/runtime_single_threaded_conv2d.h"

#include <optional>

#include "absl/base/dynamic_annotations.h"
#include "xla/service/cpu/runtime/conv_impl.h"

@@ -35,7 +37,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF16(
kernel_filters, output_rows, output_cols, row_stride, col_stride,
padding_top, padding_bottom, padding_left, padding_right,
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
feature_group_count);
feature_group_count, std::nullopt);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
@@ -55,5 +57,5 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF32(
kernel_filters, output_rows, output_cols, row_stride, col_stride,
padding_top, padding_bottom, padding_left, padding_right,
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
feature_group_count);
feature_group_count, std::nullopt);
}