1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | |
18 | #include "llvm/ADT/STLExtras.h" |
19 | #include "llvm/Support/Casting.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "mlir/IR/Block.h" // from @llvm-project |
22 | #include "mlir/IR/Builders.h" // from @llvm-project |
23 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
24 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
25 | #include "mlir/IR/Operation.h" // from @llvm-project |
26 | #include "mlir/IR/UseDefLists.h" // from @llvm-project |
27 | #include "mlir/IR/Value.h" // from @llvm-project |
28 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
29 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
30 | #include "mlir/Transforms/RegionUtils.h" // from @llvm-project |
31 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" |
32 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
33 | #include "tensorflow/dtensor/cc/constants.h" |
34 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" |
35 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
36 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
37 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
38 | |
39 | namespace tensorflow { |
40 | namespace dtensor { |
41 | |
42 | namespace { |
43 | #define GEN_PASS_DEF_DTENSORHANDLECROSSCLUSTERDEPENDENCIES |
44 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
45 | |
// Emitted when a tf_device.cluster carries no (or a null) mesh attribute.
constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorHandleCrossClusterDependencies pass. "
    "All clusters must have specified mesh." ;

// Emitted when a value crosses a mesh boundary without going through
// tf.CopyToMesh.
constexpr char kInvalidTensorTransferErrorMsg[] =
    "CopyToMeshOp must be used to send data across mesh." ;

// Emitted when the layout string attribute on a CopyToMesh op fails to parse;
// {0} is the offending layout string (filled via llvm::formatv).
constexpr char kInvalidLayoutMsg[] =
    "found CopyToMesh with invalid layout. Found layout {0}." ;
55 | |
56 | // Extracts mesh from `cluster`. |
57 | mlir::LogicalResult (mlir::tf_device::ClusterOp cluster, |
58 | Mesh* mesh_output) { |
59 | auto mesh_or_status = ExtractDeviceMeshFromOp(cluster); |
60 | if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg); |
61 | |
62 | const auto& mesh_or_null = mesh_or_status.value(); |
63 | if (!mesh_or_null.has_value()) |
64 | return cluster.emitOpError(kMissingMeshErrorMsg); |
65 | |
66 | *mesh_output = mesh_or_null.value(); |
67 | return mlir::success(); |
68 | } |
69 | |
70 | // Returns const op if `op` is a const op or DTensorLayoutOp with Const op as |
71 | // input. |
72 | mlir::Operation* GetConstOp(mlir::Operation* op) { |
73 | if (llvm::isa<mlir::TF::ConstOp>(op)) return op; |
74 | |
75 | if (auto layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) { |
76 | mlir::Operation* input_op = layout.input().getDefiningOp(); |
77 | if (input_op && llvm::isa<mlir::TF::ConstOp>(input_op)) return input_op; |
78 | } |
79 | return nullptr; |
80 | } |
81 | |
// Creates a clone of `const_op` at the beginning of `cluster` body region and
// set the output value of cloned op replace output of CopyToMesh op within
// `cluster`.
//
// `operand` is the use inside `cluster` whose owner must be a CopyToMeshOp;
// its layout attribute determines the DTensorLayout wrapped around the clone.
// The CopyToMesh op is erased once it has no remaining users.
mlir::LogicalResult CloneOpToCluster(mlir::Operation* const_op,
                                     mlir::tf_device::ClusterOp cluster,
                                     mlir::OpOperand* operand) {
  auto copy_to_mesh =
      llvm::dyn_cast<mlir::TF::CopyToMeshOp>(operand->getOwner());
  // Callers guarantee the operand's owner is a CopyToMesh op (checked in
  // CloneConstantsAcrossMesh before calling here).
  assert(copy_to_mesh);
  const std::string layout_attr = copy_to_mesh.layout().str();
  StatusOr<Layout> layout = Layout::FromString(layout_attr);
  if (!layout.ok())
    return copy_to_mesh.emitOpError(
        llvm::formatv(kInvalidLayoutMsg, layout_attr));

  // Insert the clone before the first op of the cluster body so it dominates
  // every potential use within the cluster.
  mlir::OpBuilder builder(&cluster.GetBody().front());
  mlir::Operation* cloned_op = builder.clone(*const_op);
  mlir::TensorType type =
      cloned_op->getResult(0).getType().cast<mlir::TensorType>();
  // Wrap the cloned constant in a DTensorLayout carrying the target layout so
  // downstream SPMD expansion sees the intended sharding.
  auto layout_op = builder.create<mlir::TF::DTensorLayout>(
      const_op->getLoc(), cloned_op->getResult(0),
      mlir::dtensor::LayoutAttr::get(builder.getContext(), *layout),
      mlir::TF::ShapeAttr::get(builder.getContext(), type));

  // Only redirect uses that live inside `cluster`; uses in other clusters
  // must keep consuming the original CopyToMesh result.
  copy_to_mesh.output().replaceUsesWithIf(
      layout_op.output(), [&](mlir::OpOperand& operand) {
        return cluster.getOperation()->isProperAncestor(operand.getOwner());
      });

  // The CopyToMesh may still feed other clusters; erase only when dead.
  if (copy_to_mesh->getUsers().empty()) copy_to_mesh.erase();

  return mlir::success();
}
115 | |
116 | mlir::LogicalResult GetInputProducingValue(mlir::OpOperand& operand, |
117 | mlir::Value* val_output) { |
118 | auto input_value = operand.get().dyn_cast<mlir::OpResult>(); |
119 | if (!input_value) return mlir::success(); |
120 | |
121 | auto input_cluster = |
122 | llvm::dyn_cast<mlir::tf_device::ClusterOp>(input_value.getOwner()); |
123 | if (input_cluster) { |
124 | // If value is from another tf_device.cluster output, then query into |
125 | // the terminator of the input cluster to get mlir::Value from Tensorflow |
126 | // operation that is producing the value. |
127 | *val_output = input_cluster.GetBody().getTerminator()->getOperand( |
128 | input_value.getResultNumber()); |
129 | } else { |
130 | *val_output = input_value; |
131 | } |
132 | return mlir::success(); |
133 | } |
134 | |
135 | // Copies constant operation to mesh clusters if there are multiple usages of |
136 | // constants across multiple mesh computations. This is needed for 2 reasons. |
137 | // a) Cloning constants across mesh can reduce send/recvs during execution. |
138 | // b) DTensor SPMD Expansion for some ops (like tf.reduce_sum) requires inputs |
139 | // to computation to be constants. |
140 | mlir::LogicalResult CloneConstantsAcrossMesh( |
141 | mlir::tf_device::ClusterOp cluster) { |
142 | auto& body_region = cluster.getBody(); |
143 | Mesh mesh; |
144 | if (mlir::failed(ExtractMeshFromCluster(cluster, &mesh))) |
145 | return mlir::failure(); |
146 | |
147 | mlir::LogicalResult result(mlir::success()); |
148 | mlir::visitUsedValuesDefinedAbove( |
149 | body_region, body_region, [&](mlir::OpOperand* operand) { |
150 | if (mlir::failed(result)) return; |
151 | |
152 | mlir::Value input_value; |
153 | result = GetInputProducingValue(*operand, &input_value); |
154 | if (mlir::failed(result) || !input_value) return; |
155 | |
156 | auto input_cluster = |
157 | input_value.getDefiningOp() |
158 | ->getParentOfType<mlir::tf_device::ClusterOp>(); |
159 | Mesh input_mesh; |
160 | if (mlir::failed(ExtractMeshFromCluster(input_cluster, &input_mesh))) { |
161 | result = mlir::failure(); |
162 | return; |
163 | } |
164 | |
165 | if (input_mesh == mesh) return; |
166 | if (!llvm::isa<mlir::TF::CopyToMeshOp>(operand->getOwner())) { |
167 | result = |
168 | operand->getOwner()->emitOpError(kInvalidTensorTransferErrorMsg); |
169 | return; |
170 | } |
171 | |
172 | mlir::Operation* const_op = GetConstOp(input_value.getDefiningOp()); |
173 | if (const_op) result = CloneOpToCluster(const_op, cluster, operand); |
174 | }); |
175 | |
176 | return result; |
177 | } |
178 | |
// Transforms CopyToMesh op to a pair of DTensorSend/DTensorRecv operations.
//
// `send_recv_counter` disambiguates communication keys when multiple
// send/recv pairs share the same layout; it is incremented on success.
// On success `copy_to_mesh` is erased and all its uses are rewired to the
// new DTensorRecv's output.
mlir::LogicalResult LowerToSendRecv(mlir::TF::CopyToMeshOp copy_to_mesh,
                                    mlir::MLIRContext* context,
                                    int* send_recv_counter) {
  // The input is expected to be a result of the source tf_device.cluster.
  const mlir::OpResult copied_value =
      copy_to_mesh.input().cast<mlir::OpResult>();
  const int result_index = copied_value.getResultNumber();
  auto src_cluster =
      llvm::cast<mlir::tf_device::ClusterOp>(copied_value.getDefiningOp());
  // Look through the source cluster's terminator to the TF op value that
  // actually produces the transferred tensor.
  mlir::Value value_to_send =
      src_cluster.GetBody().getTerminator()->getOperand(result_index);

  // Create DTensorSend op that sends `value_to_send` across mesh cluster.
  // Insert just before the source cluster's terminator so the send happens
  // after the value is computed.
  mlir::OpBuilder builder(value_to_send.getParentBlock()->getTerminator());

  // Communication key pairs this send with its matching recv; built from the
  // target layout plus a running counter for uniqueness.
  const std::string op_key =
      llvm::formatv("communication_key_{0}_{1}" , copy_to_mesh.layout(),
                    *send_recv_counter)
          .str();
  const std::string layout_attr = copy_to_mesh.layout().str();
  auto layout_or_status = Layout::FromString(layout_attr);
  if (!layout_or_status.ok())
    return copy_to_mesh.emitOpError(
        llvm::formatv(kInvalidLayoutMsg, layout_attr));

  // Create send op that sends data from input cluster to target cluster.
  const Layout& target_layout = layout_or_status.value();
  builder.create<mlir::TF::DTensorSend>(
      copy_to_mesh.getLoc(), value_to_send, builder.getStringAttr(op_key),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Create recv op that recvs data from send op.
  auto tensor_type = value_to_send.getType().dyn_cast<mlir::TensorType>();
  if (!tensor_type)
    return copy_to_mesh.emitOpError(
        "found CopyToMesh sending value with unknown shape. Inputs to "
        "CopyToMesh op must have static shape." );

  // The recv replaces the CopyToMesh in place within the target cluster.
  builder.setInsertionPoint(copy_to_mesh);
  auto recv_op = builder.create<mlir::TF::DTensorRecv>(
      copy_to_mesh.getLoc(), value_to_send.getType(),
      builder.getStringAttr(op_key),
      mlir::TF::ShapeAttr::get(context, tensor_type),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Replace value for recv ops for all usages of `copy_to_mesh` op.
  copy_to_mesh.replaceAllUsesWith(recv_op.output());

  // Remove copy to mesh op.
  copy_to_mesh.erase();

  *send_recv_counter += 1;

  return mlir::success();
}
234 | |
235 | // Lowers tf.CopyToMesh to a pair of DTensorSend/DTensorRecv operations. |
236 | // |
237 | // For example: |
238 | // %0 = "tf_device.cluster"() ({ |
239 | // %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32> |
240 | // tf_device.return %1 : tensor<i32> |
241 | // }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>) |
242 | // |
243 | // %2 = "tf_device.cluster"() ({ |
244 | // %3 = "tf.CopyToMesh"(%0) |
245 | // { layout ="mesh:TPU,x=2,y=2 layout:x,replicated" } : |
246 | // (tensor<i32>) -> (tensor<i32>) |
247 | // %4 = "tf.Neg"(%3) : (tensor<i32>) -> tensor<i32> |
248 | // tf_device.return %4 : tensor<i32> |
249 | // }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>) |
250 | // return |
251 | // } |
252 | // |
253 | // Is transformed to: |
254 | // |
255 | // %0 = "tf_device.cluster"() ({ |
256 | // %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32> |
257 | // "tf.DTensorSend"(%1) {...} : (tensor<i32>) -> () |
258 | // tf_device.return %1 : tensor<i32> |
259 | // }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>) |
260 | // |
261 | // %2 = "tf_device.cluster"() ({ |
262 | // %3 = "tf.DTensorRecv"() {...} : () -> (tensor<i32>) |
263 | // %4 = "tf.Neg"(%3) : (tensor<i32>) -> tensor<i32> |
264 | // tf_device.return %4 : tensor<i32> |
265 | // }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>) |
266 | // return |
267 | // } |
268 | mlir::LogicalResult ReplaceCopyToMeshWithVirtualSendRecv( |
269 | mlir::tf_device::ClusterOp cluster, mlir::MLIRContext* context, |
270 | int* send_recv_counter) { |
271 | Mesh current_mesh; |
272 | if (mlir::failed(ExtractMeshFromCluster(cluster, ¤t_mesh))) |
273 | return mlir::failure(); |
274 | |
275 | mlir::Region& cluster_region = cluster.getBody(); |
276 | mlir::LogicalResult result = mlir::success(); |
277 | |
278 | mlir::visitUsedValuesDefinedAbove( |
279 | cluster_region, cluster_region, [&](mlir::OpOperand* operand) { |
280 | mlir::Value input_value; |
281 | if (mlir::failed(GetInputProducingValue(*operand, &input_value))) { |
282 | result = mlir::failure(); |
283 | return; |
284 | } |
285 | if (!input_value) return; |
286 | |
287 | auto input_cluster = |
288 | input_value.getDefiningOp() |
289 | ->getParentOfType<mlir::tf_device::ClusterOp>(); |
290 | Mesh input_mesh; |
291 | if (mlir::failed(ExtractMeshFromCluster(input_cluster, &input_mesh))) { |
292 | result = mlir::failure(); |
293 | return; |
294 | } |
295 | |
296 | if (current_mesh == input_mesh) return; |
297 | |
298 | // Check that values that cross mesh boundaries go through CopyToMesh |
299 | // op. |
300 | mlir::Operation* input_op = operand->getOwner(); |
301 | mlir::TF::CopyToMeshOp copy_to_mesh = |
302 | llvm::dyn_cast<mlir::TF::CopyToMeshOp>(input_op); |
303 | if (!copy_to_mesh) { |
304 | result = |
305 | operand->getOwner()->emitOpError(kInvalidTensorTransferErrorMsg); |
306 | return; |
307 | } |
308 | |
309 | // Lower CopyToMesh op to a pair of virtual Send/Recv op. |
310 | if (mlir::failed( |
311 | LowerToSendRecv(copy_to_mesh, context, send_recv_counter))) { |
312 | result = mlir::failure(); |
313 | return; |
314 | } |
315 | }); |
316 | return result; |
317 | } |
318 | |
319 | struct DTensorHandleCrossClusterDependencies |
320 | : public impl::DTensorHandleCrossClusterDependenciesBase< |
321 | DTensorHandleCrossClusterDependencies> { |
322 | void getDependentDialects(mlir::DialectRegistry& registry) const override { |
323 | registry.insert<mlir::dtensor::DTensorDialect>(); |
324 | } |
325 | |
326 | void runOnOperation() override { |
327 | mlir::MLIRContext& context = getContext(); |
328 | mlir::ModuleOp module = getOperation(); |
329 | llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters; |
330 | module.walk([&](mlir::tf_device::ClusterOp cluster) { |
331 | clusters.emplace_back(cluster); |
332 | }); |
333 | |
334 | int send_recv_counter = 0; |
335 | for (auto cluster : clusters) { |
336 | if (mlir::failed(CloneConstantsAcrossMesh(cluster))) |
337 | return signalPassFailure(); |
338 | |
339 | if (mlir::failed(ReplaceCopyToMeshWithVirtualSendRecv( |
340 | cluster, &context, &send_recv_counter))) |
341 | return signalPassFailure(); |
342 | } |
343 | |
344 | // Once CopyToMesh has been lowered to DTensorSend/Recv operations, |
345 | // tf_device.Cluster may now have dangling/unused result values. Remove all |
346 | // such return values. |
347 | for (auto cluster : clusters) RemoveUnusedClusterResults(cluster); |
348 | } |
349 | }; |
350 | |
351 | } // namespace |
352 | |
// Factory for the DTensorHandleCrossClusterDependencies module pass,
// registered with the DTensor pass pipeline.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorHandleCrossClusterDependencies() {
  return std::make_unique<DTensorHandleCrossClusterDependencies>();
}
357 | |
358 | } // namespace dtensor |
359 | } // namespace tensorflow |
360 | |