Shortcut to find common axes when the operation is unary and only one tensor is incompatible individually. #329

Open · wants to merge 1 commit into base: main
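
In short: the insert-explicit-reshards export pass gains a shortcut for unary operations (one operand, one result) in which exactly one tensor is incompatible individually, meaning it shards a factor that must be replicated. In that case the common axes are taken from the remaining, compatible tensor, and only the incompatible tensor is resharded. A minimal standalone sketch of the per-tensor incompatibility test, using plain containers in place of ShardingProjection and OpShardingRuleAttr (all names in the sketch are illustrative, not Shardy API):

#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Tensor i (operands first, then results) shards factor f iff
// shardedFactors[i] contains f.
using ShardedFactors = std::vector<std::set<int64_t>>;

// A tensor is "incompatible individually" if it shards any factor that
// needs replication (mirrors findTensorIndicesIncompatibleTo below).
std::vector<int64_t> incompatibleTensors(
    const ShardedFactors& shardedFactors,
    const std::set<int64_t>& needsReplication) {
  std::vector<int64_t> result;
  for (int64_t i = 0; i < static_cast<int64_t>(shardedFactors.size()); ++i) {
    for (int64_t factor : shardedFactors[i]) {
      if (needsReplication.count(factor)) {
        result.push_back(i);
        break;
      }
    }
  }
  return result;
}

int main() {
  // Unary op: index 0 is the operand, index 1 the result; only the result
  // shards factor 2, which must be replicated.
  ShardedFactors shardedFactors = {{0, 1}, {0, 2}};
  std::set<int64_t> needsReplication = {2};

  std::vector<int64_t> bad = incompatibleTensors(shardedFactors, needsReplication);
  if (bad.size() == 1) {
    // Shortcut applies: derive common axes from every tensor except bad[0],
    // then reshard tensor bad[0] to match.
    std::cout << "shortcut: ignore tensor " << bad.front() << "\n";
  }
  return 0;
}
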
shardy/dialect/sdy/ir/axis_list_ref.h (3 additions & 2 deletions)
@@ -146,6 +146,9 @@ class AxisListRef {
                        /*isTailIterated=*/true, tailAxisRef);
   }

+  // Clears this AxisListRef.
+  void clear();
+
   friend struct AxisListRefInfo;

  private:
@@ -178,8 +181,6 @@ class AxisListRef {
   // `newSizeExcludingNewTail`.
   void trim(int64_t newSizeExcludingNewTail,
             std::optional<AxisRefAttr> newTailAxisRef);
-  // Clears this AxisListRef.
-  void clear();

   // The axes that this FactorAxesPair holds are defined by `axisRefs` and
   // `tailAxisRef` together as the concatenation of the two. If `tailAxisRef` is
shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc (102 additions & 16 deletions)
@@ -18,6 +18,7 @@ limitations under the License.
 #include <cstdint>
 #include <optional>

+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
@@ -60,19 +61,47 @@ bool hasOverflowAxes(const ShardingProjection& projection) {
   return false;
 }

+// Returns the indices of tensors whose factor shardings are incompatible
+// individually, that is, they violate:
+// 1. Factors that need replication are unsharded.
+//
+// Assumes factor shardings do not have overflow axes.
+// TODO(enver): Handle the case when some factor shardings have overflow axes.
+SmallVector<int64_t> findTensorIndicesIncompatibleTo(
+    const ShardingProjection& projection, OpShardingRuleAttr shardingRule) {
+  // Factors that need replication should be unsharded across all operands and
+  // results in order for the op to have a compatible sharding.
+  llvm::SmallDenseSet<int64_t> factorIndicesThatNeedReplication;
+  for (int64_t factorIndex : shardingRule.getNeedReplicationFactors()) {
+    factorIndicesThatNeedReplication.insert(factorIndex);
+  }
+  SmallVector<int64_t> tensorIndicesIncompatibleTo;
+  for (const auto& [tensorIndex, tensorFactorSharding] :
+       llvm::enumerate(llvm::concat<const TensorFactorShardings>(
+           projection.getOperands(), projection.getResults()))) {
+    for (const auto& [factorIndex, factorSharding] :
+         tensorFactorSharding.factorIndexToSharding) {
+      if (factorIndicesThatNeedReplication.contains(factorIndex) &&
+          !factorSharding.axisRefs.empty()) {
+        tensorIndicesIncompatibleTo.push_back(tensorIndex);
+        break;
+      }
+    }
+  }
+  return tensorIndicesIncompatibleTo;
+}
+
 // Checks if factor sharding is compatible, that is, it satisfies:
 // 1. Factors are sharded the same way across operands and results.
-// 2. Factors that need replication are unsharded.
+// 2. Each factor sharding is compatible individually.
 //
 // Assumes factor shardings do not have overflow axes.
 // TODO(enver): Handle the case when some factor shardings have overflow axes.
 bool hasCompatibleFactorShardings(const ShardingProjection& projection,
                                   OpShardingRuleAttr shardingRule) {
   FactorIndexToSharding factorIndexToCommonSharding;
-  // Factors that need replication should be unsharded across all operands and
-  // results in order for it to have a compatible sharding.
-  for (int64_t factorIndex : shardingRule.getNeedReplicationFactors()) {
-    factorIndexToCommonSharding[factorIndex] = FactorSharding{};
+  if (!findTensorIndicesIncompatibleTo(projection, shardingRule).empty()) {
+    return false;
   }
   for (const TensorFactorShardings& tensorFactorSharding :
        llvm::concat<const TensorFactorShardings>(projection.getOperands(),
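
With the per-tensor test factored out, hasCompatibleFactorShardings now reads as two conditions: no tensor shards a need-replication factor (the new function above returns an empty list), and every factor is sharded the same way in every tensor that carries it. A compact standalone sketch of that second, cross-tensor condition, with factor shardings reduced to plain maps (illustrative types, not the Shardy ones):

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Per-tensor map from factor index to the mesh axes it is sharded on.
using FactorSharding = std::map<int64_t, std::vector<std::string>>;

// Condition 1 of hasCompatibleFactorShardings, sketched: the first tensor
// that mentions a factor fixes its sharding; any later mismatch makes the
// projection incompatible.
bool factorsShardedConsistently(const std::vector<FactorSharding>& tensors) {
  FactorSharding common;
  for (const FactorSharding& tensor : tensors) {
    for (const auto& [factorIndex, axes] : tensor) {
      auto [it, inserted] = common.try_emplace(factorIndex, axes);
      if (!inserted && it->second != axes) {
        return false;  // Factor sharded two different ways.
      }
    }
  }
  return true;
}
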
@@ -458,12 +487,63 @@ SmallVector<AxisListRef> findCommonAxesUsingMajorityVoteHeuristic(
   return factorAxisRefs;
 }

+SmallVector<AxisListRef> findCommonAxesIgnoringOneTensor(
+    const int64_t tensorIndexToIgnore, const ShardingProjection& projection,
+    OpShardingRuleAttr shardingRule) {
+  SmallVector<AxisListRef> factorAxisRefs(shardingRule.getNumFactors());
+  for (const auto& [factorIndex, factorSharding] :
+       projection.getTensor(tensorIndexToIgnore).factorIndexToSharding) {
+    if (!factorSharding.axisRefs.empty()) {
+      factorAxisRefs[factorIndex] = AxisListRef(factorSharding.axisRefs);
+    }
+  }
+  for (const auto& [tensorIndex, tensorFactorSharding] :
+       llvm::enumerate(llvm::concat<const TensorFactorShardings>(
+           projection.getOperands(), projection.getResults()))) {
+    if (tensorIndex == tensorIndexToIgnore) {
+      continue;
+    }
+    for (const auto& [factorIndex, factorSharding] :
+         tensorFactorSharding.factorIndexToSharding) {
+      if (factorSharding.axisRefs.empty()) {
+        factorAxisRefs[factorIndex].clear();
+        continue;
+      }
+      factorAxisRefs[factorIndex] = AxisListRef(factorSharding.axisRefs);
+    }
+  }
+  for (int64_t factorIndex : shardingRule.getNeedReplicationFactors()) {
+    factorAxisRefs[factorIndex] = AxisListRef();
+  }
+  return factorAxisRefs;
+}
+
+// TODO(enver): Relax the unary requirement to ops with two non-scalar tensors.
+bool isUnaryOperation(OpShardingRuleAttr shardingRule) {
+  return shardingRule.getNumOperands() == 1 &&
+         shardingRule.getNumResults() == 1;
+}
+
 SmallVector<AxisListRef> findCommonAxes(const ShardingProjection& projection,
-                                        int64_t numFactors,
-                                        ArrayRef<int64_t> tensorSizes,
+                                        OpShardingRuleAttr shardingRule,
                                         MeshAttr mesh) {
-  return findCommonAxesUsingMajorityVoteHeuristic(projection, numFactors,
-                                                  tensorSizes, mesh);
+  // Take the shortcut when the operation is unary and only one tensor is
+  // incompatible individually.
+  if (isUnaryOperation(shardingRule)) {
+    if (auto tensorIndices =
+            findTensorIndicesIncompatibleTo(projection, shardingRule);
+        tensorIndices.size() == 1) {
+      return findCommonAxesIgnoringOneTensor(tensorIndices[0], projection,
+                                             shardingRule);
+    }
+  }
+
+  // TODO(enver): Support cases where factors need replication; for those
+  // factors, empty axes should be fixed during the heuristic that finds
+  // common axes.
+  return findCommonAxesUsingMajorityVoteHeuristic(
+      projection, shardingRule.getNumFactors(), shardingRule.getTensorSizes(),
+      mesh);
 }

 struct InsertExplicitReshardsPass
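
findCommonAxesIgnoringOneTensor is where the newly public AxisListRef::clear() from axis_list_ref.h earns its keep: the result is seeded from the ignored (incompatible) tensor, and each remaining tensor then either overwrites a factor's axes or clears them when it leaves that factor unsharded; factors that need replication are reset at the end. Because the shortcut is gated on isUnaryOperation, there is exactly one non-ignored tensor, so the plain overwrite cannot lose a vote. A standalone analogue of this seed-then-override loop (vectors of axis names stand in for AxisListRef; names are illustrative):

#include <cstdint>
#include <map>
#include <string>
#include <vector>

using Axes = std::vector<std::string>;
using FactorToAxes = std::map<int64_t, Axes>;

// Seed common axes from the ignored tensor, then let every other tensor
// override: empty axes clear the factor (AxisListRef::clear() in the real
// code), non-empty axes replace it; need-replication factors end up empty.
std::vector<Axes> commonAxesIgnoringOneTensor(
    int64_t ignoredIndex, const std::vector<FactorToAxes>& tensors,
    int64_t numFactors, const std::vector<int64_t>& needReplicationFactors) {
  std::vector<Axes> common(numFactors);
  for (const auto& [factorIndex, axes] : tensors[ignoredIndex]) {
    common[factorIndex] = axes;  // Seed from the tensor being ignored.
  }
  for (int64_t i = 0; i < static_cast<int64_t>(tensors.size()); ++i) {
    if (i == ignoredIndex) continue;
    for (const auto& [factorIndex, axes] : tensors[i]) {
      common[factorIndex] = axes;  // Overwrite; empty axes act as clear().
    }
  }
  for (int64_t factorIndex : needReplicationFactors) {
    common[factorIndex].clear();  // These factors must stay unsharded.
  }
  return common;
}
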
@@ -566,20 +646,26 @@ struct InsertExplicitReshardsPass
     }

     // Return without inserting reshards for operations with factors that need
-    // replication.
-    // TODO(enver): Insert explicit reshards also for the case that the
-    // factors that need replication are sharded.
-    if (isa<stablehlo::CholeskyOp, stablehlo::BitcastConvertOp,
-            stablehlo::ConcatenateOp, stablehlo::SortOp,
+    // replication and may be non-unary.
+    // TODO(enver): Drop this special check and handle the resharding.
+    if (isa<stablehlo::ConcatenateOp, stablehlo::SortOp,
             stablehlo::TriangularSolveOp>(op)) {
       return;
     }

+    // Return without inserting reshards for operations with more than one
+    // tensor that is incompatible individually, which happens when multiple
+    // tensors have sharded factors that need replication.
+    // TODO(enver): Drop this special check and handle the resharding.
+    if (findTensorIndicesIncompatibleTo(shardingProjection, shardingRule)
+            .size() > 1) {
+      return;
+    }
+
     UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
                                                 shardingRule.getNumResults());
     for (const auto& [index, axes] : llvm::enumerate(
-             findCommonAxes(shardingProjection, shardingRule.getNumFactors(),
-                            shardingRule.getTensorSizes(), mesh))) {
+             findCommonAxes(shardingProjection, shardingRule, mesh))) {
       // TODO(enver): Add unit tests to test overflow axes are cleared after
       // handling the case that some factors have overflow axes.
       updateTensorShardings |= shardingProjection.updateSharding(
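
Putting the pass body's checks in order: known-problematic ops (ConcatenateOp, SortOp, TriangularSolveOp) are skipped outright; ops with more than one individually incompatible tensor are skipped for now; a unary op with exactly one incompatible tensor takes the shortcut; everything else goes through the majority-vote heuristic. A condensed sketch of that dispatch for the fragment shown above (illustrative enum and signature, not the pass's actual structure):

#include <cstddef>
#include <cstdint>

// Condensed decision order of the pass body after this change.
enum class Action { kSkipOp, kShortcut, kMajorityVote };

Action decide(bool isSpecialCasedOp, int64_t numOperands, int64_t numResults,
              size_t numIncompatibleTensors) {
  // Ops like ConcatenateOp/SortOp/TriangularSolveOp are still skipped.
  if (isSpecialCasedOp) return Action::kSkipOp;
  // More than one individually incompatible tensor: skip for now
  // (TODO in the change: handle this resharding too).
  if (numIncompatibleTensors > 1) return Action::kSkipOp;
  // Unary op with exactly one incompatible tensor: take the shortcut.
  if (numOperands == 1 && numResults == 1 && numIncompatibleTensors == 1) {
    return Action::kShortcut;
  }
  // Otherwise fall back to the majority-vote heuristic.
  return Action::kMajorityVote;
}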