/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <deque>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {
#define GEN_PASS_DEF_DTENSORLAYOUTPROPAGATIONV2
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

// This value dictates how many times during layout propagation we allow
// fixing of oscillatory behaviors.
constexpr int kLayoutPropagationMaxStages = 3;

bool AllOpResultsHaveLayouts(
    mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
    const llvm::DenseMap<mlir::Value, Layout>& layouts) {
  const auto& result = module->walk([&](mlir::Operation* op) {
    if (op->getDialect() != tf_dialect ||
        mlir::isa<mlir::TF::DTensorLayout>(op))
      return mlir::WalkResult::advance();
    for (const auto& result : op->getOpResults()) {
      if (layouts.find(result) == layouts.end()) {
        op->emitOpError() << "missing layout for result "
                          << result.getResultNumber();
        return mlir::WalkResult::interrupt();
      }
    }
    return mlir::WalkResult::advance();
  });
  return !result.wasInterrupted();
}

void UpdateLayoutForSkippedOps(
    mlir::OpOperand& operand,
    const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
    const Layout& layout_to_copy,
    llvm::DenseMap<mlir::Value, Layout>& layouts) {
  llvm::SmallVector<mlir::Value, 4> skipped_values;
  TraceUseToNextTFOp(&operand, func_to_caller, &skipped_values);
  for (const mlir::Value& skipped_value : skipped_values)
    if ((!skipped_value.isa<mlir::OpResult>() ||
         !mlir::isa<mlir::TF::DTensorLayout, mlir::tf_device::ClusterOp>(
             skipped_value.getDefiningOp())) &&
        layouts.find(skipped_value) == layouts.end())
      // TraceUseToNextTFOp's contract is that it only skips over ops that
      // act like the identity (such as function calls, returns, yields,
      // control flow, DTensorLayouts, etc.). This means that the layout of
      // the operand we came from is the layout we want for this value.
      layouts[skipped_value] = layout_to_copy;
}

// Some ops, which are skipped by TraceUseToNextTFOp, will not have layouts
// for their mlir::OpResults.
// E.g. during the creation of the consumers map, we skip the input and output
// of the WhileRegion op. In particular if we have:
//
// %b = tf.WhileRegion(%a) ({
//   %bb0(%arg0):  # Cond
//     %c = tf.A(%arg0)
//     tf.Yield(%c)
//   }, {
//   %bb0(%arg0):  # Body
//     %d = tf.B(%arg0)
//     tf.Yield(%d)
//   }
// }
// %e = tf.C(%b)
//
// Then the consumers map would directly connect the mlir::Value %a to input 0
// of tf.A and tf.B, bypassing the WhileRegion and the block arguments (%arg0
// in each region). Similarly, it would connect the mlir::Value of %d directly
// to input 0 of tf.C, bypassing the mlir::Value of %b.
// This means that at the end of layout propagation the skipped values would
// not have an assigned layout. But this layout can be derived by taking the
// known layout of %a and propagating it to each mlir::Value that was skipped
// while connecting %a to input 0 of tf.A and tf.B. Similarly, we derive the
// layout for %b from %d.
//
// To get layouts we
// 1) Iterate over all ops that have layouts for their OpResults and call
//    TraceUseToNextTFOp to get the skipped mlir::Values.
// 2) If any skipped mlir::Value doesn't have a layout set, then we set the
//    layout.
mlir::LogicalResult CopyLayoutsForSkippedOps(
    mlir::ModuleOp module, mlir::Dialect* tf_dialect,
    llvm::DenseMap<mlir::Value, Layout>& layouts) {
  llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller;

  if (mlir::failed(GetFuncToCaller(module, func_to_caller)))
    return mlir::failure();

  // Update layouts derived from ops.
  module->walk([&](mlir::Operation* op) {
    for (mlir::OpOperand& operand : op->getOpOperands()) {
      if (layouts.find(operand.get()) == layouts.end()) continue;
      const Layout layout = layouts[operand.get()];
      UpdateLayoutForSkippedOps(operand, func_to_caller, layout, layouts);
    }
  });

  // Update layouts derived from inputs.
  mlir::func::FuncOp main_func =
      module.lookupSymbol<mlir::func::FuncOp>("main");
  if (!main_func) return mlir::success();

  for (auto& value : main_func.getArguments()) {
    if (layouts.find(value) == layouts.end()) continue;
    const Layout layout = layouts[value];

    for (mlir::OpOperand& operand : value.getUses())
      UpdateLayoutForSkippedOps(operand, func_to_caller, layout, layouts);
  }

  return mlir::success();
}

namespace {
void FilterkAnySpecs(std::vector<std::string>& proposed_specs) {
  for (auto& spec : proposed_specs) {
    if (spec == Layout::kAny) spec = Layout::kUnshardedDim;
  }
}
}  // namespace

// Merges the producer and consumer layouts into a single layout.
// Assumes that all layouts are of the same rank.
// Consumers are first merged together so that we have a layout which is
// sharded in a tensor dim if and only if all consumers are sharded with the
// same sharding_spec.
// If a producer layout is present, we merge the consumer layouts into the
// layout of the producer: if a consumer wants a sharded layout in a tensor
// dimension where the producer is unsharded *and* the mesh dimension it wants
// to shard over is not already sharded over by the producer, then we add that
// sharding to the producer layout.
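// An illustrative example of the merge rules above (the layouts are made up
// for illustration): if two consumers request ("x", *) and ("x", "y"), the
// consumer merge yields ("x", *), since the consumers disagree on the second
// dimension. If the producer layout is (*, "x"), the final result is
// (*, "x"): the producer's sharding wins on the second dimension, and the
// first dimension falls back to unsharded because mesh dimension "x" is
// already used by the producer.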
StatusOr<Layout> MergeLayouts(
    const absl::optional<Layout>& producer,
    const mlir::DenseMap<mlir::OpOperand*, Layout>& consumers) {
  if (consumers.empty()) return producer.value();

  // Initialize the specs to those of the first consumer layout and merge
  // consumers.
  std::vector<std::string> proposed_specs =
      consumers.begin()->second.sharding_spec_strs();
  int layout_rank = proposed_specs.size();

  // Verify consumer layout ranks match.
  for (const auto& consumer : consumers) {
    const Layout& consumer_layout = consumer.second;
    if (consumer_layout.rank() != layout_rank)
      return errors::InvalidArgument(
          "found two consumer layouts of different ranks: ",
          consumer_layout.rank(), " and ", layout_rank);
  }

  // Merge consumer layouts.
  for (const auto& consumer : consumers) {
    const Layout& consumer_layout = consumer.second;

    // Check every tensor dimension.
    for (int j = 0; j < consumer_layout.rank(); ++j) {
      const std::string& consumer_spec_j = consumer_layout.sharding_spec(j);
      if (consumer_spec_j == Layout::kAny) continue;

      // If the proposed spec is set to any, give priority to the consumer
      // spec.
      if (proposed_specs[j] == Layout::kAny) {
        proposed_specs[j] = consumer_spec_j;
        continue;
      }

      // If any consumer layout disagrees with the current merge, set the
      // spec to not sharded.
      if (proposed_specs[j] != consumer_spec_j)
        proposed_specs[j] = Layout::kUnshardedDim;
    }
  }

  // Filter over-sharded specs.
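  // For example, a proposed layout of ("x", "x") would shard the same mesh
  // dimension twice, so both dimensions are reset to unsharded here.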
  absl::flat_hash_map<std::string, int> counter;
  for (const std::string& spec : proposed_specs) counter[spec] += 1;
  for (std::string& spec : proposed_specs)
    if (counter[spec] > 1) spec = Layout::kUnshardedDim;

  // Return layout if there is no producer, else move into producer algorithm.
  const Mesh mesh = consumers.begin()->second.mesh();
  if (!producer) {
    FilterkAnySpecs(proposed_specs);
    return Layout::GetLayout(proposed_specs, mesh);
  }

  if (producer->rank() != layout_rank) {
    return errors::InvalidArgument(
        "producer and consumer layouts have different ranks: ",
        producer->rank(), " and ", layout_rank);
  }

  // For the producer merge, first we define mesh dims used by the producer to
  // avoid creating a layout that shards twice over the same mesh dim.
  llvm::DenseSet<llvm::StringRef> producer_dims;
  for (int j = 0; j < producer->rank(); ++j) {
    llvm::StringRef spec = producer->sharding_spec(j);
    if (Layout::IsShardedDimension(spec.str())) producer_dims.insert(spec);
  }
  // Merge producer layout with existing layout.
  for (int j = 0; j < producer->rank(); ++j) {
    const std::string& producer_spec = producer->sharding_spec(j);

    if (producer_spec == proposed_specs[j] || producer_spec == Layout::kAny)
      continue;

    if (proposed_specs[j] == Layout::kAny) {
      proposed_specs[j] = producer_spec;
      continue;
    }
    // If the producer is unsharded and the proposed spec is sharded, we need
    // to make sure the mesh dim is not used elsewhere. If it is, set the spec
    // to unsharded.
    if (Layout::IsUnshardedDimension(producer_spec)) {
      bool isMeshDimUsed = producer_dims.contains(proposed_specs[j]);
      if (isMeshDimUsed) {
        proposed_specs[j] = Layout::kUnshardedDim;
      }
    } else {
      // If the producer is sharded we can set the layout to shard over the
      // same mesh dim.
      //
      // If the mesh dim is already used elsewhere in the layout it will
      // get unset by the case above.
      proposed_specs[j] = producer_spec;
    }
  }
  FilterkAnySpecs(proposed_specs);
  return Layout::GetLayout(proposed_specs, mesh);
}

mlir::LogicalResult InsertLayoutsForDTensorLayout(
    mlir::ModuleOp& module,
    llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
    llvm::DenseSet<mlir::Value>& is_updated,
    llvm::DenseSet<mlir::Value>& is_locked) {
  return mlir::failure(
      module
          .walk([&](mlir::TF::DTensorLayout op) -> mlir::WalkResult {
            // Check there are no "Layout::kAny" or "kMatch" specs in the
            // layouts.
            for (const std::string& spec : op.layout().sharding_spec_strs())
              if (spec == Layout::kAny || spec == Layout::kMatch)
                return op->emitOpError()
                       << "found " << spec
                       << " as a sharding spec which is not allowed";
            // Insert layout.
            producer_request[op.input()].emplace(op.layout());
            is_updated.insert(op.input());
            is_locked.insert(op.input());
            return mlir::WalkResult::advance();
          })
          .wasInterrupted());
}

// Runs the ComputeLayout API on all ops inside the graph **without** any
// consumer-requested layouts or operand layouts populated.
mlir::LogicalResult InsertInitialLayoutsFromComputeLayout(
    mlir::ModuleOp module,
    const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
    llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
        consumer_requests,
    llvm::DenseSet<mlir::Value>& is_updated) {
  auto walk_result = module.walk([&](mlir::Operation* op) {
    // We ignore ops that don't have either an OpResult in consumers or an
    // OpOperand in producers. Note that if one operand is missing from
    // producers then all operands should be missing, as well as all op results
    // from consumers, and vice versa.

    if (op->getNumOperands() > 0) {
      if (producers.find(&op->getOpOperand(0)) == producers.end())
        return mlir::WalkResult::advance();
    } else if (op->getNumResults() > 0) {
      if (consumers.find(op->getOpResult(0)) == consumers.end())
        return mlir::WalkResult::advance();
    } else {
      // Note that this case should never happen (i.e. a TF op should have
      // either inputs or outputs, but that isn't technically guaranteed).
      return mlir::WalkResult::advance();
    }

    auto* expander = SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
    if (expander == nullptr) {
      op->emitOpError() << "does not implement layout propagation";
      return mlir::WalkResult::interrupt();
    }

    // Invoke ComputeLayout on `op` with empty input/consumer layouts.
    StatusOr<llvm::DenseMap<int, Layout>> forward_result =
        expander->ComputeLayoutForward(
            op, /*input_layouts=*/llvm::DenseMap<int, Layout>(),
            /*output_layouts=*/llvm::DenseMap<int, Layout>());
    if (!forward_result.ok()) {
      op->emitOpError() << "ComputeLayoutForward error: "
                        << forward_result.status().error_message();
      return mlir::WalkResult::interrupt();
    }
    StatusOr<llvm::DenseMap<int, Layout>> backward_result =
        expander->ComputeLayoutBackward(
            op, /*input_layouts=*/llvm::DenseMap<int, Layout>(),
            /*output_layouts=*/llvm::DenseMap<int, Layout>());
    if (!backward_result.ok()) {
      op->emitOpError() << "ComputeLayoutBackward error: "
                        << backward_result.status().error_message();
      return mlir::WalkResult::interrupt();
    }

    // If any operand layouts were returned, add the layout to the consumer
    // requests and mark the value as updated.
    for (auto const& op_idx_and_op_layout : *backward_result) {
      auto const& op_idx = op_idx_and_op_layout.first;
      auto const& op_layout = op_idx_and_op_layout.second;
      auto& operand = op->getOpOperand(op_idx);
      const auto& producer_values = producers.lookup(&operand);
      for (mlir::Value producer_value : producer_values) {
        if (!consumer_requests[producer_value].count(&operand))
          consumer_requests[producer_value][&operand] = op_layout;

        is_updated.insert(producer_value);
      }
    }

    // If any output layouts were returned, add the layout to the producer
    // requests and mark the value as updated.
    for (auto const& out_idx_and_out_layout : *forward_result) {
      auto const& out_idx = out_idx_and_out_layout.first;
      auto const& out_layout = out_idx_and_out_layout.second;
      mlir::Value output_value = op->getResult(out_idx);
      producer_request.try_emplace(output_value, out_layout);
      is_updated.insert(output_value);
    }

    return mlir::WalkResult::advance();
  });
  return mlir::failure(walk_result.wasInterrupted());
}

// Propagates mesh and inserts initial layouts for
// * Any DTensorLayout ops (this handles function inputs and other ops with
//   user-specified layouts).
// * CopyToMesh
// * ConstOp
mlir::LogicalResult InsertInitialLayouts(
    mlir::ModuleOp& module, mlir::func::FuncOp& main_func,
    const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
        consumer_request,
    llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
    llvm::DenseSet<mlir::Value>& is_updated,
    llvm::DenseSet<mlir::Value>& is_locked) {
  std::queue<mlir::Operation*> operations;

  if (mlir::failed(InsertLayoutsForDTensorLayout(module, producer_request,
                                                 is_updated, is_locked)))
    return mlir::failure();
  return InsertInitialLayoutsFromComputeLayout(module, consumers, producers,
                                               producer_request,
                                               consumer_request, is_updated);
}

// Given a list of mlir::Values with updated producer or consumer layouts,
// update the merged_layouts list and track which layouts actually changed.
mlir::LogicalResult MergeAndGetUpdatedLayouts(
    const llvm::DenseSet<mlir::Value>& is_locked,
    llvm::DenseSet<mlir::Value>& is_updated,
    llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
    llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
        consumer_requests,
    llvm::DenseMap<mlir::Value, Layout>& merged_layouts) {
  llvm::DenseSet<mlir::Value> updated_merge;
  for (auto& value : is_updated) {
    auto& producer_layout = producer_request[value];
    if (is_locked.find(value) != is_locked.end()) {
      // Locked values must have a producer request. If the merged_layout is
      // not already set, then this is the first pass, so we set it and mark
      // the entry as updated.
      if (merged_layouts.find(value) == merged_layouts.end()) {
        if (!producer_layout)
          return value.getDefiningOp()->emitError() << "missing locked layout";
        merged_layouts[value] = producer_layout.value();
        updated_merge.insert(value);
      }
      continue;
    }
    auto merged = MergeLayouts(producer_layout, consumer_requests[value]);
    if (!merged.ok())
      return value.getDefiningOp()->emitOpError()
             << merged.status().error_message();

    auto current_layout = merged_layouts.find(value);
    if (current_layout == merged_layouts.end() ||
        current_layout->second != merged.value()) {
      updated_merge.insert(value);
      merged_layouts[value] = merged.value();
    }
  }

  is_updated = updated_merge;
  return mlir::success();
}

// Finds the most sharded merged layout given `layouts`.
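// For example, merging ("x", *) and (*, "y") gives ("x", "y"), while merging
// ("x", *) and (*, "x") gives (*, *), since mesh dimension "x" is requested
// by two different tensor dimensions.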
mlir::LogicalResult GetMostShardedLayout(llvm::ArrayRef<Layout> layouts,
                                         mlir::Location location,
                                         absl::optional<Layout>* out) {
  // If there are no layouts to merge, leave the output empty.
  if (layouts.empty()) return mlir::success();

  absl::optional<Layout> layout;
  std::map<std::string, std::set<int>> layout_map;
  for (const Layout& layout : layouts) {
    for (int i = 0; i < layout.rank(); ++i) {
      const std::string& mesh_dim = layout.dim(i).sharding_spec();
      if (mesh_dim == Layout::kUnshardedDim) continue;

      layout_map[mesh_dim].insert(i);
    }
  }

  for (auto& it : layout_map)
    if (it.second.size() > 1) it.second.clear();

  std::map<int, std::set<std::string>> dim_to_layout_map;
  for (const auto& it : layout_map) {
    assert(it.second.size() <= 1);
    if (it.second.empty()) continue;

    const int tensor_dim_index = *it.second.begin();
    dim_to_layout_map[tensor_dim_index].insert(it.first);
  }

  for (auto& it : dim_to_layout_map)
    if (it.second.size() > 1) it.second.clear();

  std::vector<std::string> merged_spec;
  assert(!layouts.empty());
  for (int i = 0; i < layouts[0].rank(); ++i) {
    const auto it = dim_to_layout_map.find(i);
    if (it != dim_to_layout_map.end() && !it->second.empty()) {
      assert(it->second.size() == 1);
      merged_spec.emplace_back(*it->second.begin());
    } else {
      merged_spec.emplace_back(Layout::kUnshardedDim);
    }
  }
  const auto new_layout = Layout::GetLayout(merged_spec, layouts[0].mesh());
  if (!new_layout.ok()) {
    return mlir::emitError(
        location, llvm::formatv("error in layout propagation while merging "
                                "producer layouts. {0}",
                                new_layout.status().error_message()));
  }
  out->emplace(*new_layout);
  return mlir::success();
}

// Merges the layouts of an mlir::Value coming from multiple producers into a
// single final layout. A mlir::Value can have multiple producers if the value
// comes from a tf.If/tf.IfRegion op. Given multiple producer layouts for the
// same mlir::Value, the merging logic is as follows:
// 1) If a dimension can be sharded, shard the dimension as much as possible.
// 2) If the mesh dimension is already used, or the same mesh dimension is
//    requested by two different tensor dimensions, then leave the dimension
//    replicated.
//
// For example:
// ("x", replicated), (replicated, "y") will have ("x", "y") merged layout.
// ("x", replicated), (replicated, "x") will have (replicated, replicated)
// merged layout.
mlir::LogicalResult MergeProducerLayouts(
    const llvm::DenseMap<mlir::Value, Layout>& merged_layouts,
    const std::vector<mlir::Value>& producer_values, mlir::Location location,
    absl::optional<Layout>* layout_out) {
  // If there is a single producer for the mlir::Value, then return the layout
  // from that producer.
  absl::optional<Layout> layout;
  if (producer_values.size() == 1) {
    const auto it = merged_layouts.find(producer_values[0]);
    if (it != merged_layouts.end()) *layout_out = it->second;
    return mlir::success();
  }

  // For the case with multiple producers, merge the layouts.
  llvm::SmallVector<Layout, 4> candidate_layouts;
  candidate_layouts.reserve(producer_values.size());
  for (mlir::Value value : producer_values) {
    auto it = merged_layouts.find(value);
    if (it == merged_layouts.end()) continue;
    candidate_layouts.emplace_back(it->second);
  }

  if (mlir::failed(GetMostShardedLayout(candidate_layouts, location, &layout)))
    return mlir::failure();

  if (layout) *layout_out = *layout;
  return mlir::success();
}

// For an op, calls the corresponding ComputeLayoutForward/Backward functions
// with the data from the merged_layouts map. Records the result in the
// producer_request and consumer_requests maps and notes if any layouts have
// changed.
mlir::LogicalResult UpdateLayoutsForOp(
    mlir::Operation* op,
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    const llvm::DenseMap<mlir::Value, Layout>& merged_layouts,
    llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
    llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
        consumer_requests,
    llvm::DenseSet<mlir::Value>& is_updated) {
  auto* expander = SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
  if (expander == nullptr)
    return op->emitOpError() << "does not implement layout propagation";

  // Get input and output layouts for this op from the merged_layouts map.
  llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands());
  llvm::DenseMap<int, Layout> output_layouts(op->getNumResults());

  for (int i = 0; i < op->getNumOperands(); ++i) {
    // For inputs, we need to find the producer's mlir::Value that eventually
    // feeds into this op. This is in the producers map.
    // Merge the different layouts from multiple producer values.
    auto producer_values = producers.find(&(op->getOpOperand(i)));
    if (producer_values == producers.end())
      return op->emitError() << "Unable to find producer for operand " << i;

    absl::optional<Layout> layout;
    if (mlir::failed(MergeProducerLayouts(merged_layouts,
                                          producer_values->getSecond(),
                                          op->getLoc(), &layout)))
      return mlir::failure();

    if (layout) input_layouts[i] = *layout;
  }

  for (int i = 0; i < op->getNumResults(); ++i) {
    auto layout = merged_layouts.find(op->getOpResult(i));
    if (layout != merged_layouts.end()) output_layouts[i] = layout->second;
  }

  auto forward_result =
      expander->ComputeLayoutForward(op, input_layouts, output_layouts);
  if (!forward_result.ok()) {
    return op->emitOpError() << "ComputeLayoutForward error: "
                             << forward_result.status().error_message();
  }
  const auto new_output_layouts = *forward_result;
  auto backward_result =
      expander->ComputeLayoutBackward(op, input_layouts, output_layouts);
  if (!backward_result.ok()) {
    return op->emitOpError() << "ComputeLayoutBackward error: "
                             << backward_result.status().error_message();
  }
  const auto new_input_layouts = *backward_result;

  // Update the consumer layouts for this op.
  for (int i = 0; i < op->getNumOperands(); ++i) {
    mlir::OpOperand* operand = &(op->getOpOperand(i));
    // No need to check that this exists, we already did it above.
    const auto& producer_values = producers.find(operand);
    const auto input_layout = new_input_layouts.find(i);

    for (mlir::Value value : producer_values->getSecond()) {
      auto& consumer_request = consumer_requests[value];
      const auto consumer_request_from_op_operand =
          consumer_request.find(operand);

      // Update the consumer_request for this OpOperand: we respect what
      // ComputeLayout returns and erase a previously requested layout if no
      // layout is returned.
      // TODO(hongjunchoi, bfontain): Consider the case when op output type is
      // resource type with subtype information.
      if (input_layout != new_input_layouts.end() &&
          (consumer_request_from_op_operand == consumer_request.end() ||
           input_layout->second != consumer_request_from_op_operand->second)) {
        // A RestoreV2 op most likely would have unknown rank upon restoring,
        // and we relax the unknown rank check for the inputs that are produced
        // from there.
        const bool exempt_restore_unknown_rank =
            ValueRank(value) == -1 && value.getDefiningOp() &&
            llvm::isa<mlir::TF::RestoreV2Op>(value.getDefiningOp());
        if (!exempt_restore_unknown_rank &&
            input_layout->second.rank() != ValueRank(value))
          return op->emitOpError()
                 << "Rank for input " << i << " layout is "
                 << input_layout->second.rank() << " but actual rank is "
                 << ValueRank(value);

        // If there was a layout returned and either no previous request or the
        // request changed, insert and mark as updated.
        consumer_request[operand] = input_layout->second;
        is_updated.insert(value);
      } else if (input_layout == new_input_layouts.end() &&
                 consumer_request_from_op_operand != consumer_request.end()) {
        // If no layout was returned and there is a previous request, erase the
        // old consumer request.
        consumer_request.erase(operand);
        is_updated.insert(value);
      }
    }
  }

  // Update the producer request for this op.
  // If the output is different from what is in the request list, update the
  // request and mark the mlir::Value as having an updated Layout request.
  for (int i = 0; i < op->getNumResults(); ++i) {
    const auto output_layout = new_output_layouts.find(i);
    if (output_layout == new_output_layouts.end()) continue;
    const auto& result = op->getOpResult(i);
    if (producer_request[result] != output_layout->second) {
      if (output_layout->second.rank() != ValueRank(result))
        return op->emitOpError() << "Rank for output " << i << " layout is "
                                 << output_layout->second.rank()
                                 << " but actual rank is " << ValueRank(result);
      producer_request[result] = output_layout->second;
      is_updated.insert(result);
    }
  }
  return mlir::success();
}

mlir::LogicalResult InsertDTensorLayoutOps(
    mlir::OpBuilder& builder,
    const llvm::DenseMap<mlir::Value, Layout>& merged_layouts) {
  for (const auto& merged_layout : merged_layouts) {
    // merged_layout is a pair of mlir::Value and Layout.
    // If there is only one user of the Value and that user is a DTensorLayout
    // op, then we can skip creating the op as the layout is already there.
    // Note that we specifically do not allow updating a layout in an already
    // present DTensorLayout op, as we have considered them to be 'locked'
    // throughout the algorithm.
    const auto& users = merged_layout.first.getUsers();
    int num_users = std::distance(users.begin(), users.end());
    if (num_users == 1 && mlir::isa<mlir::TF::DTensorLayout>(*users.begin()))
      continue;
    builder.setInsertionPointAfterValue(merged_layout.first);
    // Handles resource and variant as the real shape is embedded in the
    // resource type elements.
    mlir::Type value_type = GetSubtypeOrSelf(merged_layout.first);

    if (auto type = value_type.dyn_cast<mlir::TensorType>()) {
      auto layout_op = builder.create<mlir::TF::DTensorLayout>(
          merged_layout.first.getLoc(), merged_layout.first,
          mlir::dtensor::LayoutAttr::get(builder.getContext(),
                                         merged_layout.second),
          mlir::TF::ShapeAttr::get(builder.getContext(), type));
      llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op};
      merged_layout.first.replaceAllUsesExcept(layout_op.output(), exception);
    } else {
      mlir::emitError(merged_layout.first.getLoc())
          << "value type is not TensorType as expected.";
    }
  }

  return mlir::success();
}

void GetOperationsNeedingUpdate(
    const llvm::DenseSet<mlir::Value>& is_updated,
    const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
    llvm::DenseSet<mlir::Operation*>& operations) {
  for (auto& value : is_updated) {
    auto uses = consumers.find(value);
    // Some values have no consumers (e.g. outputs of the main function).
    if (uses != consumers.end())
      for (auto* use : uses->second)
        if (!mlir::isa<mlir::TF::CopyToMeshOp>(use->getOwner()))
          operations.insert(use->getOwner());
    // If this is an OpResult, also add the op that produces it.
    if (value.isa<mlir::OpResult>() &&
        !mlir::isa<mlir::TF::CopyToMeshOp>(value.getDefiningOp()))
      operations.insert(value.getDefiningOp());
  }
}

namespace {

// Custom printing class which prints out layouts and ignores DTensorLayout
// ops and also non-registered attributes.
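// As a rough sketch of the resulting format (values, types and names below
// are made up for illustration), an op with one result might print as:
//   %3 "x,*" tensor<8x8xf32> = "tf.Relu"(%2 "x,*" tensor<8x8xf32>) [Relu]
// where the quoted string after each value is its merged layout and the
// bracketed suffix is the TF node name recovered from the location.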
class LayoutPrinter : public mlir::OpAsmPrinter {
 public:
  explicit LayoutPrinter(
      llvm::raw_ostream& os,
      const llvm::DenseMap<mlir::Value, Layout>& merged_layouts)
      : indent_level_(0),
        os_(os),
        current_location_(0),
        next_argument_(0),
        merged_layouts_(merged_layouts) {}

  llvm::raw_ostream& getStream() const override { return os_; }

  void printRegionArgument(mlir::BlockArgument arg,
                           llvm::ArrayRef<mlir::NamedAttribute> argAttrs,
                           bool omitType) override {
    printOperand(arg);
    if (!omitType) {
      os_ << ": ";
      printType(arg.getType());
    }
    printOptionalAttrDict(argAttrs, llvm::None);
  }

  void printOperand(mlir::Value value) override { printOperand(value, os_); }

  /// Print a newline and indent the printer to the start of the current
  /// operation.
  void printNewline() override {
    os_ << "\n";
    os_.indent(indent_level_);
  }

  // Note that we ignore the parameters printEntryBlockArgs and
  // printBlockTerminators for simplicity.
  void printRegion(mlir::Region& blocks, bool printEntryBlockArgs,
                   bool printBlockTerminators,
                   bool printEmptyBlock = false) override {
    os_ << " {\n";
    for (auto& b : blocks.getBlocks()) print(b);
    os_.indent(indent_level_) << "}";
  }

  void print(mlir::Block& block) {
    // Each nested block level increases the indent.
    os_.indent(indent_level_) << "%bb(";
    for (int i = 0; i < block.getNumArguments(); ++i) {
      if (arguments_.find(block.getArgument(i)) == arguments_.end())
        arguments_[block.getArgument(i)] = next_argument_++;
      if (i > 0) os_ << ", ";
      os_ << "%arg" << arguments_[block.getArgument(i)];
    }
    os_ << "):\n";
    indent_level_ += 2;
    for (auto& op : block.getOperations()) print(op);
    indent_level_ -= 2;
  }

  // Prints the TF node name from `loc`.
  void printLoc(mlir::Location loc) {
    os_ << " [" << mlir::GetNameFromLoc(loc) << "]";
  }

  void print(mlir::Operation& op) {
    // Don't print tf.DTensorLayout ops.
    if (mlir::isa<mlir::TF::DTensorLayout>(op)) return;

    // Don't print functions with empty bodies.
    if (auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(op))
      if (func_op.empty()) return;

    // Each operation is on its own line, so we start by indenting the line.
    os_.indent(indent_level_);

    // Record a unique identifier for the op (this will be used for printing
    // op results and operands).
    location_[&op] = current_location_++;

    // Print the outputs.
    for (int i = 0; i < op.getNumResults(); ++i) {
      if (i > 0) os_ << ", ";
      printOperand(op.getOpResult(i), os_);
    }

    if (op.getNumResults() > 0) os_ << " = ";

    // Some ops have a special printing method, call this if it exists.
    if (auto opInfo = op.getRegisteredInfo()) {
      opInfo->printAssembly(&op, *this, /*defaultDialect=*/"");
      printLoc(op.getLoc());
      os_ << "\n";
      return;
    }

    // Otherwise we do a generic printing.
    printGenericOp(&op, true);
    printLoc(op.getLoc());

    os_ << "\n";
  }

  // Print an operand; this could be either an OpResult or a BlockArgument.
  // We also print the layout if it exists, and the type.
  void printOperand(mlir::Value value, llvm::raw_ostream& os) override {
    if (auto result = value.dyn_cast<mlir::OpResult>()) {
      // If DTensorLayout ops are already in the module, we need to skip them
      // since we aren't printing them out.
      if (mlir::isa<mlir::TF::DTensorLayout>(result.getDefiningOp())) {
        printOperand(result.getDefiningOp()->getOperand(0));
        return;
      }

      // OpResults are printed in the format %op_number:%result_number. We
      // elide the result number if there is only one result (the case for
      // most ops).
      os << "%" << location_[result.getDefiningOp()];
      if (result.getDefiningOp()->getNumResults() > 1)
        os << ":" << result.getResultNumber();
    } else if (auto argument = value.dyn_cast<mlir::BlockArgument>()) {
      if (arguments_.find(argument) == arguments_.end())
        arguments_[argument] = next_argument_++;
      os << "%arg" << arguments_[argument];
    }
    auto layout = merged_layouts_.find(value);
    if (layout != merged_layouts_.end()) {
      os << " \"";
      printLayout(layout->second, os);
      os << "\"";
    }
    os << " ";
    printType(value.getType());
  }

  void printLayout(const Layout& layout, llvm::raw_ostream& os) {
    // Layouts are printed with * for an unsharded dim and the mesh dim for a
    // sharded dim. This keeps the layout compact.
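    // E.g. a rank-2 layout sharded over mesh dimension "x" in the first
    // dimension and unsharded in the second prints as "x,*".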
    for (int i = 0; i < layout.rank(); ++i) {
      if (i > 0) os << ",";
      if (Layout::IsUnshardedDimension(layout.sharding_spec(i)))
        os << "*";
      else
        os << layout.sharding_spec(i);
    }
  }

  // A generic op consists of a name, and any of the following:
  //   * arguments,
  //   * attributes,
  //   * regions.
  // These are printed out in that order.
  void printGenericOp(mlir::Operation* op, bool printOpName) override {
    if (printOpName) os_ << "\"" << op->getName().getStringRef() << "\"";
    os_ << "(";
    for (int i = 0; i < op->getNumOperands(); ++i) {
      if (i > 0) os_ << ", ";
      printOperand(op->getOperand(i), os_);
    }
    os_ << ")";

    if (!op->getAttrs().empty()) {
      std::vector<mlir::NamedAttribute> filtered;
      for (auto attr : op->getAttrs())
        if (*attr.getName().str().begin() != '_' &&
            attr.getName().str() != "device")
          filtered.emplace_back(attr);
      if (!filtered.empty()) {
        os_ << " {";
        for (int i = 0; i < filtered.size(); ++i) {
          if (i > 0) os_ << ", ";
          printNamedAttribute(filtered[i]);
        }
        os_ << "}";
      }
    }

    if (op->getNumRegions() > 0) {
      os_ << " (";
      for (auto& region : op->getRegions()) printRegion(region, false, false);
      os_ << ")";
    }
  };

  void printSymbolName(llvm::StringRef symbolRef) override {
    os_ << symbolRef;
  };

  void printNamedAttribute(mlir::NamedAttribute attr) {
    os_ << attr.getName().strref() << " = ";
    printAttribute(attr.getValue());
  }

  void printAttribute(mlir::Attribute attr) override { attr.print(os_); }

  void printType(mlir::Type type) override { type.print(os_); }

  // The following functions are part of the printing interface but aren't
  // needed for the compact printing form for Layout printing.
  void printAttributeWithoutType(mlir::Attribute attr) override{};
  void printSuccessor(mlir::Block* successor) override{};
  void printSuccessorAndUseList(mlir::Block* successor,
                                mlir::ValueRange succOperands) override{};
  void printOptionalAttrDict(
      llvm::ArrayRef<mlir::NamedAttribute> attrs,
      llvm::ArrayRef<llvm::StringRef> elidedAttrs) override{};
  void printOptionalAttrDictWithKeyword(
      llvm::ArrayRef<mlir::NamedAttribute> attrs,
      llvm::ArrayRef<llvm::StringRef> elidedAttrs) override{};

  void shadowRegionArgs(mlir::Region& region,
                        mlir::ValueRange namesToUse) override{};
  void printAffineMapOfSSAIds(mlir::AffineMapAttr mapAttr,
                              mlir::ValueRange operands) override{};

  void printAffineExprOfSSAIds(mlir::AffineExpr expr,
                               mlir::ValueRange dimOperands,
                               mlir::ValueRange symOperands) override{};

 private:
  int indent_level_;
  llvm::raw_ostream& os_;
  llvm::DenseMap<mlir::Operation*, int> location_;
  int current_location_;
  llvm::DenseMap<mlir::BlockArgument, int> arguments_;
  int next_argument_;
  const llvm::DenseMap<mlir::Value, Layout>& merged_layouts_;
};

// Log the current set of layouts to a file marked by the hash of the input
// module and the stage.
void LogLayoutsAndOps(const int stage, const uint64_t module_hash,
                      const llvm::DenseMap<mlir::Value, Layout>& merged_layouts,
                      mlir::ModuleOp& module) {
  if (module->hasAttr(kDoNotLog) || ((ClientId() != 0) && !LogOnAllTasks()))
    return;

  std::string prefix = tensorflow::GetDumpDirFromEnvVar();
  if (prefix.empty()) return;

  auto* env = tensorflow::Env::Default();
  auto status = env->RecursivelyCreateDir(prefix);
  if (!status.ok()) {
    LOG(WARNING) << "cannot create directory '" + prefix +
                        "': " + status.error_message();
    return;
  }

  absl::StrAppend(&prefix, "/layout_propagation_v2_module_", module_hash,
                  "_stage_", stage, "_");
  if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
    LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
    return;
  }

  std::unique_ptr<WritableFile> file_writer;
  status = env->NewWritableFile(prefix, &file_writer);
  if (!status.ok()) {
    LOG(WARNING) << "cannot open file '" + prefix +
                        "': " + status.error_message();
    return;
  }

  // Print the module to a string before writing to the file.
  std::string txt_module;
  {
    llvm::raw_string_ostream os(txt_module);
    LayoutPrinter printer(os, merged_layouts);
    module.print(printer);
  }

  status = file_writer->Append(txt_module);
  if (!status.ok()) {
    LOG(WARNING) << "error writing to file '" + prefix +
                        "': " + status.error_message();
    return;
  }
  (void)file_writer->Close();
  LOG(INFO) << "Dumped MLIR module to " << prefix;
}

// The Canonicalizer and DCE transformation passes may remove ops in the graph
// and result in multiple consecutive DTensorLayout ops. Detect all such cases
// and replace the unnecessary DTensorLayout ops with Identity ops.
mlir::LogicalResult ReplaceAuxiliaryDTensorLayoutOpsWithIdentity(
    mlir::ModuleOp module) {
  llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops;
  module.walk([&](mlir::TF::DTensorLayout op) { layout_ops.emplace_back(op); });

  for (auto layout_op : llvm::reverse(layout_ops)) {
    auto input_op = layout_op.input().getDefiningOp();
    if (auto input_layout_op =
            llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(input_op)) {
      // Check that the layout of the input DTensorLayout op is equivalent to
      // the layout of its connected DTensorLayout op.
      if (layout_op.layout() != input_layout_op.layout())
        return layout_op.emitOpError(
            "Found inconsistent layout. This should never happen.");

      // Replace the DTensorLayout op with an Identity op.
      mlir::OpBuilder builder(layout_op);
      auto identity = builder.create<mlir::TF::IdentityOp>(
          layout_op->getLoc(), layout_op.getType(), layout_op.input());
      layout_op.output().replaceAllUsesWith(identity.output());
      layout_op.erase();
    }
  }

  return mlir::success();
}

// Inserts/changes DTensorLayout ops after an IfRegion op and the results of
// its then/else branches to ensure that the return values of IfRegion ops are
// consistent.
// After layout propagation, the layouts of the return values of a tf.IfRegion
// op and the layouts of the terminators of its then/else branches may differ.
// In that case, the layouts must be merged into a single layout, since the
// return values of the IfRegion op and the results of the then/else branches
// are semantically equivalent.
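// For example, if the IfRegion result is annotated with (*, *) while the
// then/else terminators produce ("x", *) and (*, *), all three DTensorLayout
// ops are rewritten to the most sharded merge, ("x", *).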
mlir::LogicalResult InsertDTensorLayoutForIfRegionOp(
    const llvm::SmallVectorImpl<mlir::TF::IfRegionOp>& if_ops,
    mlir::MLIRContext* context) {
  for (mlir::TF::IfRegionOp if_op : if_ops) {
    for (mlir::OpResult if_result : if_op.getResults()) {
      const int result_index = if_result.getResultNumber();
      mlir::Value then_branch_result = if_op.then_branch()
                                           .front()
                                           .getTerminator()
                                           ->getOpOperand(result_index)
                                           .get();
      mlir::Value else_branch_result = if_op.else_branch()
                                           .front()
                                           .getTerminator()
                                           ->getOpOperand(result_index)
                                           .get();

      auto if_result_layout =
          llvm::dyn_cast<mlir::TF::DTensorLayout>(*if_result.user_begin());
      auto then_result_layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(
          *then_branch_result.getDefiningOp());
      auto else_result_layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(
          *else_branch_result.getDefiningOp());
      llvm::SmallVector<Layout, 4> layouts{if_result_layout.layout(),
                                           then_result_layout.layout(),
                                           else_result_layout.layout()};
      std::set<Layout> layouts_set{layouts.begin(), layouts.end()};
      if (layouts_set.size() == 1) continue;

      absl::optional<Layout> merged_layout;
      if (mlir::failed(
              GetMostShardedLayout(layouts, if_op.getLoc(), &merged_layout)))
        return mlir::failure();
      assert(merged_layout);

      if_result_layout->setAttr(
          kQualifiedLayoutAttr,
          mlir::dtensor::LayoutAttr::get(context, *merged_layout));
      then_result_layout->setAttr(
          kQualifiedLayoutAttr,
          mlir::dtensor::LayoutAttr::get(context, *merged_layout));
      else_result_layout->setAttr(
          kQualifiedLayoutAttr,
          mlir::dtensor::LayoutAttr::get(context, *merged_layout));
    }
  }
  return mlir::success();
}

// Inserts the necessary DTensorRelayout ops so that the layouts for while
// loops are correct.
//
// Due to how while loop layout propagation is done, we may need to fix the
// layouts so that the second and later steps of the loop receive a tensor with
// the correct layout.
// E.g.
// %b = tf.WhileRegion(%a) ({
//   %bb0(%arg0):  # Cond
//     %c = tf.A(%arg0)
//     tf.Yield(%c)
//   }, {
//   %bb0(%arg0):  # Body
//     %d = tf.B(%arg0)
//     tf.Yield(%d)
//   }
// }
// %e = tf.C(%b)
//
// Layout propagation treats the loop body as if it were an inlined function
// and does not have a constraint that fixes the layout of %d, as the return
// value, to match the layout of %arg0 (or %a).
//
// Towards this, we:
// 1) Check the layout of %arg0 and see if it matches the layout of input 0
//    (%d) of tf.Yield.
// 2) If it doesn't match, we insert a DTensorRelayout op between %d and
//    tf.Yield with the correct layout, and insert a second DTensorRelayout op
//    after the loop body.
//
// NOTE: it is necessary in general to insert both DTensorRelayout ops,
// as opposed to just updating the layout of %d (which would in general be more
// efficient), since %d may still be used by other ops in the loop body.
//
// NOTE: this is not needed for the condition, as the output of the condition
// is a scalar and therefore always replicated.
mlir::LogicalResult InsertRelayoutForWhileLoops(
    const llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>& while_ops,
    mlir::OpBuilder& builder) {
  for (mlir::TF::WhileRegionOp op : while_ops) {
    // Get the terminator so we can check the output layouts of the loop body.
    mlir::Operation* yield_op = op.body().front().getTerminator();
    if (!mlir::isa<mlir::TF::YieldOp>(yield_op))
      return op->emitOpError() << "body terminator is not a Yield op.";

    for (int i = 0; i < op.body().getNumArguments(); ++i) {
      // Inputs should have only one use, which is a DTensorLayout op.
      mlir::Value argument = op.body().getArgument(i);
      if (!argument.hasOneUse())
        return op.emitOpError()
               << "body argument " << i << " doesn't have a single use.";
      mlir::Operation* input_layout_op = argument.getUses().begin().getUser();
      if (!mlir::isa<mlir::TF::DTensorLayout>(input_layout_op))
        return op.emitOpError() << "body argument " << i
                                << " is not consumed by a DTensorLayout op.";
      const Layout input_layout =
          mlir::cast<mlir::TF::DTensorLayout>(input_layout_op).layout();

      // Inputs to Yield should also be a DTensorLayout op.
      if (!yield_op->getOperand(i).isa<mlir::OpResult>() ||
          !mlir::isa<mlir::TF::DTensorLayout>(
              yield_op->getOperand(i).getDefiningOp()))
        return yield_op->emitOpError()
               << "argument " << i << " is not a DTensorLayout op.";
      mlir::Operation* output_layout_op =
          yield_op->getOperand(i).getDefiningOp();
      const Layout output_layout =
          mlir::cast<mlir::TF::DTensorLayout>(output_layout_op).layout();

      // If the layouts are equal we have nothing to do. Note that this catches
      // the case that the input and output are a resource, since the layout
      // of a resource is fixed.
      if (input_layout == output_layout) continue;

      // Insert the first Relayout op (in the loop body).
      builder.setInsertionPointAfter(output_layout_op);
      if (!yield_op->getOperand(i).getType().isa<mlir::TensorType>())
        return yield_op->emitOpError()
               << "operand " << i << " does not have TensorType";
      mlir::TF::ShapeAttr global_shape = mlir::TF::ShapeAttr::get(
          builder.getContext(),
          yield_op->getOperand(i).getType().cast<mlir::TensorType>());
      mlir::TF::RelayoutOp first_relayout =
          builder.create<mlir::TF::RelayoutOp>(
              op.getLoc(), yield_op->getOperand(i).getType(),
              yield_op->getOperand(i), input_layout.ToString());
      mlir::TF::DTensorLayout first_layout_op =
          builder.create<mlir::TF::DTensorLayout>(
              op.getLoc(), first_relayout.output(),
              mlir::dtensor::LayoutAttr::get(builder.getContext(),
                                             input_layout),
              global_shape);
      yield_op->setOperand(i, first_layout_op.output());

      // Insert the second relayout op after the loop itself.
      builder.setInsertionPointAfter(op);
      mlir::TF::DTensorLayout second_layout_op =
          builder.create<mlir::TF::DTensorLayout>(
              op.getLoc(), op->getResult(i),
              mlir::dtensor::LayoutAttr::get(builder.getContext(),
                                             input_layout),
              global_shape);
      mlir::TF::RelayoutOp second_relayout =
          builder.create<mlir::TF::RelayoutOp>(
              op.getLoc(), second_layout_op.output().getType(),
              second_layout_op.output(), output_layout.ToString());
      op->getResult(i).replaceAllUsesExcept(
          second_relayout.output(), llvm::SmallPtrSet<mlir::Operation*, 1>{
                                        second_layout_op.getOperation()});
    }
  }
  return mlir::success();
}

// For all constants with multiple uses, clone the constants so that each
// constant operation has at most one use.
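// E.g. a tf.Const with three uses keeps its first use, and two cloned
// tf.Const ops take over the remaining two uses.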
void DuplicateConstants(mlir::ModuleOp module) {
  llvm::SmallVector<mlir::TF::ConstOp, 4> const_ops;
  module.walk(
      [&](mlir::TF::ConstOp const_op) { const_ops.emplace_back(const_op); });

  for (mlir::TF::ConstOp const_op : const_ops) {
    mlir::OpBuilder builder(const_op);
    auto uses = const_op->getUses();
    if (uses.empty()) return;

    llvm::SmallDenseMap<mlir::Operation*, mlir::OpOperand*> const_use_map;
    mlir::OpOperand& first_use = *uses.begin();
    for (mlir::OpOperand& use : uses) {
      if (&use == &first_use) continue;

      mlir::Operation* new_const = builder.clone(*const_op);
      const_use_map.try_emplace(new_const, &use);
    }

    for (const auto& it : const_use_map) it.second->set(it.first->getResult(0));
  }
}

// Finds the root value(s) of "current_value" within the cycle, and puts them
// into "roots".
void FindRoot(
    const llvm::DenseSet<mlir::Value>& is_updated,
    const mlir::Value& current_value,
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    llvm::DenseSet<mlir::Value>* roots) {
  // Standard BFS to find root values of current_value.
  std::deque<mlir::Value> to_process;
  to_process.push_back(current_value);

  llvm::DenseSet<mlir::Value> visited;
  visited.insert(current_value);

  while (!to_process.empty()) {
    int level_size = to_process.size();
    for (int UNUSED = 0; UNUSED < level_size; ++UNUSED) {
      mlir::Value cur_val = to_process.front();
      to_process.pop_front();

      // Terminating condition: if there is no defining op then this is a root.
      mlir::Operation* defining_op = cur_val.getDefiningOp();
      if (defining_op == nullptr) {
        roots->insert(current_value);
        continue;
      }

      // Expand out from 'cur_val' one step closer to the roots. If there is
      // no value one step closer to a root, then this is a root.
      bool is_root = true;
      for (int i = 0; i < defining_op->getNumOperands(); ++i) {
        mlir::Value operand = defining_op->getOperand(i);
        if (operand != cur_val && is_updated.contains(operand)) {
          is_root = false;

          if (!visited.contains(operand)) {
            visited.insert(operand);
            to_process.push_back(operand);
          }
        }
      }

      if (is_root) roots->insert(cur_val);
    }
  }
}

// Finds the root value(s) of the values that have layouts cycling back and
// forth in an infinite loop during layout propagation and prints the closest
// TF op that consumes those root value(s). This allows users and developers to
// debug the root cause of layouts that should be changed to prevent infinite
// layout propagation.
void FindRootsAndEmitError(
    mlir::ModuleOp& module,
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers,
    const llvm::DenseSet<mlir::Value>& is_updated) {
  llvm::DenseSet<mlir::Value> roots;
  for (auto& value : is_updated) {
    FindRoot(is_updated, value, producers, &roots);
  }
  module.emitOpError()
      << "Maximum number of layout propagation steps reached. Unable to "
         "converge to a fixed layout. Please rerun with a higher limit in the "
         "DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS environment variable.";
  for (auto& root : roots) {
    for (mlir::OpOperand& operand : root.getUses()) {
      llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller;
      llvm::SmallVector<mlir::Value, 4> skipped_values;

      // For each root value that may need a different layout, find the
      // closest TF op that consumes it and print it.
      llvm::SmallVector<mlir::OpOperand*, 4> consuming_operands =
          TraceUseToNextTFOp(&operand, func_to_caller, &skipped_values);

      for (mlir::OpOperand* new_operand : consuming_operands) {
        mlir::Operation* operation = new_operand->getOwner();
        mlir::Location loc = operation->getLoc();
        operation->emitOpError() << '\n'
                                 << "The following op consumes tensors that "
                                    "may need a different layout. "
                                    "["
                                 << mlir::GetNameFromLoc(loc) << "]" << '\n';
      }
    }
  }
}
}  // namespace

// Runs one iteration of layout propagation, where we merge producer and
// consumer requests and then recompute the recommended layouts on all
// operations that are connected to an updated layout.
Status RunOneIteration(
    llvm::DenseSet<mlir::Value>& is_locked,
    llvm::DenseSet<mlir::Value>& is_updated,
    llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
    llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
        consumer_requests,
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
    llvm::DenseMap<mlir::Value, Layout>& merged_layouts, mlir::ModuleOp& module,
    const uint64_t module_hash, int* stage) {
  if (is_updated.empty()) return OkStatus();
  // Merge any possibly updated layouts.
  if (mlir::failed(
          MergeAndGetUpdatedLayouts(is_locked, is_updated, producer_request,
                                    consumer_requests, merged_layouts)))
    return errors::Internal(
        "MergeAndGetUpdatedLayouts failed to merge layouts.");

  // Compile a list of operations with updated inputs or outputs.
  llvm::DenseSet<mlir::Operation*> operations_needing_update;
  GetOperationsNeedingUpdate(is_updated, consumers, operations_needing_update);
  is_updated.clear();

  if (VLOG_IS_ON(2)) {
    LogLayoutsAndOps(*stage, module_hash, merged_layouts, module);
  }

  for (auto* op : operations_needing_update) {
    if (mlir::failed(UpdateLayoutsForOp(op, producers, merged_layouts,
                                        producer_request, consumer_requests,
                                        is_updated)))
      return errors::Internal("UpdateLayoutsForOp failed to update layouts.");
  }
  ++(*stage);
  return OkStatus();
}

// Compares every value's layouts in `merged_a` with the ones in `merged_b`,
// and stores the values that differ in `changed`.
Status CompareMergedLayouts(const llvm::DenseMap<mlir::Value, Layout>& merged_a,
                            const llvm::DenseMap<mlir::Value, Layout>& merged_b,
                            llvm::DenseSet<mlir::Value>& changed) {
  if (merged_a.size() != merged_b.size())
    return errors::Internal(
        "Both merged_layouts did not have the same number of set layouts.");
  for (const auto& value_and_layout : merged_a) {
    const mlir::Value value = value_and_layout.getFirst();
    const Layout& layout = value_and_layout.getSecond();
    auto value_and_layout_in_b = merged_b.find(value);
    if (value_and_layout_in_b == merged_b.end())
      return errors::Internal(
          "Comparing merged_layouts that contain different mlir::Value's.");
    if (value_and_layout_in_b->second != layout) {
      changed.insert(value);
    }
  }
  return OkStatus();
}
1393
1394// MLIR pass that propagates layout for all ops the module.
1395struct DLayoutPropagationPassV2
1396 : public impl::DTensorLayoutPropagationV2Base<DLayoutPropagationPassV2> {
1397 void getDependentDialects(mlir::DialectRegistry& registry) const override {
1398 registry.insert<mlir::dtensor::DTensorDialect>();
1399 }

  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder builder(&context);

    auto module = getOperation();

    if (mlir::failed(ReplaceAuxiliaryDTensorLayoutOpsWithIdentity(module)))
      return signalPassFailure();

    // To ensure that constant operations with multiple uses whose consumers
    // request different layouts do not lead to replicated constant tensors,
    // we duplicate all constants so that each has at most one use. After SPMD
    // expansion, these duplicated constants will be merged back during the
    // SCCP pass.
    DuplicateConstants(module);

    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) return;

    mlir::Dialect* tf_dialect =
        context.getLoadedDialect<mlir::TF::TensorFlowDialect>();

    // Maps each OpResult to the list of OpOperands that consume it. Note that
    // this passes over/through (Stateful)PartitionedCall and other control
    // flow, directly connecting producing ops to their consumers in the
    // function, i.e. it presents a flattened/inlined view of the flow of data.
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>> consumers;
    // The reverse mapping: for each consuming OpOperand, the values that may
    // feed it.
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers;
    // For each mlir::Value, the layout its producer would like it to have.
    llvm::DenseMap<mlir::Value, absl::optional<Layout>> producer_request;
    // For each mlir::Value, the layouts its consumers would like it to have,
    // keyed by the consuming OpOperand. Note this map is 'parallel' to the
    // consumers map above.
    llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>
        consumer_requests;
    // Tracks whether a value's layout was updated since the last cycle.
    llvm::DenseSet<mlir::Value> is_updated;
    // Tracks whether a value's layout is locked; for locked values we don't
    // pass consumer layouts to MergeLayouts. Used for input layouts and
    // user-defined layouts.
    llvm::DenseSet<mlir::Value> is_locked;

    // Create the consumers and producers maps.
    if (mlir::failed(
            PopulateConsumersFromModule(&module, tf_dialect, consumers)))
      return signalPassFailure();

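    // Build the producers map as the inverse of the consumers map: each
    // consuming OpOperand maps back to the values that may feed it.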
    for (auto& consumer : consumers) {
      for (auto* operand : consumer.second) {
        if (producers.find(operand) == producers.end()) {
          producers[operand] = std::vector<mlir::Value>{consumer.first};
        } else {
          producers[operand].emplace_back(consumer.first);
        }
      }
    }

    // Set up the initial starting conditions for the layout algorithm.
    if (mlir::failed(InsertInitialLayouts(
            module, main_func, consumers, producers, consumer_requests,
            producer_request, is_updated, is_locked)))
      return signalPassFailure();

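    // Hash of the module; only used to tag the layouts and ops dumped by
    // LogLayoutsAndOps when verbose logging is enabled.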
    const auto module_hash = OpHash(module);
    int stage = 0;

    llvm::DenseMap<mlir::Value, Layout> merged_layouts;

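    // Each stage runs layout propagation until it either converges or hits
    // the step limit; on non-convergence, the oscillating values are detected
    // below, their requests are cleared, and all other layouts are frozen
    // before the next stage.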
    while (!is_updated.empty() && stage < kLayoutPropagationMaxStages) {
      ++stage;
      int steps = 0;
      // Step 1. Run the layout propagation v2 until convergence or max steps.
      while (!is_updated.empty() && steps < LayoutPropagationMaxSteps()) {
        Status status = RunOneIteration(
            is_locked, is_updated, producer_request, consumer_requests,
            producers, consumers, merged_layouts, module, module_hash, &steps);
        if (!status.ok()) {
          module.emitOpError() << "Failure running iteration.";
          return signalPassFailure();
        }
      }
      // If layout propagation converged, there is nothing left to repair in
      // this stage.
      if (is_updated.empty()) break;

      if (VLOG_IS_ON(2)) {
        LOG(INFO) << "Failed to converge in stage " << stage;
      }
      // Step 2. If we are here, then we failed to converge and there is
      // likely an oscillation of layouts. Detect all the edges that are
      // changing layouts.
      llvm::DenseMap<mlir::Value, Layout> merged_layouts_at_max_steps =
          merged_layouts;
      llvm::DenseSet<mlir::Value> changed;
      int previous_change_size = -1;

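      // Keep running single iterations, comparing against the layouts at max
      // steps, until the set of oscillating values stops growing.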
      while (static_cast<int>(changed.size()) > previous_change_size) {
        if (!RunOneIteration(is_locked, is_updated, producer_request,
                             consumer_requests, producers, consumers,
                             merged_layouts, module, module_hash, &steps)
                 .ok()) {
          module.emitOpError() << "Failure running iteration.";
          return signalPassFailure();
        }
        if (!CompareMergedLayouts(merged_layouts_at_max_steps, merged_layouts,
                                  changed)
                 .ok()) {
          module.emitOpError() << "Failure comparing merged layouts.";
          return signalPassFailure();
        }
        previous_change_size = changed.size();
      }

      // Step 3. Layouts that haven't changed are not part of the cyclic
      // problem, so freeze them.
      for (const auto& value_and_layout : merged_layouts) {
        const mlir::Value value = value_and_layout.getFirst();
        if (changed.find(value) == changed.end()) {
          is_locked.insert(value);
        }
      }
      // Step 4. Clear all recorded state (requests and merged layouts) for
      // the changed values so it can be recomputed from scratch.
      for (const mlir::Value changed_value : changed) {
        producer_request.erase(changed_value);
        consumer_requests.erase(changed_value);
        merged_layouts.erase(changed_value);
      }

      // Step 5. Run ComputeLayout again on all the ops linked to the changed
      // layouts. The next iteration will take this information and merge
      // again.
      llvm::DenseSet<mlir::Operation*> operations_needing_update;
      is_updated = changed;
      GetOperationsNeedingUpdate(is_updated, consumers,
                                 operations_needing_update);
      is_updated.clear();

      for (auto* op : operations_needing_update) {
        if (mlir::failed(UpdateLayoutsForOp(op, producers, merged_layouts,
                                            producer_request,
                                            consumer_requests, is_updated))) {
          module.emitOpError() << "Failure in UpdateLayoutsForOp.";
          return signalPassFailure();
        }
      }
    }

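    // Report an error if layout propagation did not converge within the
    // allowed number of stages.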
    if (!is_updated.empty() && stage >= kLayoutPropagationMaxStages) {
      FindRootsAndEmitError(module, producers, is_updated);
      return signalPassFailure();
    }

    if (mlir::failed(
            CopyLayoutsForSkippedOps(module, tf_dialect, merged_layouts)))
      return signalPassFailure();

    if (VLOG_IS_ON(2)) {
      LogLayoutsAndOps(stage, module_hash, merged_layouts, module);
    }

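    // Every TF op result must carry a layout before the DTensorLayout ops are
    // materialized below.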
    if (!AllOpResultsHaveLayouts(&module, tf_dialect, merged_layouts))
      return signalPassFailure();

    if (mlir::failed(InsertDTensorLayoutOps(builder, merged_layouts)))
      return signalPassFailure();

    // Handle layout of control flow operations.
    llvm::SmallVector<mlir::TF::IfRegionOp, 4> if_ops;
    llvm::SmallVector<mlir::TF::WhileRegionOp, 4> while_ops;
    module.walk([&](mlir::Operation* op) {
      if (auto if_op = llvm::dyn_cast<mlir::TF::IfRegionOp>(op))
        if_ops.emplace_back(if_op);
      else if (auto while_op = llvm::dyn_cast<mlir::TF::WhileRegionOp>(op))
        while_ops.emplace_back(while_op);
    });

    if (mlir::failed(InsertRelayoutForWhileLoops(while_ops, builder)))
      return signalPassFailure();

    if (mlir::failed(
            InsertDTensorLayoutForIfRegionOp(if_ops, builder.getContext())))
      return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorLayoutPropagationPassV2() {
  return std::make_unique<DLayoutPropagationPassV2>();
}

}  // namespace dtensor
}  // namespace tensorflow