1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <memory> |
17 | #include <optional> |
18 | |
19 | #include "llvm/ADT/SmallVector.h" |
20 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
21 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
22 | #include "mlir/Pass/Pass.h" // from @llvm-project |
23 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
24 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
25 | #include "tensorflow/dtensor/mlir/sparse_expander.h" |
26 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace dtensor { |
30 | |
31 | namespace { |
32 | #define GEN_PASS_DEF_DTENSORSPARSEEXPANSION |
33 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
34 | |
35 | constexpr char kMainFunctionName[] = "main" ; |
36 | |
37 | // Expand every op that consumes SparseTensor operands in topological order. |
38 | mlir::LogicalResult ConductSparseExpansion(mlir::ModuleOp module) { |
39 | auto main_func = module.lookupSymbol<mlir::func::FuncOp>(kMainFunctionName); |
40 | if (!main_func) |
41 | return module.emitOpError( |
42 | "could not find `main` function in module for SPMD expansion." ); |
43 | |
44 | TopologicalIterator iterator(main_func); |
45 | while (iterator.hasNext()) { |
46 | mlir::Operation* op = iterator.next(); |
47 | |
48 | mlir::Operation* expanded_op = nullptr; |
49 | auto status = RunSparseExpansion(op, &expanded_op); |
50 | if (!status.ok() || expanded_op == nullptr) { |
51 | // Sometimes op may been erased and expanded_op set. |
52 | // In this case we should emit the error on the expanded op. |
53 | mlir::Operation* emit_op = op; |
54 | if (expanded_op != nullptr) emit_op = expanded_op; |
55 | return emit_op->emitError(WithContext(status, __FILE__, __LINE__, |
56 | "While computing Sparse expansion" ) |
57 | .error_message()); |
58 | } |
59 | } |
60 | return mlir::success(); |
61 | } |
62 | |
63 | // After Sparse Expansion pass, there may be unused SparseToDenseOps due to |
64 | // expanded ops possibly taking the operands of the SparseToDenseOps instead |
65 | // of the output of the SparseToDenseOps. So remove unused SparseToDenseOps |
66 | // and its corresponding dependent ops like DTensorLayout and Const ops. |
67 | void RemoveUnusedSparseToDenseOps(mlir::ModuleOp module) { |
68 | llvm::SmallVector<mlir::TF::SparseToDenseOp, 4> sparse_ops_to_erase; |
69 | llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops_to_erase; |
70 | |
71 | module.walk([&](mlir::TF::SparseToDenseOp op) { |
72 | // Delete this op if it either has no consuming ops or the only consuming |
73 | // op is a DTensorLayout op that also has no consuming ops. |
74 | if (op->use_empty()) { |
75 | sparse_ops_to_erase.emplace_back(op); |
76 | } else if (op->hasOneUse()) { |
77 | if (auto layout_op = mlir::dyn_cast<mlir::TF::DTensorLayout>( |
78 | op->getOpResult(0).getUses().begin().getUser())) { |
79 | if (layout_op.use_empty()) { |
80 | layout_ops_to_erase.emplace_back(layout_op); |
81 | sparse_ops_to_erase.emplace_back(op); |
82 | } |
83 | } |
84 | } |
85 | }); |
86 | |
87 | // First delete Layout ops and then delete SparseToDense ops. |
88 | for (auto op : layout_ops_to_erase) op.erase(); |
89 | for (auto op : sparse_ops_to_erase) { |
90 | // Also delete the corresponding Const ops that are no longer used |
91 | // attached to the SparseToDense ops. |
92 | auto const_op = op.getOperand(3).getDefiningOp(); |
93 | op.erase(); |
94 | if (const_op->use_empty()) const_op->erase(); |
95 | } |
96 | } |
97 | |
98 | struct DTensorSparseExpansion |
99 | : public impl::DTensorSparseExpansionBase<DTensorSparseExpansion> { |
100 | void runOnOperation() override { |
101 | auto module = getOperation(); |
102 | if (failed(ConductSparseExpansion(module))) return signalPassFailure(); |
103 | |
104 | // After Sparse Expansion, we may no longer use any SparseToDenseOp outputs, |
105 | // so remove them if they are not used. |
106 | RemoveUnusedSparseToDenseOps(module); |
107 | }; |
108 | }; |
109 | |
110 | } // namespace |
111 | |
112 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
113 | CreateDTensorSparseExpansion() { |
114 | return std::make_unique<DTensorSparseExpansion>(); |
115 | } |
116 | |
117 | } // namespace dtensor |
118 | } // namespace tensorflow |
119 | |