/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <optional>

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/sparse_expander.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {

namespace {
#define GEN_PASS_DEF_DTENSORSPARSEEXPANSION
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

constexpr char kMainFunctionName[] = "main";

// Expand every op that consumes SparseTensor operands in topological order.
mlir::LogicalResult ConductSparseExpansion(mlir::ModuleOp module) {
  auto main_func = module.lookupSymbol<mlir::func::FuncOp>(kMainFunctionName);
  if (!main_func)
    return module.emitOpError(
        "could not find `main` function in module for Sparse expansion.");

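  // Walk the main function in topological order so that each op is visited
  // only after the ops that define its operands; SparseTensor producers are
  // therefore expanded before their consumers.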
  TopologicalIterator iterator(main_func);
  while (iterator.hasNext()) {
    mlir::Operation* op = iterator.next();

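    // Expand `op` if it consumes SparseTensor operands. On success,
    // `expanded_op` points at the resulting (possibly unchanged) op.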
    mlir::Operation* expanded_op = nullptr;
    auto status = RunSparseExpansion(op, &expanded_op);
    if (!status.ok() || expanded_op == nullptr) {
      // Sometimes `op` may already have been erased and `expanded_op` set.
      // In that case, emit the error on the expanded op instead.
      mlir::Operation* emit_op = op;
      if (expanded_op != nullptr) emit_op = expanded_op;
      return emit_op->emitError(WithContext(status, __FILE__, __LINE__,
                                            "While computing Sparse expansion")
                                    .error_message());
    }
  }
  return mlir::success();
}

// After the Sparse Expansion pass, there may be unused SparseToDenseOps,
// since expanded ops may take the operands of the SparseToDenseOps instead
// of their outputs. Remove such unused SparseToDenseOps together with their
// dependent ops, e.g. DTensorLayout and Const ops.
void RemoveUnusedSparseToDenseOps(mlir::ModuleOp module) {
  llvm::SmallVector<mlir::TF::SparseToDenseOp, 4> sparse_ops_to_erase;
  llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops_to_erase;

  module.walk([&](mlir::TF::SparseToDenseOp op) {
    // Delete this op if it either has no consuming ops or the only consuming
    // op is a DTensorLayout op that also has no consuming ops.
    if (op->use_empty()) {
      sparse_ops_to_erase.emplace_back(op);
    } else if (op->hasOneUse()) {
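      // The single user may be a DTensorLayout op; if that layout op itself
      // has no users, both ops can be erased.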
      if (auto layout_op = mlir::dyn_cast<mlir::TF::DTensorLayout>(
              op->getOpResult(0).getUses().begin().getUser())) {
        if (layout_op.use_empty()) {
          layout_ops_to_erase.emplace_back(layout_op);
          sparse_ops_to_erase.emplace_back(op);
        }
      }
    }
  });

  // First delete Layout ops and then delete SparseToDense ops.
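  // Erasing the DTensorLayout users first leaves the corresponding
  // SparseToDense results unused, so the SparseToDense ops can then be
  // erased safely.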
  for (auto op : layout_ops_to_erase) op.erase();
  for (auto op : sparse_ops_to_erase) {
    // Also delete the Const op that feeds the SparseToDense op's
    // `default_value` operand (operand 3) if it has no remaining uses.
    auto const_op = op.getOperand(3).getDefiningOp();
    op.erase();
    if (const_op && const_op->use_empty()) const_op->erase();
  }
}

struct DTensorSparseExpansion
    : public impl::DTensorSparseExpansionBase<DTensorSparseExpansion> {
  void runOnOperation() override {
    auto module = getOperation();
    if (failed(ConductSparseExpansion(module))) return signalPassFailure();

    // After Sparse Expansion, some SparseToDenseOp outputs may no longer be
    // used, so remove those ops if they are unused.
    RemoveUnusedSparseToDenseOps(module);
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorSparseExpansion() {
  return std::make_unique<DTensorSparseExpansion>();
}

}  // namespace dtensor
}  // namespace tensorflow