/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

#include <algorithm>
#include <atomic>
#include <iterator>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

// Checks that all layouts are fully replicated.
bool AllReplicated(const std::vector<Layout>& layouts) {
  for (const Layout& layout : layouts) {
    if (!layout.IsFullyReplicated()) return false;
  }
  return true;
}

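// Computes the local (per-device) tensor type corresponding to a global tensor
// type under `layout`. Each statically known axis is divided by the number of
// shards along it; e.g. a global shape of [8, 16] sharded 2 ways on axis 0 and
// 4 ways on axis 1 yields a local shape of [4, 4].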
StatusOr<mlir::TensorType> LocalTypeFromGlobalType(
    const Layout& layout, const mlir::TensorType& original_type) {
  if (!original_type.hasRank()) {
    return original_type;
  }
  auto shape = llvm::to_vector<4>(original_type.getShape());
  auto shard_values = layout.num_shards();
  for (int output_axis = 0; output_axis < shape.size(); ++output_axis) {
    if (shape[output_axis] != mlir::ShapedType::kDynamicSize) {
      if (shape[output_axis] % shard_values[output_axis] != 0) {
        return errors::InvalidArgument(
            "The sharding spec for axis ", output_axis, " splits among ",
            shard_values[output_axis],
            " values, which does not evenly divide the length of that axis "
            "(",
            shape[output_axis], "). The full requested layout is ",
            layout.ToString(), ".");
      }
      shape[output_axis] /= shard_values[output_axis];
    }
  }
  mlir::RankedTensorType new_output_type =
      mlir::RankedTensorType::get(shape, original_type.getElementType());
  return new_output_type;
}

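// Inverse of LocalTypeFromGlobalType: multiplies each statically known axis of
// the local type by the number of shards along it to recover the global type.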
StatusOr<mlir::TensorType> GlobalTypeFromLocalType(
    const Layout& layout, const mlir::TensorType& original_type) {
  if (!original_type.hasRank()) {
    return original_type;
  }
  auto shape = llvm::to_vector<4>(original_type.getShape());
  auto shard_values = layout.num_shards();
  for (int output_axis = 0; output_axis < shape.size(); ++output_axis)
    if (shape[output_axis] != mlir::ShapedType::kDynamicSize)
      shape[output_axis] *= shard_values[output_axis];
  mlir::RankedTensorType new_output_type =
      mlir::RankedTensorType::get(shape, original_type.getElementType());
  return new_output_type;
}

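// Creates a tf.Split op that splits `src_input` into `num_split` equal slices
// along `split_dimension` and returns it through `split_op`. Fails if the
// statically known size of that dimension is not divisible by `num_split`.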
Status CreateSplitOp(const int num_split, const int split_dimension,
                     const mlir::Location location, mlir::Value src_input,
                     mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op) {
  // Creates a const op to hold split dimension value.
  auto split_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto split_dimension_attr =
      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, split_dim_type, split_dimension_attr);

  // Correctly set the output shape of the split op if the input shape is
  // statically known.
  mlir::Type output_type;
  auto input_type = src_input.getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[split_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      if (shape[split_dimension] % num_split != 0) {
        return errors::InvalidArgument(
            llvm::formatv(
                "incorrect input sharding configuration received. "
                "{0}-th dimension of the input must be evenly divisible by {1}",
                split_dimension, num_split)
                .str());
      }

      shape[split_dimension] = shape[split_dimension] / num_split;
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  // Creates a split op that splits |src_input| along |split_dimension|.
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
  return OkStatus();
}

// Given layouts + shapes, determines if the two are broadcasting compatible.
// When broadcasting we effectively line up the shapes and layouts by the end.
// The input with lower rank can be thought of as having abs(rank_a-rank_b)
// replicated dims of size 1 prepended to it.
//
// Returns the broadcast layout and the splits in the two inputs needed to run
// an elementwise op efficiently.
//
// Checks that a given mesh dimension is not used in different tensor dimensions
// in the two input layouts.
// E.g. a layout like (unsharded,x,unsharded) is not compatible with
// (unsharded,x) or (x,unsharded,unsharded) but is compatible with
// (x,unsharded), (unsharded,unsharded) or (unsharded,x,unsharded).
// (Note that due to broadcasting, we compare the dimensions from the end).
//
// If dims_to_ignore is > 0, then we ignore when a mesh dimension is used in
// different tensor dimensions when those dimensions are both in the last
// dims_to_ignore tensor dimensions of each input.
// E.g. If dims_to_ignore = 2, then (unsharded,x,unsharded) is now compatible
// with (unsharded,x) and is not compatible with (x,unsharded,unsharded).
//
// The output layout will be of rank max(layout_a.rank(), layout_b.rank()) -
// dims_to_ignore and will be replicated on a dimension if either one of the
// input layouts is replicated on that dimension. Once again recall due to
// broadcasting, layouts are aligned by their ends and not their beginnings.
// E.g. if dims_to_ignore is zero, the output layout for the inputs
// (unsharded,x,unsharded) and (unsharded,y) is (unsharded,x,y).
// If dims_to_ignore is two, the output for (y,x,unsharded) and
// (unsharded,x) is just (y).
//
// In the case that one tensor is sharded and the other is not on a given
// dimension, element wise operations *may* need to split the unsharded tensor
// along the same mesh dimension that the other input is split on. Note that
// the split is *not* needed if the unsharded tensor has dimension of size 1,
// due to broadcasting.
//
// To help with the needed splits, the vectors to_split_* are resized to the
// rank of each input, and if a dimension of the tensor needs to be split for
// an elementwise op, we record the mesh dimension it should be split along in
// the vector.
// E.g. in the case of input layouts (unsharded,x,unsharded) and
// (unsharded,unsharded) with dimensions (10,10,10) and (10,10),
// to_split_a = {"unsharded", "unsharded", "unsharded"} and to_split_b =
// {"x", "unsharded"}.
// If the shapes were (10,10,10) and (1,10), then to_split_a = {"unsharded",
// "unsharded", "unsharded"} and to_split_b = {"unsharded", "unsharded"}.
//
// Note that "unsharded" == Layout::kUnshardedDim.
// NOTE: shape_a and shape_b are *global* shapes.
StatusOr<Layout> GetBroadcastLayoutForElementWise(
    const Layout& layout_a, const Layout& layout_b,
    mlir::ArrayRef<int64_t> shape_a, mlir::ArrayRef<int64_t> shape_b,
    int64_t dims_to_ignore, std::vector<std::string>& to_split_a,
    std::vector<std::string>& to_split_b) {
  if (layout_a.mesh() != layout_b.mesh())
    return errors::InvalidArgument(
        "layout_a and layout_b cannot be broadcast as they are on different "
        "meshes.");

  const int rank_a = layout_a.rank();
  const int rank_b = layout_b.rank();
  const int rank_offset_a = std::max(0, rank_b - rank_a);
  const int rank_offset_b = std::max(0, rank_a - rank_b);
  absl::flat_hash_map<std::string, int> mesh_dim_map_a;
  absl::flat_hash_map<std::string, int> mesh_dim_map_b;
  std::vector<string> output_layout_specs;

  auto unsharded_specs = [](const int new_size) -> std::vector<std::string> {
    std::vector<std::string> spec_strs(new_size, Layout::kUnshardedDim);
    return spec_strs;
  };

  to_split_a = unsharded_specs(rank_a - dims_to_ignore);
  to_split_b = unsharded_specs(rank_b - dims_to_ignore);

  // Note that we record ranks over all dimensions even ones we ignore.
  // We will check that a non-ignored dimension of a tensor does not use a
  // mesh dimension that is used by an ignored dimension in the other tensor.
  for (int i = 0; i < rank_a; ++i)
    if (!Layout::IsUnshardedDimension(layout_a.sharding_spec(i)))
      mesh_dim_map_a[layout_a.sharding_spec(i)] = i;
  for (int i = 0; i < rank_b; ++i)
    if (!Layout::IsUnshardedDimension(layout_b.sharding_spec(i)))
      mesh_dim_map_b[layout_b.sharding_spec(i)] = i;

  for (int i = 0; i < std::max(rank_a, rank_b) - dims_to_ignore; ++i) {
    const int dim_a = i - rank_offset_a;
    const int dim_b = i - rank_offset_b;
    // When ranks are not equal we treat the first rank_offset_* dims of the
    // shorter layout as not sharded.
    const std::string mesh_dim_a =
        dim_a >= 0 ? layout_a.sharding_spec(dim_a) : Layout::kUnshardedDim;
    const std::string mesh_dim_b =
        dim_b >= 0 ? layout_b.sharding_spec(dim_b) : Layout::kUnshardedDim;
    // When ranks are not equal, we treat the first rank_offset_* dims of the
    // shorter shape as if they were 1.
    const int64_t tensor_dim_a = dim_a >= 0 ? shape_a[dim_a] : 1;
    const int64_t tensor_dim_b = dim_b >= 0 ? shape_b[dim_b] : 1;

    // Check for conflicting dimensions. If a conflict occurs, choose
    // kUnshardedDim for the merged result.
    bool have_conflicted_dim = false;
    if (!Layout::IsUnshardedDimension(mesh_dim_a) &&
        mesh_dim_map_b.contains(mesh_dim_a) &&
        mesh_dim_map_b[mesh_dim_a] != dim_b)
      have_conflicted_dim = true;

    if (!Layout::IsUnshardedDimension(mesh_dim_b) &&
        mesh_dim_map_a.contains(mesh_dim_b) &&
        mesh_dim_map_a[mesh_dim_b] != dim_a)
      have_conflicted_dim = true;

    // If both dimensions are sharded, we have already verified that they are
    // sharded on the same mesh dim.
    if (have_conflicted_dim) {
      output_layout_specs.emplace_back(Layout::kUnshardedDim);
    } else {
      output_layout_specs.emplace_back(
          Layout::IsUnshardedDimension(mesh_dim_a) ? mesh_dim_b : mesh_dim_a);
    }
    if (dim_a >= 0 && tensor_dim_a > 1 &&
        Layout::IsUnshardedDimension(mesh_dim_a) &&
        !Layout::IsUnshardedDimension(mesh_dim_b)) {
      to_split_a[dim_a] = mesh_dim_b;
    }
    if (dim_b >= 0 && tensor_dim_b > 1 &&
        Layout::IsUnshardedDimension(mesh_dim_b) &&
        !Layout::IsUnshardedDimension(mesh_dim_a)) {
      to_split_b[dim_b] = mesh_dim_a;
    }
  }
  return Layout::GetLayout(output_layout_specs, layout_a.mesh());
}

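// Merges the known layouts of an op's operands into a single
// broadcast-compatible layout using GetBroadcastLayoutForElementWise. Returns
// absl::nullopt when no operand layout is given.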
StatusOr<absl::optional<Layout>> GetMergedOperandLayout(
    const llvm::DenseMap<int, Layout>& operand_layouts, mlir::Operation* op) {
  // Collects, for every operand whose layout is known (i.e. not absl::nullopt),
  // the layout together with the shape of that operand.
  llvm::SmallVector<std::pair<const Layout&, llvm::ArrayRef<int64_t>>, 4>
      filtered_preferred_operand_layouts;
  filtered_preferred_operand_layouts.reserve(op->getNumOperands());

  for (const auto& index_and_layout : operand_layouts) {
    TF_ASSIGN_OR_RETURN(
        llvm::ArrayRef<int64_t> shape_to_merge,
        GetShapeOfValue(op->getOperand(index_and_layout.first)));
    filtered_preferred_operand_layouts.emplace_back(index_and_layout.second,
                                                    shape_to_merge);
  }

  if (filtered_preferred_operand_layouts.empty())
    return absl::optional<Layout>();

  // Merge all operand layouts into a single broadcast layout.
  Layout merged_operand_layout = filtered_preferred_operand_layouts[0].first;
  llvm::ArrayRef<int64_t> merged_shape =
      filtered_preferred_operand_layouts[0].second;

  // Statically merge the input operand layouts. Broadcasting is allowed, but no
  // cross-device communication should be incurred.
  for (int i = 1; i < filtered_preferred_operand_layouts.size(); ++i) {
    const auto& operand_index_and_layout_to_merge =
        filtered_preferred_operand_layouts[i];
    const Layout& layout_to_merge = operand_index_and_layout_to_merge.first;
    llvm::ArrayRef<int64_t> shape_to_merge =
        operand_index_and_layout_to_merge.second;

    std::vector<std::string> left_splits;
    std::vector<std::string> right_splits;
    TF_ASSIGN_OR_RETURN(merged_operand_layout,
                        GetBroadcastLayoutForElementWise(
                            merged_operand_layout, layout_to_merge,
                            merged_shape, shape_to_merge,
                            /*dims_to_ignore=*/0, left_splits, right_splits));
  }
  return absl::optional<Layout>(merged_operand_layout);
}

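// If `value` is produced by a tf.DTensorLayout op, returns that op's input;
// otherwise returns `value` unchanged.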
mlir::Value GetForwardedDTensorLayoutInput(mlir::Value value) {
  auto layout_op =
      llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(value.getDefiningOp());
  if (!layout_op) return value;

  return layout_op.input();
}

// Takes an operand and traces its use across function call and
// tf_device.cluster boundaries. Note that this may turn one operand into many.
// TODO(bfontain): Assumes that a function is only called once. This is checked
// when creating func_to_caller.
llvm::SmallVector<mlir::OpOperand*, 4> TraceUseToNextTFOp(
    mlir::OpOperand* operand,
    const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
    llvm::SmallVector<mlir::Value, 4>* skipped_values) {
  mlir::Operation* owner = operand->getOwner();
  llvm::SmallVector<mlir::Value, 4> values;
  if (mlir::isa<mlir::TF::PartitionedCallOp>(owner) ||
      mlir::isa<mlir::TF::StatefulPartitionedCallOp>(owner)) {
    mlir::func::FuncOp func;
    if (mlir::isa<mlir::TF::PartitionedCallOp>(owner))
      func = mlir::cast<mlir::TF::PartitionedCallOp>(owner).func();
    else
      func = mlir::cast<mlir::TF::StatefulPartitionedCallOp>(owner).func();
    values.emplace_back(func.getArgument(operand->getOperandNumber()));
  } else if (mlir::isa<mlir::tf_device::ReturnOp>(owner)) {
    auto device_return = mlir::cast<mlir::tf_device::ReturnOp>(owner);
    auto enclosing_cluster =
        device_return->getParentOfType<mlir::tf_device::ClusterOp>();
    values.emplace_back(
        enclosing_cluster.getResult(operand->getOperandNumber()));
  } else if (mlir::isa<mlir::func::ReturnOp>(owner)) {
    auto func = mlir::cast<mlir::func::ReturnOp>(owner)
                    ->getParentOfType<mlir::func::FuncOp>();
    // The one function we don't have a caller for is the main function.
    // In this case return the empty list as there are no consumers.
    auto caller = func_to_caller.find(func.getName());
    if (caller != func_to_caller.end())
      values.emplace_back(
          caller->second->getOpResult(operand->getOperandNumber()));
  } else if (auto yield = mlir::dyn_cast<mlir::TF::YieldOp>(owner)) {
    if (auto if_op = owner->getParentOfType<mlir::TF::IfRegionOp>()) {
      values.emplace_back(if_op.getResult(operand->getOperandNumber()));
    } else if (auto while_op =
                   owner->getParentOfType<mlir::TF::WhileRegionOp>()) {
      if (while_op && !while_op.cond().isAncestor(yield->getParentRegion()))
        values.emplace_back(while_op.getResult(operand->getOperandNumber()));
    } else {
      LOG(WARNING)
          << "Found terminator op for unsupported controlflow operations.";
    }
  } else if (mlir::isa<mlir::TF::DTensorLayout>(owner)) {
    auto dtensor_layout = mlir::cast<mlir::TF::DTensorLayout>(owner);
    values.emplace_back(dtensor_layout.output());
  } else if (auto while_op = mlir::dyn_cast<mlir::TF::WhileRegionOp>(owner)) {
    // Handle loop variant inputs of while op.
    mlir::Region& cond = while_op.cond();
    mlir::Region& body = while_op.body();
    const int operand_index = operand->getOperandNumber();
    values.emplace_back(cond.front().getArgument(operand_index));
    values.emplace_back(body.front().getArgument(operand_index));
  } else {
    return {operand};
  }
  llvm::SmallVector<mlir::OpOperand*, 4> ret;
  for (mlir::Value value : values) {
    if (skipped_values != nullptr) skipped_values->emplace_back(value);
    for (mlir::OpOperand& use : value.getUses()) {
      // TODO(bfontain): Remove recursion here.
      const auto& traced_operands =
          TraceUseToNextTFOp(&use, func_to_caller, skipped_values);
      ret.append(traced_operands.begin(), traced_operands.end());
    }
  }

  return ret;
}

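// Builds a map from function name to the (unique) op that calls it. Fails if a
// function is called by more than one op.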
mlir::LogicalResult GetFuncToCaller(
    mlir::ModuleOp module,
    llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller) {
  // For now this is a 1:1 mapping and we will error out if a function is called
  // by more than one op. The layout code assumes there is a 1:many relationship
  // between producers and consumers. If we allow a function to be called
  // multiple times, then its consumers consume from multiple producers, which
  // breaks this assumption.
  // TODO(bfontain): Fix this, possibly by duplicating all functions in order to
  // make this mapping 1:1 in truth.
  auto result = module->walk([&](mlir::Operation* op) -> mlir::WalkResult {
    mlir::StringRef func;
    if (mlir::TF::PartitionedCallOp call_op =
            mlir::dyn_cast<mlir::TF::PartitionedCallOp>(op))
      func = call_op.func().getName();
    else if (mlir::TF::StatefulPartitionedCallOp call_op =
                 mlir::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op))
      func = call_op.func().getName();
    else
      return mlir::WalkResult::advance();
    if (func_to_caller.find(func) != func_to_caller.end())
      return op->emitOpError()
             << "multiple calls to " << func << " found.";
    func_to_caller[func] = op;
    return mlir::WalkResult::advance();
  });
  return mlir::failure(result.wasInterrupted());
}

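// Populates `consumers` with, for every result of a TF op in the module, the
// list of OpOperands that consume that value, tracing uses across function
// calls, control flow regions, and tf_device.cluster boundaries. Arguments of
// the main function are included as producers as well.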
mlir::LogicalResult PopulateConsumersFromModule(
    mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers) {
  mlir::func::FuncOp main_func =
      module->lookupSymbol<mlir::func::FuncOp>("main");
  llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller;

  if (mlir::failed(GetFuncToCaller(*module, func_to_caller)))
    return mlir::failure();

  module->walk([&](mlir::Operation* op) {
    if (op->getDialect() != tf_dialect) return;

    if (mlir::isa<mlir::TF::PartitionedCallOp>(op) ||
        mlir::isa<mlir::TF::StatefulPartitionedCallOp>(op) ||
        mlir::isa<mlir::TF::WhileRegionOp>(op) ||
        mlir::isa<mlir::TF::IfRegionOp>(op) ||
        mlir::isa<mlir::TF::DTensorLayout>(op))
      return;

    for (const auto& value : op->getOpResults()) {
      // Call clear so that the value is in consumers (with an empty vector)
      // even if there are no uses. This should only happen for ops whose
      // outputs go directly to the main return, e.g. eagerly executed ops.
      consumers[value].clear();
      for (auto& operand : value.getUses())
        for (auto& traced_operand :
             TraceUseToNextTFOp(&operand, func_to_caller))
          consumers[value].emplace_back(traced_operand);
    }
  });

  // Note that we need to add in the inputs from the main function (otherwise
  // we won't have any layouts to propagate!).
  for (auto& value : main_func.getArguments())
    for (auto& operand : value.getUses())
      for (auto* traced_operand : TraceUseToNextTFOp(&operand, func_to_caller))
        consumers[value].emplace_back(traced_operand);
  return mlir::success();
}

// Compute the mesh coordinates from a device id + the current cluster.
//
// If the mesh shape is [a, b, c, d], then the mesh coordinates are
// [device_id/b/c/d, device_id/c/d%b, device_id/d%c, device_id%d]
// For convenience, since device_id < a*b*c*d, we can also apply %a to the
// first coordinate for uniformity.
// Thus we can decompose this calculation into the following tf ops:
// tf.FloorMod(tf.Div(device_id, [b*c*d, c*d, d, 1]), [a, b, c, d]) where
// [a, b, c, d] and [b*c*d, c*d, d, 1] are simply precomputed constants.
//
// Note that this returns a tensor of shape [1, mesh.rank()], suitable for
// using with MatMul.
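//
// Worked example: for a mesh of shape [2, 2, 2, 2] and device_id 13, the
// divisors are [8, 4, 2, 1], so tf.Div gives [1, 3, 6, 13] and tf.FloorMod
// with [2, 2, 2, 2] gives the coordinates [1, 1, 0, 1].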
StatusOr<mlir::Value> GetMeshCoordinatesFromCluster(
    mlir::tf_device::ClusterOp cluster) {
  // First try to find a FloorMod op with kMeshCoordinatesAttr attribute that
  // has the given mesh in it. If it exists, simply return that op's value.
  TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshFromOp(cluster));
  if (!mesh) return errors::InvalidArgument("missing mesh on cluster");
  string serialized_mesh = mesh->ToString();
  mlir::Value ret_val;
  auto result = cluster.walk([&](mlir::TF::FloorModOp op) -> mlir::WalkResult {
    if (op->hasAttrOfType<mlir::StringAttr>(kMeshCoordinatesAttr) &&
        op->getAttrOfType<mlir::StringAttr>(kMeshCoordinatesAttr)
                .getValue()
                .str() == serialized_mesh) {
      ret_val = op.z();
      return mlir::WalkResult::interrupt();
    }
    return mlir::WalkResult::advance();
  });
  if (result.wasInterrupted()) return ret_val;

  // We didn't find a FloorModOp for the given mesh, so we must produce the
  // FloorModOp and add the attr so we can find it on next call.
  std::vector<int32> mesh_shape(mesh->rank());
  for (int i = 0; i < mesh->rank(); ++i) mesh_shape[i] = mesh->dim(i).size;

  // This product represents the [b*c*d, c*d, d, 1] from the function
  // documentation.
  std::vector<int32> running_product(mesh->rank());
  running_product[mesh->rank() - 1] = 1;
  for (int i = mesh->rank() - 1; i > 0; --i)
    running_product[i - 1] = running_product[i] * mesh_shape[i];

  mlir::OpBuilder builder(cluster.getContext());
  builder.setInsertionPointToStart(&cluster.GetBody());

  auto mesh_shape_type = mlir::RankedTensorType::get(
      {1, mesh->rank()}, builder.getIntegerType(32));
  mlir::Attribute mesh_shape_attr =
      mlir::DenseIntElementsAttr::get(mesh_shape_type, mesh_shape);
  auto mesh_shape_value =
      builder.create<mlir::TF::ConstOp>(cluster.getLoc(), mesh_shape_attr)
          .getResult();

  auto running_product_value =
      IntConst(builder, cluster.getLoc(), running_product);

  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(cluster));

  auto div_op = builder.create<mlir::TF::DivOp>(cluster.getLoc(), device_id,
                                                running_product_value);

  auto mod_op = builder.create<mlir::TF::FloorModOp>(
      cluster.getLoc(), div_op.z(), mesh_shape_value);

  mod_op->setAttr(kMeshCoordinatesAttr, builder.getStringAttr(serialized_mesh));
  return mod_op.z();
}

mlir::LogicalResult ValidateMetadataAttributes(mlir::Operation* op) {
  // If the cluster function has attributes containing the inferred layouts of
  // resource handle arguments, then those attributes are added to the newly
  // created StatefulPartitionedCallOp; check that they are consistent.
  auto inferred_resource_handle_indices =
      op->getAttrOfType<mlir::DenseIntElementsAttr>(kNewResourceLayoutIndices);
  auto inferred_resource_handle_layouts =
      op->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);
  if (inferred_resource_handle_indices || inferred_resource_handle_layouts) {
    if (!inferred_resource_handle_indices ||
        !inferred_resource_handle_layouts ||
        inferred_resource_handle_indices.getNumElements() !=
            inferred_resource_handle_layouts.size())
      return op->emitOpError(
                 "inferred layout args doesn't match. indices size: ")
             << (inferred_resource_handle_indices
                     ? inferred_resource_handle_indices.getNumElements()
                     : 0)
             << ", layouts size : "
             << (inferred_resource_handle_layouts
                     ? inferred_resource_handle_layouts.size()
                     : 0);
  }

  auto shape_layouts = op->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
  auto shape_op_indices =
      op->getAttrOfType<mlir::DenseIntElementsAttr>(kShapeOpInputLayoutIndices);
  if (shape_op_indices || shape_layouts) {
    if (!shape_op_indices || !shape_layouts ||
        shape_op_indices.getNumElements() != shape_layouts.size())
      return op->emitOpError("shape layout args doesn't match. indices size: ")
             << (shape_op_indices ? shape_op_indices.getNumElements() : 0)
             << ", layouts size : "
             << (shape_layouts ? shape_layouts.size() : 0);
  }
  return mlir::success();
}

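// Removes results of `cluster` that have no uses by rebuilding the cluster op
// with only the results that are still consumed and rewiring all remaining
// uses to the new cluster.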
void RemoveUnusedClusterResults(mlir::tf_device::ClusterOp cluster) {
  llvm::SmallVector<mlir::OpResult, 4> new_result_values;
  llvm::SmallVector<mlir::Value, 4> result_producing_values;
  new_result_values.reserve(cluster->getNumResults());
  result_producing_values.reserve(cluster->getNumResults());
  for (mlir::OpResult result : cluster.getResults()) {
    if (!result.use_empty()) {
      new_result_values.emplace_back(result);
      result_producing_values.emplace_back(
          cluster.GetBody().getTerminator()->getOperand(
              result.getResultNumber()));
    }
  }

  if (new_result_values.size() == cluster.getNumResults()) return;

  llvm::SmallVector<mlir::Type, 4> new_result_types;
  llvm::transform(new_result_values, std::back_inserter(new_result_types),
                  [](mlir::Value v) { return v.getType(); });

  mlir::OpBuilder builder(cluster);
  auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
      cluster.getLoc(), new_result_types);
  new_cluster->setAttr(kMeshAttr,
                       cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr));
  new_cluster.getBody().push_back(new mlir::Block);

  auto& cluster_body = cluster.GetBody().getOperations();
  new_cluster.GetBody().getOperations().splice(
      new_cluster.GetBody().end(), cluster_body, cluster_body.begin(),
      std::prev(cluster_body.end()));

  builder.setInsertionPointToEnd(&new_cluster.GetBody());
  builder.create<mlir::tf_device::ReturnOp>(cluster.getLoc(),
                                            result_producing_values);

  assert(new_cluster.getNumResults() == new_result_values.size());
  for (auto it : llvm::zip(new_result_values, new_cluster.getResults())) {
    mlir::Value value_to_replace = std::get<0>(it);
    mlir::Value new_result = std::get<1>(it);
    value_to_replace.replaceAllUsesWith(new_result);
  }
  cluster.erase();
}

namespace {

// Keeps track of the number of functions added to the global graph for
// control flow. When converting regional control flow to functional control
// flow ops, function names may collide if non-unique branch function names are
// used. To ensure that all branch functions of TF control flow ops are unique,
// we keep an atomic counter for control flow functions.
// See b/174253694 for more details.
std::atomic<int32> dtensor_controlflow_function_counter{0};

}  // namespace

mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix,
                                            mlir::OpBuilder& builder) {
  int32 unique_id = dtensor_controlflow_function_counter++;
  return builder.getStringAttr(
      absl::StrCat(prefix, "_dtensor_function_", unique_id));
}

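// Sets the insertion point of `builder` to immediately after `value` if it is
// an op result. If `value` is a block argument, the insertion point is set to
// the start of the single tf_device.cluster that uses it; returns an error if
// the value is used in multiple clusters or in none.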
Status SetBuilderInsertionAfterValue(mlir::Value value,
                                     mlir::OpBuilder& builder) {
  if (value.isa<mlir::OpResult>()) {
    builder.setInsertionPointAfterValue(value);
    return OkStatus();
  }
  mlir::tf_device::ClusterOp cluster;
  for (mlir::Operation* op : value.getUsers()) {
    mlir::tf_device::ClusterOp new_cluster =
        op->getParentOfType<mlir::tf_device::ClusterOp>();
    if (!new_cluster) continue;
    if (!cluster) cluster = new_cluster;
    if (cluster != new_cluster)
      return errors::Internal("value has multiple uses in different clusters");
  }
  if (!cluster) return errors::Internal("value not used in any cluster");

  builder.setInsertionPointToStart(cluster.SingleBlock::getBody());
  return OkStatus();
}

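// Debugging helper that emits tf.StringFormat and tf.PrintV2 ops which print
// `value` (prefixed with the device id) at runtime using `format_string`.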
Status PrintTensor(mlir::Value value, const std::string& format_string = "%s") {
  mlir::OpBuilder builder(value.getContext());
  builder.setInsertionPointAfterValue(value);
  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(value));
  std::string all_format = absl::StrCat("Core %s: ", format_string);
  // Scalar string type
  mlir::RankedTensorType scalar_string =
      mlir::RankedTensorType::get({}, builder.getType<mlir::TF::StringType>());
  mlir::TF::StringFormatOp format = builder.create<mlir::TF::StringFormatOp>(
      value.getLoc(), scalar_string, mlir::ValueRange({device_id, value}));
  format->setAttr("template", builder.getStringAttr(all_format));
  builder.create<mlir::TF::PrintV2Op>(value.getLoc(), format.output(),
                                      /*output_stream=*/"log(info)",
                                      /*end=*/"\n");
  return OkStatus();
}

Status ExtractConstStringVectorFromValue(
    mlir::Value value, llvm::SmallVectorImpl<std::string>& out_vector) {
  value = GetForwardedDTensorLayoutInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal("Unable to get constant value from block argument.");
  mlir::DenseStringElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(
        llvm::formatv("failed to extract constant string vector from : {0}",
                      value)
            .str());
  }
  for (const auto& str : attr.getRawStringData()) {
    out_vector.push_back(str.str());
  }
  return OkStatus();
}

StatusOr<std::string> ExtractConstScalarStringFromValue(mlir::Value value) {
  value = GetForwardedDTensorLayoutInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal("Unable to get constant value from block argument.");
  mlir::DenseStringElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(absl::StrCat("required constant value for ",
                                         OpName(value.getDefiningOp())));
  }
  if (attr.size() != 1) {
    return errors::Internal(absl::StrCat("expected 1 element, got ",
                                         attr.size(), " for ",
                                         OpName(value.getDefiningOp())));
  }
  return std::string(*attr.getRawStringData().begin());
}

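// Initializes the iterator at the first op of the main function and marks the
// main function as visited.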
TopologicalIterator::TopologicalIterator(mlir::func::FuncOp main_func)
    : ops_to_visit_{&main_func.front().front()} {
  funcs_visited_.insert(main_func.getName());
  funcs_visited_in_call_stack_.insert(main_func.getName());
}

mlir::Operation* TopologicalIterator::next() {
  if (!hasNext()) return nullptr;

  auto* op = ops_to_visit_.pop_back_val();
  auto* next_op = op->getNextNode();
  if (next_op) ops_to_visit_.push_back(next_op);

  // If this is a function call op, push the first op of the function body so
  // that the function body is converted before the call site.
  absl::optional<mlir::func::FuncOp> func = MaybeFindFunction(op);
  if (func.has_value()) {
    mlir::StringRef func_name = func->getName();

    if (funcs_visited_.contains(func_name)) return next();

    ops_to_visit_.push_back(&(func->front().front()));
    funcs_visited_.insert(func_name);
  }

  // If we have reached the end of a function body, remove the function from
  // our active set.
  if (!next_op && !funcs_visited_in_call_stack_.empty())
    if (auto func = op->getParentOfType<mlir::func::FuncOp>())
      funcs_visited_in_call_stack_.erase(func.getName());

  if (auto cluster_op = mlir::dyn_cast<mlir::tf_device::ClusterOp>(op))
    ops_to_visit_.push_back(&cluster_op.GetBody().front());

  if (auto while_op = mlir::dyn_cast<mlir::TF::WhileRegionOp>(op)) {
    ops_to_visit_.push_back(&while_op.cond().front().front());
    ops_to_visit_.push_back(&while_op.body().front().front());
  }

  if (auto if_op = mlir::dyn_cast<mlir::TF::IfRegionOp>(op)) {
    ops_to_visit_.push_back(&if_op.then_branch().front().front());
    ops_to_visit_.push_back(&if_op.else_branch().front().front());
  }
  return op;
}

bool TopologicalIterator::hasNext() { return !ops_to_visit_.empty(); }

}  // namespace dtensor
}  // namespace tensorflow