1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "llvm/ADT/STLExtras.h" |
17 | #include "llvm/ADT/SmallVector.h" |
18 | #include "llvm/ADT/StringRef.h" |
19 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
20 | #include "mlir/IR/Attributes.h" // from @llvm-project |
21 | #include "mlir/IR/Builders.h" // from @llvm-project |
22 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
23 | #include "mlir/IR/Diagnostics.h" // from @llvm-project |
24 | #include "mlir/IR/Operation.h" // from @llvm-project |
25 | #include "mlir/IR/Types.h" // from @llvm-project |
26 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
27 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
28 | #include "mlir/Transforms/RegionUtils.h" // from @llvm-project |
29 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
30 | #include "tensorflow/dtensor/cc/constants.h" |
31 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
32 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
33 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
34 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
35 | |
36 | namespace tensorflow { |
37 | namespace dtensor { |
38 | |
39 | namespace { |
40 | #define GEN_PASS_DEF_DTENSORDEVICEMESHCLUSTERCOARSENING |
41 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
42 | |
// Error emitted when a tf_device.cluster op is missing the mesh attribute
// that mesh propagation is expected to have attached.
constexpr char kMissingMeshAttributeErrorMessage[] =
    "failed to merge mesh cluster as cluster does not have mesh attribute. "
    "This is likely due to problem in mesh propagation." ;
46 | |
47 | // Determines whether two adjoining clusters should be merged. |
48 | mlir::LogicalResult ShouldMergeClusters(mlir::tf_device::ClusterOp cluster_a, |
49 | mlir::tf_device::ClusterOp cluster_b, |
50 | bool* should_merge) { |
51 | if (cluster_a->getParentRegion() != cluster_b->getParentRegion()) { |
52 | *should_merge = false; |
53 | return mlir::success(); |
54 | } |
55 | |
56 | auto mesh_a_or_status = ExtractDeviceMeshFromOp(cluster_a.getOperation()); |
57 | if (!mesh_a_or_status.ok()) |
58 | return cluster_a.emitOpError(mesh_a_or_status.status().error_message()); |
59 | |
60 | auto mesh_b_or_status = ExtractDeviceMeshFromOp(cluster_b.getOperation()); |
61 | if (!mesh_b_or_status.ok()) |
62 | return cluster_b.emitOpError(mesh_b_or_status.status().error_message()); |
63 | |
64 | auto mesh_a = mesh_a_or_status.value(); |
65 | auto mesh_b = mesh_b_or_status.value(); |
66 | if (!mesh_a || !mesh_b) { |
67 | return !mesh_a ? cluster_a.emitOpError(kMissingMeshAttributeErrorMessage) |
68 | : cluster_b.emitOpError(kMissingMeshAttributeErrorMessage); |
69 | } |
70 | |
71 | *should_merge = mesh_a == mesh_b; |
72 | return mlir::success(); |
73 | } |
74 | |
75 | // Moves all ops (except tf_device.return op) inside `src_cluster` to |
76 | // block inside `target_cluster`. Ops are moved before the `exit_op` |
77 | // inside the `target_cluster`. |
78 | void MoveOpsInsideCluster(mlir::tf_device::ClusterOp src_cluster, |
79 | mlir::tf_device::ClusterOp target_cluster, |
80 | mlir::Operation* exit_op) { |
81 | auto& cluster_body = src_cluster.GetBody().getOperations(); |
82 | target_cluster.GetBody().getOperations().splice( |
83 | exit_op->getIterator(), cluster_body, cluster_body.begin(), |
84 | std::prev(cluster_body.end())); |
85 | } |
86 | |
// Returns a list of pairs of mlir Values that represent <return value of an op
// inside the merged cluster, output value of the original cluster>.
//
// If outputs of `current_cluster` are used as operands to ops in
// `merging_cluster`, those operands are rewritten in place so that the
// result values from the inner ops of `current_cluster` are used instead.
//
// For example,
// %0 = "tf_device.cluster"() ({
//   %1 = "tf.A"() : () -> tensor<i32>
//   "tf_device.return"(%1) : (tensor<i32>) -> ()
// }) { mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<i32>)
//
// %2 = "tf_device.cluster"() ({
//   %3 = "tf.B"(%0) : (tensor<i32>) -> tensor<f32>
//   "tf_device.return"(%3) : (tensor<f32>) -> ()
// }) { mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<f32>)
//
// will be:
// %0 = "tf_device.cluster"() ({
//   %1 = "tf.A"() : () -> tensor<i32>
//
//   # NOTE: `tf.B` op now takes operand directly from
//   # `tf.A` instead of the `tf_device.cluster` op.
//   %2 = "tf.B"(%1) : (tensor<i32>) -> tensor<f32>
//   "tf_device.return"(%1, %2) : (tensor<i32>, tensor<f32>) -> ()
// }) {mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<i32>, tensor<f32>)
llvm::SmallVector<std::pair<mlir::Value, mlir::Value>, 8>
GetMergedMeshClusterResults(mlir::tf_device::ClusterOp current_cluster,
                            mlir::tf_device::ClusterOp merging_cluster) {
  llvm::SmallVector<std::pair<mlir::Value, mlir::Value>, 8>
      merged_cluster_results;
  // Upper bound: every result of both clusters may need to be returned.
  merged_cluster_results.reserve(current_cluster.getNumResults() +
                                 merging_cluster.getNumResults());

  auto current_cluster_return_op = current_cluster.GetBody().getTerminator();
  for (auto result : llvm::zip(current_cluster_return_op->getOpOperands(),
                               current_cluster.getResults())) {
    // `inner_op_result` is the value produced inside the cluster body;
    // `outer_op_result` is the corresponding cluster op result seen by users.
    mlir::Value inner_op_result = std::get<0>(result).get();
    mlir::Value outer_op_result = std::get<1>(result);

    // If the output value of `current_cluster` is only used by ops
    // inside the `merged_cluster`, do not add the value as a return
    // value for newly created tf_device.cluster op.
    bool result_only_used_by_merging_cluster = true;
    // make_early_inc_range is required here: `use.set()` unlinks `use` from
    // the use list that is being iterated.
    for (auto& use : llvm::make_early_inc_range(outer_op_result.getUses())) {
      if (merging_cluster.GetBody().findAncestorOpInBlock(*use.getOwner())) {
        use.set(inner_op_result);
      } else {
        result_only_used_by_merging_cluster = false;
      }
    }

    if (!result_only_used_by_merging_cluster) {
      merged_cluster_results.emplace_back(inner_op_result, outer_op_result);
    }
  }

  auto merging_cluster_return_op = merging_cluster.GetBody().getTerminator();
  for (auto result : llvm::zip(merging_cluster_return_op->getOpOperands(),
                               merging_cluster.getResults())) {
    mlir::Value inner_op_result = std::get<0>(result).get();
    mlir::Value outer_op_result = std::get<1>(result);

    // Keep only results of `merging_cluster` that still have external users.
    if (!outer_op_result.getUses().empty())
      merged_cluster_results.emplace_back(inner_op_result, outer_op_result);
  }

  return merged_cluster_results;
}
157 | |
158 | // Updates the users of `merging_cluster` so that they use values |
159 | // from `merged_cluster` instead. |
160 | void ReplaceOperandUsagesWithMergedClusterOutputs( |
161 | const llvm::SmallVectorImpl<mlir::Value>& values_to_replace, |
162 | mlir::tf_device::ClusterOp merged_cluster) { |
163 | for (auto result : |
164 | llvm::zip(values_to_replace, merged_cluster.getResults())) { |
165 | std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); |
166 | } |
167 | } |
168 | |
169 | // Creates a new tf_device.cluster op that merges |
170 | // `current_cluster` and `merging_cluster`. |
171 | mlir::LogicalResult CreateMergedMeshCluster( |
172 | mlir::OpBuilder* builder, mlir::tf_device::ClusterOp current_cluster, |
173 | mlir::tf_device::ClusterOp merging_cluster, |
174 | mlir::tf_device::ClusterOp* merged_cluster) { |
175 | auto return_values = |
176 | GetMergedMeshClusterResults(current_cluster, merging_cluster); |
177 | |
178 | llvm::SmallVector<mlir::Type, 8> merged_cluster_output_types; |
179 | llvm::SmallVector<mlir::Value, 8> merged_cluster_output_values; |
180 | llvm::SmallVector<mlir::Value, 8> output_values_to_replace; |
181 | merged_cluster_output_types.reserve(return_values.size()); |
182 | merged_cluster_output_values.reserve(return_values.size()); |
183 | output_values_to_replace.reserve(return_values.size()); |
184 | for (auto cluster_return_value : return_values) { |
185 | auto inner_op_return_value = std::get<0>(cluster_return_value); |
186 | merged_cluster_output_types.emplace_back(inner_op_return_value.getType()); |
187 | merged_cluster_output_values.emplace_back(inner_op_return_value); |
188 | output_values_to_replace.emplace_back(std::get<1>(cluster_return_value)); |
189 | } |
190 | |
191 | *merged_cluster = builder->create<mlir::tf_device::ClusterOp>( |
192 | current_cluster.getLoc(), merged_cluster_output_types); |
193 | auto mesh_attr = current_cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr); |
194 | if (!mesh_attr) |
195 | return current_cluster.emitOpError(kMissingMeshAttributeErrorMessage); |
196 | |
197 | (*merged_cluster)->setAttr(kMeshAttr, mesh_attr); |
198 | |
199 | // Create a terminator op that returns all return values from |
200 | // `current_cluster` and `merging_cluster`. |
201 | merged_cluster->getBody().push_back(new mlir::Block); |
202 | builder->setInsertionPointToEnd(&merged_cluster->GetBody()); |
203 | builder->create<mlir::tf_device::ReturnOp>(merged_cluster->getLoc(), |
204 | merged_cluster_output_values); |
205 | |
206 | // Make sure to replace usages of tf_device.cluster ops to be merged-away with |
207 | // newly created tf_device.cluster op. |
208 | ReplaceOperandUsagesWithMergedClusterOutputs(output_values_to_replace, |
209 | *merged_cluster); |
210 | |
211 | return mlir::success(); |
212 | } |
213 | |
214 | // Merges `current_cluster` and `merging_cluster` and returns a new merged |
215 | // tf_device.cluster. |
216 | mlir::LogicalResult MergeClusters(mlir::OpBuilder* builder, |
217 | mlir::tf_device::ClusterOp current_cluster, |
218 | mlir::tf_device::ClusterOp merging_cluster, |
219 | mlir::tf_device::ClusterOp* merged_cluster) { |
220 | builder->setInsertionPoint(current_cluster); |
221 | |
222 | // Create new tf_device.cluster op that outputs results of both |
223 | // `current_cluster` and `merging_cluster`. |
224 | if (mlir::failed(CreateMergedMeshCluster(builder, current_cluster, |
225 | merging_cluster, merged_cluster))) |
226 | return mlir::failure(); |
227 | |
228 | // Move all ops to newly created merged cluster. |
229 | auto exit_op = merged_cluster->GetBody().getTerminator(); |
230 | MoveOpsInsideCluster(current_cluster, *merged_cluster, exit_op); |
231 | MoveOpsInsideCluster(merging_cluster, *merged_cluster, exit_op); |
232 | |
233 | // Remove mesh clusters as they are now merged to a new cluster. |
234 | current_cluster.erase(); |
235 | merging_cluster.erase(); |
236 | return mlir::success(); |
237 | } |
238 | |
239 | // Loops through tf_device.Cluster ops and merge clusters with same execution |
240 | // device set. |
241 | mlir::LogicalResult ClusterDeviceClusterOpsInBlock(mlir::OpBuilder* builder, |
242 | mlir::Block* block) { |
243 | llvm::SmallVector<mlir::tf_device::ClusterOp, 4> block_ops; |
244 | block->walk([&](mlir::Operation* op) { |
245 | if (auto cluster = llvm::dyn_cast<mlir::tf_device::ClusterOp>(op)) |
246 | block_ops.emplace_back(cluster); |
247 | }); |
248 | |
249 | llvm::Optional<mlir::tf_device::ClusterOp> current_cluster; |
250 | for (mlir::tf_device::ClusterOp cluster : |
251 | llvm::make_early_inc_range(block_ops)) { |
252 | if (!current_cluster.has_value()) { |
253 | current_cluster = cluster; |
254 | continue; |
255 | } |
256 | bool should_merge; |
257 | if (failed(ShouldMergeClusters(*current_cluster, cluster, &should_merge))) |
258 | return mlir::failure(); |
259 | |
260 | if (should_merge) { |
261 | mlir::tf_device::ClusterOp new_cluster; |
262 | if (mlir::failed( |
263 | MergeClusters(builder, *current_cluster, cluster, &new_cluster))) |
264 | return mlir::failure(); |
265 | |
266 | current_cluster.emplace(new_cluster); |
267 | } else { |
268 | current_cluster.emplace(cluster); |
269 | } |
270 | } |
271 | return mlir::success(); |
272 | } |
273 | |
274 | } // namespace |
275 | |
276 | // MLIR pass that merges cluster ops with the same mesh attribute. |
277 | struct DTensorDeviceMeshClusterCoarsening |
278 | : public impl::DTensorDeviceMeshClusterCoarseningBase< |
279 | DTensorDeviceMeshClusterCoarsening> { |
280 | void runOnOperation() override { |
281 | mlir::MLIRContext& context = getContext(); |
282 | mlir::OpBuilder builder(&context); |
283 | for (mlir::Block& block : getOperation()) |
284 | if (mlir::failed(ClusterDeviceClusterOpsInBlock(&builder, &block))) |
285 | return signalPassFailure(); |
286 | } |
287 | }; |
288 | |
289 | std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> |
290 | CreateDTensorDeviceMeshClusterCoarsening() { |
291 | return std::make_unique<DTensorDeviceMeshClusterCoarsening>(); |
292 | } |
293 | |
294 | } // namespace dtensor |
295 | } // namespace tensorflow |
296 | |