1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <string>
17#include <unordered_set>
18#include <utility>
19#include <vector>
20
21#include "absl/container/flat_hash_set.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/ADT/StringRef.h"
24#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25#include "mlir/IR/Attributes.h" // from @llvm-project
26#include "mlir/IR/Builders.h" // from @llvm-project
27#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
28#include "mlir/IR/BuiltinOps.h" // from @llvm-project
29#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30#include "mlir/IR/Operation.h" // from @llvm-project
31#include "mlir/IR/SymbolTable.h" // from @llvm-project
32#include "mlir/IR/Visitors.h" // from @llvm-project
33#include "mlir/Pass/Pass.h" // from @llvm-project
34#include "mlir/Pass/PassManager.h" // from @llvm-project
35#include "mlir/Support/LogicalResult.h" // from @llvm-project
36#include "mlir/Transforms/Passes.h" // from @llvm-project
37#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
38#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
39#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
40#include "tensorflow/dtensor/cc/constants.h"
41#include "tensorflow/dtensor/mlir/device_utils.h"
42#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
43#include "tensorflow/dtensor/mlir/op_utils.h"
44#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
45#include "tensorflow/dtensor/mlir/value_utils.h"
46
47namespace tensorflow {
48namespace dtensor {
49
50namespace {
51#define GEN_PASS_DEF_DTENSORSPARSETENSORTODENSETENSOR
52#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
53
// Prefixes of the input names generated for the three component tensors of a
// sparse tensor argument. The final name is "<prefix>_<argument index>".
constexpr char kSparseIndicesStr[] = "op_input_sparse_indices";
constexpr char kSparseDenseShapesStr[] = "op_input_sparse_dense_shapes";
constexpr char kSparseValuesStr[] = "op_input_sparse_values";

// Name of the dictionary attribute on the main function whose `inputs` entry
// holds the comma separated list of input names.
constexpr char kEntryFuncAttr[] = "tf.entry_function";
58
59typedef struct SparseTensorToComponentInfo {
60 mlir::RankedTensorType indices;
61 mlir::RankedTensorType values;
62 mlir::RankedTensorType dense_shapes;
63 unsigned int func_op_arg_index;
64} SparseTensorToComponentInfo;
65
66void UpdateFunctionSignature(mlir::func::FuncOp function,
67 mlir::OpBuilder& builder) {
68 function.setType(mlir::FunctionType::get(
69 builder.getContext(),
70 llvm::to_vector<4>(function.front().getArgumentTypes()),
71 function.getFunctionType().getResults()));
72}
73
74// Add input attributes for new sparsetensor components and remove the
75// old sparsetensor value input attributes.
76//
77// TF has a list of comma separated input names within `kEntryFuncAttr`
78// attribute, under 'inputs'. Update this comma separated list of input names
79// by correctly deleting the sparse tensor input name and replacing it with
80// three new sparse component input names.
81//
82// Without this update, MLIR conversion to GraphDef will fail since
83// the number of input names will not match with the FuncOp num arguments.
84//
85// e.g. "op_input_1" should become
86// "op_input_sparse_indices_0,op_input_sparse_dense_shapes_0,
87// "op_input_sparse_values_0"
88mlir::LogicalResult UpdateFunctionInputAttributes(
89 mlir::MLIRContext& context, mlir::func::FuncOp main_func,
90 mlir::OpBuilder& builder,
91 const std::vector<SparseTensorToComponentInfo>& sparse_tensor_components) {
92 llvm::SmallVector<llvm::StringRef, 2> input_names;
93
94 auto dict_attr =
95 main_func->getAttrOfType<mlir::DictionaryAttr>(kEntryFuncAttr);
96 if (dict_attr) {
97 if (!dict_attr.get("inputs").isa<mlir::StringAttr>())
98 return main_func.emitOpError("Missing attribute inputs in main FuncOp.");
99
100 dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
101 input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
102
103 llvm::SmallVector<std::string, 2> new_input_names;
104
105 absl::flat_hash_set<int> skip_indices;
106 for (const auto component : sparse_tensor_components) {
107 skip_indices.insert(component.func_op_arg_index);
108 }
109
110 for (auto i = 0; i < input_names.size(); ++i) {
111 if (skip_indices.find(i) == skip_indices.end()) {
112 new_input_names.push_back(input_names[i].str());
113 }
114 }
115
116 for (const auto component : sparse_tensor_components) {
117 int arg_index = component.func_op_arg_index;
118 new_input_names.push_back(
119 absl::StrCat(kSparseIndicesStr, "_", arg_index));
120 new_input_names.push_back(
121 absl::StrCat(kSparseDenseShapesStr, "_", arg_index));
122 new_input_names.push_back(absl::StrCat(kSparseValuesStr, "_", arg_index));
123 }
124
125 mlir::NamedAttrList attributes(dict_attr);
126 attributes.set(
127 "inputs",
128 mlir::StringAttr::get(&context, absl::StrJoin(new_input_names, ",")));
129 main_func->setAttr(kEntryFuncAttr, attributes.getDictionary(&context));
130 }
131 UpdateFunctionSignature(main_func, builder);
132 return mlir::success();
133}
134
135// For each SparseTensor block argument of the main FuncOp, create
136// three of the component tensors, `indices`, `values`, and `dense_shapes`
137// and add it to `sparse_tensor_components`.
138void CreateComponentTensorsFromSparseTensors(
139 mlir::func::FuncOp main_func, mlir::OpBuilder& builder,
140 std::vector<SparseTensorToComponentInfo>* sparse_tensor_components) {
141 for (const auto block_arg : main_func.getArguments()) {
142 const auto is_sparse = main_func.getArgAttrOfType<mlir::BoolAttr>(
143 block_arg.getArgNumber(), kSparseValue);
144 if (is_sparse) {
145 sparse_tensor_components->push_back(SparseTensorToComponentInfo{
146 /*indices=*/mlir::RankedTensorType::get({-1, ValueRank(block_arg)},
147 builder.getI64Type()),
148 /*values=*/
149 mlir::RankedTensorType::get({-1},
150 block_arg.getType()
151 .dyn_cast<mlir::RankedTensorType>()
152 .getElementType()),
153 /*dense_shapes=*/
154 mlir::RankedTensorType::get({ValueRank(block_arg)},
155 builder.getI64Type()),
156 /*func_op_arg_index=*/block_arg.getArgNumber()});
157 }
158 }
159}
160
161// Inserts SparseTensor components `components` into `main_func` at the end
162// of block arguments list.
163void UpdateFunctionWithSparseTensorComponents(
164 mlir::MLIRContext& context, mlir::func::FuncOp main_func,
165 mlir::OpBuilder& builder, const SparseTensorToComponentInfo& component) {
166 main_func.front().addArgument(component.indices, main_func.getLoc());
167 main_func.front().addArgument(component.dense_shapes, main_func.getLoc());
168 main_func.front().addArgument(component.values, main_func.getLoc());
169 UpdateFunctionSignature(main_func, builder);
170}
171
172struct DTensorSparseTensorToDenseTensor
173 : public impl::DTensorSparseTensorToDenseTensorBase<
174 DTensorSparseTensorToDenseTensor> {
175 void runOnOperation() override {
176 mlir::MLIRContext& context = getContext();
177 auto module = getOperation();
178 mlir::OpBuilder builder(&context);
179
180 mlir::func::FuncOp main_func =
181 module.lookupSymbol<mlir::func::FuncOp>("main");
182
183 // Save Arg Attributes for each argument for later use, this will be
184 // reset and reordered after we insert sparse tensor components arguments.
185 llvm::DenseMap<mlir::Value, llvm::ArrayRef<mlir::NamedAttribute>>
186 arg_attribute_map;
187 for (auto block_arg : main_func.getArguments()) {
188 arg_attribute_map.insert(std::make_pair(
189 block_arg, main_func.getArgAttrs(block_arg.getArgNumber())));
190 }
191
192 std::vector<SparseTensorToComponentInfo> sparse_tensor_components;
193 CreateComponentTensorsFromSparseTensors(main_func, builder,
194 &sparse_tensor_components);
195
196 // Update func arguments in place by replacing SparseTensors with their
197 // components and emitting a SparseToDenseOp before all ops that consume
198 // a SparseTensor.
199 for (const SparseTensorToComponentInfo& components :
200 sparse_tensor_components) {
201 // Insert SparseTensor component into the main function's block
202 // arguments.
203 mlir::Value sparse_tensor_value =
204 main_func.getArgument(components.func_op_arg_index);
205
206 UpdateFunctionWithSparseTensorComponents(context, main_func, builder,
207 components);
208 mlir::Operation* front_op = &main_func.front().front();
209 builder.setInsertionPoint(front_op);
210
211 // Emit a SparseToDenseOp and replace the SparseTensor with the result of
212 // this new op.
213 auto zero_scalar = CreateZeroScalarConst(builder, front_op->getLoc(),
214 sparse_tensor_value.getType()
215 .cast<mlir::TensorType>()
216 .getElementType());
217 if (!zero_scalar.has_value()) return signalPassFailure();
218 mlir::TF::SparseToDenseOp sparse_to_dense_op =
219 builder.create<mlir::TF::SparseToDenseOp>(
220 front_op->getLoc(), sparse_tensor_value.getType(),
221 mlir::ValueRange(
222 {main_func.getArgument(main_func.getNumArguments() - 3),
223 main_func.getArgument(main_func.getNumArguments() - 2),
224 main_func.getArgument(main_func.getNumArguments() - 1),
225 zero_scalar.value()}));
226
227 sparse_tensor_value.replaceAllUsesWith(sparse_to_dense_op);
228 if (!sparse_tensor_value.use_empty()) return signalPassFailure();
229 }
230
231 // Erase sparse tensor arguments now that we converted all of them.
232 for (int i = 0; i < sparse_tensor_components.size(); ++i)
233 main_func.front().eraseArgument(
234 sparse_tensor_components[i].func_op_arg_index - i);
235
236 // Reset block argument attributes since they are likely mixed up
237 // due to change in ordering of arguments.
238 for (auto block_arg : main_func.getArguments()) {
239 if (arg_attribute_map.find(block_arg) == arg_attribute_map.end()) {
240 main_func.setArgAttrs(block_arg.getArgNumber(),
241 llvm::ArrayRef<mlir::NamedAttribute>{});
242 } else {
243 main_func.setArgAttrs(block_arg.getArgNumber(),
244 arg_attribute_map[block_arg]);
245 }
246 }
247 if (mlir::failed(UpdateFunctionInputAttributes(context, main_func, builder,
248 sparse_tensor_components)))
249 return signalPassFailure();
250 };
251};
252
253} // namespace
254
255std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
256CreateDTensorSparseTensorToDenseTensor() {
257 return std::make_unique<DTensorSparseTensorToDenseTensor>();
258}
259
260} // namespace dtensor
261} // namespace tensorflow
262