1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <map> |
17 | #include <memory> |
18 | #include <string> |
19 | #include <utility> |
20 | |
21 | #include "llvm/ADT/DenseMap.h" |
22 | #include "llvm/Support/FormatVariadic.h" |
23 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
24 | #include "mlir/IR/Attributes.h" // from @llvm-project |
25 | #include "mlir/IR/Builders.h" // from @llvm-project |
26 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
27 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
28 | #include "mlir/IR/OpDefinition.h" // from @llvm-project |
29 | #include "mlir/IR/Operation.h" // from @llvm-project |
30 | #include "mlir/IR/Value.h" // from @llvm-project |
31 | #include "mlir/Pass/Pass.h" // from @llvm-project |
32 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
33 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
34 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
35 | #include "tensorflow/dtensor/mlir/device_utils.h" |
36 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
37 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
38 | #include "tensorflow/dtensor/mlir/value_utils.h" |
39 | |
40 | namespace tensorflow { |
41 | namespace dtensor { |
42 | |
43 | namespace { |
44 | #define GEN_PASS_DEF_DTENSORMOVECOMPILATIONTOHOST |
45 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
46 | |
// Common prefix of the tensor names used by the send/recv op pairs that carry
// the TPU compilation program key. A running counter is appended to this
// prefix to make each send/recv pair's name distinct.
constexpr char kSendRecvKeyPrefix[] = "compilation_send_recv_key_";
49 | |
50 | // Identifies all StatefulPartitionedCallOps for executing computation for |
51 | // each mesh cluster and validate that at most one TPU computation exists. |
52 | mlir::LogicalResult IdentifyAndValidateMeshComputations( |
53 | mlir::func::FuncOp function, |
54 | std::map<Mesh, mlir::TF::StatefulPartitionedCallOp>* function_map) { |
55 | for (auto dtensor_function : |
56 | function.getOps<mlir::TF::StatefulPartitionedCallOp>()) { |
57 | auto mesh_or = ExtractDeviceMeshFromOp(dtensor_function); |
58 | if (!mesh_or.ok() || !mesh_or->has_value()) |
59 | return dtensor_function.emitOpError( |
60 | "StatefulPartitionCall op must have `_mesh` attribute specified." ); |
61 | |
62 | const Mesh& computation_mesh = mesh_or->value(); |
63 | if (function_map->count(computation_mesh)) |
64 | return dtensor_function.emitOpError( |
65 | "Found DTensor function with duplicate mesh specification. There " |
66 | "should be exactly 1 function for each mesh in computation cluster." ); |
67 | |
68 | (*function_map)[computation_mesh] = dtensor_function; |
69 | } |
70 | |
71 | int num_xla_meshes = 0; |
72 | for (const auto& it : *function_map) { |
73 | if (it.first.is_tpu_mesh()) num_xla_meshes += 1; |
74 | } |
75 | |
76 | if (num_xla_meshes > 1) |
77 | return function.emitOpError( |
78 | "Multiple XLA computation clusters found. Only 1 XLA cluster for " |
79 | "DTensor computation is supported for now." ); |
80 | |
81 | return mlir::success(); |
82 | } |
83 | |
84 | // Creates Send/Recv ops to transfer TPUCompile program key from host |
85 | // computation to XLA computation. |
86 | mlir::LogicalResult CreateSendRecvOpsToTransferProgramKey( |
87 | const Mesh& mesh, mlir::ModuleOp module, mlir::func::FuncOp function, |
88 | mlir::OpBuilder::InsertPoint insertpoint, |
89 | mlir::TF::_TPUCompileMlirOp compile_op, |
90 | mlir::tf_device::LaunchOp compile_op_launch, int* num_send_recv, |
91 | mlir::Value* program_key_output) { |
92 | mlir::OpBuilder builder(module.getContext()); |
93 | mlir::Value compilation_key = *compile_op.program().begin(); |
94 | absl::Span<const std::string> local_devices = mesh.local_devices(); |
95 | |
96 | // Create tensor name mapping for each send/recv pair. |
97 | llvm::SmallDenseMap<int, std::string> device_key_map; |
98 | const int num_tpu_devices = local_devices.size(); |
99 | device_key_map.reserve(num_tpu_devices); |
100 | for (int i = 0; i < num_tpu_devices; ++i) { |
101 | std::string tensor_name = absl::StrCat(kSendRecvKeyPrefix, *num_send_recv); |
102 | *num_send_recv += 1; |
103 | device_key_map.try_emplace(i, std::move(tensor_name)); |
104 | } |
105 | |
106 | // Create send op to send TPU program key from host computation to XLA |
107 | // computation. |
108 | builder.setInsertionPointAfter(compile_op); |
109 | for (int i = 0; i < num_tpu_devices; ++i) { |
110 | const std::string& tensor_name = device_key_map[i]; |
111 | auto send = builder.create<mlir::TF::_HostSendOp>( |
112 | compile_op->getLoc(), compilation_key, tensor_name, |
113 | compile_op_launch.getDevice(), |
114 | /*send_device_incarnation=*/0, local_devices[i]); |
115 | send->setAttr("device" , compile_op_launch.getDeviceAttr()); |
116 | } |
117 | |
118 | // Create Recv ops to receive program key from host to each xla device |
119 | // computation. |
120 | llvm::SmallVector<mlir::func::FuncOp, 4> compilation_key_functions; |
121 | compilation_key_functions.reserve(num_tpu_devices); |
122 | mlir::SymbolTable symbol_table(module); |
123 | |
124 | // For receiving TPU program key from host, `recv_device` attribute depends |
125 | // on `device_id` argument and therefore cannot be known statically. |
126 | // Therefore, we use tf.Case op to select correct receive op depending on |
127 | // the device id value. |
128 | for (int i = 0; i < num_tpu_devices; ++i) { |
129 | auto func_type = mlir::FunctionType::get( |
130 | builder.getContext(), llvm::ArrayRef<mlir::Type>{}, |
131 | llvm::ArrayRef<mlir::Type>{compilation_key.getType()}); |
132 | |
133 | mlir::func::FuncOp recv_select_fn = mlir::func::FuncOp::create( |
134 | compile_op.getLoc(), |
135 | llvm::formatv("recv_compile_key_{0}_{1}" , i, *num_send_recv).str(), |
136 | func_type, llvm::ArrayRef<mlir::NamedAttribute>{}); |
137 | symbol_table.insert(recv_select_fn); |
138 | *num_send_recv += 1; |
139 | |
140 | mlir::Block* fn_block = recv_select_fn.addEntryBlock(); |
141 | mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockEnd(fn_block); |
142 | auto recv = fn_builder.create<mlir::TF::_HostRecvOp>( |
143 | compile_op->getLoc(), |
144 | compilation_key.getType().cast<mlir::TensorType>(), device_key_map[i], |
145 | compile_op_launch.getDevice(), /*send_device_incarnation=*/0, |
146 | local_devices[i]); |
147 | recv->setAttr("device" , builder.getStringAttr(local_devices[i])); |
148 | |
149 | fn_builder.create<mlir::func::ReturnOp>(recv_select_fn.getLoc(), |
150 | recv.tensor()); |
151 | |
152 | compilation_key_functions.emplace_back(recv_select_fn); |
153 | } |
154 | |
155 | // Create logic that receives program key. |
156 | builder.restoreInsertionPoint(insertpoint); |
157 | auto device_id = GetDeviceOrdinal(mesh, function.getLoc(), function, &builder, |
158 | /*return_int64_type=*/false); |
159 | if (!device_id.ok()) return function->emitOpError("Cannot get device id" ); |
160 | |
161 | llvm::SmallVector<mlir::Attribute, 4> symbols; |
162 | for (auto& func : compilation_key_functions) |
163 | symbols.push_back(mlir::SymbolRefAttr::get(func)); |
164 | |
165 | // Create a TF::Case op that selects `values` based on `id`. |
166 | auto program_key = builder.create<mlir::TF::CaseOp>( |
167 | compile_op.getLoc(), |
168 | /*output=*/llvm::SmallVector<mlir::Type, 4>{compilation_key.getType()}, |
169 | /*branch_index=*/*device_id, |
170 | /*input=*/llvm::ArrayRef<mlir::Value>{}, |
171 | /*branches=*/builder.getArrayAttr(symbols), |
172 | /*is_stateless=*/builder.getBoolAttr(false)); |
173 | *program_key_output = program_key.getResult(0); |
174 | return mlir::success(); |
175 | } |
176 | |
177 | struct CompilationKeyRecvInfo { |
178 | const Mesh& receiving_function_mesh; |
179 | mlir::func::FuncOp receiving_function; |
180 | mlir::OpBuilder::InsertPoint recv_insertion_point; |
181 | mlir::Value program_key; |
182 | }; |
183 | |
184 | // Broadcasts compilation key across meshes specified by `recv_info`. The |
185 | // broadcasted compilation key is added to `program_key` of each vector |
186 | // element of `recv_info`. |
187 | mlir::LogicalResult SendRecvCompilationKey( |
188 | const Mesh& host_mesh, mlir::ModuleOp module, |
189 | mlir::TF::_TPUCompileMlirOp compile_op, |
190 | mlir::tf_device::LaunchOp compile_launch_op, |
191 | mlir::Operation* compilation_move_before, int* num_send_recv, |
192 | llvm::SmallVectorImpl<CompilationKeyRecvInfo>* recv_info) { |
193 | for (int i = 0; i < recv_info->size(); ++i) { |
194 | CompilationKeyRecvInfo& info = (*recv_info)[i]; |
195 | // Create send/recv ops to transfer compilation key from receiving meshes. |
196 | mlir::Value program_key; |
197 | if (mlir::failed(CreateSendRecvOpsToTransferProgramKey( |
198 | info.receiving_function_mesh, module, info.receiving_function, |
199 | info.recv_insertion_point, compile_op, compile_launch_op, |
200 | num_send_recv, &program_key))) |
201 | return mlir::failure(); |
202 | |
203 | info.program_key = program_key; |
204 | } |
205 | |
206 | return mlir::success(); |
207 | } |
208 | |
209 | mlir::LogicalResult HandleCompilationOps( |
210 | const llvm::SmallVectorImpl< |
211 | mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp>& compilation_key_ops, |
212 | std::map<Mesh, mlir::TF::StatefulPartitionedCallOp>& computation_map, |
213 | mlir::ModuleOp module, int* num_send_recv) { |
214 | // Identity XLA function and corresponding CPU functions to move compilation. |
215 | const auto xla_mesh = llvm::find_if( |
216 | computation_map, [](const auto& it) { return it.first.is_tpu_mesh(); }); |
217 | |
218 | if (xla_mesh == computation_map.end()) { |
219 | return module.emitOpError( |
220 | "Found TPUCompilationKey op but XLA computation does not exist." ); |
221 | } |
222 | |
223 | mlir::func::FuncOp tpu_function = xla_mesh->second.func(); |
224 | mlir::func::FuncOp host_function; |
225 | Mesh host_mesh; |
226 | for (auto compilation_key : compilation_key_ops) { |
227 | auto parent_function = |
228 | compilation_key->getParentOfType<mlir::func::FuncOp>(); |
229 | |
230 | if (!host_function) { |
231 | host_function = parent_function; |
232 | auto mesh_it = llvm::find_if(computation_map, [&](auto& it) { |
233 | return it.second.f() == host_function.getSymName(); |
234 | }); |
235 | if (mesh_it == computation_map.end()) |
236 | return compilation_key.emitOpError( |
237 | "cannot find host mesh for TPU computation." ); |
238 | |
239 | host_mesh = mesh_it->first; |
240 | |
241 | } else { |
242 | // TODO(hongjunchoi): Handle the case when CopyToMesh is used with |
243 | // special topology approach. In this case there will be 2 host |
244 | // meshes/functions. |
245 | if (host_function != parent_function) |
246 | return compilation_key.emitOpError( |
247 | "Found multiple TPU host mesh functions. There must be at most one " |
248 | "TPU host function." ); |
249 | } |
250 | } |
251 | |
252 | // Identify TPUCompileOp to host side mesh. |
253 | llvm::SmallVector<mlir::TF::_TPUCompileMlirOp, 4> compile_ops; |
254 | tpu_function.walk( |
255 | [&](mlir::TF::_TPUCompileMlirOp op) { compile_ops.emplace_back(op); }); |
256 | |
257 | const int num_compilations = compile_ops.size(); |
258 | if (num_compilations != 1) |
259 | return tpu_function.emitOpError(llvm::formatv( |
260 | "Expected exactly 1 compilation op for TPU computation. Found {0}" , |
261 | num_compilations)); |
262 | |
263 | mlir::TF::_TPUCompileMlirOp compile_op = *compile_ops.begin(); |
264 | mlir::Operation& first_host_op = host_function.getBody().front().front(); |
265 | mlir::OpBuilder builder(&first_host_op); |
266 | mlir::OpBuilder::InsertPoint host_insertion_point = |
267 | builder.saveInsertionPoint(); |
268 | mlir::Operation* compilation_move_before = &first_host_op; |
269 | |
270 | // If host mesh has multiple local devices only conduct compilation for the |
271 | // first host device by creating If Op to only compile for host with device |
272 | // ordinal 0. |
273 | if (host_mesh.local_device_ids().size() > 1) { |
274 | auto device_ordinal_host = GetDeviceOrdinal( |
275 | host_mesh, compile_op.getLoc(), |
276 | first_host_op.getParentOfType<mlir::func::FuncOp>(), &builder); |
277 | if (!device_ordinal_host.ok()) |
278 | return compile_op.emitOpError( |
279 | llvm::formatv("error while creating TPU compilation logic. {0}" , |
280 | device_ordinal_host.status().error_message())); |
281 | |
282 | mlir::Value predicate_host = builder.create<mlir::TF::EqualOp>( |
283 | compile_op.getLoc(), *device_ordinal_host, |
284 | CreateIntScalarConst(0, builder, compile_op.getLoc()), |
285 | /*incompatible_shape_error=*/builder.getBoolAttr(true)); |
286 | |
287 | // If op here contains send/recv and TPUCompile op that should not be pruned |
288 | // away. Therefore, we explicitly set the op to be stateful. |
289 | auto if_host = builder.create<mlir::TF::IfRegionOp>( |
290 | compile_op.getLoc(), llvm::SmallVector<mlir::Type, 4>{}, predicate_host, |
291 | /*is_stateless=*/builder.getBoolAttr(false), |
292 | GetUniqueControlflowFnName("compilation_host_then" , builder), |
293 | GetUniqueControlflowFnName("compilation_host_else" , builder)); |
294 | |
295 | // Create empty else branch region. |
296 | auto& host_else_branch = if_host.else_branch(); |
297 | host_else_branch.push_back(new mlir::Block); |
298 | builder.setInsertionPointToEnd(&host_else_branch.front()); |
299 | builder.create<mlir::TF::YieldOp>( |
300 | compile_op.getLoc(), |
301 | /*operands=*/llvm::ArrayRef<mlir::Value>{}); |
302 | |
303 | // Create then branch region with logic to compile TPU program and send |
304 | // program key to all TPU devices. |
305 | auto& host_then_branch = if_host.then_branch(); |
306 | host_then_branch.push_back(new mlir::Block); |
307 | builder.setInsertionPointToEnd(&host_then_branch.front()); |
308 | auto yield = builder.create<mlir::TF::YieldOp>( |
309 | compile_op.getLoc(), |
310 | /*operands=*/llvm::ArrayRef<mlir::Value>{}); |
311 | compilation_move_before = yield; |
312 | |
313 | builder.setInsertionPointAfter(if_host); |
314 | host_insertion_point = builder.saveInsertionPoint(); |
315 | } |
316 | |
317 | auto compile_launch_op = |
318 | compile_op->getParentOfType<mlir::tf_device::LaunchOp>(); |
319 | |
320 | // Move Compile op and compile succeeded assert ops to host function. |
321 | compile_launch_op->moveBefore(compilation_move_before); |
322 | |
323 | for (mlir::Operation* user : compile_launch_op.getResult(0).getUsers()) |
324 | user->getParentOfType<mlir::tf_device::LaunchOp>()->moveBefore( |
325 | compilation_move_before); |
326 | |
327 | // Send and receive compilation key across meshes. |
328 | llvm::SmallVector<CompilationKeyRecvInfo, 4> compilation_key_recv_info; |
329 | builder.setInsertionPointToStart(&tpu_function.front()); |
330 | auto device_insertion_point = builder.saveInsertionPoint(); |
331 | compilation_key_recv_info.emplace_back(CompilationKeyRecvInfo{ |
332 | xla_mesh->first, tpu_function, device_insertion_point, nullptr}); |
333 | |
334 | compilation_key_recv_info.emplace_back(CompilationKeyRecvInfo{ |
335 | host_mesh, host_function, host_insertion_point, nullptr}); |
336 | |
337 | if (mlir::failed(SendRecvCompilationKey( |
338 | host_mesh, module, compile_op, compile_launch_op, |
339 | compilation_move_before, num_send_recv, &compilation_key_recv_info))) |
340 | return mlir::failure(); |
341 | |
342 | // Replace usages of TPU program key in host and device meshes. |
343 | mlir::Value device_program_key = compilation_key_recv_info[0].program_key; |
344 | tpu_function.walk([&](mlir::Operation* op) { |
345 | if (llvm::isa<mlir::TF::TPUExecuteOp, |
346 | mlir::TF::TPUExecuteAndUpdateVariablesOp>(op)) |
347 | op->setOperand(op->getNumOperands() - 1, device_program_key); |
348 | }); |
349 | |
350 | // Remove placeholder CompilationKey ops and replace it's usages with output |
351 | // of TPUCompile op. |
352 | mlir::Value host_program_key = compilation_key_recv_info[1].program_key; |
353 | for (auto compilation_key_op : compilation_key_ops) { |
354 | compilation_key_op.replaceAllUsesWith(host_program_key); |
355 | compilation_key_op.erase(); |
356 | } |
357 | return mlir::success(); |
358 | } |
359 | |
360 | // Pass to move TPUCompile/TPUCompileSucceededAssert op to host mesh computation |
361 | // and add necessary send/recv ops to transfer TPU program key to TPU device |
362 | // computation. |
363 | struct DTensorMoveCompilationToHost |
364 | : public impl::DTensorMoveCompilationToHostBase< |
365 | DTensorMoveCompilationToHost> { |
366 | void runOnOperation() override { |
367 | mlir::MLIRContext& context = getContext(); |
368 | mlir::OpBuilder builder(&context); |
369 | auto module = getOperation(); |
370 | |
371 | llvm::SmallVector<mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp, 4> |
372 | compilation_key_ops; |
373 | module.walk([&](mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp op) { |
374 | compilation_key_ops.emplace_back(op); |
375 | }); |
376 | |
377 | if (compilation_key_ops.empty()) return; |
378 | |
379 | mlir::func::FuncOp main_func = |
380 | module.lookupSymbol<mlir::func::FuncOp>("main" ); |
381 | if (!main_func) return; |
382 | |
383 | std::map<Mesh, mlir::TF::StatefulPartitionedCallOp> computation_map; |
384 | if (mlir::failed( |
385 | IdentifyAndValidateMeshComputations(main_func, &computation_map))) |
386 | return signalPassFailure(); |
387 | |
388 | int num_send_recv = 0; |
389 | if (mlir::failed(HandleCompilationOps(compilation_key_ops, computation_map, |
390 | module, &num_send_recv))) |
391 | return signalPassFailure(); |
392 | }; |
393 | }; |
394 | |
395 | } // namespace |
396 | |
397 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
398 | CreateDTensorMoveCompilationToHost() { |
399 | return std::make_unique<DTensorMoveCompilationToHost>(); |
400 | } |
401 | |
402 | } // namespace dtensor |
403 | } // namespace tensorflow |
404 | |