1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <map>
17#include <memory>
18#include <string>
19#include <utility>
20
21#include "llvm/ADT/DenseMap.h"
22#include "llvm/Support/FormatVariadic.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
24#include "mlir/IR/Attributes.h" // from @llvm-project
25#include "mlir/IR/Builders.h" // from @llvm-project
26#include "mlir/IR/BuiltinOps.h" // from @llvm-project
27#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28#include "mlir/IR/OpDefinition.h" // from @llvm-project
29#include "mlir/IR/Operation.h" // from @llvm-project
30#include "mlir/IR/Value.h" // from @llvm-project
31#include "mlir/Pass/Pass.h" // from @llvm-project
32#include "mlir/Support/LogicalResult.h" // from @llvm-project
33#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
34#include "tensorflow/dtensor/cc/tensor_layout.h"
35#include "tensorflow/dtensor/mlir/device_utils.h"
36#include "tensorflow/dtensor/mlir/layout_parsing.h"
37#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
38#include "tensorflow/dtensor/mlir/value_utils.h"
39
40namespace tensorflow {
41namespace dtensor {
42
43namespace {
44#define GEN_PASS_DEF_DTENSORMOVECOMPILATIONTOHOST
45#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
46
// Name prefix for every send/recv tensor this pass creates to move the TPU
// compilation program key between computations. A running counter is
// appended to it so that each tensor name is unique.
constexpr char kSendRecvKeyPrefix[]{"compilation_send_recv_key_"};
49
50// Identifies all StatefulPartitionedCallOps for executing computation for
51// each mesh cluster and validate that at most one TPU computation exists.
52mlir::LogicalResult IdentifyAndValidateMeshComputations(
53 mlir::func::FuncOp function,
54 std::map<Mesh, mlir::TF::StatefulPartitionedCallOp>* function_map) {
55 for (auto dtensor_function :
56 function.getOps<mlir::TF::StatefulPartitionedCallOp>()) {
57 auto mesh_or = ExtractDeviceMeshFromOp(dtensor_function);
58 if (!mesh_or.ok() || !mesh_or->has_value())
59 return dtensor_function.emitOpError(
60 "StatefulPartitionCall op must have `_mesh` attribute specified.");
61
62 const Mesh& computation_mesh = mesh_or->value();
63 if (function_map->count(computation_mesh))
64 return dtensor_function.emitOpError(
65 "Found DTensor function with duplicate mesh specification. There "
66 "should be exactly 1 function for each mesh in computation cluster.");
67
68 (*function_map)[computation_mesh] = dtensor_function;
69 }
70
71 int num_xla_meshes = 0;
72 for (const auto& it : *function_map) {
73 if (it.first.is_tpu_mesh()) num_xla_meshes += 1;
74 }
75
76 if (num_xla_meshes > 1)
77 return function.emitOpError(
78 "Multiple XLA computation clusters found. Only 1 XLA cluster for "
79 "DTensor computation is supported for now.");
80
81 return mlir::success();
82}
83
84// Creates Send/Recv ops to transfer TPUCompile program key from host
85// computation to XLA computation.
86mlir::LogicalResult CreateSendRecvOpsToTransferProgramKey(
87 const Mesh& mesh, mlir::ModuleOp module, mlir::func::FuncOp function,
88 mlir::OpBuilder::InsertPoint insertpoint,
89 mlir::TF::_TPUCompileMlirOp compile_op,
90 mlir::tf_device::LaunchOp compile_op_launch, int* num_send_recv,
91 mlir::Value* program_key_output) {
92 mlir::OpBuilder builder(module.getContext());
93 mlir::Value compilation_key = *compile_op.program().begin();
94 absl::Span<const std::string> local_devices = mesh.local_devices();
95
96 // Create tensor name mapping for each send/recv pair.
97 llvm::SmallDenseMap<int, std::string> device_key_map;
98 const int num_tpu_devices = local_devices.size();
99 device_key_map.reserve(num_tpu_devices);
100 for (int i = 0; i < num_tpu_devices; ++i) {
101 std::string tensor_name = absl::StrCat(kSendRecvKeyPrefix, *num_send_recv);
102 *num_send_recv += 1;
103 device_key_map.try_emplace(i, std::move(tensor_name));
104 }
105
106 // Create send op to send TPU program key from host computation to XLA
107 // computation.
108 builder.setInsertionPointAfter(compile_op);
109 for (int i = 0; i < num_tpu_devices; ++i) {
110 const std::string& tensor_name = device_key_map[i];
111 auto send = builder.create<mlir::TF::_HostSendOp>(
112 compile_op->getLoc(), compilation_key, tensor_name,
113 compile_op_launch.getDevice(),
114 /*send_device_incarnation=*/0, local_devices[i]);
115 send->setAttr("device", compile_op_launch.getDeviceAttr());
116 }
117
118 // Create Recv ops to receive program key from host to each xla device
119 // computation.
120 llvm::SmallVector<mlir::func::FuncOp, 4> compilation_key_functions;
121 compilation_key_functions.reserve(num_tpu_devices);
122 mlir::SymbolTable symbol_table(module);
123
124 // For receiving TPU program key from host, `recv_device` attribute depends
125 // on `device_id` argument and therefore cannot be known statically.
126 // Therefore, we use tf.Case op to select correct receive op depending on
127 // the device id value.
128 for (int i = 0; i < num_tpu_devices; ++i) {
129 auto func_type = mlir::FunctionType::get(
130 builder.getContext(), llvm::ArrayRef<mlir::Type>{},
131 llvm::ArrayRef<mlir::Type>{compilation_key.getType()});
132
133 mlir::func::FuncOp recv_select_fn = mlir::func::FuncOp::create(
134 compile_op.getLoc(),
135 llvm::formatv("recv_compile_key_{0}_{1}", i, *num_send_recv).str(),
136 func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
137 symbol_table.insert(recv_select_fn);
138 *num_send_recv += 1;
139
140 mlir::Block* fn_block = recv_select_fn.addEntryBlock();
141 mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockEnd(fn_block);
142 auto recv = fn_builder.create<mlir::TF::_HostRecvOp>(
143 compile_op->getLoc(),
144 compilation_key.getType().cast<mlir::TensorType>(), device_key_map[i],
145 compile_op_launch.getDevice(), /*send_device_incarnation=*/0,
146 local_devices[i]);
147 recv->setAttr("device", builder.getStringAttr(local_devices[i]));
148
149 fn_builder.create<mlir::func::ReturnOp>(recv_select_fn.getLoc(),
150 recv.tensor());
151
152 compilation_key_functions.emplace_back(recv_select_fn);
153 }
154
155 // Create logic that receives program key.
156 builder.restoreInsertionPoint(insertpoint);
157 auto device_id = GetDeviceOrdinal(mesh, function.getLoc(), function, &builder,
158 /*return_int64_type=*/false);
159 if (!device_id.ok()) return function->emitOpError("Cannot get device id");
160
161 llvm::SmallVector<mlir::Attribute, 4> symbols;
162 for (auto& func : compilation_key_functions)
163 symbols.push_back(mlir::SymbolRefAttr::get(func));
164
165 // Create a TF::Case op that selects `values` based on `id`.
166 auto program_key = builder.create<mlir::TF::CaseOp>(
167 compile_op.getLoc(),
168 /*output=*/llvm::SmallVector<mlir::Type, 4>{compilation_key.getType()},
169 /*branch_index=*/*device_id,
170 /*input=*/llvm::ArrayRef<mlir::Value>{},
171 /*branches=*/builder.getArrayAttr(symbols),
172 /*is_stateless=*/builder.getBoolAttr(false));
173 *program_key_output = program_key.getResult(0);
174 return mlir::success();
175}
176
177struct CompilationKeyRecvInfo {
178 const Mesh& receiving_function_mesh;
179 mlir::func::FuncOp receiving_function;
180 mlir::OpBuilder::InsertPoint recv_insertion_point;
181 mlir::Value program_key;
182};
183
184// Broadcasts compilation key across meshes specified by `recv_info`. The
185// broadcasted compilation key is added to `program_key` of each vector
186// element of `recv_info`.
187mlir::LogicalResult SendRecvCompilationKey(
188 const Mesh& host_mesh, mlir::ModuleOp module,
189 mlir::TF::_TPUCompileMlirOp compile_op,
190 mlir::tf_device::LaunchOp compile_launch_op,
191 mlir::Operation* compilation_move_before, int* num_send_recv,
192 llvm::SmallVectorImpl<CompilationKeyRecvInfo>* recv_info) {
193 for (int i = 0; i < recv_info->size(); ++i) {
194 CompilationKeyRecvInfo& info = (*recv_info)[i];
195 // Create send/recv ops to transfer compilation key from receiving meshes.
196 mlir::Value program_key;
197 if (mlir::failed(CreateSendRecvOpsToTransferProgramKey(
198 info.receiving_function_mesh, module, info.receiving_function,
199 info.recv_insertion_point, compile_op, compile_launch_op,
200 num_send_recv, &program_key)))
201 return mlir::failure();
202
203 info.program_key = program_key;
204 }
205
206 return mlir::success();
207}
208
209mlir::LogicalResult HandleCompilationOps(
210 const llvm::SmallVectorImpl<
211 mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp>& compilation_key_ops,
212 std::map<Mesh, mlir::TF::StatefulPartitionedCallOp>& computation_map,
213 mlir::ModuleOp module, int* num_send_recv) {
214 // Identity XLA function and corresponding CPU functions to move compilation.
215 const auto xla_mesh = llvm::find_if(
216 computation_map, [](const auto& it) { return it.first.is_tpu_mesh(); });
217
218 if (xla_mesh == computation_map.end()) {
219 return module.emitOpError(
220 "Found TPUCompilationKey op but XLA computation does not exist.");
221 }
222
223 mlir::func::FuncOp tpu_function = xla_mesh->second.func();
224 mlir::func::FuncOp host_function;
225 Mesh host_mesh;
226 for (auto compilation_key : compilation_key_ops) {
227 auto parent_function =
228 compilation_key->getParentOfType<mlir::func::FuncOp>();
229
230 if (!host_function) {
231 host_function = parent_function;
232 auto mesh_it = llvm::find_if(computation_map, [&](auto& it) {
233 return it.second.f() == host_function.getSymName();
234 });
235 if (mesh_it == computation_map.end())
236 return compilation_key.emitOpError(
237 "cannot find host mesh for TPU computation.");
238
239 host_mesh = mesh_it->first;
240
241 } else {
242 // TODO(hongjunchoi): Handle the case when CopyToMesh is used with
243 // special topology approach. In this case there will be 2 host
244 // meshes/functions.
245 if (host_function != parent_function)
246 return compilation_key.emitOpError(
247 "Found multiple TPU host mesh functions. There must be at most one "
248 "TPU host function.");
249 }
250 }
251
252 // Identify TPUCompileOp to host side mesh.
253 llvm::SmallVector<mlir::TF::_TPUCompileMlirOp, 4> compile_ops;
254 tpu_function.walk(
255 [&](mlir::TF::_TPUCompileMlirOp op) { compile_ops.emplace_back(op); });
256
257 const int num_compilations = compile_ops.size();
258 if (num_compilations != 1)
259 return tpu_function.emitOpError(llvm::formatv(
260 "Expected exactly 1 compilation op for TPU computation. Found {0}",
261 num_compilations));
262
263 mlir::TF::_TPUCompileMlirOp compile_op = *compile_ops.begin();
264 mlir::Operation& first_host_op = host_function.getBody().front().front();
265 mlir::OpBuilder builder(&first_host_op);
266 mlir::OpBuilder::InsertPoint host_insertion_point =
267 builder.saveInsertionPoint();
268 mlir::Operation* compilation_move_before = &first_host_op;
269
270 // If host mesh has multiple local devices only conduct compilation for the
271 // first host device by creating If Op to only compile for host with device
272 // ordinal 0.
273 if (host_mesh.local_device_ids().size() > 1) {
274 auto device_ordinal_host = GetDeviceOrdinal(
275 host_mesh, compile_op.getLoc(),
276 first_host_op.getParentOfType<mlir::func::FuncOp>(), &builder);
277 if (!device_ordinal_host.ok())
278 return compile_op.emitOpError(
279 llvm::formatv("error while creating TPU compilation logic. {0}",
280 device_ordinal_host.status().error_message()));
281
282 mlir::Value predicate_host = builder.create<mlir::TF::EqualOp>(
283 compile_op.getLoc(), *device_ordinal_host,
284 CreateIntScalarConst(0, builder, compile_op.getLoc()),
285 /*incompatible_shape_error=*/builder.getBoolAttr(true));
286
287 // If op here contains send/recv and TPUCompile op that should not be pruned
288 // away. Therefore, we explicitly set the op to be stateful.
289 auto if_host = builder.create<mlir::TF::IfRegionOp>(
290 compile_op.getLoc(), llvm::SmallVector<mlir::Type, 4>{}, predicate_host,
291 /*is_stateless=*/builder.getBoolAttr(false),
292 GetUniqueControlflowFnName("compilation_host_then", builder),
293 GetUniqueControlflowFnName("compilation_host_else", builder));
294
295 // Create empty else branch region.
296 auto& host_else_branch = if_host.else_branch();
297 host_else_branch.push_back(new mlir::Block);
298 builder.setInsertionPointToEnd(&host_else_branch.front());
299 builder.create<mlir::TF::YieldOp>(
300 compile_op.getLoc(),
301 /*operands=*/llvm::ArrayRef<mlir::Value>{});
302
303 // Create then branch region with logic to compile TPU program and send
304 // program key to all TPU devices.
305 auto& host_then_branch = if_host.then_branch();
306 host_then_branch.push_back(new mlir::Block);
307 builder.setInsertionPointToEnd(&host_then_branch.front());
308 auto yield = builder.create<mlir::TF::YieldOp>(
309 compile_op.getLoc(),
310 /*operands=*/llvm::ArrayRef<mlir::Value>{});
311 compilation_move_before = yield;
312
313 builder.setInsertionPointAfter(if_host);
314 host_insertion_point = builder.saveInsertionPoint();
315 }
316
317 auto compile_launch_op =
318 compile_op->getParentOfType<mlir::tf_device::LaunchOp>();
319
320 // Move Compile op and compile succeeded assert ops to host function.
321 compile_launch_op->moveBefore(compilation_move_before);
322
323 for (mlir::Operation* user : compile_launch_op.getResult(0).getUsers())
324 user->getParentOfType<mlir::tf_device::LaunchOp>()->moveBefore(
325 compilation_move_before);
326
327 // Send and receive compilation key across meshes.
328 llvm::SmallVector<CompilationKeyRecvInfo, 4> compilation_key_recv_info;
329 builder.setInsertionPointToStart(&tpu_function.front());
330 auto device_insertion_point = builder.saveInsertionPoint();
331 compilation_key_recv_info.emplace_back(CompilationKeyRecvInfo{
332 xla_mesh->first, tpu_function, device_insertion_point, nullptr});
333
334 compilation_key_recv_info.emplace_back(CompilationKeyRecvInfo{
335 host_mesh, host_function, host_insertion_point, nullptr});
336
337 if (mlir::failed(SendRecvCompilationKey(
338 host_mesh, module, compile_op, compile_launch_op,
339 compilation_move_before, num_send_recv, &compilation_key_recv_info)))
340 return mlir::failure();
341
342 // Replace usages of TPU program key in host and device meshes.
343 mlir::Value device_program_key = compilation_key_recv_info[0].program_key;
344 tpu_function.walk([&](mlir::Operation* op) {
345 if (llvm::isa<mlir::TF::TPUExecuteOp,
346 mlir::TF::TPUExecuteAndUpdateVariablesOp>(op))
347 op->setOperand(op->getNumOperands() - 1, device_program_key);
348 });
349
350 // Remove placeholder CompilationKey ops and replace it's usages with output
351 // of TPUCompile op.
352 mlir::Value host_program_key = compilation_key_recv_info[1].program_key;
353 for (auto compilation_key_op : compilation_key_ops) {
354 compilation_key_op.replaceAllUsesWith(host_program_key);
355 compilation_key_op.erase();
356 }
357 return mlir::success();
358}
359
360// Pass to move TPUCompile/TPUCompileSucceededAssert op to host mesh computation
361// and add necessary send/recv ops to transfer TPU program key to TPU device
362// computation.
363struct DTensorMoveCompilationToHost
364 : public impl::DTensorMoveCompilationToHostBase<
365 DTensorMoveCompilationToHost> {
366 void runOnOperation() override {
367 mlir::MLIRContext& context = getContext();
368 mlir::OpBuilder builder(&context);
369 auto module = getOperation();
370
371 llvm::SmallVector<mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp, 4>
372 compilation_key_ops;
373 module.walk([&](mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp op) {
374 compilation_key_ops.emplace_back(op);
375 });
376
377 if (compilation_key_ops.empty()) return;
378
379 mlir::func::FuncOp main_func =
380 module.lookupSymbol<mlir::func::FuncOp>("main");
381 if (!main_func) return;
382
383 std::map<Mesh, mlir::TF::StatefulPartitionedCallOp> computation_map;
384 if (mlir::failed(
385 IdentifyAndValidateMeshComputations(main_func, &computation_map)))
386 return signalPassFailure();
387
388 int num_send_recv = 0;
389 if (mlir::failed(HandleCompilationOps(compilation_key_ops, computation_map,
390 module, &num_send_recv)))
391 return signalPassFailure();
392 };
393};
394
395} // namespace
396
397std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
398CreateDTensorMoveCompilationToHost() {
399 return std::make_unique<DTensorMoveCompilationToHost>();
400}
401
402} // namespace dtensor
403} // namespace tensorflow
404