1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <string>
17#include <utility>
18
19#include "absl/types/optional.h"
20#include "llvm/ADT/SetVector.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/FormatVariadic.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
24#include "mlir/IR/Attributes.h" // from @llvm-project
25#include "mlir/IR/Builders.h" // from @llvm-project
26#include "mlir/IR/BuiltinOps.h" // from @llvm-project
27#include "mlir/IR/Diagnostics.h" // from @llvm-project
28#include "mlir/IR/Operation.h" // from @llvm-project
29#include "mlir/IR/Value.h" // from @llvm-project
30#include "mlir/IR/Visitors.h" // from @llvm-project
31#include "mlir/Pass/Pass.h" // from @llvm-project
32#include "mlir/Pass/PassManager.h" // from @llvm-project
33#include "mlir/Support/LogicalResult.h" // from @llvm-project
34#include "mlir/Transforms/Passes.h" // from @llvm-project
35#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
36#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
37#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
38#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
39#include "tensorflow/dtensor/cc/constants.h"
40#include "tensorflow/dtensor/cc/tensor_layout.h"
41#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
42#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
43#include "tensorflow/dtensor/mlir/layout_parsing.h"
44#include "tensorflow/dtensor/mlir/op_utils.h"
45#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
46
47namespace tensorflow {
48namespace dtensor {
49
50namespace {
51#define GEN_PASS_DEF_DTENSORMESHPROPAGATION
52#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
53
54// Extracts mesh of `block_arg` by parsing function argument attributes of it's
55// enclosing function. Mesh is inferred either using `tf._layout` or `tf._mesh`
56// attributes.
57mlir::LogicalResult ExtractMeshFromBlockArgument(mlir::BlockArgument block_arg,
58 absl::optional<Mesh>* out) {
59 auto func_op = mlir::dyn_cast_or_null<mlir::func::FuncOp>(
60 block_arg.getOwner()->getParentOp());
61 if (!func_op) {
62 return block_arg.getOwner()->getParentOp()->emitOpError(
63 "must be enclosed by a function");
64 }
65 auto layout_or_status = ExtractLayoutFromOperand(block_arg);
66 if (!layout_or_status.ok())
67 return func_op.emitOpError(layout_or_status.status().error_message());
68
69 if (layout_or_status->has_value()) {
70 out->emplace(layout_or_status->value().mesh());
71 return mlir::success();
72 }
73
74 auto mesh_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
75 block_arg.getArgNumber(), kCustomDeviceMeshAttr);
76 if (!mesh_attr) return mlir::success();
77
78 auto mesh_from_block_arg_or_status =
79 Mesh::FromString(mesh_attr.getValue().str());
80 if (!mesh_from_block_arg_or_status.ok()) {
81 return func_op.emitOpError(
82 "Failed during mesh propagation. Op operand has invalid serialized "
83 "mesh");
84 }
85
86 out->emplace(mesh_from_block_arg_or_status.value());
87 return mlir::success();
88}
89
90// Extracts mesh of operation that produces `value`.
91mlir::LogicalResult ExtractMeshFromOpOutput(mlir::Value value,
92 absl::optional<Mesh>* out) {
93 auto input_op = value.getDefiningOp();
94 if (!input_op) return mlir::success();
95
96 auto operand_cluster =
97 llvm::dyn_cast<mlir::tf_device::ClusterOp>(value.getDefiningOp());
98 if (!operand_cluster) {
99 return mlir::emitError(value.getLoc())
100 << "operand must be from different device cluster.";
101 }
102
103 auto mesh_or_status = ExtractDeviceMeshFromOp(operand_cluster);
104 if (!mesh_or_status.ok())
105 return operand_cluster.emitOpError(
106 llvm::formatv("Failed during mesh propagation. {0}",
107 mesh_or_status.status().error_message()));
108
109 auto extracted_mesh = mesh_or_status.value();
110 if (extracted_mesh) *out = extracted_mesh.value();
111 return mlir::success();
112}
113
// Extracts mesh configuration from `operand`. If operand is a function
// argument, then mesh config is extracted from "tf._mesh" arg attribute of the
// corresponding func op. If operand is from a preceding op, then mesh
// configuration is extracted from the enclosing tf_device.Cluster op.
//
// `producers` maps an operand to the set of values that may flow into it
// (there can be several through control flow such as tf.If); when the operand
// is a block argument without mesh attributes, all of its producers must
// agree on a single mesh.
mlir::LogicalResult ExtractMeshFromOperand(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::OpOperand* operand, absl::optional<Mesh>* out) {
  mlir::Value operand_value = operand->get();

  // Moves `mesh` into `operand_mesh` when the latter is still unset; fails if
  // both are set but disagree. Used to enforce mesh consistency across all
  // producers of the operand.
  const auto check_and_assign_mesh =
      [](mlir::Location loc, absl::optional<Mesh>& mesh,
         absl::optional<Mesh>& operand_mesh) -> mlir::LogicalResult {
    if (mesh && !operand_mesh) {
      operand_mesh.swap(mesh);
    } else if (mesh && operand_mesh && mesh != operand_mesh) {
      return mlir::emitError(
          loc,
          "Error during mesh propagation. Found inconsistent mesh "
          "while inferring mesh from operands.");
    }
    return mlir::success();
  };

  // If `operand` is a block argument then extract mesh from `tf._mesh`
  // attribute of the corresponding function argument.
  if (auto block_arg = operand_value.dyn_cast<mlir::BlockArgument>()) {
    if (mlir::failed(ExtractMeshFromBlockArgument(block_arg, out)))
      return mlir::failure();

    // Argument attributes carried no mesh; fall back to the values that
    // produce this operand and require they agree on one mesh.
    if (!out->has_value()) {
      auto it = producers.find(operand);
      if (it != producers.end()) {
        auto producer_values = it->getSecond();
        absl::optional<Mesh> operand_mesh;
        for (mlir::Value producer_value : producer_values) {
          if (auto arg = producer_value.dyn_cast<mlir::BlockArgument>()) {
            // Producer is itself a function argument; read its attributes.
            absl::optional<Mesh> mesh;
            if (mlir::failed(ExtractMeshFromBlockArgument(arg, &mesh)))
              return mlir::failure();

            if (mlir::failed(check_and_assign_mesh(
                    operand->getOwner()->getLoc(), mesh, operand_mesh)))
              return mlir::failure();
          } else {
            // Producer is an op result: map it to the corresponding result of
            // the enclosing tf_device.Cluster and read that cluster's mesh.
            // NOTE(review): this assumes the cluster forwards its inner op's
            // results 1:1 (same result numbering) -- confirm.
            auto input_cluster =
                producer_value.getDefiningOp()
                    ->getParentOfType<mlir::tf_device::ClusterOp>();
            auto output_from_producing_op = input_cluster.getResult(
                producer_value.cast<mlir::OpResult>().getResultNumber());

            absl::optional<Mesh> mesh;
            if (mlir::failed(
                    ExtractMeshFromOpOutput(output_from_producing_op, &mesh)))
              return mlir::failure();

            if (mlir::failed(check_and_assign_mesh(
                    operand->getOwner()->getLoc(), mesh, operand_mesh)))
              return mlir::failure();
          }
        }
        *out = operand_mesh;
      }
    }
    return mlir::success();
  }

  // If `operand` is from another operation, extract mesh from enclosing
  // tf_device.cluster op of the input operation.
  if (mlir::failed(ExtractMeshFromOpOutput(operand_value, out)))
    return mlir::failure();

  return mlir::success();
}
187
// Infers mesh of `cluster` from its operands. If mesh can be inferred, all
// operands must have same mesh.
//
// On success `mesh` holds the inferred mesh (if any) and
// `inputs_with_inferred_mesh` records every operand whose mesh contributed to
// the inference.
mlir::LogicalResult InferMeshFromInputs(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, absl::optional<Mesh>* mesh,
    llvm::SmallVector<mlir::OpOperand*, 8>* inputs_with_inferred_mesh) {
  auto result = mlir::success();

  // If `cluster` wraps a `tf.CopyToMesh` op, do not infer mesh from its
  // inputs. `tf.CopyToMesh` specifies that all operations following the
  // operation is executed on target device mesh cluster specified by
  // `tf.CopyToMesh`.
  if (llvm::isa<mlir::TF::CopyToMeshOp>(&cluster.GetBody().front()))
    return result;

  mlir::visitUsedValuesDefinedAbove(
      cluster.getBody(), cluster.getBody(), [&](mlir::OpOperand* operand) {
        // A previous operand already produced an error; skip further work.
        if (mlir::failed(result)) return;
        absl::optional<Mesh> extracted_config;

        // If inputs to mesh is from DTensorLayout op, then use the mesh
        // extracted from the DTensorLayout op to infer the mesh of the cluster.
        if (auto layout_op =
                llvm::dyn_cast<mlir::TF::DTensorLayout>(operand->getOwner())) {
          auto mesh = layout_op.layout().mesh();
          extracted_config.emplace(mesh);
        } else {
          auto extract_result =
              ExtractMeshFromOperand(producers, operand, &extracted_config);
          if (mlir::failed(extract_result)) {
            result = extract_result;
            return;
          }
        }

        // DTensorDevice may create a graph with resource arguments with an
        // empty layout. These layouts of the resource values will be added
        // after layout is inferred from resource update ops. Therefore, ignore
        // DTensorLayout ops with empty layouts.
        if (!extracted_config || extracted_config->IsEmpty()) return;

        inputs_with_inferred_mesh->emplace_back(operand);
        // All mesh-carrying inputs must agree with the first mesh seen.
        if (mesh->has_value() && extracted_config != mesh->value()) {
          result = cluster.emitOpError(
              "failed during mesh propagation. All inputs to "
              "`tf_device.Cluster` must have same mesh configuration.");
        }

        if (!mesh->has_value()) mesh->emplace(extracted_config.value());
      });

  return result;
}
241
242// Extracts mesh from function return attributes. If `tf._default_layout`
243// attribute exists, mesh from the default layout is used. If not, mesh from
244// `tf._mesh` attribute is used.
245StatusOr<absl::optional<Mesh>> ExtractMeshFromFuctionOutput(
246 const int output_index, mlir::func::FuncOp function) {
247 absl::optional<Mesh> function_mesh;
248 auto terminator = llvm::cast<mlir::func::ReturnOp>(
249 function.getBody().front().getTerminator());
250 TF_ASSIGN_OR_RETURN(auto layout, ExtractLayoutFromFunctionReturnAttr(
251 terminator, output_index));
252
253 if (layout) {
254 function_mesh.emplace(layout->mesh());
255 return function_mesh;
256 }
257
258 auto output_mesh_attr = function.getResultAttrOfType<mlir::StringAttr>(
259 output_index, kCustomDeviceMeshAttr);
260 if (output_mesh_attr) {
261 TF_ASSIGN_OR_RETURN(auto mesh,
262 Mesh::FromString(output_mesh_attr.getValue().str()));
263 function_mesh.emplace(std::move(mesh));
264 }
265 return function_mesh;
266}
267
// Infers mesh from users of `cluster` and records the usages that were used to
// infer mesh configuration in `consumers_with_mesh`.
//
// All consumers that contribute a mesh must agree on it; a disagreement is a
// hard error. Consumers without mesh information are skipped.
mlir::LogicalResult InferMeshFromConsumers(
    mlir::tf_device::ClusterOp cluster, absl::optional<Mesh>* mesh,
    llvm::SmallVector<mlir::OpOperand*, 8>* consumers_with_mesh) {
  for (auto& use_value : cluster.getOperation()->getUses()) {
    mlir::Operation* consumer = use_value.getOwner();

    // `tf.CopyToMesh` specifies that all operations following the
    // operation are executed on target device mesh cluster specified by
    // `tf.CopyToMesh`. Therefore, if `consumer` operation is `tf.CopyToMesh`
    // do not propagate mesh backwards to `cluster`.
    if (llvm::isa<mlir::TF::CopyToMeshOp>(consumer)) continue;

    Mesh extracted_mesh;

    // If `cluster` output is output value of a function, then infer mesh using
    // function return value attribute, if it exists.
    if (auto return_op = llvm::dyn_cast<mlir::func::ReturnOp>(consumer)) {
      auto status_or_mesh = ExtractMeshFromFuctionOutput(
          use_value.getOperandNumber(),
          return_op->getParentOfType<mlir::func::FuncOp>());
      if (!status_or_mesh.ok())
        return cluster.emitOpError(status_or_mesh.status().ToString());

      auto mesh = status_or_mesh.value();
      if (mesh) extracted_mesh = *mesh;
    } else {
      // If `cluster` output is input to another cluster/op then infer mesh from
      // the consumer operation.
      auto consumer_cluster =
          consumer->getParentOfType<mlir::tf_device::ClusterOp>();
      if (!consumer_cluster) {
        return cluster.emitOpError(
            "failed to propagate mesh information. All operations must be "
            "enclosed inside a tf_device.cluster op.");
      }

      auto mesh_or_status = ExtractDeviceMeshFromOp(consumer_cluster);
      if (!mesh_or_status.ok())
        return cluster.emitOpError(mesh_or_status.status().error_message());

      // Consumer cluster has no mesh assigned yet; it cannot contribute.
      auto consumer_mesh = mesh_or_status.value();
      if (!consumer_mesh) continue;

      extracted_mesh = consumer_mesh.value();
    }

    // An empty mesh carries no information; skip it.
    if (extracted_mesh.IsEmpty()) continue;

    if (mesh->has_value() && extracted_mesh != mesh->value()) {
      return cluster.emitOpError(
          "failed to propagate mesh information. Mesh for op is ambiguous as "
          "consumers have different mesh attributes");
    }

    consumers_with_mesh->emplace_back(&use_value);
    if (!mesh->has_value()) mesh->emplace(std::move(extracted_mesh));
  }
  return mlir::success();
}
329
// Infers default mesh of function given its inputs and outputs. Function has a
// default mesh if all its inputs/outputs have values assigned to the same
// mesh. If inputs/outputs disagree, `inferred_default_mesh` is reset and
// success is returned (no default mesh, not an error).
mlir::LogicalResult InferFunctionDefaultMesh(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::func::FuncOp function, mlir::OpBuilder* builder,
    absl::optional<mlir::StringAttr>* inferred_default_mesh) {
  // First pass: every op-produced function return value must carry the same
  // `_mesh` attribute on its producing cluster.
  auto terminator = function.getCallableRegion()->front().getTerminator();
  for (auto& result_value : terminator->getOpOperands()) {
    auto result_defining_op = result_value.get().getDefiningOp();
    if (!result_defining_op) continue;

    // NOTE(review): this cast assumes every op-produced function result comes
    // from a tf_device.ClusterOp at this stage of the pipeline -- confirm.
    auto result_cluster =
        llvm::cast<mlir::tf_device::ClusterOp>(result_defining_op);
    auto result_mesh =
        result_cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
    if (!result_mesh) continue;

    // Conflicting result meshes: there is no single default mesh.
    if (inferred_default_mesh->has_value() &&
        inferred_default_mesh->value() != result_mesh) {
      inferred_default_mesh->reset();
      return mlir::success();
    }
    inferred_default_mesh->emplace(result_mesh);
  }

  // Second pass: function arguments must agree with the mesh found above.
  absl::optional<Mesh> inferred_mesh_from_args;
  for (auto function_arg : function.getArguments()) {
    auto uses = function_arg.getUses();
    if (uses.empty()) {
      // Unused argument: mesh can only come from its own argument attributes.
      if (mlir::failed(ExtractMeshFromBlockArgument(function_arg,
                                                    &inferred_mesh_from_args)))
        return mlir::failure();
    } else {
      // Infer mesh from the argument's first use.
      auto operand = uses.begin().getOperand();
      if (mlir::failed(ExtractMeshFromOperand(producers, operand,
                                              &inferred_mesh_from_args)))
        return mlir::failure();
    }
    if (!inferred_mesh_from_args) continue;

    // Compare serialized forms since one side is a StringAttr.
    std::string mesh_str = inferred_mesh_from_args->ToString();
    if (inferred_default_mesh->has_value() &&
        inferred_default_mesh->value().getValue().str() != mesh_str) {
      inferred_default_mesh->reset();
      return mlir::success();
    }

    inferred_default_mesh->emplace(builder->getStringAttr(std::move(mesh_str)));
  }
  return mlir::success();
}
381
382// Annotates `tf._mesh` attribute to argument of `function` with
383// string of `mesh`.
384void AnnotateFunctionArgumentsWithMeshInformation(
385 const Mesh& mesh,
386 const llvm::SmallVector<mlir::OpOperand*, 8>& input_values_from_mesh,
387 mlir::func::FuncOp function, mlir::OpBuilder* builder) {
388 for (auto value : input_values_from_mesh) {
389 function.setArgAttr(value->getOperandNumber(), kCustomDeviceMeshAttr,
390 builder->getStringAttr(mesh.ToString()));
391 }
392}
393
// Annotates return value attributes of `function_to_annotate` with mesh
// information parsed from usages of the function. `callsite_operation` is
// callable op whose function definition is `function_to_annotate`.
mlir::LogicalResult AnnotateFunctionReturnValuesWithMeshInformation(
    const llvm::SmallVector<mlir::OpOperand*, 8>& return_values_from_mesh,
    mlir::Operation* callsite_operation,
    mlir::func::FuncOp function_to_annotate, mlir::OpBuilder* builder) {
  for (auto value : return_values_from_mesh) {
    absl::optional<mlir::StringAttr> result_mesh_attribute;
    if (llvm::isa<mlir::func::ReturnOp>(value->getOwner())) {
      // The callsite result is returned from the parent function; derive the
      // mesh from the parent function's result attributes, preferring an
      // explicit default layout over a plain mesh attribute.
      auto parent_function =
          callsite_operation->getParentOfType<mlir::func::FuncOp>();
      auto function_result_layout =
          parent_function.getResultAttrOfType<mlir::StringAttr>(
              value->getOperandNumber(), kCustomDefaultLayoutAttr);
      if (function_result_layout) {
        auto layout_or_status =
            Layout::FromString(function_result_layout.getValue().str());
        if (!layout_or_status.ok())
          return parent_function.emitOpError(
              layout_or_status.status().error_message());

        result_mesh_attribute.emplace(
            builder->getStringAttr(layout_or_status->mesh().ToString()));
      } else {
        auto function_result_mesh =
            parent_function.getResultAttrOfType<mlir::StringAttr>(
                value->getOperandNumber(), kCustomDeviceMeshAttr);
        if (function_result_mesh)
          result_mesh_attribute.emplace(function_result_mesh);
      }
    } else {
      // Otherwise the consumer is a regular op; use its `_mesh` attribute.
      auto op_mesh =
          value->getOwner()->getAttrOfType<mlir::StringAttr>(kMeshAttr);
      if (op_mesh) result_mesh_attribute.emplace(std::move(op_mesh));
    }

    // Record the inferred mesh on the corresponding result of the callee so
    // that propagation inside its body can use it.
    if (result_mesh_attribute)
      function_to_annotate.setResultAttr(
          value->get().cast<mlir::OpResult>().getResultNumber(),
          kCustomDeviceMeshAttr, result_mesh_attribute.value());
  }
  return mlir::success();
}
438
// MLIR pass that propagates mesh information to tf_device.Cluster ops.
struct DTensorMeshPropagation
    : public impl::DTensorMeshPropagationBase<DTensorMeshPropagation> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder builder(&context);
    auto module = getOperation();
    // Propagation starts from the `main` entry function; modules without one
    // are left untouched.
    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) return;

    mlir::Dialect* tf_dialect =
        context.getLoadedDialect<mlir::TF::TensorFlowDialect>();

    // This maps from OpResults to a list of OpOperands that consume this.
    // Note that this will pass over/through
    // (Stateful)PartitionedCall and other control flow, directly connecting
    // producing ops to their consumers in the function. I.e. it presents
    // flattened/inlined view of the flow of data.
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>> consumers;
    // Maintain a reverse mapping. Note that for controlflow operations like
    // tf.If op, there may be multiple producers for a mlir::Value.
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers;

    // Create consumers and producers maps.
    if (mlir::failed(
            PopulateConsumersFromModule(&module, tf_dialect, consumers)))
      return signalPassFailure();

    // Invert the consumers map to obtain the producers map.
    for (auto& consumer : consumers) {
      for (auto* operand : consumer.second) {
        producers[operand].emplace_back(consumer.first);
      }
    }

    // Iterate to a fixed point: each sweep may assign mesh to more clusters,
    // which in turn can unlock further inference on the next sweep.
    bool mesh_changed = true;
    while (mesh_changed) {
      mesh_changed = false;
      if (mlir::failed(
              PropagateMesh(producers, main_func, &builder, &mesh_changed)))
        return signalPassFailure();
    }
  }

  // Propagates and sets `_mesh` attributes to all clusters inside `function` if
  // possible.
  mlir::LogicalResult PropagateMesh(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::func::FuncOp, mlir::OpBuilder* builder, bool* mesh_changed);

  // Infers mesh of `cluster` from its input operations.
  mlir::LogicalResult PropagateMeshFromInputs(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
      bool* mesh_changed);

  // Infers mesh of `cluster` from its consuming operations.
  mlir::LogicalResult PropagateMeshFromConsumers(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
      bool* mesh_changed);

  // Assigns function default mesh to clusters with no mesh specified. Note that
  // function has default mesh if all its dtensor inputs/outputs are assigned to
  // a single mesh.
  mlir::LogicalResult PropagateDefaultMeshToUnAssignedClusters(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::func::FuncOp, mlir::OpBuilder* builder, bool* mesh_changed);
};
512
513mlir::LogicalResult
514DTensorMeshPropagation::PropagateDefaultMeshToUnAssignedClusters(
515 const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
516 mlir::func::FuncOp function, mlir::OpBuilder* builder, bool* mesh_changed) {
517 absl::optional<mlir::StringAttr> mesh;
518 if (mlir::failed(
519 InferFunctionDefaultMesh(producers, function, builder, &mesh)))
520 return mlir::failure();
521
522 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters_without_mesh;
523 auto walk_result = function.walk([&](mlir::tf_device::ClusterOp cluster) {
524 auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
525 if (!mesh_or_status.ok()) {
526 cluster.GetBody().front().emitOpError(
527 mesh_or_status.status().error_message());
528 return mlir::WalkResult::interrupt();
529 }
530
531 const auto& mesh = mesh_or_status.value();
532 if (mesh.has_value()) return mlir::WalkResult::advance();
533
534 clusters_without_mesh.emplace_back(cluster);
535 return mlir::WalkResult::advance();
536 });
537
538 if (walk_result.wasInterrupted()) return mlir::failure();
539
540 if (!mesh.has_value()) return mlir::success();
541
542 // Set function default mesh to cluster with unspecified mesh.
543 for (auto cluster_without_mesh : clusters_without_mesh) {
544 *mesh_changed = true;
545 cluster_without_mesh->setAttr(kMeshAttr, mesh.value());
546 }
547
548 return mlir::success();
549}
550
// Sets mesh of `cluster`, inferring it from the operations producing the
// cluster's inputs. For callable ops, the inferred mesh is also pushed into
// the callee's argument attributes and propagation recurses into the
// function body.
mlir::LogicalResult DTensorMeshPropagation::PropagateMeshFromInputs(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
    bool* mesh_changed) {
  // If operation inside a mesh cluster is not a callable operation and
  // mesh is already specified on a cluster, do nothing.
  auto inner_func = MaybeFindFunction(&cluster.GetBody().front());
  auto cluster_mesh = cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  if (!inner_func && cluster_mesh) return mlir::success();

  // If mesh of `cluster` is not specified, infer mesh using inputs of mesh
  // cluster.
  absl::optional<Mesh> extracted_mesh;
  llvm::SmallVector<mlir::OpOperand*, 8> inputs_with_inferred_mesh;
  if (failed(InferMeshFromInputs(producers, cluster, &extracted_mesh,
                                 &inputs_with_inferred_mesh))) {
    return mlir::failure();
  }

  // If operation include 'cluster` is a function call, annotate input and
  // output mesh of `cluster` using function argument and return value
  // attributes, then recursively propagate mesh of the function definition.
  if (inner_func) {
    // All inputs to cluster must be from the same mesh. If input mesh to
    // callable operation is inferred, then annotated the input mesh to
    // function argument attribute so that this information can be used to
    // infer mesh of ops inside `inner_func`.
    if (extracted_mesh.has_value()) {
      AnnotateFunctionArgumentsWithMeshInformation(extracted_mesh.value(),
                                                   inputs_with_inferred_mesh,
                                                   inner_func.value(), builder);
    }

    // Recursively propagate mesh to clusters in function definition of
    // `inner_func`.
    if (mlir::failed(PropagateMesh(producers, inner_func.value(), builder,
                                   mesh_changed)))
      return mlir::failure();

    // Once all clusters inside `inner_func` callable has been set, now we can
    // infer mesh of `cluster`. That is, mesh of call site operation is equal
    // to mesh of return values of the function.
    absl::optional<mlir::StringAttr> function_mesh;
    if (mlir::failed(InferFunctionDefaultMesh(producers, inner_func.value(),
                                              builder, &function_mesh)))
      return mlir::failure();

    // Only overwrite when the cluster did not already carry a mesh.
    if (function_mesh && !cluster_mesh) {
      *mesh_changed = true;
      cluster->setAttr(kMeshAttr, function_mesh.value());
    }
  } else if (!cluster_mesh && extracted_mesh.has_value()) {
    // Non-callable op: record the mesh inferred from the inputs.
    *mesh_changed = true;
    cluster->setAttr(kMeshAttr,
                     builder->getStringAttr(extracted_mesh->ToString()));
  }
  return mlir::success();
}
609
// Set mesh of `cluster`, inferring mesh from consumer operations of `cluster`.
// For callable ops, the consumer-derived mesh is first recorded on the
// callee's result attributes and propagation recurses into the function body.
mlir::LogicalResult DTensorMeshPropagation::PropagateMeshFromConsumers(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
    bool* mesh_changed) {
  mlir::Operation* op_inside_cluster = &cluster.GetBody().front();
  auto inner_func = MaybeFindFunction(op_inside_cluster);
  auto cluster_mesh = cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  // If mesh is already set, then do nothing.
  if (!inner_func && cluster_mesh) return mlir::success();

  // Infer mesh of `cluster` from its output usages.
  absl::optional<Mesh> extracted_mesh_from_consumers;
  llvm::SmallVector<mlir::OpOperand*, 8> consumers_with_mesh_information;
  if (failed(InferMeshFromConsumers(cluster, &extracted_mesh_from_consumers,
                                    &consumers_with_mesh_information)))
    return mlir::failure();

  // If operation inside mesh cluster is a function callsite operation,
  // then propagate mesh of the function recursively.
  if (inner_func) {
    // Record the consumer-derived mesh on the callee's result attributes so
    // the recursive propagation below can use it.
    if (mlir::failed(AnnotateFunctionReturnValuesWithMeshInformation(
            consumers_with_mesh_information, op_inside_cluster,
            inner_func.value(), builder)))
      return mlir::failure();

    if (mlir::failed(PropagateMesh(producers, inner_func.value(), builder,
                                   mesh_changed)))
      return mlir::failure();

    // The callsite's mesh equals the callee's function-wide default mesh.
    absl::optional<mlir::StringAttr> function_mesh;
    if (mlir::failed(InferFunctionDefaultMesh(producers, inner_func.value(),
                                              builder, &function_mesh)))
      return mlir::failure();

    // Only assign when the cluster did not already carry a mesh.
    if (function_mesh && !cluster_mesh) {
      *mesh_changed = true;
      cluster->setAttr(kMeshAttr, function_mesh.value());
    }
  } else if (extracted_mesh_from_consumers && !cluster_mesh) {
    // Non-callable op: record the mesh inferred from the consumers.
    *mesh_changed = true;
    cluster->setAttr(kMeshAttr, builder->getStringAttr(
                                    extracted_mesh_from_consumers->ToString()));
  }
  return mlir::success();
}
656
657// Propagates mesh information to all `tf_device.Cluster` ops in `function`. If
658// `function` includes callable ops, then recursively traverse the function
659// definition to propagate mesh information using input operands and consuming
660// result ops. Note that at current stage of graph optimization,
661// tf_device.cluster ops are enclosing a single operation.
662mlir::LogicalResult DTensorMeshPropagation::PropagateMesh(
663 const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
664 mlir::func::FuncOp function, mlir::OpBuilder* builder, bool* mesh_changed) {
665 // Iterate clusters in topological order propagating mesh from operations'
666 // inputs.
667 llvm::SmallVector<mlir::tf_device::ClusterOp, 8> cluster_ops;
668 for (auto cluster : function.getOps<mlir::tf_device::ClusterOp>()) {
669 cluster_ops.emplace_back(cluster);
670
671 if (mlir::failed(
672 PropagateMeshFromInputs(producers, cluster, builder, mesh_changed)))
673 return mlir::failure();
674 }
675
676 // Iterate clusters in reverse topological order and propagate mesh from
677 // consumers.
678 for (auto cluster : llvm::reverse(cluster_ops)) {
679 if (mlir::failed(PropagateMeshFromConsumers(producers, cluster, builder,
680 mesh_changed)))
681 return mlir::failure();
682 }
683
684 if (mlir::failed(PropagateDefaultMeshToUnAssignedClusters(
685 producers, function, builder, mesh_changed)))
686 return mlir::failure();
687
688 return mlir::success();
689}
690
691} // namespace
692
693std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
694CreateDTensorMeshPropagationPass() {
695 return std::make_unique<DTensorMeshPropagation>();
696}
697
698} // namespace dtensor
699} // namespace tensorflow
700