1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
17 | |
18 | #include <algorithm> |
19 | #include <atomic> |
20 | #include <iterator> |
21 | #include <string> |
22 | #include <vector> |
23 | |
24 | #include "absl/strings/str_cat.h" |
25 | #include "absl/strings/string_view.h" |
26 | #include "llvm/ADT/SmallPtrSet.h" |
27 | #include "llvm/Support/FormatVariadic.h" |
28 | #include "llvm/Support/raw_ostream.h" |
29 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
30 | #include "mlir/IR/Builders.h" // from @llvm-project |
31 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
32 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
33 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
34 | #include "mlir/IR/Location.h" // from @llvm-project |
35 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
36 | #include "mlir/IR/OperationSupport.h" // from @llvm-project |
37 | #include "mlir/IR/Value.h" // from @llvm-project |
38 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" |
39 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
40 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
41 | #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" |
42 | #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" |
43 | #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h" |
44 | #include "tensorflow/core/platform/errors.h" |
45 | #include "tensorflow/dtensor/cc/constants.h" |
46 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
47 | #include "tensorflow/dtensor/mlir/device_utils.h" |
48 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
49 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
50 | #include "tensorflow/dtensor/mlir/op_utils.h" |
51 | #include "tensorflow/dtensor/mlir/shape_utils.h" |
52 | #include "tensorflow/dtensor/mlir/value_utils.h" |
53 | |
54 | namespace tensorflow { |
55 | namespace dtensor { |
56 | |
// Checks that all layouts are fully replicated.
58 | bool AllReplicated(const std::vector<Layout>& layouts) { |
59 | for (const Layout& layout : layouts) { |
60 | if (!layout.IsFullyReplicated()) return false; |
61 | } |
62 | return true; |
63 | } |
64 | |
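// Converts a global tensor type to the corresponding per-device local type
// under |layout| by dividing each statically known dimension by its number of
// shards. E.g. for a rank-2 layout ("x", unsharded) where mesh dimension "x"
// has 4 shards, tensor<8x3xf32> becomes tensor<2x3xf32>.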
65 | StatusOr<mlir::TensorType> LocalTypeFromGlobalType( |
66 | const Layout& layout, const mlir::TensorType& original_type) { |
67 | if (!original_type.hasRank()) { |
68 | return original_type; |
69 | } |
70 | auto shape = llvm::to_vector<4>(original_type.getShape()); |
71 | auto shard_values = layout.num_shards(); |
72 | for (int output_axis = 0; output_axis < shape.size(); ++output_axis) { |
73 | if (shape[output_axis] != mlir::ShapedType::kDynamicSize) { |
74 | if (shape[output_axis] % shard_values[output_axis] != 0) { |
        return errors::InvalidArgument(
            "The sharding spec for axis ", output_axis, " splits among ",
            shard_values[output_axis],
            " values, which does not evenly divide the length of that axis (",
            shape[output_axis], "). The full requested layout is ",
            layout.ToString(), ".");
82 | } |
83 | shape[output_axis] /= shard_values[output_axis]; |
84 | } |
85 | } |
86 | mlir::RankedTensorType new_output_type = |
87 | mlir::RankedTensorType::get(shape, original_type.getElementType()); |
88 | return new_output_type; |
89 | } |
90 | |
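// Inverse of LocalTypeFromGlobalType: multiplies each statically known
// dimension by its number of shards, e.g. tensor<2x3xf32> back to
// tensor<8x3xf32> for the layout above.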
91 | StatusOr<mlir::TensorType> GlobalTypeFromLocalType( |
92 | const Layout& layout, const mlir::TensorType& original_type) { |
93 | if (!original_type.hasRank()) { |
94 | return original_type; |
95 | } |
96 | auto shape = llvm::to_vector<4>(original_type.getShape()); |
97 | auto shard_values = layout.num_shards(); |
98 | for (int output_axis = 0; output_axis < shape.size(); ++output_axis) |
99 | if (shape[output_axis] != mlir::ShapedType::kDynamicSize) |
100 | shape[output_axis] *= shard_values[output_axis]; |
101 | mlir::RankedTensorType new_output_type = |
102 | mlir::RankedTensorType::get(shape, original_type.getElementType()); |
103 | return new_output_type; |
104 | } |
105 | |
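// Creates a tf.Split op that splits |src_input| into |num_split| equal slices
// along |split_dimension|. A sketch of a call site (local names are
// hypothetical):
//
//   mlir::TF::SplitOp split_op;
//   TF_RETURN_IF_ERROR(CreateSplitOp(/*num_split=*/2, /*split_dimension=*/0,
//                                    op->getLoc(), input_value, &builder,
//                                    &split_op));
//   // split_op now produces two results, each with dimension 0 halved.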
106 | Status CreateSplitOp(const int num_split, const int split_dimension, |
107 | const mlir::Location location, mlir::Value src_input, |
108 | mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op) { |
  // Creates a const op to hold the split dimension value.
110 | auto split_dim_type = |
111 | mlir::RankedTensorType::get({}, builder->getIntegerType(32)); |
112 | auto split_dimension_attr = |
113 | mlir::DenseElementsAttr::get(split_dim_type, split_dimension); |
114 | auto split_dimension_op = builder->create<mlir::TF::ConstOp>( |
115 | location, split_dim_type, split_dimension_attr); |
116 | |
  // Correctly set the output shape of the split op if the input shape is
  // statically known.
119 | mlir::Type output_type; |
120 | auto input_type = src_input.getType().cast<mlir::TensorType>(); |
121 | |
122 | if (input_type.hasRank()) { |
123 | if (input_type.getShape()[split_dimension] == |
124 | mlir::ShapedType::kDynamicSize) { |
125 | output_type = input_type; |
126 | } else { |
127 | auto shape = llvm::to_vector<4>(input_type.getShape()); |
128 | if (shape[split_dimension] % num_split != 0) { |
129 | return errors::InvalidArgument( |
130 | llvm::formatv( |
131 | "incorrect input sharding configuration received. " |
132 | "{0}-th dimension of the input must be evenly divisible by {1}" , |
133 | split_dimension, num_split) |
134 | .str()); |
135 | } |
136 | |
137 | shape[split_dimension] = shape[split_dimension] / num_split; |
138 | output_type = |
139 | mlir::RankedTensorType::get(shape, input_type.getElementType()); |
140 | } |
141 | } else { |
142 | output_type = input_type; |
143 | } |
144 | |
145 | // Creates a split op that splits |src_input| along |split_dimension|. |
146 | llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type); |
147 | *split_op = builder->create<mlir::TF::SplitOp>( |
148 | location, output_types, split_dimension_op.output(), src_input); |
149 | return OkStatus(); |
150 | } |
151 | |
// Given layouts + shapes, determines if the two are broadcast compatible.
153 | // When broadcasting we effectively line up the shapes and layouts by the end. |
154 | // The input with lower rank can be thought of as having abs(rank_a-rank_b) |
155 | // replicated dims of size 1 prepended to it. |
156 | // |
157 | // Returns the broadcast layout and the splits in the two inputs needed to run |
158 | // an elementwise op efficiently. |
159 | // |
160 | // Checks that a given mesh dimension is not used in different tensor dimensions |
161 | // in the two input layouts. |
162 | // E.g. a layout like (unsharded,x,unsharded) is not compatible with |
163 | // (unsharded,x) or (x,unsharded,unsharded) but is compatible with |
164 | // (x,unsharded), (unsharded,unsharded) or (unsharded,x,unsharded). |
165 | // (Note that due to broadcasting, we compare the dimensions from the end). |
166 | // |
167 | // If dims_to_ignore is > 0, then we ignore when a mesh dimension is used in |
168 | // different tensor dimensions when those dimensions are both in the last |
169 | // dims_to_ignore tensor dimensions of each input. |
// E.g. If dims_to_ignore = 2, then (unsharded,x,unsharded) is now compatible
// with (unsharded,x) but not compatible with (x,unsharded,unsharded).
172 | // |
173 | // The output layout will be of rank max(layout_a.rank(), layout_b.rank()) - |
174 | // dims_to_ignore and will be replicated on a dimension if either one of the |
175 | // input layouts is replicated on that dimension. Once again recall due to |
176 | // broadcasting, layouts are aligned by their ends and not their beginnings. |
177 | // E.g. if dims_to_ignore is zero, the output layout for the inputs |
178 | // (unsharded,x,unsharded) and (unsharded,y) is (unsharded,x,y). |
179 | // If dims_to_ignore is two, the output for (y,x,unsharded) and |
180 | // (unsharded,x) is just (y). |
181 | // |
182 | // In the case that one tensor is sharded and the other is not on a given |
// dimension, elementwise operations *may* need to split the unsharded tensor
184 | // along the same mesh dimension that the other input is split on. Note that |
185 | // the split is *not* needed if the unsharded tensor has dimension of size 1, |
186 | // due to broadcasting. |
187 | // |
188 | // To help with the needed splittings, the vectors to_split_* are resized to the |
189 | // rank of each input and if that dimension of the tensor needs to be split for |
// an elementwise op, we record the mesh dimension it should be split along in
191 | // the vector. |
192 | // E.g. in the case of input layouts (unsharded,x,unsharded) and |
193 | // (unsharded,unsharded) with dimensions (10,10,10) and (10,10), |
194 | // to_split_a = {"unsharded", "unsharded", "unsharded"} and to_split_b = |
195 | // {"x", "unsharded"}. |
196 | // If the shapes were (10,10,10) and (1,10), then to_split_a = {"unsharded", |
197 | // "unsharded", "unsharded"} and to_split_b = {"unsharded", "unsharded"}. |
198 | // |
199 | // Note that "unsharded" == Layout::kUnshardedDim. |
200 | // NOTE: shape_a and shape_b are *global* shapes. |
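// A sketch of a typical call (local names are hypothetical):
//
//   std::vector<std::string> to_split_a, to_split_b;
//   TF_ASSIGN_OR_RETURN(
//       const Layout broadcast_layout,
//       GetBroadcastLayoutForElementWise(layout_a, layout_b, global_shape_a,
//                                        global_shape_b, /*dims_to_ignore=*/0,
//                                        to_split_a, to_split_b));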
201 | StatusOr<Layout> GetBroadcastLayoutForElementWise( |
202 | const Layout& layout_a, const Layout& layout_b, |
203 | mlir::ArrayRef<int64_t> shape_a, mlir::ArrayRef<int64_t> shape_b, |
204 | int64_t dims_to_ignore, std::vector<std::string>& to_split_a, |
205 | std::vector<std::string>& to_split_b) { |
206 | if (layout_a.mesh() != layout_b.mesh()) |
207 | return errors::InvalidArgument( |
208 | "layout_a and layout_b cannot be broadcast as they are on different " |
209 | "meshes." ); |
210 | |
211 | const int rank_a = layout_a.rank(); |
212 | const int rank_b = layout_b.rank(); |
213 | const int rank_offset_a = std::max(0, rank_b - rank_a); |
214 | const int rank_offset_b = std::max(0, rank_a - rank_b); |
215 | absl::flat_hash_map<std::string, int> mesh_dim_map_a; |
216 | absl::flat_hash_map<std::string, int> mesh_dim_map_b; |
217 | std::vector<string> output_layout_specs; |
218 | |
219 | auto unsharded_specs = [](const int new_size) -> std::vector<std::string> { |
220 | std::vector<std::string> spec_strs(new_size, Layout::kUnshardedDim); |
221 | return spec_strs; |
222 | }; |
223 | |
224 | to_split_a = unsharded_specs(rank_a - dims_to_ignore); |
225 | to_split_b = unsharded_specs(rank_b - dims_to_ignore); |
226 | |
  // Note that we record mesh dimension usage for all tensor dimensions, even
  // the ones we ignore. We will check that a non-ignored dimension of one
  // tensor does not use a mesh dimension that is used by an ignored dimension
  // in the other tensor.
230 | for (int i = 0; i < rank_a; ++i) |
231 | if (!Layout::IsUnshardedDimension(layout_a.sharding_spec(i))) |
232 | mesh_dim_map_a[layout_a.sharding_spec(i)] = i; |
233 | for (int i = 0; i < rank_b; ++i) |
234 | if (!Layout::IsUnshardedDimension(layout_b.sharding_spec(i))) |
235 | mesh_dim_map_b[layout_b.sharding_spec(i)] = i; |
236 | |
237 | for (int i = 0; i < std::max(rank_a, rank_b) - dims_to_ignore; ++i) { |
238 | const int dim_a = i - rank_offset_a; |
239 | const int dim_b = i - rank_offset_b; |
240 | // When ranks are not equal we treat the first rank_offset_* dims of the |
241 | // shorter layout as not sharded. |
242 | const std::string mesh_dim_a = |
243 | dim_a >= 0 ? layout_a.sharding_spec(dim_a) : Layout::kUnshardedDim; |
244 | const std::string mesh_dim_b = |
245 | dim_b >= 0 ? layout_b.sharding_spec(dim_b) : Layout::kUnshardedDim; |
246 | // When ranks are not equal, we treat the first rank_offset_* dims of the |
247 | // shorter shape as if they were 1. |
248 | const int64_t tensor_dim_a = dim_a >= 0 ? shape_a[dim_a] : 1; |
249 | const int64_t tensor_dim_b = dim_b >= 0 ? shape_b[dim_b] : 1; |
250 | |
    // Check for conflicting dimensions: a mesh dimension used by one layout at
    // a different tensor dimension than in the other layout. If a conflict
    // occurs, choose unsharded for the merged result.
253 | bool have_conflicted_dim = false; |
254 | if (!Layout::IsUnshardedDimension(mesh_dim_a) && |
255 | mesh_dim_map_b.contains(mesh_dim_a) && |
256 | mesh_dim_map_b[mesh_dim_a] != dim_b) |
257 | have_conflicted_dim = true; |
258 | |
259 | if (!Layout::IsUnshardedDimension(mesh_dim_b) && |
260 | mesh_dim_map_a.contains(mesh_dim_b) && |
261 | mesh_dim_map_a[mesh_dim_b] != dim_a) |
262 | have_conflicted_dim = true; |
263 | |
264 | // If both dimensions are sharded, we have already verified that they are |
265 | // sharded on the same mesh dim. |
266 | if (have_conflicted_dim) { |
267 | output_layout_specs.emplace_back(Layout::kUnshardedDim); |
268 | } else { |
269 | output_layout_specs.emplace_back( |
270 | Layout::IsUnshardedDimension(mesh_dim_a) ? mesh_dim_b : mesh_dim_a); |
271 | } |
272 | if (dim_a >= 0 && tensor_dim_a > 1 && |
273 | Layout::IsUnshardedDimension(mesh_dim_a) && |
274 | !Layout::IsUnshardedDimension(mesh_dim_b)) { |
275 | to_split_a[dim_a] = mesh_dim_b; |
276 | } |
277 | if (dim_b >= 0 && tensor_dim_b > 1 && |
278 | Layout::IsUnshardedDimension(mesh_dim_b) && |
279 | !Layout::IsUnshardedDimension(mesh_dim_a)) { |
280 | to_split_b[dim_b] = mesh_dim_a; |
281 | } |
282 | } |
283 | return Layout::GetLayout(output_layout_specs, layout_a.mesh()); |
284 | } |
285 | |
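// Merges the layouts in |operand_layouts| (keyed by operand index) into a
// single broadcast-compatible layout, or returns absl::nullopt when no
// operand has a known layout.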
286 | StatusOr<absl::optional<Layout>> GetMergedOperandLayout( |
287 | const llvm::DenseMap<int, Layout>& operand_layouts, mlir::Operation* op) { |
  // List of layouts together with the global shape of the operand each layout
  // belongs to, restricted to operands whose layout is defined (i.e. layout is
  // not absl::nullopt).
290 | llvm::SmallVector<std::pair<const Layout&, llvm::ArrayRef<int64_t>>, 4> |
291 | filtered_preferred_operand_layouts; |
292 | filtered_preferred_operand_layouts.reserve(op->getNumOperands()); |
293 | |
294 | for (const auto& index_and_layout : operand_layouts) { |
295 | TF_ASSIGN_OR_RETURN( |
296 | llvm::ArrayRef<int64_t> shape_to_merge, |
297 | GetShapeOfValue(op->getOperand(index_and_layout.first))); |
298 | filtered_preferred_operand_layouts.emplace_back(index_and_layout.second, |
299 | shape_to_merge); |
300 | } |
301 | |
302 | if (filtered_preferred_operand_layouts.empty()) |
303 | return absl::optional<Layout>(); |
304 | |
  // Merge all operand layouts into a single broadcast layout.
306 | Layout merged_operand_layout = filtered_preferred_operand_layouts[0].first; |
307 | llvm::ArrayRef<int64_t> merged_shape = |
308 | filtered_preferred_operand_layouts[0].second; |
309 | |
  // Statically merge the operand layouts. Broadcasting is allowed, but it must
  // not incur any cross-device communication.
312 | for (int i = 1; i < filtered_preferred_operand_layouts.size(); ++i) { |
313 | const auto& operand_index_and_layout_to_merge = |
314 | filtered_preferred_operand_layouts[i]; |
315 | const Layout& layout_to_merge = operand_index_and_layout_to_merge.first; |
316 | llvm::ArrayRef<int64_t> shape_to_merge = |
317 | operand_index_and_layout_to_merge.second; |
318 | |
319 | std::vector<std::string> left_splits; |
320 | std::vector<std::string> right_splits; |
321 | TF_ASSIGN_OR_RETURN(merged_operand_layout, |
322 | GetBroadcastLayoutForElementWise( |
323 | merged_operand_layout, layout_to_merge, |
324 | merged_shape, shape_to_merge, |
325 | /*dims_to_ignore=*/0, left_splits, right_splits)); |
326 | } |
327 | return absl::optional<Layout>(merged_operand_layout); |
328 | } |
329 | |
330 | mlir::Value GetForwardedDTensorLayoutInput(mlir::Value value) { |
331 | auto layout_op = |
332 | llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(value.getDefiningOp()); |
333 | if (!layout_op) return value; |
334 | |
335 | return layout_op.input(); |
336 | } |
337 | |
338 | // Takes an operand and traces its use across function call and |
339 | // tf_device.cluster boundaries. Note that this may turn one operand into many. |
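// E.g. tracing an operand of a PartitionedCallOp yields the uses of the
// callee's corresponding block argument, and tracing an operand of a
// tf_device.return yields the uses of the matching result of the enclosing
// cluster.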
340 | // TODO(bfontain): Assumes that a function is only called once. This is checked |
341 | // when creating func_to_caller. |
342 | llvm::SmallVector<mlir::OpOperand*, 4> TraceUseToNextTFOp( |
343 | mlir::OpOperand* operand, |
344 | const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller, |
345 | llvm::SmallVector<mlir::Value, 4>* skipped_values) { |
346 | mlir::Operation* owner = operand->getOwner(); |
347 | llvm::SmallVector<mlir::Value, 4> values; |
348 | if (mlir::isa<mlir::TF::PartitionedCallOp>(owner) || |
349 | mlir::isa<mlir::TF::StatefulPartitionedCallOp>(owner)) { |
350 | mlir::func::FuncOp func; |
351 | if (mlir::isa<mlir::TF::PartitionedCallOp>(owner)) |
352 | func = mlir::cast<mlir::TF::PartitionedCallOp>(owner).func(); |
353 | else |
354 | func = mlir::cast<mlir::TF::StatefulPartitionedCallOp>(owner).func(); |
355 | values.emplace_back(func.getArgument(operand->getOperandNumber())); |
356 | } else if (mlir::isa<mlir::tf_device::ReturnOp>(owner)) { |
357 | auto device_return = mlir::cast<mlir::tf_device::ReturnOp>(owner); |
358 | auto enclosing_cluster = |
359 | device_return->getParentOfType<mlir::tf_device::ClusterOp>(); |
360 | values.emplace_back( |
361 | enclosing_cluster.getResult(operand->getOperandNumber())); |
362 | } else if (mlir::isa<mlir::func::ReturnOp>(owner)) { |
363 | auto func = mlir::cast<mlir::func::ReturnOp>(owner) |
364 | ->getParentOfType<mlir::func::FuncOp>(); |
365 | // The one function we don't have a caller for is the main function. |
366 | // In this case return the empty list as there are no consumers. |
367 | auto caller = func_to_caller.find(func.getName()); |
368 | if (caller != func_to_caller.end()) |
369 | values.emplace_back( |
370 | caller->second->getOpResult(operand->getOperandNumber())); |
371 | } else if (auto yield = mlir::dyn_cast<mlir::TF::YieldOp>(owner)) { |
372 | if (auto if_op = owner->getParentOfType<mlir::TF::IfRegionOp>()) { |
373 | values.emplace_back(if_op.getResult(operand->getOperandNumber())); |
374 | } else if (auto while_op = |
375 | owner->getParentOfType<mlir::TF::WhileRegionOp>()) { |
376 | if (while_op && !while_op.cond().isAncestor(yield->getParentRegion())) |
377 | values.emplace_back(while_op.getResult(operand->getOperandNumber())); |
378 | } else { |
379 | LOG(WARNING) |
380 | << "Found terminator op for unsupported controlflow operations." ; |
381 | } |
382 | } else if (mlir::isa<mlir::TF::DTensorLayout>(owner)) { |
383 | auto dtensor_layout = mlir::cast<mlir::TF::DTensorLayout>(owner); |
384 | values.emplace_back(dtensor_layout.output()); |
385 | } else if (auto while_op = mlir::dyn_cast<mlir::TF::WhileRegionOp>(owner)) { |
386 | // Handle loop variant inputs of while op. |
387 | mlir::Region& cond = while_op.cond(); |
388 | mlir::Region& body = while_op.body(); |
389 | const int operand_index = operand->getOperandNumber(); |
390 | values.emplace_back(cond.front().getArgument(operand_index)); |
391 | values.emplace_back(body.front().getArgument(operand_index)); |
392 | } else { |
393 | return {operand}; |
394 | } |
395 | llvm::SmallVector<mlir::OpOperand*, 4> ret; |
396 | for (mlir::Value value : values) { |
397 | if (skipped_values != nullptr) skipped_values->emplace_back(value); |
398 | for (mlir::OpOperand& use : value.getUses()) { |
399 | // TODO(bfontain): Remove recursion here. |
400 | const auto& traced_operands = |
401 | TraceUseToNextTFOp(&use, func_to_caller, skipped_values); |
402 | ret.append(traced_operands.begin(), traced_operands.end()); |
403 | } |
404 | } |
405 | |
406 | return ret; |
407 | } |
408 | |
409 | mlir::LogicalResult GetFuncToCaller( |
410 | mlir::ModuleOp module, |
411 | llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller) { |
412 | // For now this is a 1:1 mapping and we will error out if a function is called |
  // by more than one op. The layout code assumes a 1:many relationship
  // between producers and consumers. If we allow a function to be called
  // multiple times, then its consumers consume from multiple producers, which
  // breaks this assumption.
417 | // TODO(bfontain): Fix this, possibly by duplicating all functions in order to |
418 | // make this mapping 1:1 in truth. |
419 | auto result = module->walk([&](mlir::Operation* op) -> mlir::WalkResult { |
420 | mlir::StringRef func; |
421 | if (mlir::TF::PartitionedCallOp call_op = |
422 | mlir::dyn_cast<mlir::TF::PartitionedCallOp>(op)) |
423 | func = call_op.func().getName(); |
424 | else if (mlir::TF::StatefulPartitionedCallOp call_op = |
425 | mlir::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op)) |
426 | func = call_op.func().getName(); |
427 | else |
428 | return mlir::WalkResult::advance(); |
429 | if (func_to_caller.find(func) != func_to_caller.end()) |
      return op->emitOpError()
             << "multiple calls to " << func << " found.";
432 | func_to_caller[func] = op; |
433 | return mlir::WalkResult::advance(); |
434 | }); |
435 | return mlir::failure(result.wasInterrupted()); |
436 | } |
437 | |
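// Builds a map from each TF op result (and each main function argument) to
// the OpOperands that consume it, tracing through function calls, control
// flow regions, and tf_device.cluster boundaries.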
438 | mlir::LogicalResult PopulateConsumersFromModule( |
439 | mlir::ModuleOp* module, mlir::Dialect* tf_dialect, |
440 | llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers) { |
441 | mlir::func::FuncOp main_func = |
      module->lookupSymbol<mlir::func::FuncOp>("main");
443 | llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller; |
444 | |
445 | if (mlir::failed(GetFuncToCaller(*module, func_to_caller))) |
446 | return mlir::failure(); |
447 | |
448 | module->walk([&](mlir::Operation* op) { |
449 | if (op->getDialect() != tf_dialect) return; |
450 | |
451 | if (mlir::isa<mlir::TF::PartitionedCallOp>(op) || |
452 | mlir::isa<mlir::TF::StatefulPartitionedCallOp>(op) || |
453 | mlir::isa<mlir::TF::WhileRegionOp>(op) || |
454 | mlir::isa<mlir::TF::IfRegionOp>(op) || |
455 | mlir::isa<mlir::TF::DTensorLayout>(op)) |
456 | return; |
457 | |
458 | for (const auto& value : op->getOpResults()) { |
      // Call clear so that the value is in consumers (with an empty vector)
      // even if there are no uses. This should only happen for ops whose
      // outputs go directly to the main function return, e.g. eagerly
      // executed ops.
462 | consumers[value].clear(); |
463 | for (auto& operand : value.getUses()) |
464 | for (auto& traced_operand : |
465 | TraceUseToNextTFOp(&operand, func_to_caller)) |
466 | consumers[value].emplace_back(traced_operand); |
467 | } |
468 | }); |
469 | |
470 | // Note that we need to add in the inputs from the main function (otherwise |
471 | // we won't have any layouts to propagate!). |
472 | for (auto& value : main_func.getArguments()) |
473 | for (auto& operand : value.getUses()) |
474 | for (auto* traced_operand : TraceUseToNextTFOp(&operand, func_to_caller)) |
475 | consumers[value].emplace_back(traced_operand); |
476 | return mlir::success(); |
477 | } |
478 | |
479 | // Compute the mesh coordinates from a device id + the current cluster. |
480 | // |
481 | // If the mesh shape is [a, b, c, d], then the mesh coordinates are |
482 | // [device_id/b/c/d, device_id/c/d%b, device_id/d%c, device_id%d] |
// For convenience, since device_id < a*b*c*d, we can apply %a to the first
// coordinate as well.
485 | // Thus we can decompose this calculation into the following tf ops: |
486 | // tf.FloorMod(tf.Div(device_id, [b*c*d, c*d, d, 1]), [a, b, c, d]) where |
487 | // [a, b, c, d] and [b*c*d, c*d, d, 1] are simply precomputed constants. |
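// E.g. for mesh shape [2, 3] the strides are [3, 1], so device_id 5 maps to
// tf.FloorMod(tf.Div(5, [3, 1]), [2, 3]) = [1 % 2, 5 % 3] = [1, 2].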
488 | // |
// Note that this returns a tensor of shape [1, mesh.rank()], suitable for use
// with MatMul.
491 | StatusOr<mlir::Value> GetMeshCoordinatesFromCluster( |
492 | mlir::tf_device::ClusterOp cluster) { |
493 | // First try to find a FloorMod op with kMeshCoordinatesAttr attribute that |
494 | // has the given mesh in it. If it exists, simply return that op's value. |
495 | TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshFromOp(cluster)); |
  if (!mesh) return errors::InvalidArgument("missing mesh on cluster");
497 | string serialized_mesh = mesh->ToString(); |
498 | mlir::Value ret_val; |
499 | auto result = cluster.walk([&](mlir::TF::FloorModOp op) -> mlir::WalkResult { |
500 | if (op->hasAttrOfType<mlir::StringAttr>(kMeshCoordinatesAttr) && |
501 | op->getAttrOfType<mlir::StringAttr>(kMeshCoordinatesAttr) |
502 | .getValue() |
503 | .str() == serialized_mesh) { |
504 | ret_val = op.z(); |
505 | return mlir::WalkResult::interrupt(); |
506 | } |
507 | return mlir::WalkResult::advance(); |
508 | }); |
509 | if (result.wasInterrupted()) return ret_val; |
510 | |
511 | // We didn't find a FloorModOp for the given mesh, so we must produce the |
512 | // FloorModOp and add the attr so we can find it on next call. |
513 | std::vector<int32> mesh_shape(mesh->rank()); |
514 | for (int i = 0; i < mesh->rank(); ++i) mesh_shape[i] = mesh->dim(i).size; |
515 | |
516 | // This product represents the [b*c*d, c*d, d, 1] from the function |
517 | // documentation. |
518 | std::vector<int32> running_product(mesh->rank()); |
519 | running_product[mesh->rank() - 1] = 1; |
520 | for (int i = mesh->rank() - 1; i > 0; --i) |
521 | running_product[i - 1] = running_product[i] * mesh_shape[i]; |
522 | |
523 | mlir::OpBuilder builder(cluster.getContext()); |
524 | builder.setInsertionPointToStart(&cluster.GetBody()); |
525 | |
526 | auto mesh_shape_type = mlir::RankedTensorType::get( |
527 | {1, mesh->rank()}, builder.getIntegerType(32)); |
528 | mlir::Attribute mesh_shape_attr = |
529 | mlir::DenseIntElementsAttr::get(mesh_shape_type, mesh_shape); |
530 | auto mesh_shape_value = |
531 | builder.create<mlir::TF::ConstOp>(cluster.getLoc(), mesh_shape_attr) |
532 | .getResult(); |
533 | |
534 | auto running_product_value = |
535 | IntConst(builder, cluster.getLoc(), running_product); |
536 | |
537 | TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(cluster)); |
538 | |
539 | auto div_op = builder.create<mlir::TF::DivOp>(cluster.getLoc(), device_id, |
540 | running_product_value); |
541 | |
542 | auto mod_op = builder.create<mlir::TF::FloorModOp>( |
543 | cluster.getLoc(), div_op.z(), mesh_shape_value); |
544 | |
545 | mod_op->setAttr(kMeshCoordinatesAttr, builder.getStringAttr(serialized_mesh)); |
546 | return mod_op.z(); |
547 | } |
548 | |
549 | mlir::LogicalResult ValidateMetadataAttributes(mlir::Operation* op) { |
  // If the op carries attributes with the inferred layouts of resource handle
  // arguments, check that the indices attribute and the layouts attribute are
  // present together and have matching sizes.
553 | auto inferred_resource_handle_indices = |
554 | op->getAttrOfType<mlir::DenseIntElementsAttr>(kNewResourceLayoutIndices); |
555 | auto inferred_resource_handle_layouts = |
556 | op->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts); |
557 | if (inferred_resource_handle_indices || inferred_resource_handle_layouts) { |
558 | if (!inferred_resource_handle_indices || |
559 | !inferred_resource_handle_layouts || |
560 | inferred_resource_handle_indices.getNumElements() != |
561 | inferred_resource_handle_layouts.size()) |
      return op->emitOpError(
                 "inferred layout args don't match. indices size: ")
             << (inferred_resource_handle_indices
                     ? inferred_resource_handle_indices.getNumElements()
                     : 0)
             << ", layouts size: "
568 | << (inferred_resource_handle_layouts |
569 | ? inferred_resource_handle_layouts.size() |
570 | : 0); |
571 | } |
572 | |
573 | auto shape_layouts = op->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout); |
574 | auto shape_op_indices = |
575 | op->getAttrOfType<mlir::DenseIntElementsAttr>(kShapeOpInputLayoutIndices); |
576 | if (shape_op_indices || shape_layouts) { |
577 | if (!shape_op_indices || !shape_layouts || |
578 | shape_op_indices.getNumElements() != shape_layouts.size()) |
579 | return op->emitOpError("shape layout args doesn't match. indices size: " ) |
580 | << (shape_op_indices ? shape_op_indices.getNumElements() : 0) |
581 | << ", layouts size : " |
582 | << (shape_layouts ? shape_layouts.size() : 0); |
583 | } |
584 | return mlir::success(); |
585 | } |
586 | |
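// Rebuilds |cluster| without the results that have no uses, and replaces all
// uses of the remaining results with the corresponding results of the new
// cluster.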
587 | void RemoveUnusedClusterResults(mlir::tf_device::ClusterOp cluster) { |
588 | llvm::SmallVector<mlir::OpResult, 4> new_result_values; |
589 | llvm::SmallVector<mlir::Value, 4> result_producing_values; |
590 | new_result_values.reserve(cluster->getNumResults()); |
591 | result_producing_values.reserve(cluster->getNumResults()); |
592 | for (mlir::OpResult result : cluster.getResults()) { |
593 | if (!result.use_empty()) { |
594 | new_result_values.emplace_back(result); |
595 | result_producing_values.emplace_back( |
596 | cluster.GetBody().getTerminator()->getOperand( |
597 | result.getResultNumber())); |
598 | } |
599 | } |
600 | |
601 | if (new_result_values.size() == cluster.getNumResults()) return; |
602 | |
603 | llvm::SmallVector<mlir::Type, 4> new_result_types; |
604 | llvm::transform(new_result_values, std::back_inserter(new_result_types), |
605 | [](mlir::Value v) { return v.getType(); }); |
606 | |
607 | mlir::OpBuilder builder(cluster); |
608 | auto new_cluster = builder.create<mlir::tf_device::ClusterOp>( |
609 | cluster.getLoc(), new_result_types); |
610 | new_cluster->setAttr(kMeshAttr, |
611 | cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr)); |
612 | new_cluster.getBody().push_back(new mlir::Block); |
613 | |
614 | auto& cluster_body = cluster.GetBody().getOperations(); |
615 | new_cluster.GetBody().getOperations().splice( |
616 | new_cluster.GetBody().end(), cluster_body, cluster_body.begin(), |
617 | std::prev(cluster_body.end())); |
618 | |
619 | builder.setInsertionPointToEnd(&new_cluster.GetBody()); |
620 | builder.create<mlir::tf_device::ReturnOp>(cluster.getLoc(), |
621 | result_producing_values); |
622 | |
623 | assert(new_cluster.getNumResults() == new_result_values.size()); |
624 | for (auto it : llvm::zip(new_result_values, new_cluster.getResults())) { |
625 | mlir::Value value_to_replace = std::get<0>(it); |
626 | mlir::Value new_result = std::get<1>(it); |
627 | value_to_replace.replaceAllUsesWith(new_result); |
628 | } |
629 | cluster.erase(); |
630 | } |
631 | |
632 | namespace { |
633 | |
// Keeps track of the number of functions added to the global graph for
// control flow. When converting regional control flow to functional control
// flow ops, function names may collide if non-unique branch function names are
// used. To ensure that all branch functions of TF control flow ops are unique,
// we keep an atomic counter for control flow functions.
// See b/174253694 for more details.
640 | std::atomic<int32> dtensor_controlflow_function_counter{0}; |
641 | |
642 | } // namespace |
643 | |
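// Returns a unique name of the form "<prefix>_dtensor_function_<id>" for a
// control flow branch function, where <id> comes from the atomic counter
// above.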
644 | mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix, |
645 | mlir::OpBuilder& builder) { |
646 | int32 unique_id = dtensor_controlflow_function_counter++; |
647 | return builder.getStringAttr( |
648 | absl::StrCat(prefix, "_dtensor_function_" , unique_id)); |
649 | } |
650 | |
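// Sets |builder|'s insertion point directly after |value| if it is an op
// result. For a block argument, the insertion point is instead set to the
// start of the unique tf_device.cluster that uses it; multiple using clusters
// (or none) is an error.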
651 | Status SetBuilderInsertionAfterValue(mlir::Value value, |
652 | mlir::OpBuilder& builder) { |
653 | if (value.isa<mlir::OpResult>()) { |
654 | builder.setInsertionPointAfterValue(value); |
655 | return OkStatus(); |
656 | } |
657 | mlir::tf_device::ClusterOp cluster; |
658 | for (mlir::Operation* op : value.getUsers()) { |
659 | mlir::tf_device::ClusterOp new_cluster = |
660 | op->getParentOfType<mlir::tf_device::ClusterOp>(); |
661 | if (!new_cluster) continue; |
662 | if (!cluster) cluster = new_cluster; |
663 | if (cluster != new_cluster) |
664 | return errors::Internal("value has multiple uses in different clusters" ); |
665 | } |
  if (!cluster) return errors::Internal("value not used in any cluster");
667 | |
668 | builder.setInsertionPointToStart(cluster.SingleBlock::getBody()); |
669 | return OkStatus(); |
670 | } |
671 | |
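// Emits a tf.StringFormat + tf.PrintV2 pair that logs |value| at runtime,
// prefixed with the executing core's device id. Useful for debugging SPMD
// expansion.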
Status PrintTensor(mlir::Value value,
                   const std::string& format_string = "%s") {
673 | mlir::OpBuilder builder(value.getContext()); |
674 | builder.setInsertionPointAfterValue(value); |
675 | TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(value)); |
  std::string all_format = absl::StrCat("Core %s: ", format_string);
677 | // Scalar string type |
678 | mlir::RankedTensorType scalar_string = |
679 | mlir::RankedTensorType::get({}, builder.getType<mlir::TF::StringType>()); |
680 | mlir::TF::StringFormatOp format = builder.create<mlir::TF::StringFormatOp>( |
681 | value.getLoc(), scalar_string, mlir::ValueRange({device_id, value})); |
682 | format->setAttr("template" , builder.getStringAttr(all_format)); |
  builder.create<mlir::TF::PrintV2Op>(value.getLoc(), format.output(),
                                      /*output_stream=*/"log(info)",
                                      /*end=*/"\n");
686 | return OkStatus(); |
687 | } |
688 | |
Status ExtractConstStringVectorFromValue(
    mlir::Value value, llvm::SmallVectorImpl<std::string>& out_vector) {
691 | value = GetForwardedDTensorLayoutInput(value); |
692 | if (value.isa<mlir::BlockArgument>()) |
693 | return errors::Internal("Unable get constant value from block argument." ); |
694 | mlir::DenseStringElementsAttr attr; |
695 | if (!matchPattern(value, m_Constant(&attr))) { |
696 | return errors::Internal( |
697 | llvm::formatv("failed to extract constant string vector from : {0}" , |
698 | value) |
699 | .str()); |
700 | } |
701 | for (const auto& str : attr.getRawStringData()) { |
702 | out_vector.push_back(str.str()); |
703 | } |
704 | return OkStatus(); |
705 | } |
706 | |
StatusOr<std::string> ExtractConstScalarStringFromValue(mlir::Value value) {
708 | value = GetForwardedDTensorLayoutInput(value); |
709 | if (value.isa<mlir::BlockArgument>()) |
710 | return errors::Internal("Unable get constant value from block argument." ); |
711 | mlir::DenseStringElementsAttr attr; |
712 | if (!matchPattern(value, m_Constant(&attr))) { |
713 | return errors::Internal(absl::StrCat("required constant value for " , |
714 | OpName(value.getDefiningOp()))); |
715 | } |
716 | if (attr.size() != 1) { |
717 | return errors::Internal(absl::StrCat("expected 1 element, got " , |
718 | attr.size(), " for " , |
719 | OpName(value.getDefiningOp()))); |
720 | } |
721 | return std::string(*attr.getRawStringData().begin()); |
722 | } |
723 | |
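// Visits the ops of |main_func| in topological order, pushing function bodies
// and control flow regions onto the work stack so that they are visited
// before the ops that follow their call site or parent op.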
724 | TopologicalIterator::TopologicalIterator(mlir::func::FuncOp main_func) |
725 | : ops_to_visit_{&main_func.front().front()} { |
726 | funcs_visited_.insert(main_func.getName()); |
727 | funcs_visited_in_call_stack_.insert(main_func.getName()); |
728 | } |
729 | |
730 | mlir::Operation* TopologicalIterator::next() { |
731 | if (!hasNext()) return nullptr; |
732 | |
733 | auto* op = ops_to_visit_.pop_back_val(); |
734 | auto* next_op = op->getNextNode(); |
735 | if (next_op) ops_to_visit_.push_back(next_op); |
736 | |
  // If this is a function call op, push the first op of the function body so
  // that the function body is visited immediately after the call op.
739 | absl::optional<mlir::func::FuncOp> func = MaybeFindFunction(op); |
740 | if (func.has_value()) { |
741 | mlir::StringRef func_name = func->getName(); |
742 | |
743 | if (funcs_visited_.contains(func_name)) return next(); |
744 | |
745 | ops_to_visit_.push_back(&(func->front().front())); |
746 | funcs_visited_.insert(func_name); |
747 | } |
748 | |
749 | // If we have reached the end of a function body, remove the function from |
750 | // our active set. |
751 | if (!next_op && !funcs_visited_in_call_stack_.empty()) |
752 | if (auto func = op->getParentOfType<mlir::func::FuncOp>()) |
753 | funcs_visited_in_call_stack_.erase(func.getName()); |
754 | |
755 | if (auto cluster_op = mlir::dyn_cast<mlir::tf_device::ClusterOp>(op)) |
756 | ops_to_visit_.push_back(&cluster_op.GetBody().front()); |
757 | |
758 | if (auto while_op = mlir::dyn_cast<mlir::TF::WhileRegionOp>(op)) { |
759 | ops_to_visit_.push_back(&while_op.cond().front().front()); |
760 | ops_to_visit_.push_back(&while_op.body().front().front()); |
761 | } |
762 | |
763 | if (auto if_op = mlir::dyn_cast<mlir::TF::IfRegionOp>(op)) { |
764 | ops_to_visit_.push_back(&if_op.then_branch().front().front()); |
765 | ops_to_visit_.push_back(&if_op.else_branch().front().front()); |
766 | } |
767 | return op; |
768 | } |
769 | |
770 | bool TopologicalIterator::hasNext() { return !ops_to_visit_.empty(); } |
771 | |
772 | } // namespace dtensor |
773 | } // namespace tensorflow |
774 | |