/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "absl/types/optional.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {

namespace {
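// Pulls in the auto-generated pass base class definition
// (impl::DTensorMeshPropagationBase) used by the pass defined below.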
#define GEN_PASS_DEF_DTENSORMESHPROPAGATION
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

// Extracts the mesh of `block_arg` by parsing the argument attributes of its
// enclosing function. The mesh is inferred from either the `tf._layout` or
// the `tf._mesh` attribute.
mlir::LogicalResult ExtractMeshFromBlockArgument(mlir::BlockArgument block_arg,
                                                 absl::optional<Mesh>* out) {
  auto func_op = mlir::dyn_cast_or_null<mlir::func::FuncOp>(
      block_arg.getOwner()->getParentOp());
  if (!func_op) {
    return block_arg.getOwner()->getParentOp()->emitOpError(
        "must be enclosed by a function");
  }
  auto layout_or_status = ExtractLayoutFromOperand(block_arg);
  if (!layout_or_status.ok())
    return func_op.emitOpError(layout_or_status.status().error_message());

  if (layout_or_status->has_value()) {
    out->emplace(layout_or_status->value().mesh());
    return mlir::success();
  }

  auto mesh_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
      block_arg.getArgNumber(), kCustomDeviceMeshAttr);
  if (!mesh_attr) return mlir::success();

  auto mesh_from_block_arg_or_status =
      Mesh::FromString(mesh_attr.getValue().str());
  if (!mesh_from_block_arg_or_status.ok()) {
    return func_op.emitOpError(
        "Failed during mesh propagation. Op operand has invalid serialized "
        "mesh");
  }

  out->emplace(mesh_from_block_arg_or_status.value());
  return mlir::success();
}

// Extracts the mesh of the operation that produces `value`.
mlir::LogicalResult ExtractMeshFromOpOutput(mlir::Value value,
                                            absl::optional<Mesh>* out) {
  auto input_op = value.getDefiningOp();
  if (!input_op) return mlir::success();

  auto operand_cluster =
      llvm::dyn_cast<mlir::tf_device::ClusterOp>(value.getDefiningOp());
  if (!operand_cluster) {
    return mlir::emitError(value.getLoc())
           << "operand must be from a different device cluster.";
  }

  auto mesh_or_status = ExtractDeviceMeshFromOp(operand_cluster);
  if (!mesh_or_status.ok())
    return operand_cluster.emitOpError(
        llvm::formatv("Failed during mesh propagation. {0}",
                      mesh_or_status.status().error_message()));

  auto extracted_mesh = mesh_or_status.value();
  if (extracted_mesh) *out = extracted_mesh.value();
  return mlir::success();
}

// Extracts the mesh configuration from `operand`. If the operand is a function
// argument, then the mesh is extracted from the "tf._mesh" arg attribute of
// the corresponding func op. If the operand is produced by a preceding op,
// then the mesh is extracted from the enclosing tf_device.Cluster op.
mlir::LogicalResult ExtractMeshFromOperand(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::OpOperand* operand, absl::optional<Mesh>* out) {
  mlir::Value operand_value = operand->get();

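  // Merges a newly inferred `mesh` into `operand_mesh`: assigns it if
  // `operand_mesh` is still unset, and emits an error if two different meshes
  // are inferred for the same operand.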
  const auto check_and_assign_mesh =
      [](mlir::Location loc, absl::optional<Mesh>& mesh,
         absl::optional<Mesh>& operand_mesh) -> mlir::LogicalResult {
    if (mesh && !operand_mesh) {
      operand_mesh.swap(mesh);
    } else if (mesh && operand_mesh && mesh != operand_mesh) {
      return mlir::emitError(
          loc,
          "Error during mesh propagation. Found inconsistent mesh "
          "while inferring mesh from operands.");
    }
    return mlir::success();
  };

  // If `operand` is a block argument then extract mesh from the `tf._mesh`
  // attribute of the corresponding function argument.
  if (auto block_arg = operand_value.dyn_cast<mlir::BlockArgument>()) {
    if (mlir::failed(ExtractMeshFromBlockArgument(block_arg, out)))
      return mlir::failure();

    if (!out->has_value()) {
      auto it = producers.find(operand);
      if (it != producers.end()) {
        auto producer_values = it->getSecond();
        absl::optional<Mesh> operand_mesh;
        for (mlir::Value producer_value : producer_values) {
          if (auto arg = producer_value.dyn_cast<mlir::BlockArgument>()) {
            absl::optional<Mesh> mesh;
            if (mlir::failed(ExtractMeshFromBlockArgument(arg, &mesh)))
              return mlir::failure();

            if (mlir::failed(check_and_assign_mesh(
                    operand->getOwner()->getLoc(), mesh, operand_mesh)))
              return mlir::failure();
          } else {
            auto input_cluster =
                producer_value.getDefiningOp()
                    ->getParentOfType<mlir::tf_device::ClusterOp>();
            auto output_from_producing_op = input_cluster.getResult(
                producer_value.cast<mlir::OpResult>().getResultNumber());

            absl::optional<Mesh> mesh;
            if (mlir::failed(
                    ExtractMeshFromOpOutput(output_from_producing_op, &mesh)))
              return mlir::failure();

            if (mlir::failed(check_and_assign_mesh(
                    operand->getOwner()->getLoc(), mesh, operand_mesh)))
              return mlir::failure();
          }
        }
        *out = operand_mesh;
      }
    }
    return mlir::success();
  }

  // If `operand` is produced by another operation, extract mesh from the
  // enclosing tf_device.Cluster op of that operation.
  if (mlir::failed(ExtractMeshFromOpOutput(operand_value, out)))
    return mlir::failure();

  return mlir::success();
}

// Infers the mesh of `cluster` from its operands. If a mesh can be inferred,
// all operands must have the same mesh.
mlir::LogicalResult InferMeshFromInputs(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, absl::optional<Mesh>* mesh,
    llvm::SmallVector<mlir::OpOperand*, 8>* inputs_with_inferred_mesh) {
  auto result = mlir::success();

  // If `cluster` wraps a `tf.CopyToMesh` op, do not infer mesh from its
  // inputs. `tf.CopyToMesh` specifies that all operations following it are
  // executed on the target device mesh cluster specified by `tf.CopyToMesh`.
  if (llvm::isa<mlir::TF::CopyToMeshOp>(&cluster.GetBody().front()))
    return result;

  mlir::visitUsedValuesDefinedAbove(
      cluster.getBody(), cluster.getBody(), [&](mlir::OpOperand* operand) {
        if (mlir::failed(result)) return;
        absl::optional<Mesh> extracted_config;

        // If the operand is consumed by a DTensorLayout op, then use the mesh
        // extracted from the DTensorLayout op to infer the mesh of the
        // cluster.
        if (auto layout_op =
                llvm::dyn_cast<mlir::TF::DTensorLayout>(operand->getOwner())) {
          auto mesh = layout_op.layout().mesh();
          extracted_config.emplace(mesh);
        } else {
          auto extract_result =
              ExtractMeshFromOperand(producers, operand, &extracted_config);
          if (mlir::failed(extract_result)) {
            result = extract_result;
            return;
          }
        }

        // DTensorDevice may create a graph with resource arguments with an
        // empty layout. The layouts of these resource values will be added
        // after layout is inferred from resource update ops. Therefore, ignore
        // DTensorLayout ops with empty layouts.
        if (!extracted_config || extracted_config->IsEmpty()) return;

        inputs_with_inferred_mesh->emplace_back(operand);
        if (mesh->has_value() && extracted_config != mesh->value()) {
          result = cluster.emitOpError(
              "failed during mesh propagation. All inputs to "
              "`tf_device.Cluster` must have the same mesh configuration.");
        }

        if (!mesh->has_value()) mesh->emplace(extracted_config.value());
      });

  return result;
}

// Extracts mesh from function return attributes. If the `tf._default_layout`
// attribute exists, the mesh from the default layout is used. Otherwise, the
// mesh from the `tf._mesh` attribute is used.
StatusOr<absl::optional<Mesh>> ExtractMeshFromFunctionOutput(
    const int output_index, mlir::func::FuncOp function) {
  absl::optional<Mesh> function_mesh;
  auto terminator = llvm::cast<mlir::func::ReturnOp>(
      function.getBody().front().getTerminator());
  TF_ASSIGN_OR_RETURN(auto layout, ExtractLayoutFromFunctionReturnAttr(
                                       terminator, output_index));

  if (layout) {
    function_mesh.emplace(layout->mesh());
    return function_mesh;
  }

  auto output_mesh_attr = function.getResultAttrOfType<mlir::StringAttr>(
      output_index, kCustomDeviceMeshAttr);
  if (output_mesh_attr) {
    TF_ASSIGN_OR_RETURN(auto mesh,
                        Mesh::FromString(output_mesh_attr.getValue().str()));
    function_mesh.emplace(std::move(mesh));
  }
  return function_mesh;
}

// Infers mesh from users of `cluster` and records the usages that were used to
// infer the mesh configuration in `consumers_with_mesh`.
mlir::LogicalResult InferMeshFromConsumers(
    mlir::tf_device::ClusterOp cluster, absl::optional<Mesh>* mesh,
    llvm::SmallVector<mlir::OpOperand*, 8>* consumers_with_mesh) {
  for (auto& use_value : cluster.getOperation()->getUses()) {
    mlir::Operation* consumer = use_value.getOwner();

    // `tf.CopyToMesh` specifies that all operations following it are executed
    // on the target device mesh cluster specified by `tf.CopyToMesh`.
    // Therefore, if the `consumer` operation is `tf.CopyToMesh`, do not
    // propagate mesh backwards to `cluster`.
    if (llvm::isa<mlir::TF::CopyToMeshOp>(consumer)) continue;

    Mesh extracted_mesh;

    // If the `cluster` output is an output value of a function, then infer
    // mesh using the function return value attribute, if it exists.
    if (auto return_op = llvm::dyn_cast<mlir::func::ReturnOp>(consumer)) {
      auto status_or_mesh = ExtractMeshFromFunctionOutput(
          use_value.getOperandNumber(),
          return_op->getParentOfType<mlir::func::FuncOp>());
      if (!status_or_mesh.ok())
        return cluster.emitOpError(status_or_mesh.status().ToString());

      auto mesh = status_or_mesh.value();
      if (mesh) extracted_mesh = *mesh;
    } else {
      // If the `cluster` output is an input to another cluster/op then infer
      // mesh from the consumer operation.
      auto consumer_cluster =
          consumer->getParentOfType<mlir::tf_device::ClusterOp>();
      if (!consumer_cluster) {
        return cluster.emitOpError(
            "failed to propagate mesh information. All operations must be "
            "enclosed inside a tf_device.cluster op.");
      }

      auto mesh_or_status = ExtractDeviceMeshFromOp(consumer_cluster);
      if (!mesh_or_status.ok())
        return cluster.emitOpError(mesh_or_status.status().error_message());

      auto consumer_mesh = mesh_or_status.value();
      if (!consumer_mesh) continue;

      extracted_mesh = consumer_mesh.value();
    }

    if (extracted_mesh.IsEmpty()) continue;

    if (mesh->has_value() && extracted_mesh != mesh->value()) {
      return cluster.emitOpError(
          "failed to propagate mesh information. Mesh for op is ambiguous as "
          "consumers have different mesh attributes");
    }

    consumers_with_mesh->emplace_back(&use_value);
    if (!mesh->has_value()) mesh->emplace(std::move(extracted_mesh));
  }
  return mlir::success();
}

// Infers the default mesh of a function given its inputs and outputs. A
// function has a default mesh if all its inputs/outputs have values assigned
// to the same mesh.
mlir::LogicalResult InferFunctionDefaultMesh(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::func::FuncOp function, mlir::OpBuilder* builder,
    absl::optional<mlir::StringAttr>* inferred_default_mesh) {
  auto terminator = function.getCallableRegion()->front().getTerminator();
  for (auto& result_value : terminator->getOpOperands()) {
    auto result_defining_op = result_value.get().getDefiningOp();
    if (!result_defining_op) continue;

    auto result_cluster =
        llvm::cast<mlir::tf_device::ClusterOp>(result_defining_op);
    auto result_mesh =
        result_cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
    if (!result_mesh) continue;

    if (inferred_default_mesh->has_value() &&
        inferred_default_mesh->value() != result_mesh) {
      inferred_default_mesh->reset();
      return mlir::success();
    }
    inferred_default_mesh->emplace(result_mesh);
  }

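  // Next, fold in meshes inferred from the function arguments. If an
  // argument's mesh disagrees with the mesh inferred from the return values,
  // the function has no unique default mesh.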
  absl::optional<Mesh> inferred_mesh_from_args;
  for (auto function_arg : function.getArguments()) {
    auto uses = function_arg.getUses();
    if (uses.empty()) {
      if (mlir::failed(ExtractMeshFromBlockArgument(function_arg,
                                                    &inferred_mesh_from_args)))
        return mlir::failure();
    } else {
      auto operand = uses.begin().getOperand();
      if (mlir::failed(ExtractMeshFromOperand(producers, operand,
                                              &inferred_mesh_from_args)))
        return mlir::failure();
    }
    if (!inferred_mesh_from_args) continue;

    std::string mesh_str = inferred_mesh_from_args->ToString();
    if (inferred_default_mesh->has_value() &&
        inferred_default_mesh->value().getValue().str() != mesh_str) {
      inferred_default_mesh->reset();
      return mlir::success();
    }

    inferred_default_mesh->emplace(builder->getStringAttr(std::move(mesh_str)));
  }
  return mlir::success();
}

// Annotates the `tf._mesh` attribute on arguments of `function` with the
// serialized string of `mesh`.
void AnnotateFunctionArgumentsWithMeshInformation(
    const Mesh& mesh,
    const llvm::SmallVector<mlir::OpOperand*, 8>& input_values_from_mesh,
    mlir::func::FuncOp function, mlir::OpBuilder* builder) {
  for (auto value : input_values_from_mesh) {
    function.setArgAttr(value->getOperandNumber(), kCustomDeviceMeshAttr,
                        builder->getStringAttr(mesh.ToString()));
  }
}

// Annotates return value attributes of `function_to_annotate` with mesh
// information parsed from usages of the function. `callsite_operation` is the
// callable op whose function definition is `function_to_annotate`.
mlir::LogicalResult AnnotateFunctionReturnValuesWithMeshInformation(
    const llvm::SmallVector<mlir::OpOperand*, 8>& return_values_from_mesh,
    mlir::Operation* callsite_operation,
    mlir::func::FuncOp function_to_annotate, mlir::OpBuilder* builder) {
  for (auto value : return_values_from_mesh) {
    absl::optional<mlir::StringAttr> result_mesh_attribute;
    if (llvm::isa<mlir::func::ReturnOp>(value->getOwner())) {
      auto parent_function =
          callsite_operation->getParentOfType<mlir::func::FuncOp>();
      auto function_result_layout =
          parent_function.getResultAttrOfType<mlir::StringAttr>(
              value->getOperandNumber(), kCustomDefaultLayoutAttr);
      if (function_result_layout) {
        auto layout_or_status =
            Layout::FromString(function_result_layout.getValue().str());
        if (!layout_or_status.ok())
          return parent_function.emitOpError(
              layout_or_status.status().error_message());

        result_mesh_attribute.emplace(
            builder->getStringAttr(layout_or_status->mesh().ToString()));
      } else {
        auto function_result_mesh =
            parent_function.getResultAttrOfType<mlir::StringAttr>(
                value->getOperandNumber(), kCustomDeviceMeshAttr);
        if (function_result_mesh)
          result_mesh_attribute.emplace(function_result_mesh);
      }
    } else {
      auto op_mesh =
          value->getOwner()->getAttrOfType<mlir::StringAttr>(kMeshAttr);
      if (op_mesh) result_mesh_attribute.emplace(std::move(op_mesh));
    }

    if (result_mesh_attribute)
      function_to_annotate.setResultAttr(
          value->get().cast<mlir::OpResult>().getResultNumber(),
          kCustomDeviceMeshAttr, result_mesh_attribute.value());
  }
  return mlir::success();
}

// MLIR pass that propagates mesh information to tf_device.Cluster ops.
struct DTensorMeshPropagation
    : public impl::DTensorMeshPropagationBase<DTensorMeshPropagation> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder builder(&context);
    auto module = getOperation();
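    // Mesh propagation is seeded from the module's `main` function; if there
    // is no `main` function, there is nothing to do.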
    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) return;

    mlir::Dialect* tf_dialect =
        context.getLoadedDialect<mlir::TF::TensorFlowDialect>();

    // This maps from OpResults to a list of OpOperands that consume them.
    // Note that this will pass over/through
    // (Stateful)PartitionedCall and other control flow, directly connecting
    // producing ops to their consumers in the function. I.e. it presents a
    // flattened/inlined view of the flow of data.
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>> consumers;
    // Maintain a reverse mapping. Note that for control flow operations like
    // the tf.If op, there may be multiple producers for an mlir::Value.
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers;

    // Create consumers and producers maps.
    if (mlir::failed(
            PopulateConsumersFromModule(&module, tf_dialect, consumers)))
      return signalPassFailure();

    for (auto& consumer : consumers) {
      for (auto* operand : consumer.second) {
        producers[operand].emplace_back(consumer.first);
      }
    }

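    // Iterate mesh propagation to a fixed point: each pass over the module may
    // assign meshes to additional clusters, which in turn may enable further
    // inference on the next pass.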
    bool mesh_changed = true;
    while (mesh_changed) {
      mesh_changed = false;
      if (mlir::failed(
              PropagateMesh(producers, main_func, &builder, &mesh_changed)))
        return signalPassFailure();
    }
  }

  // Propagates and sets `_mesh` attributes on all clusters inside `function`
  // if possible.
  mlir::LogicalResult PropagateMesh(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::func::FuncOp, mlir::OpBuilder* builder, bool* mesh_changed);

  // Infers mesh of `cluster` from its input operations.
  mlir::LogicalResult PropagateMeshFromInputs(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
      bool* mesh_changed);

  // Infers mesh of `cluster` from its consuming operations.
  mlir::LogicalResult PropagateMeshFromConsumers(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
      bool* mesh_changed);

  // Assigns the function's default mesh to clusters with no mesh specified.
  // Note that a function has a default mesh only if all its dtensor
  // inputs/outputs are assigned to a single mesh.
  mlir::LogicalResult PropagateDefaultMeshToUnAssignedClusters(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::func::FuncOp, mlir::OpBuilder* builder, bool* mesh_changed);
};

mlir::LogicalResult
DTensorMeshPropagation::PropagateDefaultMeshToUnAssignedClusters(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::func::FuncOp function, mlir::OpBuilder* builder, bool* mesh_changed) {
  absl::optional<mlir::StringAttr> mesh;
  if (mlir::failed(
          InferFunctionDefaultMesh(producers, function, builder, &mesh)))
    return mlir::failure();

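  // Collect every cluster whose mesh attribute has not yet been assigned.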
  llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters_without_mesh;
  auto walk_result = function.walk([&](mlir::tf_device::ClusterOp cluster) {
    auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
    if (!mesh_or_status.ok()) {
      cluster.GetBody().front().emitOpError(
          mesh_or_status.status().error_message());
      return mlir::WalkResult::interrupt();
    }

    const auto& mesh = mesh_or_status.value();
    if (mesh.has_value()) return mlir::WalkResult::advance();

    clusters_without_mesh.emplace_back(cluster);
    return mlir::WalkResult::advance();
  });

  if (walk_result.wasInterrupted()) return mlir::failure();

  if (!mesh.has_value()) return mlir::success();

  // Set the function's default mesh on clusters with unspecified mesh.
  for (auto cluster_without_mesh : clusters_without_mesh) {
    *mesh_changed = true;
    cluster_without_mesh->setAttr(kMeshAttr, mesh.value());
  }

  return mlir::success();
}

mlir::LogicalResult DTensorMeshPropagation::PropagateMeshFromInputs(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
    bool* mesh_changed) {
  // If the operation inside the mesh cluster is not a callable operation and
  // the mesh is already specified on the cluster, do nothing.
  auto inner_func = MaybeFindFunction(&cluster.GetBody().front());
  auto cluster_mesh = cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  if (!inner_func && cluster_mesh) return mlir::success();

  // If the mesh of `cluster` is not specified, infer the mesh using the inputs
  // of the mesh cluster.
  absl::optional<Mesh> extracted_mesh;
  llvm::SmallVector<mlir::OpOperand*, 8> inputs_with_inferred_mesh;
  if (failed(InferMeshFromInputs(producers, cluster, &extracted_mesh,
                                 &inputs_with_inferred_mesh))) {
    return mlir::failure();
  }

  // If the operation inside `cluster` is a function call, annotate the input
  // and output mesh of `cluster` using function argument and return value
  // attributes, then recursively propagate mesh in the function definition.
  if (inner_func) {
    // All inputs to the cluster must be from the same mesh. If the input mesh
    // to the callable operation is inferred, then annotate the input mesh on
    // the function argument attribute so that this information can be used to
    // infer the mesh of ops inside `inner_func`.
    if (extracted_mesh.has_value()) {
      AnnotateFunctionArgumentsWithMeshInformation(extracted_mesh.value(),
                                                   inputs_with_inferred_mesh,
                                                   inner_func.value(), builder);
    }

    // Recursively propagate mesh to clusters in the function definition of
    // `inner_func`.
    if (mlir::failed(PropagateMesh(producers, inner_func.value(), builder,
                                   mesh_changed)))
      return mlir::failure();

    // Once the meshes of all clusters inside the `inner_func` callable have
    // been set, we can infer the mesh of `cluster`. That is, the mesh of the
    // call site operation is equal to the mesh of the return values of the
    // function.
    absl::optional<mlir::StringAttr> function_mesh;
    if (mlir::failed(InferFunctionDefaultMesh(producers, inner_func.value(),
                                              builder, &function_mesh)))
      return mlir::failure();

    if (function_mesh && !cluster_mesh) {
      *mesh_changed = true;
      cluster->setAttr(kMeshAttr, function_mesh.value());
    }
  } else if (!cluster_mesh && extracted_mesh.has_value()) {
    *mesh_changed = true;
    cluster->setAttr(kMeshAttr,
                     builder->getStringAttr(extracted_mesh->ToString()));
  }
  return mlir::success();
}

// Sets the mesh of `cluster`, inferring the mesh from the consumer operations
// of `cluster`.
mlir::LogicalResult DTensorMeshPropagation::PropagateMeshFromConsumers(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
    bool* mesh_changed) {
  mlir::Operation* op_inside_cluster = &cluster.GetBody().front();
  auto inner_func = MaybeFindFunction(op_inside_cluster);
  auto cluster_mesh = cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  // If the mesh is already set, then do nothing.
  if (!inner_func && cluster_mesh) return mlir::success();

  // Infer the mesh of `cluster` from its output usages.
  absl::optional<Mesh> extracted_mesh_from_consumers;
  llvm::SmallVector<mlir::OpOperand*, 8> consumers_with_mesh_information;
  if (failed(InferMeshFromConsumers(cluster, &extracted_mesh_from_consumers,
                                    &consumers_with_mesh_information)))
    return mlir::failure();

  // If the operation inside the mesh cluster is a function callsite operation,
  // then propagate the mesh of the function recursively.
  if (inner_func) {
    if (mlir::failed(AnnotateFunctionReturnValuesWithMeshInformation(
            consumers_with_mesh_information, op_inside_cluster,
            inner_func.value(), builder)))
      return mlir::failure();

    if (mlir::failed(PropagateMesh(producers, inner_func.value(), builder,
                                   mesh_changed)))
      return mlir::failure();

    absl::optional<mlir::StringAttr> function_mesh;
    if (mlir::failed(InferFunctionDefaultMesh(producers, inner_func.value(),
                                              builder, &function_mesh)))
      return mlir::failure();

    if (function_mesh && !cluster_mesh) {
      *mesh_changed = true;
      cluster->setAttr(kMeshAttr, function_mesh.value());
    }
  } else if (extracted_mesh_from_consumers && !cluster_mesh) {
    *mesh_changed = true;
    cluster->setAttr(kMeshAttr,
                     builder->getStringAttr(
                         extracted_mesh_from_consumers->ToString()));
  }
  return mlir::success();
}

// Propagates mesh information to all `tf_device.Cluster` ops in `function`. If
// `function` includes callable ops, then recursively traverses the function
// definition to propagate mesh information using input operands and consuming
// result ops. Note that at this stage of graph optimization, each
// tf_device.cluster op encloses a single operation.
mlir::LogicalResult DTensorMeshPropagation::PropagateMesh(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::func::FuncOp function, mlir::OpBuilder* builder, bool* mesh_changed) {
  // Iterate clusters in topological order, propagating mesh from operations'
  // inputs.
  llvm::SmallVector<mlir::tf_device::ClusterOp, 8> cluster_ops;
  for (auto cluster : function.getOps<mlir::tf_device::ClusterOp>()) {
    cluster_ops.emplace_back(cluster);

    if (mlir::failed(
            PropagateMeshFromInputs(producers, cluster, builder, mesh_changed)))
      return mlir::failure();
  }

  // Iterate clusters in reverse topological order and propagate mesh from
  // consumers.
  for (auto cluster : llvm::reverse(cluster_ops)) {
    if (mlir::failed(PropagateMeshFromConsumers(producers, cluster, builder,
                                                mesh_changed)))
      return mlir::failure();
  }

  if (mlir::failed(PropagateDefaultMeshToUnAssignedClusters(
          producers, function, builder, mesh_changed)))
    return mlir::failure();

  return mlir::success();
}

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorMeshPropagationPass() {
  return std::make_unique<DTensorMeshPropagation>();
}

}  // namespace dtensor
}  // namespace tensorflow