1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <iterator> |
18 | #include <queue> |
19 | #include <string> |
20 | |
21 | #include "absl/container/flat_hash_set.h" |
22 | #include "absl/types/optional.h" |
23 | #include "llvm/ADT/DenseMap.h" |
24 | #include "llvm/ADT/DenseSet.h" |
25 | #include "llvm/ADT/STLExtras.h" |
26 | #include "llvm/ADT/SmallVector.h" |
27 | #include "llvm/ADT/StringRef.h" |
28 | #include "llvm/Support/FormatVariadic.h" |
29 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
30 | #include "mlir/IR/Attributes.h" // from @llvm-project |
31 | #include "mlir/IR/Builders.h" // from @llvm-project |
32 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
33 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
34 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
35 | #include "mlir/IR/OpImplementation.h" // from @llvm-project |
36 | #include "mlir/IR/Operation.h" // from @llvm-project |
37 | #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
38 | #include "mlir/IR/Types.h" // from @llvm-project |
39 | #include "mlir/IR/Value.h" // from @llvm-project |
40 | #include "mlir/IR/Visitors.h" // from @llvm-project |
41 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
42 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
43 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
44 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
45 | #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" |
46 | #include "tensorflow/compiler/mlir/utils/name_utils.h" |
47 | #include "tensorflow/dtensor/cc/constants.h" |
48 | #include "tensorflow/dtensor/cc/dtensor_utils.h" |
49 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
50 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" |
51 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h" |
52 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
53 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
54 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
55 | #include "tensorflow/dtensor/mlir/op_utils.h" |
56 | #include "tensorflow/dtensor/mlir/spmd_expander.h" |
57 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
58 | #include "tensorflow/dtensor/mlir/value_utils.h" |
59 | |
60 | namespace tensorflow { |
61 | namespace dtensor { |
62 | |
63 | namespace { |
64 | #define GEN_PASS_DEF_DTENSORLAYOUTPROPAGATIONV2 |
65 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
66 | |
// Upper bound on how many times layout propagation is allowed to step in and
// fix oscillating layouts.
constexpr int kLayoutPropagationMaxStages{3};
70 | |
71 | bool AllOpResultsHaveLayouts( |
72 | mlir::ModuleOp* module, mlir::Dialect* tf_dialect, |
73 | const llvm::DenseMap<mlir::Value, Layout>& layouts) { |
74 | const auto& result = module->walk([&](mlir::Operation* op) { |
75 | if (op->getDialect() != tf_dialect || |
76 | mlir::isa<mlir::TF::DTensorLayout>(op)) |
77 | return mlir::WalkResult::advance(); |
78 | for (const auto& result : op->getOpResults()) { |
79 | if (layouts.find(result) == layouts.end()) { |
80 | op->emitOpError() << "missing layout for result " |
81 | << result.getResultNumber(); |
82 | return mlir::WalkResult::interrupt(); |
83 | } |
84 | } |
85 | return mlir::WalkResult::advance(); |
86 | }); |
87 | return !result.wasInterrupted(); |
88 | } |
89 | |
90 | void UpdateLayoutForSkippedOps( |
91 | mlir::OpOperand& operand, |
92 | const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller, |
93 | const Layout& layout_to_copy, |
94 | llvm::DenseMap<mlir::Value, Layout>& layouts) { |
95 | llvm::SmallVector<mlir::Value, 4> skipped_values; |
96 | TraceUseToNextTFOp(&operand, func_to_caller, &skipped_values); |
97 | for (const mlir::Value& skipped_value : skipped_values) |
98 | if ((!skipped_value.isa<mlir::OpResult>() || |
99 | !mlir::isa<mlir::TF::DTensorLayout, mlir::tf_device::ClusterOp>( |
100 | skipped_value.getDefiningOp())) && |
101 | layouts.find(skipped_value) == layouts.end()) |
102 | // TraceUseToNextTFOp's contract is that it only skips over ops that |
103 | // act like the identity (such as function calls, returns, yields, |
104 | // controlflow, DTensorLayouts, etc). This means that operand layout |
105 | // that we came from is the layout we want for this value. |
106 | layouts[skipped_value] = layout_to_copy; |
107 | } |
108 | |
109 | // Some ops, which are skipped by TraceUseToNextTFOp, will not have layouts |
110 | // for their mlir::OpResults. |
111 | // E.g. during the creation of the consumers map, we skip the input and output |
112 | // of the WhileRegion op. In particular if we have: |
113 | // |
114 | // %b = tf.WhileRegion(%a) ({ |
115 | // %bb0(%arg0): # Cond |
116 | // %c = tf.A(%arg0) |
117 | // tf.Yield(%c) |
118 | // }, { |
119 | // %bb0(%arg0): # Body |
120 | // %d = tf.B(%arg0) |
121 | // tf.Yield(%d) |
122 | // } |
123 | // } |
124 | // %e = tf.C(%b) |
125 | // |
126 | // Then the consumers map would directly connect the mlir::Value %a to input 0 |
127 | // of tf.A and tf.B, bypassing the WhileRegion and the mlir::Value of %arg1. |
128 | // Similarly it would connect the mlir::Value of %d directly to input 0 of |
129 | // tf.C bypassing the mlir::Value of %b. |
130 | // This means that at the end of layout propagation the skipped values would not |
131 | // have an assigned layout. But this layout can be derived by taking the known |
132 | // layout of %a and propagating to each mlir::Value that was skipped while |
133 | // connecting %a to the input 0 of tf.A and tf.B. Similarly we derive the layout |
134 | // for %b from %d. |
135 | // |
136 | // To get layouts we |
137 | // 1) Iterate over all ops that have layouts for their OpResults and call |
138 | // TraceUseToNextTFOp to get the skipped mlir::Values. |
139 | // 2) If any skipped mlir::Value doesn't have a layout set, then we set the |
140 | // layout. |
141 | mlir::LogicalResult CopyLayoutsForSkippedOps( |
142 | mlir::ModuleOp module, mlir::Dialect* tf_dialect, |
143 | llvm::DenseMap<mlir::Value, Layout>& layouts) { |
144 | llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller; |
145 | |
146 | if (mlir::failed(GetFuncToCaller(module, func_to_caller))) |
147 | return mlir::failure(); |
148 | |
149 | // Update layouts derived from ops. |
150 | module->walk([&](mlir::Operation* op) { |
151 | for (mlir::OpOperand& operand : op->getOpOperands()) { |
152 | if (layouts.find(operand.get()) == layouts.end()) continue; |
153 | const Layout layout = layouts[operand.get()]; |
154 | UpdateLayoutForSkippedOps(operand, func_to_caller, layout, layouts); |
155 | } |
156 | }); |
157 | |
158 | // Update layouts derived from inputs |
159 | mlir::func::FuncOp main_func = |
160 | module.lookupSymbol<mlir::func::FuncOp>("main" ); |
161 | if (!main_func) return mlir::success(); |
162 | |
163 | for (auto& value : main_func.getArguments()) { |
164 | if (layouts.find(value) == layouts.end()) continue; |
165 | const Layout layout = layouts[value]; |
166 | |
167 | for (mlir::OpOperand& operand : value.getUses()) |
168 | UpdateLayoutForSkippedOps(operand, func_to_caller, layout, layouts); |
169 | } |
170 | |
171 | return mlir::success(); |
172 | } |
173 | |
174 | namespace { |
175 | void FilterkAnySpecs(std::vector<std::string>& proposed_specs) { |
176 | for (auto& spec : proposed_specs) { |
177 | if (spec == Layout::kAny) spec = Layout::kUnshardedDim; |
178 | } |
179 | } |
180 | } // namespace |
181 | |
182 | // Merges the producer and consumer layouts into a single layout. |
183 | // Assumes that all layouts are of the same rank. |
184 | // Consumers are first merged together so that we have the layout which is |
185 | // sharded in a tensor dim if and only if all consumers are sharded in the same |
186 | // sharding_spec. |
187 | // If producer layout is present, we merge the consumer layouts into the layout |
188 | // of the producer: if the consumer wants a sharded layout in a tensor dimension |
189 | // where the producer is unshared *and* the mesh dimension it wants to be |
190 | // sharded over is not already sharded over by the producer, then we add that |
191 | // sharding to the producer layout. |
192 | StatusOr<Layout> MergeLayouts( |
193 | const absl::optional<Layout>& producer, |
194 | const mlir::DenseMap<mlir::OpOperand*, Layout>& consumers) { |
195 | if (consumers.empty()) return producer.value(); |
196 | |
197 | // Initialize the specs to those of the first consumer layout and merge |
198 | // consumers. |
199 | std::vector<std::string> proposed_specs = |
200 | consumers.begin()->second.sharding_spec_strs(); |
201 | int layout_rank = proposed_specs.size(); |
202 | |
203 | // Verify consumer layout ranks match. |
204 | for (const auto& consumer : consumers) { |
205 | const Layout& consumer_layout = consumer.second; |
206 | if (consumer_layout.rank() != layout_rank) |
207 | return errors::InvalidArgument( |
208 | "found two consumer layout of different ranks: " , |
209 | consumer_layout.rank(), " and " , layout_rank); |
210 | } |
211 | |
212 | // Merge consumer layouts. |
213 | for (const auto& consumer : consumers) { |
214 | const Layout& consumer_layout = consumer.second; |
215 | |
216 | // Check every tensor dimension. |
217 | for (int j = 0; j < consumer_layout.rank(); ++j) { |
218 | const std::string& consumer_spec_j = consumer_layout.sharding_spec(j); |
219 | if (consumer_spec_j == Layout::kAny) continue; |
220 | |
221 | // If the proposed spec is set as any, give priority to the consumer spec. |
222 | if (proposed_specs[j] == Layout::kAny) { |
223 | proposed_specs[j] = consumer_spec_j; |
224 | continue; |
225 | } |
226 | |
227 | // If any consumer layout disagrees with the current merge, set the |
228 | // spec to not sharded. |
229 | if (proposed_specs[j] != consumer_spec_j) |
230 | proposed_specs[j] = Layout::kUnshardedDim; |
231 | } |
232 | } |
233 | |
234 | // Filter over-sharded specs. |
235 | absl::flat_hash_map<std::string, int> counter; |
236 | for (const std::string& spec : proposed_specs) counter[spec] += 1; |
237 | for (std::string& spec : proposed_specs) |
238 | if (counter[spec] > 1) spec = Layout::kUnshardedDim; |
239 | |
240 | // Return layout if there is no producer, else move into producer algorithm. |
241 | const Mesh mesh = consumers.begin()->second.mesh(); |
242 | if (!producer) { |
243 | FilterkAnySpecs(proposed_specs); |
244 | return Layout::GetLayout(proposed_specs, mesh); |
245 | } |
246 | |
247 | if (producer->rank() != layout_rank) { |
248 | return errors::InvalidArgument( |
249 | "producer and consumer layout have different ranks: " , producer->rank(), |
250 | " and " , layout_rank); |
251 | } |
252 | |
253 | // For the producer merge, first we define mesh dims used by the producer to |
254 | // avoid creating a layout that shards twice over the same mesh dim. |
255 | llvm::DenseSet<llvm::StringRef> producer_dims; |
256 | for (int j = 0; j < producer->rank(); ++j) { |
257 | llvm::StringRef spec = producer->sharding_spec(j); |
258 | if (Layout::IsShardedDimension(spec.str())) producer_dims.insert(spec); |
259 | } |
260 | // Merge producer layout with existing layout. |
261 | for (int j = 0; j < producer->rank(); ++j) { |
262 | const std::string& producer_spec = producer->sharding_spec(j); |
263 | |
264 | if (producer_spec == proposed_specs[j] || producer_spec == Layout::kAny) |
265 | continue; |
266 | |
267 | if (proposed_specs[j] == Layout::kAny) { |
268 | proposed_specs[j] = producer_spec; |
269 | continue; |
270 | } |
271 | // If producer is unsharded and proposed_spec is sharded. Need to make sure |
272 | // mesh dim is not used elsewhere. If so, set to unsharded. |
273 | if (Layout::IsUnshardedDimension(producer_spec)) { |
274 | bool isMeshDimUsed = producer_dims.contains(proposed_specs[j]); |
275 | if (isMeshDimUsed) { |
276 | proposed_specs[j] = Layout::kUnshardedDim; |
277 | } |
278 | } else { |
279 | // If producer is sharded we can set layout to shard over same |
280 | // mesh dim. |
281 | // |
282 | // If mesh dim is already used in the layout elsewhere it will |
283 | // get unset by the case above. |
284 | proposed_specs[j] = producer_spec; |
285 | } |
286 | } |
287 | FilterkAnySpecs(proposed_specs); |
288 | return Layout::GetLayout(proposed_specs, mesh); |
289 | } |
290 | |
291 | mlir::LogicalResult InsertLayoutsForDTensorLayout( |
292 | mlir::ModuleOp& module, |
293 | llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request, |
294 | llvm::DenseSet<mlir::Value>& is_updated, |
295 | llvm::DenseSet<mlir::Value>& is_locked) { |
296 | return mlir::failure( |
297 | module |
298 | .walk([&](mlir::TF::DTensorLayout op) -> mlir::WalkResult { |
299 | // Check there are no "Layout::kAny" or "kMatch" specs in the |
300 | // layouts. |
301 | for (const std::string& spec : op.layout().sharding_spec_strs()) |
302 | if (spec == Layout::kAny || spec == Layout::kMatch) |
303 | return op->emitOpError() |
304 | << "found " << spec |
305 | << " as a sharding spec which is not allowed" ; |
306 | // Insert layout. |
307 | producer_request[op.input()].emplace(op.layout()); |
308 | is_updated.insert(op.input()); |
309 | is_locked.insert(op.input()); |
310 | return mlir::WalkResult::advance(); |
311 | }) |
312 | .wasInterrupted()); |
313 | } |
314 | |
315 | // Runs ComputeLayout API on all ops inside graph **without** any consumer |
316 | // requested layout/ operand layouts populated. |
317 | mlir::LogicalResult InsertInitialLayoutsFromComputeLayout( |
318 | mlir::ModuleOp module, |
319 | const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers, |
320 | const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers, |
321 | llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request, |
322 | llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>& |
323 | consumer_requests, |
324 | llvm::DenseSet<mlir::Value>& is_updated) { |
325 | auto walk_result = module.walk([&](mlir::Operation* op) { |
326 | // We ignore ops that don't have either an OpResult in consumers or an |
327 | // OpOperand in producers. Note that if one operand is missing from |
328 | // producers then all operands should be missing as well as all op results |
329 | // from consumers and the opposite as well. |
330 | |
331 | if (op->getNumOperands() > 0) { |
332 | if (producers.find(&op->getOpOperand(0)) == producers.end()) |
333 | return mlir::WalkResult::advance(); |
334 | } else if (op->getNumResults() > 0) { |
335 | if (consumers.find(op->getOpResult(0)) == consumers.end()) |
336 | return mlir::WalkResult::advance(); |
337 | } else { |
338 | // Note that this case should never happen (I.e. a TF ops should have |
339 | // either inputs or outputs, but that isn't technically guaranteed). |
340 | return mlir::WalkResult::advance(); |
341 | } |
342 | |
343 | auto* expander = SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op); |
344 | if (expander == nullptr) { |
345 | op->emitOpError() << "does not implement layout propagation" ; |
346 | return mlir::WalkResult::interrupt(); |
347 | } |
348 | |
349 | // Invoke ComputeLayout on `cluster_op` with empty input/consumer layouts. |
350 | StatusOr<llvm::DenseMap<int, Layout>> forward_result = |
351 | expander->ComputeLayoutForward( |
352 | op, /*input_layouts=*/llvm::DenseMap<int, Layout>(), |
353 | /*output_layouts=*/llvm::DenseMap<int, Layout>()); |
354 | if (!forward_result.ok()) { |
355 | op->emitOpError() << "ComputeLayoutForward error: " |
356 | << forward_result.status().error_message(); |
357 | return mlir::WalkResult::interrupt(); |
358 | } |
359 | StatusOr<llvm::DenseMap<int, Layout>> backward_result = |
360 | expander->ComputeLayoutBackward( |
361 | op, /*input_layouts=*/llvm::DenseMap<int, Layout>(), |
362 | /*output_layouts=*/llvm::DenseMap<int, Layout>()); |
363 | if (!backward_result.ok()) { |
364 | op->emitOpError() << "ComputeLayoutBackward error: " |
365 | << backward_result.status().error_message(); |
366 | return mlir::WalkResult::interrupt(); |
367 | } |
368 | |
369 | // If any operand layouts were returned, add the layout to consumer requests |
370 | // and set the value as updated. |
371 | for (auto const& op_idx_and_op_layout : *backward_result) { |
372 | auto const& op_idx = op_idx_and_op_layout.first; |
373 | auto const& op_layout = op_idx_and_op_layout.second; |
374 | auto& operand = op->getOpOperand(op_idx); |
375 | const auto& producer_values = producers.lookup(&operand); |
376 | for (mlir::Value producer_value : producer_values) { |
377 | if (!consumer_requests[producer_value].count(&operand)) |
378 | consumer_requests[producer_value][&operand] = op_layout; |
379 | |
380 | is_updated.insert(producer_value); |
381 | } |
382 | } |
383 | |
384 | // If any output layouts were returned, add the layout to producer requests |
385 | // and set the value as updated. |
386 | for (auto const& out_idx_and_out_layout : *forward_result) { |
387 | auto const& out_idx = out_idx_and_out_layout.first; |
388 | auto const& out_layout = out_idx_and_out_layout.second; |
389 | mlir::Value output_value = op->getResult(out_idx); |
390 | producer_request.try_emplace(output_value, out_layout); |
391 | is_updated.insert(output_value); |
392 | } |
393 | |
394 | return mlir::WalkResult::advance(); |
395 | }); |
396 | return mlir::failure(walk_result.wasInterrupted()); |
397 | } |
398 | |
399 | // Propagates mesh and inserts initial layouts for |
400 | // * Any DTensorLayout ops (this handles function inputs and other ops with user |
401 | // layouts. |
402 | // * CopyToMesh |
403 | // * ConstOp |
404 | mlir::LogicalResult InsertInitialLayouts( |
405 | mlir::ModuleOp& module, mlir::func::FuncOp& main_func, |
406 | const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers, |
407 | const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers, |
408 | llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>& |
409 | consumer_request, |
410 | llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request, |
411 | llvm::DenseSet<mlir::Value>& is_updated, |
412 | llvm::DenseSet<mlir::Value>& is_locked) { |
413 | std::queue<mlir::Operation*> operations; |
414 | |
415 | if (mlir::failed(InsertLayoutsForDTensorLayout(module, producer_request, |
416 | is_updated, is_locked))) |
417 | return mlir::failure(); |
418 | return InsertInitialLayoutsFromComputeLayout(module, consumers, producers, |
419 | producer_request, |
420 | consumer_request, is_updated); |
421 | } |
422 | |
423 | // Given a list of mlir::Values with updated producer or consumer layouts |
424 | // update the merged_layouts list and track which layouts actually changed. |
425 | mlir::LogicalResult MergeAndGetUpdatedLayouts( |
426 | const llvm::DenseSet<mlir::Value>& is_locked, |
427 | llvm::DenseSet<mlir::Value>& is_updated, |
428 | llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request, |
429 | llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>& |
430 | consumer_requests, |
431 | llvm::DenseMap<mlir::Value, Layout>& merged_layouts) { |
432 | llvm::DenseSet<mlir::Value> updated_merge; |
433 | for (auto& value : is_updated) { |
434 | auto& producer_layout = producer_request[value]; |
435 | if (is_locked.find(value) != is_locked.end()) { |
436 | // Locked values must have a producer request. If the merged_layout is |
437 | // not already set, then this is the first pass, so we set it and mark |
438 | // then entry as updated. |
439 | if (merged_layouts.find(value) == merged_layouts.end()) { |
440 | if (!producer_layout) |
441 | return value.getDefiningOp()->emitError() << "missing locked layout" ; |
442 | merged_layouts[value] = producer_layout.value(); |
443 | updated_merge.insert(value); |
444 | } |
445 | continue; |
446 | } |
447 | auto merged = MergeLayouts(producer_layout, consumer_requests[value]); |
448 | if (!merged.ok()) |
449 | return value.getDefiningOp()->emitOpError() |
450 | << merged.status().error_message(); |
451 | |
452 | auto current_layout = merged_layouts.find(value); |
453 | if (current_layout == merged_layouts.end() || |
454 | current_layout->second != merged.value()) { |
455 | updated_merge.insert(value); |
456 | merged_layouts[value] = merged.value(); |
457 | } |
458 | } |
459 | |
460 | is_updated = updated_merge; |
461 | return mlir::success(); |
462 | } |
463 | |
464 | // Finds the most sharded merged layout given `layouts`. |
465 | mlir::LogicalResult GetMostShardedLayout(llvm::ArrayRef<Layout> layouts, |
466 | mlir::Location location, |
467 | absl::optional<Layout>* out) { |
468 | // If there are no layouts to merge, leave the output empty. |
469 | if (layouts.empty()) return mlir::success(); |
470 | |
471 | absl::optional<Layout> layout; |
472 | std::map<std::string, std::set<int>> layout_map; |
473 | for (const Layout& layout : layouts) { |
474 | for (int i = 0; i < layout.rank(); ++i) { |
475 | const std::string& mesh_dim = layout.dim(i).sharding_spec(); |
476 | if (mesh_dim == Layout::kUnshardedDim) continue; |
477 | |
478 | layout_map[mesh_dim].insert(i); |
479 | } |
480 | } |
481 | |
482 | for (auto& it : layout_map) |
483 | if (it.second.size() > 1) it.second.clear(); |
484 | |
485 | std::map<int, std::set<std::string>> dim_to_layout_map; |
486 | for (const auto& it : layout_map) { |
487 | assert(it.second.size() <= 1); |
488 | if (it.second.empty()) continue; |
489 | |
490 | const int tensor_dim_index = *it.second.begin(); |
491 | dim_to_layout_map[tensor_dim_index].insert(it.first); |
492 | } |
493 | |
494 | for (auto& it : dim_to_layout_map) |
495 | if (it.second.size() > 1) it.second.clear(); |
496 | |
497 | std::vector<std::string> merged_spec; |
498 | assert(!layouts.empty()); |
499 | for (int i = 0; i < layouts[0].rank(); ++i) { |
500 | const auto it = dim_to_layout_map.find(i); |
501 | if (it != dim_to_layout_map.end() && !it->second.empty()) { |
502 | assert(it->second.size() == 1); |
503 | merged_spec.emplace_back(*it->second.begin()); |
504 | } else { |
505 | merged_spec.emplace_back(Layout::kUnshardedDim); |
506 | } |
507 | } |
508 | const auto new_layout = Layout::GetLayout(merged_spec, layouts[0].mesh()); |
509 | if (!new_layout.ok()) { |
510 | return mlir::emitError( |
511 | location, llvm::formatv("error in layout propagation while merging " |
512 | "producer layouts. {0}" , |
513 | new_layout.status().error_message())); |
514 | } |
515 | out->emplace(*new_layout); |
516 | return mlir::success(); |
517 | } |
518 | |
519 | // Merge layouts of mlir::Value from multiple producers into a single final |
520 | // layout. A mlir::Value can have multiple producers if the value is from a |
521 | // tf.If/tf.IfRegion op. Given multiple producer layouts of the same |
522 | // mlir::Value, the merging logic is as follows: |
523 | // 1) If a dimension can be sharded, shard the dimension as much as possible. |
524 | // 2) If mesh dimension is already used or two same mesh dimensions are used |
525 | // in different dimensions, then leave the dimension as replicated. |
526 | // |
527 | // For example: |
528 | // ("x", replicated) , (replicated, "y") will have ("x", "y") merged layout. |
529 | // ("x", replicated) , (replicated, "x") will have (replicated, replicated) |
530 | // merged layout. |
531 | mlir::LogicalResult MergeProducerLayouts( |
532 | const llvm::DenseMap<mlir::Value, Layout>& merged_layouts, |
533 | const std::vector<mlir::Value>& producer_values, mlir::Location location, |
534 | absl::optional<Layout>* layout_out) { |
535 | // If there is a single producer for mlir::Value, then return the layout |
536 | // from the producer. |
537 | absl::optional<Layout> layout; |
538 | if (producer_values.size() == 1) { |
539 | const auto it = merged_layouts.find(producer_values[0]); |
540 | if (it != merged_layouts.end()) *layout_out = it->second; |
541 | return mlir::success(); |
542 | } |
543 | |
544 | // For the case with multiple producer, merge the layouts. |
545 | llvm::SmallVector<Layout, 4> candidate_layouts; |
546 | candidate_layouts.reserve(producer_values.size()); |
547 | for (mlir::Value value : producer_values) { |
548 | auto it = merged_layouts.find(value); |
549 | if (it == merged_layouts.end()) continue; |
550 | candidate_layouts.emplace_back(it->second); |
551 | } |
552 | |
553 | if (mlir::failed(GetMostShardedLayout(candidate_layouts, location, &layout))) |
554 | return mlir::failure(); |
555 | |
556 | if (layout) *layout_out = *layout; |
557 | return mlir::success(); |
558 | } |
559 | |
560 | // For an op, calls the corresponding ComputeLayouts function with the data from |
561 | // the merged_layouts map. Records the result in the producer_request and |
562 | // consumer_requests maps and notes if any layouts have changed. |
563 | mlir::LogicalResult UpdateLayoutsForOp( |
564 | mlir::Operation* op, |
565 | const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers, |
566 | const llvm::DenseMap<mlir::Value, Layout>& merged_layouts, |
567 | llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request, |
568 | llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>& |
569 | consumer_requests, |
570 | llvm::DenseSet<mlir::Value>& is_updated) { |
571 | auto* expander = SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op); |
572 | if (expander == nullptr) |
573 | return op->emitOpError() << "does not implement layout propagation" ; |
574 | |
575 | // Get input and output layouts for this op from the merged_layouts map. |
576 | llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands()); |
577 | llvm::DenseMap<int, Layout> output_layouts(op->getNumResults()); |
578 | |
579 | for (int i = 0; i < op->getNumOperands(); ++i) { |
580 | // For inputs, we need to find the producer's mlir::Value that eventually |
581 | // feeds into this op. This is in the producers map. |
582 | // Merge different layouts for multiples producers `values`. |
583 | auto producer_values = producers.find(&(op->getOpOperand(i))); |
584 | if (producer_values == producers.end()) |
585 | return op->emitError() << "Unable to find producer for operand " << i; |
586 | |
587 | absl::optional<Layout> layout; |
588 | if (mlir::failed(MergeProducerLayouts(merged_layouts, |
589 | producer_values->getSecond(), |
590 | op->getLoc(), &layout))) |
591 | return mlir::failure(); |
592 | |
593 | if (layout) input_layouts[i] = *layout; |
594 | } |
595 | |
596 | for (int i = 0; i < op->getNumResults(); ++i) { |
597 | auto layout = merged_layouts.find(op->getOpResult(i)); |
598 | if (layout != merged_layouts.end()) output_layouts[i] = layout->second; |
599 | } |
600 | |
601 | auto forward_result = |
602 | expander->ComputeLayoutForward(op, input_layouts, output_layouts); |
603 | if (!forward_result.ok()) { |
604 | return op->emitOpError() << "ComputeLayoutForward error: " |
605 | << forward_result.status().error_message(); |
606 | } |
607 | const auto new_output_layouts = *forward_result; |
608 | auto backward_result = |
609 | expander->ComputeLayoutBackward(op, input_layouts, output_layouts); |
610 | if (!backward_result.ok()) { |
611 | return op->emitOpError() << "ComputeLayoutBackward error: " |
612 | << backward_result.status().error_message(); |
613 | } |
614 | const auto new_input_layouts = *backward_result; |
615 | |
616 | // Update the consumer layouts for this op. |
617 | for (int i = 0; i < op->getNumOperands(); ++i) { |
618 | mlir::OpOperand* operand = &(op->getOpOperand(i)); |
619 | // No need to check that this exists, we already did it above. |
620 | const auto& producer_values = producers.find(operand); |
621 | const auto input_layout = new_input_layouts.find(i); |
622 | |
623 | for (mlir::Value value : producer_values->getSecond()) { |
624 | auto& consumer_request = consumer_requests[value]; |
625 | const auto consumer_request_from_op_operand = |
626 | consumer_request.find(operand); |
627 | |
628 | // Update the consumer_request for this OpOperand: we respect what compute |
629 | // layout returns and erase the a requested layout if no layout is |
630 | // returned. |
631 | // TODO(hongjunchoi, bfontain): Consider the case when op output type is |
632 | // resource type with subtype information. |
633 | if (input_layout != new_input_layouts.end() && |
634 | (consumer_request_from_op_operand == consumer_request.end() || |
635 | input_layout->second != consumer_request_from_op_operand->second)) { |
636 | // RestoreV2 op most likely would have unknown rank upon restoring, and |
637 | // we relax unknown rank check for the inputs that are produced from |
638 | // there. |
639 | const bool exempt_restore_unknown_rank = |
640 | ValueRank(value) == -1 && value.getDefiningOp() && |
641 | llvm::isa<mlir::TF::RestoreV2Op>(value.getDefiningOp()); |
642 | if (!exempt_restore_unknown_rank && |
643 | input_layout->second.rank() != ValueRank(value)) |
644 | return op->emitOpError() |
645 | << "Rank for input " << i << " layout is " |
646 | << input_layout->second.rank() << " but actual rank is " |
647 | << ValueRank(value); |
648 | |
649 | // If there was a layout returned and either no previous request or the |
650 | // request changed, insert and mark as updated. |
651 | consumer_request[operand] = input_layout->second; |
652 | is_updated.insert(value); |
653 | } else if (input_layout == new_input_layouts.end() && |
654 | consumer_request_from_op_operand != consumer_request.end()) { |
655 | // If no layout was returned and there is previous request, erase the |
656 | // old consumer request. |
657 | consumer_request.erase(operand); |
658 | is_updated.insert(value); |
659 | } |
660 | } |
661 | } |
662 | |
663 | // Update the producer request for this op. |
664 | // If the output is different from what is in the request list, update the |
665 | // the request and mark the mlir::Value as having an updated Layout request. |
666 | for (int i = 0; i < op->getNumResults(); ++i) { |
667 | const auto output_layout = new_output_layouts.find(i); |
668 | if (output_layout == new_output_layouts.end()) continue; |
669 | const auto& result = op->getOpResult(i); |
670 | if (producer_request[result] != output_layout->second) { |
671 | if (output_layout->second.rank() != ValueRank(result)) |
672 | return op->emitOpError() << "Rank for output " << i << " layout is " |
673 | << output_layout->second.rank() |
674 | << " but actual rank is " << ValueRank(result); |
675 | producer_request[result] = output_layout->second; |
676 | is_updated.insert(result); |
677 | } |
678 | } |
679 | return mlir::success(); |
680 | } |
681 | |
682 | mlir::LogicalResult InsertDTensorLayoutOps( |
683 | mlir::OpBuilder& builder, |
684 | const llvm::DenseMap<mlir::Value, Layout>& merged_layouts) { |
685 | for (const auto& merged_layout : merged_layouts) { |
686 | // merged_layout is a pair of mlir::Value and Layout. |
687 | // If there is only one user of the Value and that user is a DTensorLayout |
688 | // op, then we can skip creating the op as the layout is already there. Note |
689 | // that we specifically do not allow updating a layout in an already present |
690 | // DTensorLayout op as we have considered them to be 'locked' throughout |
691 | // the algorithm. |
692 | const auto& users = merged_layout.first.getUsers(); |
693 | int num_users = std::distance(users.begin(), users.end()); |
694 | if (num_users == 1 && mlir::isa<mlir::TF::DTensorLayout>(*users.begin())) |
695 | continue; |
696 | builder.setInsertionPointAfterValue(merged_layout.first); |
697 | // Handles resource and variant as the real shape is embedded in the |
698 | // resource type elements. |
699 | mlir::Type value_type = GetSubtypeOrSelf(merged_layout.first); |
700 | |
701 | if (auto type = value_type.dyn_cast<mlir::TensorType>()) { |
702 | auto layout_op = builder.create<mlir::TF::DTensorLayout>( |
703 | merged_layout.first.getLoc(), merged_layout.first, |
704 | mlir::dtensor::LayoutAttr::get(builder.getContext(), |
705 | merged_layout.second), |
706 | mlir::TF::ShapeAttr::get(builder.getContext(), type)); |
707 | llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op}; |
708 | merged_layout.first.replaceAllUsesExcept(layout_op.output(), exception); |
709 | } else { |
710 | mlir::emitError(merged_layout.first.getLoc()) |
711 | << "value type is not TensorType as expected." ; |
712 | } |
713 | } |
714 | |
715 | return mlir::success(); |
716 | } |
717 | |
718 | void GetOperationsNeedingUpdate( |
719 | const llvm::DenseSet<mlir::Value>& is_updated, |
720 | const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers, |
721 | llvm::DenseSet<mlir::Operation*>& operations) { |
722 | for (auto& value : is_updated) { |
723 | auto uses = consumers.find(value); |
724 | // Some values have no consumers (e.g. outputs of the main function). |
725 | if (uses != consumers.end()) |
726 | for (auto* use : uses->second) |
727 | if (!mlir::isa<mlir::TF::CopyToMeshOp>(use->getOwner())) |
728 | operations.insert(use->getOwner()); |
729 | // If this is an OpResult, also add the op that produces it. |
730 | if (value.isa<mlir::OpResult>() && |
731 | !mlir::isa<mlir::TF::CopyToMeshOp>(value.getDefiningOp())) |
732 | operations.insert(value.getDefiningOp()); |
733 | } |
734 | } |
735 | |
736 | namespace { |
737 | |
738 | // Custom printing class which prints out layouts and ignores DTensorLayout |
739 | // ops and also non registered attributes. |
740 | class LayoutPrinter : public mlir::OpAsmPrinter { |
741 | public: |
742 | explicit LayoutPrinter( |
743 | llvm::raw_ostream& os, |
744 | const llvm::DenseMap<mlir::Value, Layout>& merged_layouts) |
745 | : indent_level_(0), |
746 | os_(os), |
747 | current_location_(0), |
748 | next_argument_(0), |
749 | merged_layouts_(merged_layouts) {} |
750 | |
751 | llvm::raw_ostream& getStream() const override { return os_; } |
752 | |
753 | void printRegionArgument(mlir::BlockArgument arg, |
754 | llvm::ArrayRef<mlir::NamedAttribute> argAttrs, |
755 | bool omitType) override { |
756 | printOperand(arg); |
757 | if (!omitType) { |
758 | os_ << ": " ; |
759 | printType(arg.getType()); |
760 | } |
761 | printOptionalAttrDict(argAttrs, llvm::None); |
762 | } |
763 | |
764 | void printOperand(mlir::Value value) override { printOperand(value, os_); } |
765 | |
766 | /// Print a newline and indent the printer to the start of the current |
767 | /// operation. |
768 | void printNewline() override { |
769 | os_ << "\n" ; |
770 | os_.indent(indent_level_); |
771 | } |
772 | |
773 | // Note that we ignore the parameters printEntryBlockArgs and |
774 | // printBlockTerminators for simplicity. |
775 | void printRegion(mlir::Region& blocks, bool printEntryBlockArgs, |
776 | bool printBlockTerminators, |
777 | bool printEmptyBlock = false) override { |
778 | os_ << " {\n" ; |
779 | for (auto& b : blocks.getBlocks()) print(b); |
780 | os_.indent(indent_level_) << "}" ; |
781 | } |
782 | |
783 | void print(mlir::Block& block) { |
784 | // Each nested block level increases the indent. |
785 | os_.indent(indent_level_) << "%bb(" ; |
786 | for (int i = 0; i < block.getNumArguments(); ++i) { |
787 | if (arguments_.find(block.getArgument(i)) == arguments_.end()) |
788 | arguments_[block.getArgument(i)] = next_argument_++; |
789 | if (i > 0) os_ << ", " ; |
790 | os_ << "%arg" << arguments_[block.getArgument(i)]; |
791 | } |
792 | os_ << "):\n" ; |
793 | indent_level_ += 2; |
794 | for (auto& op : block.getOperations()) print(op); |
795 | indent_level_ -= 2; |
796 | } |
797 | |
798 | // Prints the TF node name from `loc`. |
799 | void printLoc(mlir::Location loc) { |
800 | os_ << " [" << mlir::GetNameFromLoc(loc) << "]" ; |
801 | } |
802 | |
803 | void print(mlir::Operation& op) { |
804 | // Don't print tf.DTensorLayout ops. |
805 | if (mlir::isa<mlir::TF::DTensorLayout>(op)) return; |
806 | |
807 | // Don't print functions with empty bodies. |
808 | if (auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(op)) |
809 | if (func_op.empty()) return; |
810 | |
811 | // Each operation is on its own line, so we start by indenting the |
812 | // the line. |
813 | os_.indent(indent_level_); |
814 | |
815 | // Record a unique identifier for the op (this will be used for printing |
816 | // op results and operands). |
817 | location_[&op] = current_location_++; |
818 | |
819 | // Print the outputs. |
820 | for (int i = 0; i < op.getNumResults(); ++i) { |
821 | if (i > 0) os_ << ", " ; |
822 | printOperand(op.getOpResult(i), os_); |
823 | } |
824 | |
825 | if (op.getNumResults() > 0) os_ << " = " ; |
826 | |
827 | // Some ops have a special printing method, call this if it exists. |
828 | if (auto opInfo = op.getRegisteredInfo()) { |
829 | opInfo->printAssembly(&op, *this, /*defaultDialect=*/"" ); |
830 | printLoc(op.getLoc()); |
831 | os_ << "\n" ; |
832 | return; |
833 | } |
834 | |
835 | // Otherwise we do a generic printing. |
836 | printGenericOp(&op, true); |
837 | printLoc(op.getLoc()); |
838 | |
839 | os_ << "\n" ; |
840 | } |
841 | |
842 | // Print an operand, this could be both the OpResult or a BlockArgument. |
843 | // We also print the layout if it exists and the type. |
844 | void printOperand(mlir::Value value, llvm::raw_ostream& os) override { |
845 | if (auto result = value.dyn_cast<mlir::OpResult>()) { |
846 | // If DTensorLayout ops are already in the module, we need to skip them |
847 | // since we aren't printing them out. |
848 | if (mlir::isa<mlir::TF::DTensorLayout>(result.getDefiningOp())) { |
849 | printOperand(result.getDefiningOp()->getOperand(0)); |
850 | return; |
851 | } |
852 | |
853 | // OpResult are of the format %op_number:%result_number. We elide the |
854 | // result number if there is only one result (the case for most ops). |
855 | os << "%" << location_[result.getDefiningOp()]; |
856 | if (result.getDefiningOp()->getNumResults() > 1) |
857 | os << ":" << result.getResultNumber(); |
858 | } else if (auto argument = value.dyn_cast<mlir::BlockArgument>()) { |
859 | if (arguments_.find(argument) == arguments_.end()) |
860 | arguments_[argument] = next_argument_++; |
861 | os << "%arg" << arguments_[argument]; |
862 | } |
863 | auto layout = merged_layouts_.find(value); |
864 | if (layout != merged_layouts_.end()) { |
865 | os << " \"" ; |
866 | printLayout(layout->second, os); |
867 | os << "\"" ; |
868 | } |
869 | os << " " ; |
870 | printType(value.getType()); |
871 | } |
872 | |
873 | void printLayout(const Layout& layout, llvm::raw_ostream& os) { |
874 | // Layouts are printed with * for an unsharded dim and the mesh dim for a |
875 | // sharded dim. This keeps the layout compact. |
876 | for (int i = 0; i < layout.rank(); ++i) { |
877 | if (i > 0) os << "," ; |
878 | if (Layout::IsUnshardedDimension(layout.sharding_spec(i))) |
879 | os << "*" ; |
880 | else |
881 | os << layout.sharding_spec(i); |
882 | } |
883 | } |
884 | |
885 | // A generic op consists of a name, and any of the following: |
886 | // * arguments, |
887 | // * attributes |
888 | // * regions |
889 | // These are printed out in that order. |
890 | void printGenericOp(mlir::Operation* op, bool printOpName) override { |
891 | if (printOpName) os_ << "\"" << op->getName().getStringRef() << "\"" ; |
892 | os_ << "(" ; |
893 | for (int i = 0; i < op->getNumOperands(); ++i) { |
894 | if (i > 0) os_ << ", " ; |
895 | printOperand(op->getOperand(i), os_); |
896 | } |
897 | os_ << ")" ; |
898 | |
899 | if (!op->getAttrs().empty()) { |
900 | std::vector<mlir::NamedAttribute> filtered; |
901 | for (auto attr : op->getAttrs()) |
902 | if (*attr.getName().str().begin() != '_' && |
903 | attr.getName().str() != "device" ) |
904 | filtered.emplace_back(attr); |
905 | if (!filtered.empty()) { |
906 | os_ << " {" ; |
907 | for (int i = 0; i < filtered.size(); ++i) { |
908 | if (i > 0) os_ << ", " ; |
909 | printNamedAttribute(filtered[i]); |
910 | } |
911 | os_ << "}" ; |
912 | } |
913 | } |
914 | |
915 | if (op->getNumRegions() > 0) { |
916 | os_ << " (" ; |
917 | for (auto& region : op->getRegions()) printRegion(region, false, false); |
918 | os_ << ")" ; |
919 | } |
920 | }; |
921 | |
922 | void printSymbolName(llvm::StringRef symbolRef) override { |
923 | os_ << symbolRef; |
924 | }; |
925 | |
926 | void printNamedAttribute(mlir::NamedAttribute attr) { |
927 | os_ << attr.getName().strref() << " = " ; |
928 | printAttribute(attr.getValue()); |
929 | } |
930 | |
931 | void printAttribute(mlir::Attribute attr) override { attr.print(os_); } |
932 | |
933 | void printType(mlir::Type type) override { type.print(os_); } |
934 | |
935 | // The following functions are part of the printing interface but aren't |
936 | // needed for the compact printing form for Layout printing. |
937 | void printAttributeWithoutType(mlir::Attribute attr) override{}; |
938 | void printSuccessor(mlir::Block* successor) override{}; |
939 | void printSuccessorAndUseList(mlir::Block* successor, |
940 | mlir::ValueRange succOperands) override{}; |
941 | void printOptionalAttrDict( |
942 | llvm::ArrayRef<mlir::NamedAttribute> attrs, |
943 | llvm::ArrayRef<llvm::StringRef> elidedAttrs) override{}; |
944 | void printOptionalAttrDictWithKeyword( |
945 | llvm::ArrayRef<mlir::NamedAttribute> attrs, |
946 | llvm::ArrayRef<llvm::StringRef> elidedAttrs) override{}; |
947 | |
948 | void shadowRegionArgs(mlir::Region& region, |
949 | mlir::ValueRange namesToUse) override{}; |
950 | void printAffineMapOfSSAIds(mlir::AffineMapAttr mapAttr, |
951 | mlir::ValueRange operands) override{}; |
952 | |
953 | void printAffineExprOfSSAIds(mlir::AffineExpr expr, |
954 | mlir::ValueRange dimOperands, |
955 | mlir::ValueRange symOperands) override{}; |
956 | |
957 | private: |
958 | int indent_level_; |
959 | llvm::raw_ostream& os_; |
960 | llvm::DenseMap<mlir::Operation*, int> location_; |
961 | int current_location_; |
962 | llvm::DenseMap<mlir::BlockArgument, int> arguments_; |
963 | int next_argument_; |
964 | const llvm::DenseMap<mlir::Value, Layout>& merged_layouts_; |
965 | }; |
966 | |
967 | // Log the current set of layouts to a file marked by the hash of the input |
968 | // module and the stage. |
969 | void LogLayoutsAndOps(const int stage, const uint64_t module_hash, |
970 | const llvm::DenseMap<mlir::Value, Layout>& merged_layouts, |
971 | mlir::ModuleOp& module) { |
972 | if (module->hasAttr(kDoNotLog) || ((ClientId() != 0) && !LogOnAllTasks())) |
973 | return; |
974 | |
975 | std::string prefix = tensorflow::GetDumpDirFromEnvVar(); |
976 | if (prefix.empty()) return; |
977 | |
978 | auto* env = tensorflow::Env::Default(); |
979 | auto status = env->RecursivelyCreateDir(prefix); |
980 | if (!status.ok()) { |
981 | LOG(WARNING) << "cannot create directory '" + prefix + |
982 | "': " + status.error_message(); |
983 | return; |
984 | } |
985 | |
986 | absl::StrAppend(&prefix, "/layout_propagation_v2_module_" , module_hash, |
987 | "_stage_" , stage, "_" ); |
988 | if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir" )) { |
989 | LOG(WARNING) << "cannot create unique filename, won't dump MLIR module." ; |
990 | return; |
991 | } |
992 | |
993 | std::unique_ptr<WritableFile> file_writer; |
994 | status = env->NewWritableFile(prefix, &file_writer); |
995 | if (!status.ok()) { |
996 | LOG(WARNING) << "cannot open file '" + prefix + |
997 | "': " + status.error_message(); |
998 | return; |
999 | } |
1000 | |
1001 | // Print the module to a string before writing to the file. |
1002 | std::string txt_module; |
1003 | { |
1004 | llvm::raw_string_ostream os(txt_module); |
1005 | LayoutPrinter printer(os, merged_layouts); |
1006 | module.print(printer); |
1007 | } |
1008 | |
1009 | status = file_writer->Append(txt_module); |
1010 | if (!status.ok()) { |
1011 | LOG(WARNING) << "error writing to file '" + prefix + |
1012 | "': " + status.error_message(); |
1013 | return; |
1014 | } |
1015 | (void)file_writer->Close(); |
1016 | LOG(INFO) << "Dumped MLIR module to " << prefix; |
1017 | } |
1018 | |
1019 | // Canonicalizer and DCE transformation passes may removed ops in the graph and |
1020 | // result in multiple consecutive DTensorLayout ops. Detect all such cases and |
1021 | // replace unnecessary DTensorLayout ops with Identity ops. |
1022 | mlir::LogicalResult ReplaceAuxiliaryDTensorLayoutOpsWithIdentity( |
1023 | mlir::ModuleOp module) { |
1024 | llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops; |
1025 | module.walk([&](mlir::TF::DTensorLayout op) { layout_ops.emplace_back(op); }); |
1026 | |
1027 | for (auto layout_op : llvm::reverse(layout_ops)) { |
1028 | auto input_op = layout_op.input().getDefiningOp(); |
1029 | if (auto input_layout_op = |
1030 | llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(input_op)) { |
1031 | // Check that layout of input DTensorLayout op is equivalent to |
1032 | // the layout of its connected DTensorLayout op. |
1033 | if (layout_op.layout() != input_layout_op.layout()) |
1034 | return layout_op.emitOpError( |
1035 | "Found inconsistent layout. This should never happen." ); |
1036 | |
1037 | // Replace DTensorLayout op with identity op. |
1038 | mlir::OpBuilder builder(layout_op); |
1039 | auto identity = builder.create<mlir::TF::IdentityOp>( |
1040 | layout_op->getLoc(), layout_op.getType(), layout_op.input()); |
1041 | layout_op.output().replaceAllUsesWith(identity.output()); |
1042 | layout_op.erase(); |
1043 | } |
1044 | } |
1045 | |
1046 | return mlir::success(); |
1047 | } |
1048 | |
1049 | // Inserts/changes DTensorLayout op after IfRegion op and results of then/else |
1050 | // branches to ensure that the return values of IfRegion ops are consistent. |
1051 | // After layout propagation, layouts of return value of tf.IfRegion op, and |
1052 | // layouts of terminators of then/else branches of IfRegion op may be different. |
1053 | // In that case, the layouts of returns values must be merged to a same layout |
1054 | // as return values of IfRegion op and results of then/else branches are |
1055 | // semantically equivalent. |
1056 | mlir::LogicalResult InsertDTensorLayoutForIfRegionOp( |
1057 | const llvm::SmallVectorImpl<mlir::TF::IfRegionOp>& if_ops, |
1058 | mlir::MLIRContext* context) { |
1059 | for (mlir::TF::IfRegionOp if_op : if_ops) { |
1060 | for (mlir::OpResult if_result : if_op.getResults()) { |
1061 | const int result_index = if_result.getResultNumber(); |
1062 | mlir::Value then_branch_result = if_op.then_branch() |
1063 | .front() |
1064 | .getTerminator() |
1065 | ->getOpOperand(result_index) |
1066 | .get(); |
1067 | mlir::Value else_branch_result = if_op.else_branch() |
1068 | .front() |
1069 | .getTerminator() |
1070 | ->getOpOperand(result_index) |
1071 | .get(); |
1072 | |
1073 | auto if_result_layout = |
1074 | llvm::dyn_cast<mlir::TF::DTensorLayout>(*if_result.user_begin()); |
1075 | auto then_result_layout = llvm::dyn_cast<mlir::TF::DTensorLayout>( |
1076 | *then_branch_result.getDefiningOp()); |
1077 | auto else_result_layout = llvm::dyn_cast<mlir::TF::DTensorLayout>( |
1078 | *else_branch_result.getDefiningOp()); |
1079 | llvm::SmallVector<Layout, 4> layouts{if_result_layout.layout(), |
1080 | then_result_layout.layout(), |
1081 | else_result_layout.layout()}; |
1082 | std::set<Layout> layouts_set{layouts.begin(), layouts.end()}; |
1083 | if (layouts_set.size() == 1) continue; |
1084 | |
1085 | absl::optional<Layout> merged_layout; |
1086 | if (mlir::failed( |
1087 | GetMostShardedLayout(layouts, if_op.getLoc(), &merged_layout))) |
1088 | return mlir::failure(); |
1089 | assert(merged_layout); |
1090 | |
1091 | if_result_layout->setAttr( |
1092 | kQualifiedLayoutAttr, |
1093 | mlir::dtensor::LayoutAttr::get(context, *merged_layout)); |
1094 | then_result_layout->setAttr( |
1095 | kQualifiedLayoutAttr, |
1096 | mlir::dtensor::LayoutAttr::get(context, *merged_layout)); |
1097 | else_result_layout->setAttr( |
1098 | kQualifiedLayoutAttr, |
1099 | mlir::dtensor::LayoutAttr::get(context, *merged_layout)); |
1100 | } |
1101 | } |
1102 | return mlir::success(); |
1103 | } |
1104 | |
1105 | // Inserts necessary DTensorRelayout ops so that the layouts for while loops |
1106 | // are correct. |
1107 | // |
1108 | // Due to how while loop layout propagation is done, we may need to fix the |
1109 | // layouts so that the second and beyond step of the loop receive a tensor with |
1110 | // the correct layout. |
1111 | // E.g. |
1112 | // %b = tf.WhileRegion(%a) ({ |
1113 | // %bb0(%arg0): # Cond |
1114 | // %c = tf.A(%arg0) |
1115 | // tf.Yield(%c) |
1116 | // }, { |
1117 | // %bb0(%arg0): # Body |
1118 | // %d = tf.B(%arg0) |
1119 | // tf.Yield(%d) |
1120 | // } |
1121 | // } |
1122 | // %e = tf.C(%b) |
1123 | // |
1124 | // Layout propagation treats the loop body as if it were an inlined function and |
1125 | // does not have a condition which fixes the layout of %d, as return value, to |
1126 | // match the layout of %arg0 (or %a). |
1127 | // |
1128 | // Towards this, we: |
1129 | // 1) Check the layout of %arg0 and see if matches the layout of the input 0 |
1130 | // (%d) of tf.Yield. |
1131 | // 2) If it doesn't match we update the we insert a DTensorRelayout op between |
1132 | // %d and tf.Yield with the correct layout and insert a second |
1133 | // DTensorRelayout op after the loop body. |
1134 | // |
1135 | // NOTE: that it is necessary in general to insert both DTensorRelayout ops, |
1136 | // as opposed to just updating the layout of %d (which would in general be more |
1137 | // efficient) since %d may still be used by other ops in the loop body. |
1138 | // |
1139 | // NOTE: this is not needed for the condition as the output of the condition is |
1140 | // a scalar and therefore always replicated. |
1141 | mlir::LogicalResult InsertRelayoutForWhileLoops( |
1142 | const llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>& while_ops, |
1143 | mlir::OpBuilder& builder) { |
1144 | for (mlir::TF::WhileRegionOp op : while_ops) { |
1145 | // Get the terminator so we can check the output layouts of the loop body. |
1146 | mlir::Operation* yield_op = op.body().front().getTerminator(); |
1147 | if (!mlir::isa<mlir::TF::YieldOp>(yield_op)) |
1148 | return op->emitOpError() << "body terminator is not a Yield op." ; |
1149 | |
1150 | for (int i = 0; i < op.body().getNumArguments(); ++i) { |
1151 | // Inputs should only have one, a DTensorLayout op. |
1152 | mlir::Value argument = op.body().getArgument(i); |
1153 | if (!argument.hasOneUse()) |
1154 | return op.emitOpError() |
1155 | << "body argument " << i << " doesn't have a single use." ; |
1156 | mlir::Operation* input_layout_op = argument.getUses().begin().getUser(); |
1157 | if (!mlir::isa<mlir::TF::DTensorLayout>(input_layout_op)) |
1158 | return op.emitOpError() << "body argument " << i |
1159 | << " is not consumed by a DTensorLayout op." ; |
1160 | const Layout input_layout = |
1161 | mlir::cast<mlir::TF::DTensorLayout>(input_layout_op).layout(); |
1162 | |
1163 | // Inputs to Yield should also be a DTensorLayout op. |
1164 | if (!yield_op->getOperand(i).isa<mlir::OpResult>() || |
1165 | !mlir::isa<mlir::TF::DTensorLayout>( |
1166 | yield_op->getOperand(i).getDefiningOp())) |
1167 | return yield_op->emitOpError() |
1168 | << "argument " << i << " to is not a DTensorLayout op." ; |
1169 | mlir::Operation* output_layout_op = |
1170 | yield_op->getOperand(i).getDefiningOp(); |
1171 | const Layout output_layout = |
1172 | mlir::cast<mlir::TF::DTensorLayout>(output_layout_op).layout(); |
1173 | |
1174 | // If the layouts are equal we have nothing to do. Note that this caches |
1175 | // the case that that input and output are a resource, since the layout |
1176 | // of a resource is fixed. |
1177 | if (input_layout == output_layout) continue; |
1178 | |
1179 | // Insert the first Relayout op (in the loop body). |
1180 | builder.setInsertionPointAfter(output_layout_op); |
1181 | if (!yield_op->getOperand(i).getType().isa<mlir::TensorType>()) |
1182 | return yield_op->emitOpError() |
1183 | << "operand " << i << " does not have TensorType" ; |
1184 | mlir::TF::ShapeAttr global_shape = mlir::TF::ShapeAttr::get( |
1185 | builder.getContext(), |
1186 | yield_op->getOperand(i).getType().cast<mlir::TensorType>()); |
1187 | mlir::TF::RelayoutOp first_relayout = |
1188 | builder.create<mlir::TF::RelayoutOp>( |
1189 | op.getLoc(), yield_op->getOperand(i).getType(), |
1190 | yield_op->getOperand(i), input_layout.ToString()); |
1191 | mlir::TF::DTensorLayout first_layout_op = |
1192 | builder.create<mlir::TF::DTensorLayout>( |
1193 | op.getLoc(), first_relayout.output(), |
1194 | mlir::dtensor::LayoutAttr::get(builder.getContext(), |
1195 | input_layout), |
1196 | global_shape); |
1197 | yield_op->setOperand(i, first_layout_op.output()); |
1198 | |
1199 | // Insert the second relayout op after the loop itself. |
1200 | builder.setInsertionPointAfter(op); |
1201 | mlir::TF::DTensorLayout second_layout_op = |
1202 | builder.create<mlir::TF::DTensorLayout>( |
1203 | op.getLoc(), op->getResult(i), |
1204 | mlir::dtensor::LayoutAttr::get(builder.getContext(), |
1205 | input_layout), |
1206 | global_shape); |
1207 | mlir::TF::RelayoutOp second_relayout = |
1208 | builder.create<mlir::TF::RelayoutOp>( |
1209 | op.getLoc(), second_layout_op.output().getType(), |
1210 | second_layout_op.output(), output_layout.ToString()); |
1211 | op->getResult(i).replaceAllUsesExcept( |
1212 | second_relayout.output(), llvm::SmallPtrSet<mlir::Operation*, 1>{ |
1213 | second_layout_op.getOperation()}); |
1214 | } |
1215 | } |
1216 | return mlir::success(); |
1217 | } |
1218 | |
1219 | // For all constants with multiple usages, clone the constants so that each |
1220 | // constant operation has at most 1 usage. |
1221 | void DuplicateConstants(mlir::ModuleOp module) { |
1222 | llvm::SmallVector<mlir::TF::ConstOp, 4> const_ops; |
1223 | module.walk( |
1224 | [&](mlir::TF::ConstOp const_op) { const_ops.emplace_back(const_op); }); |
1225 | |
1226 | for (mlir::TF::ConstOp const_op : const_ops) { |
1227 | mlir::OpBuilder builder(const_op); |
1228 | auto uses = const_op->getUses(); |
1229 | if (uses.empty()) return; |
1230 | |
1231 | llvm::SmallDenseMap<mlir::Operation*, mlir::OpOperand*> const_use_map; |
1232 | mlir::OpOperand& first_use = *uses.begin(); |
1233 | for (mlir::OpOperand& use : uses) { |
1234 | if (&use == &first_use) continue; |
1235 | |
1236 | mlir::Operation* new_const = builder.clone(*const_op); |
1237 | const_use_map.try_emplace(new_const, &use); |
1238 | } |
1239 | |
1240 | for (const auto& it : const_use_map) it.second->set(it.first->getResult(0)); |
1241 | } |
1242 | } |
1243 | |
1244 | // Find the root(s) values of "current_value" within the cycle, and put it |
1245 | // into "roots". |
1246 | void FindRoot( |
1247 | const llvm::DenseSet<mlir::Value>& is_updated, |
1248 | const mlir::Value& current_value, |
1249 | llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers, |
1250 | llvm::DenseSet<mlir::Value>* roots) { |
1251 | // Standard BFS to find root values of current_value. |
1252 | std::deque<mlir::Value> to_process; |
1253 | to_process.push_back(current_value); |
1254 | |
1255 | llvm::DenseSet<mlir::Value> visited; |
1256 | visited.insert(current_value); |
1257 | |
1258 | while (!to_process.empty()) { |
1259 | int level_size = to_process.size(); |
1260 | for (int UNUSED = 0; UNUSED < level_size; ++UNUSED) { |
1261 | mlir::Value cur_val = to_process.front(); |
1262 | to_process.pop_front(); |
1263 | |
1264 | // Terminating condition, if there is no defining op then this is a root. |
1265 | mlir::Operation* defining_op = cur_val.getDefiningOp(); |
1266 | if (defining_op == nullptr) { |
1267 | roots->insert(current_value); |
1268 | continue; |
1269 | } |
1270 | |
1271 | // Expand out from 'cur_val' one step closer to roots. If there was |
1272 | // no-one one step closer to root, then this is a root. |
1273 | bool is_root = true; |
1274 | for (int i = 0; i < defining_op->getNumOperands(); ++i) { |
1275 | mlir::Value operand = defining_op->getOperand(i); |
1276 | if (operand != cur_val && is_updated.contains(operand)) { |
1277 | is_root = false; |
1278 | |
1279 | if (!visited.contains(operand)) { |
1280 | visited.insert(operand); |
1281 | to_process.push_back(operand); |
1282 | } |
1283 | } |
1284 | } |
1285 | |
1286 | if (is_root) roots->insert(cur_val); |
1287 | } |
1288 | } |
1289 | } |
1290 | |
1291 | // Finds the root value(s) of the values that have layouts cycling back and |
1292 | // forth in an infinite loop during layout propagation and prints the closest TF |
1293 | // op that consumes those root value(s). This allows users and developers to |
1294 | // debug the root cause of layouts that should be changed to prevent infinite |
1295 | // layout propagation. |
1296 | void FindRootsAndEmitError( |
1297 | mlir::ModuleOp& module, |
1298 | llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers, |
1299 | const llvm::DenseSet<mlir::Value>& is_updated) { |
1300 | llvm::DenseSet<mlir::Value> roots; |
1301 | for (auto& value : is_updated) { |
1302 | FindRoot(is_updated, value, producers, &roots); |
1303 | } |
1304 | module.emitOpError() |
1305 | << "Maximum number of layout propagation steps reached. Unable to " |
1306 | "converge to a fixed layout. Please rerun with a higher limit in the " |
1307 | "DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS environment variable." ; |
1308 | for (auto& root : roots) { |
1309 | for (mlir::OpOperand& operand : root.getUses()) { |
1310 | llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller; |
1311 | llvm::SmallVector<mlir::Value, 4> skipped_values; |
1312 | |
1313 | // For each root value that may need a different layout, find the |
1314 | // closest TF op that consumes it and print it. |
1315 | llvm::SmallVector<mlir::OpOperand*, 4> consuming_operands = |
1316 | TraceUseToNextTFOp(&operand, func_to_caller, &skipped_values); |
1317 | |
1318 | for (mlir::OpOperand* new_operand : consuming_operands) { |
1319 | mlir::Operation* operation = new_operand->getOwner(); |
1320 | mlir::Location loc = operation->getLoc(); |
1321 | operation->emitOpError() << '\n' |
1322 | << "The following op consumes tensors that " |
1323 | "may need a different layout. " |
1324 | "[" |
1325 | << mlir::GetNameFromLoc(loc) << "]" << '\n'; |
1326 | } |
1327 | } |
1328 | } |
1329 | } |
1330 | } // namespace |
1331 | |
1332 | // Runs an iteration of layout propagation, where we merge producer and consumer |
1333 | // requests and then recompute recommended layouts on all operations that |
1334 | // are connected to an updated layout. |
1335 | Status RunOneIteration( |
1336 | llvm::DenseSet<mlir::Value>& is_locked, |
1337 | llvm::DenseSet<mlir::Value>& is_updated, |
1338 | llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request, |
1339 | llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>& |
1340 | consumer_requests, |
1341 | llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers, |
1342 | llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers, |
1343 | llvm::DenseMap<mlir::Value, Layout>& merged_layouts, mlir::ModuleOp& module, |
1344 | const uint64_t module_hash, int* stage) { |
1345 | if (is_updated.empty()) return OkStatus(); |
1346 | // Merge any possibly updated layouts. |
1347 | if (mlir::failed( |
1348 | MergeAndGetUpdatedLayouts(is_locked, is_updated, producer_request, |
1349 | consumer_requests, merged_layouts))) |
1350 | return errors::Internal( |
1351 | "MergeAndGetUpdatedLayouts failed to merge layouts." ); |
1352 | |
1353 | // Compile a list of operations with updated inputs or outputs. |
1354 | llvm::DenseSet<mlir::Operation*> operations_needing_update; |
1355 | GetOperationsNeedingUpdate(is_updated, consumers, operations_needing_update); |
1356 | is_updated.clear(); |
1357 | |
1358 | if (VLOG_IS_ON(2)) { |
1359 | LogLayoutsAndOps(*stage, module_hash, merged_layouts, module); |
1360 | } |
1361 | |
1362 | for (auto* op : operations_needing_update) { |
1363 | if (mlir::failed(UpdateLayoutsForOp(op, producers, merged_layouts, |
1364 | producer_request, consumer_requests, |
1365 | is_updated))) |
1366 | return errors::Internal("UpdateLayoutsForOp failed to update layouts." ); |
1367 | } |
1368 | ++(*stage); |
1369 | return OkStatus(); |
1370 | } |
1371 | |
1372 | // Compares every value's layouts in `merged_a` with the ones in `merged_b`, |
1373 | // and store the values that differ in `changed`. |
1374 | Status CompareMergedLayouts(const llvm::DenseMap<mlir::Value, Layout>& merged_a, |
1375 | const llvm::DenseMap<mlir::Value, Layout>& merged_b, |
1376 | llvm::DenseSet<mlir::Value>& changed) { |
1377 | if (merged_a.size() != merged_b.size()) |
1378 | return errors::Internal( |
1379 | "Both merged_layouts did not have the same number of set layouts." ); |
1380 | for (const auto& value_and_layout : merged_a) { |
1381 | const mlir::Value value = value_and_layout.getFirst(); |
1382 | const Layout& layout = value_and_layout.getSecond(); |
1383 | auto value_and_layout_in_b = merged_b.find(value); |
1384 | if (value_and_layout_in_b == merged_b.end()) |
1385 | return errors::Internal( |
1386 | "Comparing merged_layouts that contain different mlir::Value's." ); |
1387 | if (value_and_layout_in_b->second != layout) { |
1388 | changed.insert(value); |
1389 | } |
1390 | } |
1391 | return OkStatus(); |
1392 | } |
1393 | |
1394 | // MLIR pass that propagates layout for all ops the module. |
1395 | struct DLayoutPropagationPassV2 |
1396 | : public impl::DTensorLayoutPropagationV2Base<DLayoutPropagationPassV2> { |
1397 | void getDependentDialects(mlir::DialectRegistry& registry) const override { |
1398 | registry.insert<mlir::dtensor::DTensorDialect>(); |
1399 | } |
1400 | |
1401 | void runOnOperation() override { |
1402 | mlir::MLIRContext& context = getContext(); |
1403 | mlir::OpBuilder builder(&context); |
1404 | |
1405 | auto module = getOperation(); |
1406 | |
1407 | if (mlir::failed(ReplaceAuxiliaryDTensorLayoutOpsWithIdentity(module))) |
1408 | return signalPassFailure(); |
1409 | |
1410 | // In order to ensure that constant operations with multiple usages with |
1411 | // different consumer layout requests does not lead to replicated constant |
1412 | // tensors, we duplicate all constants to have at most 1 usages. |
1413 | // After SPMD Expansion, these duplicated constants will be merged back |
1414 | // during SCCP pass. |
1415 | DuplicateConstants(module); |
1416 | |
1417 | mlir::func::FuncOp main_func = |
1418 | module.lookupSymbol<mlir::func::FuncOp>("main" ); |
1419 | if (!main_func) return; |
1420 | |
1421 | mlir::Dialect* tf_dialect = |
1422 | context.getLoadedDialect<mlir::TF::TensorFlowDialect>(); |
1423 | |
1424 | // This maps from OpResults to a list of OpOperands that consume this. |
1425 | // Note that this will pass over/through |
1426 | // (Stateful)PartitionedCall and other control flow, directly connecting |
1427 | // producing ops to their consumers in the function. I.e. it presents |
1428 | // flattened/inlined view of the flow of data. |
1429 | llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>> consumers; |
1430 | // Maintain a reverse mapping. |
1431 | llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers; |
1432 | // For each mlir::Value this is what the producer would like to have the |
1433 | // layout be. |
1434 | llvm::DenseMap<mlir::Value, absl::optional<Layout>> producer_request; |
1435 | // For each mlir::Value this is what the consumers would like to have the |
1436 | // layout be. Note the map is in 'parallel' to the consumers map above. |
1437 | llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>> |
1438 | consumer_requests; |
1439 | // Tracks if the layout was updated since last cycle. |
1440 | llvm::DenseSet<mlir::Value> is_updated; |
1441 | // Tracks if the layout is locked. In this case we don't pass consumer |
1442 | // layouts to MergeLayouts. Used for input layouts and user defined layouts. |
1443 | llvm::DenseSet<mlir::Value> is_locked; |
1444 | |
1445 | // Create consumers and producers maps. |
1446 | if (mlir::failed( |
1447 | PopulateConsumersFromModule(&module, tf_dialect, consumers))) |
1448 | return signalPassFailure(); |
1449 | |
1450 | for (auto& consumer : consumers) { |
1451 | for (auto* operand : consumer.second) { |
1452 | if (producers.find(operand) == producers.end()) { |
1453 | producers[operand] = std::vector<mlir::Value>{consumer.first}; |
1454 | } else { |
1455 | producers[operand].emplace_back(consumer.first); |
1456 | } |
1457 | } |
1458 | } |
1459 | |
1460 | // Setup the initial starting conditions for the layout algorithm |
1461 | if (mlir::failed(InsertInitialLayouts( |
1462 | module, main_func, consumers, producers, consumer_requests, |
1463 | producer_request, is_updated, is_locked))) |
1464 | return signalPassFailure(); |
1465 | |
1466 | const auto module_hash = OpHash(module); |
1467 | int stage = 0; |
1468 | |
1469 | llvm::DenseMap<mlir::Value, Layout> merged_layouts; |
1470 | Status status; |
1471 | |
1472 | while (!is_updated.empty() && stage < kLayoutPropagationMaxStages) { |
1473 | ++stage; |
1474 | int steps = 0; |
1475 | // Step 1. Run the layout propagation v2 until convergence or max steps. |
1476 | while (!is_updated.empty() && steps < LayoutPropagationMaxSteps()) { |
1477 | Status status = RunOneIteration( |
1478 | is_locked, is_updated, producer_request, consumer_requests, |
1479 | producers, consumers, merged_layouts, module, module_hash, &steps); |
1480 | if (!status.ok()) { |
1481 | module.emitOpError() << "Failure running iteration." ; |
1482 | return signalPassFailure(); |
1483 | } |
1484 | } |
1485 | if (VLOG_IS_ON(2)) { |
1486 | LOG(INFO) << "Failed to converge in stage " << stage; |
1487 | } |
1488 | // Step 2. If we are here, then we failed to converge, and likely |
1489 | // there is an oscillation of layouts. Detect all the edges that are |
1490 | // changing layouts. |
1491 | llvm::DenseMap<mlir::Value, Layout> merged_layouts_at_max_steps = |
1492 | merged_layouts; |
1493 | llvm::DenseSet<mlir::Value> changed; |
1494 | int previous_change_size = -1; |
1495 | |
1496 | while (changed.size() > previous_change_size) { |
1497 | if (!RunOneIteration(is_locked, is_updated, producer_request, |
1498 | consumer_requests, producers, consumers, |
1499 | merged_layouts, module, module_hash, &steps) |
1500 | .ok()) { |
1501 | module.emitOpError() << "Failure running iteration." ; |
1502 | return signalPassFailure(); |
1503 | } |
1504 | if (!CompareMergedLayouts(merged_layouts_at_max_steps, merged_layouts, |
1505 | changed) |
1506 | .ok()) { |
1507 | module.emitOpError() << "Failure comparing merged layouts." ; |
1508 | return signalPassFailure(); |
1509 | } |
1510 | previous_change_size = changed.size(); |
1511 | } |
1512 | |
1513 | // Step 3. Layouts that haven't changed means they're not part of the |
1514 | // cyclic problem, so freeze them. |
1515 | for (const auto& value_and_layout : merged_layouts) { |
1516 | const mlir::Value value = value_and_layout.getFirst(); |
1517 | if (changed.find(value) == changed.end()) { |
1518 | is_locked.insert(value); |
1519 | } |
1520 | } |
1521 | // Step 4. Any information corresponding to the changed layouts |
1522 | // should be disinfected, we do this by clearing all information |
1523 | // regarding them. |
1524 | for (const mlir::Value changed_value : changed) { |
1525 | producer_request.erase(changed_value); |
1526 | consumer_requests.erase(changed_value); |
1527 | merged_layouts.erase(changed_value); |
1528 | } |
1529 | |
1530 | // Step 5. ComputeLayout again on all the ops linked to the changed |
1531 | // layouts. The next iteration will take this information and merge again. |
1532 | llvm::DenseSet<mlir::Operation*> operations_needing_update; |
1533 | is_updated = changed; |
1534 | GetOperationsNeedingUpdate(is_updated, consumers, |
1535 | operations_needing_update); |
1536 | is_updated.clear(); |
1537 | |
1538 | for (auto* op : operations_needing_update) { |
1539 | if (mlir::failed(UpdateLayoutsForOp(op, producers, merged_layouts, |
1540 | producer_request, consumer_requests, |
1541 | is_updated))) { |
1542 | module.emitOpError() << "Failure in UpdateLayoutsForOp." ; |
1543 | return signalPassFailure(); |
1544 | } |
1545 | } |
1546 | } |
1547 | |
1548 | if (stage >= kLayoutPropagationMaxStages) { |
1549 | FindRootsAndEmitError(module, producers, is_updated); |
1550 | return signalPassFailure(); |
1551 | } |
1552 | |
1553 | if (mlir::failed( |
1554 | CopyLayoutsForSkippedOps(module, tf_dialect, merged_layouts))) |
1555 | return signalPassFailure(); |
1556 | |
1557 | if (VLOG_IS_ON(2)) { |
1558 | LogLayoutsAndOps(stage, module_hash, merged_layouts, module); |
1559 | } |
1560 | |
1561 | if (!AllOpResultsHaveLayouts(&module, tf_dialect, merged_layouts)) |
1562 | return signalPassFailure(); |
1563 | |
1564 | if (mlir::failed(InsertDTensorLayoutOps(builder, merged_layouts))) |
1565 | return signalPassFailure(); |
1566 | |
1567 | // Handle layout of control flow operations. |
1568 | llvm::SmallVector<mlir::TF::IfRegionOp, 4> if_ops; |
1569 | llvm::SmallVector<mlir::TF::WhileRegionOp, 4> while_ops; |
1570 | module.walk([&](mlir::Operation* op) { |
1571 | if (auto if_op = llvm::dyn_cast<mlir::TF::IfRegionOp>(op)) |
1572 | if_ops.emplace_back(if_op); |
1573 | else if (auto while_op = llvm::dyn_cast<mlir::TF::WhileRegionOp>(op)) |
1574 | while_ops.emplace_back(while_op); |
1575 | }); |
1576 | |
1577 | if (mlir::failed(InsertRelayoutForWhileLoops(while_ops, builder))) |
1578 | return signalPassFailure(); |
1579 | |
1580 | if (mlir::failed( |
1581 | InsertDTensorLayoutForIfRegionOp(if_ops, builder.getContext()))) |
1582 | return signalPassFailure(); |
1583 | }; |
1584 | }; |
1585 | |
1586 | } // namespace |
1587 | |
1588 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
1589 | CreateDTensorLayoutPropagationPassV2() { |
1590 | return std::make_unique<DLayoutPropagationPassV2>(); |
1591 | } |
1592 | |
1593 | } // namespace dtensor |
1594 | } // namespace tensorflow |
1595 | |