/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {

namespace {
#define GEN_PASS_DEF_DTENSORTPUINTEGRATION
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

// Adds metadata used in TPU Compilation to `cluster` as attributes.
void AddMetadataToTPUCluster(const Mesh& mesh_config,
                             mlir::tf_device::ClusterOp cluster,
                             mlir::OpBuilder* builder) {
  cluster->setAttr("_tpu_replicate",
                   builder->getStringAttr(mesh_config.ToString()));
  cluster->setAttr("step_marker_location", builder->getStringAttr(""));
  cluster->setAttr("padding_map", builder->getArrayAttr({}));
  cluster->setAttr("use_spmd_for_xla_partitioning",
                   builder->getBoolAttr(false));
  cluster->setAttr(tensorflow::kTopologyAttr, builder->getStringAttr(""));
  cluster->setAttr(tensorflow::kDeviceAssignmentAttr,
                   builder->getArrayAttr({}));
  cluster->setAttr(tensorflow::kNumCoresPerReplicaAttr,
                   builder->getI64IntegerAttr(1));
}

// TODO(hongjunchoi): Implement cluster inlining pass so that there are no
// nested tf_device.cluster ops with the same mesh.
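// Collects the TPU meshes and the StatefulPartitionedCall ops in the `main`
// function that target them, skipping calls that are marked to skip XLA
// compilation.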
void IdentifyTPUFunctions(
    mlir::ModuleOp module, llvm::SmallVectorImpl<Mesh>* tpu_meshes,
    llvm::SmallVectorImpl<mlir::TF::StatefulPartitionedCallOp>* tpu_functions) {
  auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main");
  if (!main_func) return;

  for (auto call : main_func.getOps<mlir::TF::StatefulPartitionedCallOp>()) {
    auto mesh_or_status = Mesh::FromString(string(call.config()));
    // Function calls created by end users instead of being converted from
    // tf_device.cluster do not have a serialized mesh as a config attribute.
    // We ignore the error returned from parsing in this case and skip the
    // call.
    if (!mesh_or_status.ok()) continue;
    bool skip_xla_compilation = false;
    if (call->hasAttr(kSkipXlaCompilation)) {
      skip_xla_compilation =
          call->getAttrOfType<mlir::BoolAttr>(kSkipXlaCompilation).getValue();
    }
    if (mesh_or_status->is_tpu_mesh() && !skip_xla_compilation) {
      tpu_functions->emplace_back(call);
      tpu_meshes->emplace_back(std::move(mesh_or_status.value()));
    }
  }
}

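// Wraps the body of the function called by `tpu_call` in a single
// tf_device.cluster op: every op except the terminator is moved into the
// cluster, and the function terminator is rewired to return the cluster
// results.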
mlir::LogicalResult CreateTPUCluster(
    mlir::TF::StatefulPartitionedCallOp tpu_call, mlir::OpBuilder* builder,
    mlir::tf_device::ClusterOp* newly_created_cluster) {
  auto function = MaybeFindFunction(tpu_call);
  if (!function)
    return tpu_call.emitOpError(
        "failed during TPU Integration as Func op TPU mesh was not found");

  auto& function_block = function->getCallableRegion()->front();
  builder->setInsertionPointToStart(&function_block);

  auto cluster = builder->create<mlir::tf_device::ClusterOp>(
      tpu_call.getLoc(), function->getCallableResults());
  cluster.getBody().push_back(new mlir::Block);

  // Move all ops of the function body, except the newly created cluster op
  // and the function terminator, into the cluster body.
  auto& function_body = function_block.getOperations();
  cluster.GetBody().getOperations().splice(
      cluster.GetBody().getOperations().begin(), function_body,
      std::next(function_body.begin()), std::prev(function_body.end()));

  builder->setInsertionPointToEnd(&cluster.GetBody());
  mlir::Operation* function_block_terminator = function_block.getTerminator();
  builder->create<mlir::tf_device::ReturnOp>(
      tpu_call.getLoc(), function_block_terminator->getOperands());

  // The enclosing function now returns the results of the cluster op.
  function_block_terminator->setOperands(cluster.getResults());

  *newly_created_cluster = cluster;
  return mlir::success();
}

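// Pass that wraps the body of every DTensor TPU function in a
// tf_device.cluster op and attaches the metadata attributes required for TPU
// compilation. Illustrative sketch of a rewritten callee (function name and
// attribute values abbreviated):
//
//   func.func @tpu_fn(...) -> ... {
//     %0 = "tf_device.cluster"() ({
//       ...  // original function body
//       tf_device.return ...
//     }) {_tpu_replicate = "<serialized mesh>", ...}
//     return %0
//   }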
struct DTensorTPUIntegration
    : public impl::DTensorTPUIntegrationBase<DTensorTPUIntegration> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
    registry.insert<mlir::tf_device::TensorFlowDeviceDialect>();
  }

  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder op_builder(&context);
    auto module = getOperation();
    llvm::SmallVector<mlir::TF::StatefulPartitionedCallOp, 4> tpu_functions;
    llvm::SmallVector<Mesh, 4> tpu_meshes;
    IdentifyTPUFunctions(module, &tpu_meshes, &tpu_functions);

    for (auto tpu_function_and_mesh : llvm::zip(tpu_meshes, tpu_functions)) {
      mlir::tf_device::ClusterOp cluster;

      if (mlir::failed(CreateTPUCluster(std::get<1>(tpu_function_and_mesh),
                                        &op_builder, &cluster)))
        return signalPassFailure();

      AddMetadataToTPUCluster(std::get<0>(tpu_function_and_mesh), cluster,
                              &op_builder);
    }
  }
};

}  // namespace

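// Creates a module-level pass that performs the DTensor TPU integration
// rewrite described above.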
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorTPUIntegration() {
  return std::make_unique<DTensorTPUIntegration>();
}

}  // namespace dtensor
}  // namespace tensorflow