1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | #include <unordered_set> |
18 | #include <utility> |
19 | #include <vector> |
20 | |
21 | #include "absl/container/flat_hash_set.h" |
22 | #include "llvm/ADT/SmallVector.h" |
23 | #include "llvm/ADT/StringRef.h" |
24 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
25 | #include "mlir/IR/Attributes.h" // from @llvm-project |
26 | #include "mlir/IR/Builders.h" // from @llvm-project |
27 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
28 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
29 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
30 | #include "mlir/IR/Operation.h" // from @llvm-project |
31 | #include "mlir/IR/SymbolTable.h" // from @llvm-project |
32 | #include "mlir/IR/Visitors.h" // from @llvm-project |
33 | #include "mlir/Pass/Pass.h" // from @llvm-project |
34 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
35 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
36 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
37 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
38 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" |
39 | #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" |
40 | #include "tensorflow/dtensor/cc/constants.h" |
41 | #include "tensorflow/dtensor/mlir/device_utils.h" |
42 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
43 | #include "tensorflow/dtensor/mlir/op_utils.h" |
44 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
45 | #include "tensorflow/dtensor/mlir/value_utils.h" |
46 | |
47 | namespace tensorflow { |
48 | namespace dtensor { |
49 | |
50 | namespace { |
51 | #define GEN_PASS_DEF_DTENSORSPARSETENSORTODENSETENSOR |
52 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
53 | |
// Name of the dictionary attribute on the main function that carries the
// comma separated `inputs` name list.
constexpr char kEntryFuncAttr[] = "tf.entry_function";

// Name prefixes for the three component inputs that replace one SparseTensor
// input. The original function argument index is appended as `_<index>`.
constexpr char kSparseIndicesStr[] = "op_input_sparse_indices";
constexpr char kSparseDenseShapesStr[] = "op_input_sparse_dense_shapes";
constexpr char kSparseValuesStr[] = "op_input_sparse_values";
58 | |
59 | typedef struct SparseTensorToComponentInfo { |
60 | mlir::RankedTensorType indices; |
61 | mlir::RankedTensorType values; |
62 | mlir::RankedTensorType dense_shapes; |
63 | unsigned int func_op_arg_index; |
64 | } SparseTensorToComponentInfo; |
65 | |
66 | void UpdateFunctionSignature(mlir::func::FuncOp function, |
67 | mlir::OpBuilder& builder) { |
68 | function.setType(mlir::FunctionType::get( |
69 | builder.getContext(), |
70 | llvm::to_vector<4>(function.front().getArgumentTypes()), |
71 | function.getFunctionType().getResults())); |
72 | } |
73 | |
74 | // Add input attributes for new sparsetensor components and remove the |
75 | // old sparsetensor value input attributes. |
76 | // |
77 | // TF has a list of comma separated input names within `kEntryFuncAttr` |
78 | // attribute, under 'inputs'. Update this comma separated list of input names |
79 | // by correctly deleting the sparse tensor input name and replacing it with |
80 | // three new sparse component input names. |
81 | // |
82 | // Without this update, MLIR conversion to GraphDef will fail since |
83 | // the number of input names will not match with the FuncOp num arguments. |
84 | // |
85 | // e.g. "op_input_1" should become |
86 | // "op_input_sparse_indices_0,op_input_sparse_dense_shapes_0, |
87 | // "op_input_sparse_values_0" |
88 | mlir::LogicalResult UpdateFunctionInputAttributes( |
89 | mlir::MLIRContext& context, mlir::func::FuncOp main_func, |
90 | mlir::OpBuilder& builder, |
91 | const std::vector<SparseTensorToComponentInfo>& sparse_tensor_components) { |
92 | llvm::SmallVector<llvm::StringRef, 2> input_names; |
93 | |
94 | auto dict_attr = |
95 | main_func->getAttrOfType<mlir::DictionaryAttr>(kEntryFuncAttr); |
96 | if (dict_attr) { |
97 | if (!dict_attr.get("inputs" ).isa<mlir::StringAttr>()) |
98 | return main_func.emitOpError("Missing attribute inputs in main FuncOp." ); |
99 | |
100 | dict_attr.get("inputs" ).cast<mlir::StringAttr>().getValue().split( |
101 | input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); |
102 | |
103 | llvm::SmallVector<std::string, 2> new_input_names; |
104 | |
105 | absl::flat_hash_set<int> skip_indices; |
106 | for (const auto component : sparse_tensor_components) { |
107 | skip_indices.insert(component.func_op_arg_index); |
108 | } |
109 | |
110 | for (auto i = 0; i < input_names.size(); ++i) { |
111 | if (skip_indices.find(i) == skip_indices.end()) { |
112 | new_input_names.push_back(input_names[i].str()); |
113 | } |
114 | } |
115 | |
116 | for (const auto component : sparse_tensor_components) { |
117 | int arg_index = component.func_op_arg_index; |
118 | new_input_names.push_back( |
119 | absl::StrCat(kSparseIndicesStr, "_" , arg_index)); |
120 | new_input_names.push_back( |
121 | absl::StrCat(kSparseDenseShapesStr, "_" , arg_index)); |
122 | new_input_names.push_back(absl::StrCat(kSparseValuesStr, "_" , arg_index)); |
123 | } |
124 | |
125 | mlir::NamedAttrList attributes(dict_attr); |
126 | attributes.set( |
127 | "inputs" , |
128 | mlir::StringAttr::get(&context, absl::StrJoin(new_input_names, "," ))); |
129 | main_func->setAttr(kEntryFuncAttr, attributes.getDictionary(&context)); |
130 | } |
131 | UpdateFunctionSignature(main_func, builder); |
132 | return mlir::success(); |
133 | } |
134 | |
135 | // For each SparseTensor block argument of the main FuncOp, create |
136 | // three of the component tensors, `indices`, `values`, and `dense_shapes` |
137 | // and add it to `sparse_tensor_components`. |
138 | void CreateComponentTensorsFromSparseTensors( |
139 | mlir::func::FuncOp main_func, mlir::OpBuilder& builder, |
140 | std::vector<SparseTensorToComponentInfo>* sparse_tensor_components) { |
141 | for (const auto block_arg : main_func.getArguments()) { |
142 | const auto is_sparse = main_func.getArgAttrOfType<mlir::BoolAttr>( |
143 | block_arg.getArgNumber(), kSparseValue); |
144 | if (is_sparse) { |
145 | sparse_tensor_components->push_back(SparseTensorToComponentInfo{ |
146 | /*indices=*/mlir::RankedTensorType::get({-1, ValueRank(block_arg)}, |
147 | builder.getI64Type()), |
148 | /*values=*/ |
149 | mlir::RankedTensorType::get({-1}, |
150 | block_arg.getType() |
151 | .dyn_cast<mlir::RankedTensorType>() |
152 | .getElementType()), |
153 | /*dense_shapes=*/ |
154 | mlir::RankedTensorType::get({ValueRank(block_arg)}, |
155 | builder.getI64Type()), |
156 | /*func_op_arg_index=*/block_arg.getArgNumber()}); |
157 | } |
158 | } |
159 | } |
160 | |
161 | // Inserts SparseTensor components `components` into `main_func` at the end |
162 | // of block arguments list. |
163 | void UpdateFunctionWithSparseTensorComponents( |
164 | mlir::MLIRContext& context, mlir::func::FuncOp main_func, |
165 | mlir::OpBuilder& builder, const SparseTensorToComponentInfo& component) { |
166 | main_func.front().addArgument(component.indices, main_func.getLoc()); |
167 | main_func.front().addArgument(component.dense_shapes, main_func.getLoc()); |
168 | main_func.front().addArgument(component.values, main_func.getLoc()); |
169 | UpdateFunctionSignature(main_func, builder); |
170 | } |
171 | |
172 | struct DTensorSparseTensorToDenseTensor |
173 | : public impl::DTensorSparseTensorToDenseTensorBase< |
174 | DTensorSparseTensorToDenseTensor> { |
175 | void runOnOperation() override { |
176 | mlir::MLIRContext& context = getContext(); |
177 | auto module = getOperation(); |
178 | mlir::OpBuilder builder(&context); |
179 | |
180 | mlir::func::FuncOp main_func = |
181 | module.lookupSymbol<mlir::func::FuncOp>("main" ); |
182 | |
183 | // Save Arg Attributes for each argument for later use, this will be |
184 | // reset and reordered after we insert sparse tensor components arguments. |
185 | llvm::DenseMap<mlir::Value, llvm::ArrayRef<mlir::NamedAttribute>> |
186 | arg_attribute_map; |
187 | for (auto block_arg : main_func.getArguments()) { |
188 | arg_attribute_map.insert(std::make_pair( |
189 | block_arg, main_func.getArgAttrs(block_arg.getArgNumber()))); |
190 | } |
191 | |
192 | std::vector<SparseTensorToComponentInfo> sparse_tensor_components; |
193 | CreateComponentTensorsFromSparseTensors(main_func, builder, |
194 | &sparse_tensor_components); |
195 | |
196 | // Update func arguments in place by replacing SparseTensors with their |
197 | // components and emitting a SparseToDenseOp before all ops that consume |
198 | // a SparseTensor. |
199 | for (const SparseTensorToComponentInfo& components : |
200 | sparse_tensor_components) { |
201 | // Insert SparseTensor component into the main function's block |
202 | // arguments. |
203 | mlir::Value sparse_tensor_value = |
204 | main_func.getArgument(components.func_op_arg_index); |
205 | |
206 | UpdateFunctionWithSparseTensorComponents(context, main_func, builder, |
207 | components); |
208 | mlir::Operation* front_op = &main_func.front().front(); |
209 | builder.setInsertionPoint(front_op); |
210 | |
211 | // Emit a SparseToDenseOp and replace the SparseTensor with the result of |
212 | // this new op. |
213 | auto zero_scalar = CreateZeroScalarConst(builder, front_op->getLoc(), |
214 | sparse_tensor_value.getType() |
215 | .cast<mlir::TensorType>() |
216 | .getElementType()); |
217 | if (!zero_scalar.has_value()) return signalPassFailure(); |
218 | mlir::TF::SparseToDenseOp sparse_to_dense_op = |
219 | builder.create<mlir::TF::SparseToDenseOp>( |
220 | front_op->getLoc(), sparse_tensor_value.getType(), |
221 | mlir::ValueRange( |
222 | {main_func.getArgument(main_func.getNumArguments() - 3), |
223 | main_func.getArgument(main_func.getNumArguments() - 2), |
224 | main_func.getArgument(main_func.getNumArguments() - 1), |
225 | zero_scalar.value()})); |
226 | |
227 | sparse_tensor_value.replaceAllUsesWith(sparse_to_dense_op); |
228 | if (!sparse_tensor_value.use_empty()) return signalPassFailure(); |
229 | } |
230 | |
231 | // Erase sparse tensor arguments now that we converted all of them. |
232 | for (int i = 0; i < sparse_tensor_components.size(); ++i) |
233 | main_func.front().eraseArgument( |
234 | sparse_tensor_components[i].func_op_arg_index - i); |
235 | |
236 | // Reset block argument attributes since they are likely mixed up |
237 | // due to change in ordering of arguments. |
238 | for (auto block_arg : main_func.getArguments()) { |
239 | if (arg_attribute_map.find(block_arg) == arg_attribute_map.end()) { |
240 | main_func.setArgAttrs(block_arg.getArgNumber(), |
241 | llvm::ArrayRef<mlir::NamedAttribute>{}); |
242 | } else { |
243 | main_func.setArgAttrs(block_arg.getArgNumber(), |
244 | arg_attribute_map[block_arg]); |
245 | } |
246 | } |
247 | if (mlir::failed(UpdateFunctionInputAttributes(context, main_func, builder, |
248 | sparse_tensor_components))) |
249 | return signalPassFailure(); |
250 | }; |
251 | }; |
252 | |
253 | } // namespace |
254 | |
255 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
256 | CreateDTensorSparseTensorToDenseTensor() { |
257 | return std::make_unique<DTensorSparseTensorToDenseTensor>(); |
258 | } |
259 | |
260 | } // namespace dtensor |
261 | } // namespace tensorflow |
262 | |