1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <memory>
#include <string>
17 | |
18 | #include "llvm/Support/Casting.h" |
19 | #include "llvm/Support/FormatVariadic.h" |
20 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
21 | #include "mlir/IR/Attributes.h" // from @llvm-project |
22 | #include "mlir/IR/Builders.h" // from @llvm-project |
23 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
24 | #include "mlir/IR/Operation.h" // from @llvm-project |
25 | #include "mlir/IR/Visitors.h" // from @llvm-project |
26 | #include "mlir/Pass/Pass.h" // from @llvm-project |
27 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
28 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
29 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
30 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
31 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
32 | #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
33 | #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" |
34 | #include "tensorflow/dtensor/cc/constants.h" |
35 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
36 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" |
37 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
38 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
39 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
40 | |
41 | namespace tensorflow { |
42 | namespace dtensor { |
43 | |
44 | namespace { |
45 | #define GEN_PASS_DEF_DTENSOROPTODEVICECLUSTER |
46 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
47 | |
// Wraps `op` in a tf_device.cluster op, attaching the mesh configuration as a
// string attribute when one can be extracted from the op.
// We currently extract mesh information from all the args and assume they are
// the same; this assumption no longer holds once multiple functions are
// involved.
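//
// As an illustrative sketch (types abbreviated; the `_mesh` key stands in for
// kMeshAttr), wrapping a layout-annotated op rewrites
//
//   %0 = "tf.DTensorLayout"(%arg0) {layout = ...}
//       : (tensor<8xf32>) -> tensor<8xf32>
//
// into
//
//   %0 = "tf_device.cluster"() ({
//     %1 = "tf.DTensorLayout"(%arg0) {layout = ...}
//         : (tensor<8xf32>) -> tensor<8xf32>
//     tf_device.return %1 : tensor<8xf32>
//   }) {_mesh = "..."} : () -> tensor<8xf32>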
51 | mlir::LogicalResult WrapDeviceCluster(mlir::OpBuilder *builder, |
52 | mlir::Operation *op) { |
  // Create a new tf_device.cluster op wrapping the single operation.
54 | builder->setInsertionPoint(op); |
55 | auto cluster = builder->create<mlir::tf_device::ClusterOp>( |
56 | op->getLoc(), op->getResultTypes()); |
57 | if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) { |
58 | cluster->setAttr(kMeshAttr, builder->getStringAttr( |
59 | layout_op.layout().mesh().ToString())); |
60 | } else if (auto copy_to_mesh = llvm::dyn_cast<mlir::TF::CopyToMeshOp>(op)) { |
61 | const std::string layout_string = copy_to_mesh.layout().str(); |
62 | auto layout_or = Layout::FromString(layout_string); |
63 | if (!layout_or.ok()) |
64 | return op->emitOpError( |
65 | llvm::formatv("Found tf.CopyToMesh Op with unparsable layout : {0}" , |
66 | layout_string)); |
67 | |
68 | cluster->setAttr(kMeshAttr, |
69 | builder->getStringAttr(layout_or->mesh().ToString())); |
70 | } else { |
    // If the mesh configuration can be inferred from the op's attributes, use
    // it directly. Otherwise, leave the cluster unannotated; the mesh will be
    // inferred from consumers or operands in the subsequent
    // DTensorMeshPropagation pass. A sketch of an annotated op follows.
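    //
    // For example (sketch; `_mesh` again stands in for kMeshAttr), an op
    // carrying its own mesh annotation such as
    //
    //   %0 = "tf.Identity"(%arg0) {_mesh = "..."}
    //       : (tensor<f32>) -> tensor<f32>
    //
    // yields a mesh here, while an unannotated op falls through with no mesh.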
76 | auto status_or_mesh = ExtractDeviceMeshFromOp(op); |
77 | |
78 | if (!status_or_mesh.ok()) |
79 | return op->emitOpError( |
80 | llvm::formatv("failed to wrap to device cluster. {0}" , |
81 | status_or_mesh.status().error_message())); |
82 | |
83 | const auto mesh_config = status_or_mesh.value(); |
84 | if (mesh_config) |
85 | cluster->setAttr(kMeshAttr, |
86 | builder->getStringAttr(mesh_config->ToString())); |
87 | } |
88 | |
89 | op->replaceAllUsesWith(cluster); |
90 | |
91 | cluster.getBody().push_back(new mlir::Block); |
92 | |
93 | builder->setInsertionPointToEnd(&cluster.GetBody()); |
94 | builder->create<mlir::tf_device::ReturnOp>(op->getLoc(), op->getResults()); |
95 | |
96 | // Move `op` inside newly created `ClusterOp`. |
97 | op->moveBefore(cluster.GetBody().getTerminator()); |
98 | |
99 | return mlir::success(); |
100 | } |
101 | |
// MLIR pass that wraps every TF op in a tf_device.cluster op.
103 | struct DTensorOpToDeviceClusterPass |
104 | : public impl::DTensorOpToDeviceClusterBase<DTensorOpToDeviceClusterPass> { |
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
106 | registry.insert<mlir::dtensor::DTensorDialect>(); |
107 | registry.insert<mlir::tf_device::TensorFlowDeviceDialect>(); |
108 | } |
109 | |
110 | void runOnOperation() override { |
111 | mlir::MLIRContext &context = getContext(); |
112 | mlir::OpBuilder op_builder(&context); |
113 | mlir::Dialect *tf = |
114 | getContext().getLoadedDialect<mlir::TF::TensorFlowDialect>(); |
115 | |
116 | auto walk_result = getOperation().walk([&](mlir::Operation *operation) { |
117 | const auto op_dialect = operation->getDialect(); |
      // Only ops in the TF dialect are supported for layout propagation.
119 | if (op_dialect != tf) return mlir::WalkResult::advance(); |
120 | |
      // tf.Yield ops of control flow operations should not be wrapped in a
      // tf_device.cluster: they do not need to be transformed during SPMD
      // expansion, and tf.If/tf.While ops require their regions to terminate
      // with a tf.Yield op. Wrapping the yield op in a tf_device.cluster
      // would break that invariant, as sketched below.
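      //
      // For instance (illustrative sketch, using the region-based form of
      // tf.If), a branch region must end with its tf.Yield terminator:
      //
      //   "tf.IfRegion"(%cond) ({
      //     ...
      //     "tf.Yield"(%then_out) : (tensor<f32>) -> ()
      //   }, { ... }) : (tensor<i1>) -> tensor<f32>
      //
      // so the tf.Yield must remain a direct child of the region rather than
      // being nested inside a tf_device.cluster.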
126 | if (llvm::isa<mlir::TF::YieldOp>(operation)) |
127 | return mlir::WalkResult::advance(); |
128 | |
129 | if (mlir::failed(WrapDeviceCluster(&op_builder, operation))) |
130 | return mlir::WalkResult::interrupt(); |
131 | return mlir::WalkResult::advance(); |
132 | }); |
133 | |
134 | if (walk_result.wasInterrupted()) signalPassFailure(); |
135 | } |
136 | }; |
137 | |
138 | } // namespace |
139 | |
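// A minimal usage sketch (assuming the TF and DTensor dialects are loaded in
// the MLIRContext): the pass runs on func.func, so it is typically nested
// inside a module-level pipeline, e.g.
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       CreateDTensorOpToDeviceClusterPass());
//   if (mlir::failed(pm.run(module))) { /* handle failure */ }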
140 | std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> |
141 | CreateDTensorOpToDeviceClusterPass() { |
142 | return std::make_unique<DTensorOpToDeviceClusterPass>(); |
143 | } |
144 | |
145 | } // namespace dtensor |
146 | } // namespace tensorflow |
147 | |