1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | #include <utility> |
18 | |
19 | #include "llvm/ADT/APInt.h" |
20 | #include "llvm/ADT/ArrayRef.h" |
21 | #include "llvm/ADT/DenseMap.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/ADT/SetVector.h" |
24 | #include "llvm/ADT/SmallVector.h" |
25 | #include "llvm/ADT/StringRef.h" |
26 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
27 | #include "mlir/IR/Attributes.h" // from @llvm-project |
28 | #include "mlir/IR/Builders.h" // from @llvm-project |
29 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
30 | #include "mlir/IR/Operation.h" // from @llvm-project |
31 | #include "mlir/IR/Types.h" // from @llvm-project |
32 | #include "mlir/IR/Visitors.h" // from @llvm-project |
33 | #include "mlir/Pass/Pass.h" // from @llvm-project |
34 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
35 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
36 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
37 | #include "mlir/Transforms/RegionUtils.h" // from @llvm-project |
38 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
39 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
40 | #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" |
41 | #include "tensorflow/compiler/xla/client/sharding_builder.h" |
42 | #include "tensorflow/dtensor/cc/constants.h" |
43 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
44 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" |
45 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
46 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
47 | #include "tensorflow/dtensor/mlir/op_utils.h" |
48 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
49 | |
50 | namespace tensorflow { |
51 | namespace dtensor { |
52 | |
53 | namespace { |
54 | #define GEN_PASS_DEF_DTENSORTPUINTEGRATION |
55 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
56 | |
57 | // Adds metadata used in TPU Compilation to `cluster` as attributes. |
58 | void AddMetadataToTPUCluster(const Mesh& mesh_config, |
59 | mlir::tf_device::ClusterOp cluster, |
60 | mlir::OpBuilder* builder) { |
61 | cluster->setAttr("_tpu_replicate" , |
62 | builder->getStringAttr(mesh_config.ToString())); |
63 | cluster->setAttr("step_marker_location" , builder->getStringAttr("" )); |
64 | cluster->setAttr("padding_map" , builder->getArrayAttr({})); |
65 | cluster->setAttr("use_spmd_for_xla_partitioning" , |
66 | builder->getBoolAttr(false)); |
67 | cluster->setAttr(tensorflow::kTopologyAttr, builder->getStringAttr("" )); |
68 | cluster->setAttr(tensorflow::kDeviceAssignmentAttr, |
69 | builder->getArrayAttr({})); |
70 | cluster->setAttr(tensorflow::kNumCoresPerReplicaAttr, |
71 | builder->getI64IntegerAttr(1)); |
72 | } |
73 | |
74 | // TODO(hongjunchoi): Implement cluster inlining pass so that there are no |
75 | // nested tf_device.cluster ops with same mesh. |
76 | void IdentifyTPUFunctions( |
77 | mlir::ModuleOp module, llvm::SmallVectorImpl<Mesh>* tpu_meshs, |
78 | llvm::SmallVectorImpl<mlir::TF::StatefulPartitionedCallOp>* tpu_functions) { |
79 | auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main" ); |
80 | if (!main_func) return; |
81 | |
82 | for (auto call : main_func.getOps<mlir::TF::StatefulPartitionedCallOp>()) { |
83 | auto mesh_or_status = Mesh::FromString(string(call.config())); |
84 | // Function calls created by end users instead of being converted from |
85 | // tf_device.cluster do not have a serialized mesh as a config attribute. We |
86 | // ignore the error returned from parsing in this case. |
87 | if (!mesh_or_status.ok()) return; |
88 | bool skip_xla_compilation = false; |
89 | if (call->hasAttr(kSkipXlaCompilation)) { |
90 | skip_xla_compilation = |
91 | call->getAttrOfType<mlir::BoolAttr>(kSkipXlaCompilation).getValue(); |
92 | } |
93 | if (mesh_or_status->is_tpu_mesh() && !skip_xla_compilation) { |
94 | tpu_functions->emplace_back(call); |
95 | tpu_meshs->emplace_back(std::move(mesh_or_status.value())); |
96 | } |
97 | } |
98 | } |
99 | |
// Wraps the body of the function called by `tpu_call` inside a
// tf_device.ClusterOp so downstream TPU passes treat it as one compilation
// unit. On success, sets `*newly_created_cluster` to the new cluster op.
// Emits an op error on `tpu_call` if its callee cannot be resolved.
mlir::LogicalResult CreateTPUCluster(
    mlir::TF::StatefulPartitionedCallOp tpu_call, mlir::OpBuilder* builder,
    mlir::tf_device::ClusterOp* newly_created_cluster) {
  auto function = MaybeFindFunction(tpu_call);
  if (!function)
    return tpu_call.emitOpError(
        "failed during TPU Integration as Func op TPU mesh was not found");

  auto& function_block = function->getCallableRegion()->front();
  builder->setInsertionPointToStart(&function_block);

  // The cluster produces exactly the types the function returns; it is
  // created as the first op of the function block.
  auto cluster = builder->create<mlir::tf_device::ClusterOp>(
      tpu_call.getLoc(), function->getCallableResults());
  cluster.getBody().push_back(new mlir::Block);

  // Move every op of the function body into the cluster's block, skipping
  // the first op (the cluster itself, inserted at the block start above) and
  // the function terminator, which must stay behind.
  auto& function_body = function_block.getOperations();
  cluster.GetBody().getOperations().splice(
      cluster.GetBody().getOperations().begin(), function_body,
      std::next(function_body.begin()), std::prev(function_body.end()));

  // Terminate the cluster body with tf_device.return, forwarding the values
  // the original function terminator was returning.
  builder->setInsertionPointToEnd(&cluster.GetBody());
  mlir::Operation* function_block_terminator = function_block.getTerminator();
  builder->create<mlir::tf_device::ReturnOp>(
      tpu_call.getLoc(), function_block_terminator->getOperands());

  // The function now returns the cluster's results instead of values defined
  // by the ops that were moved inside it.
  function_block_terminator->setOperands(cluster.getResults());

  *newly_created_cluster = cluster;
  return mlir::success();
}
130 | |
131 | struct DTensorTPUIntegration |
132 | : public impl::DTensorTPUIntegrationBase<DTensorTPUIntegration> { |
133 | void getDependentDialects(mlir::DialectRegistry& registry) const override { |
134 | registry.insert<mlir::dtensor::DTensorDialect>(); |
135 | registry.insert<mlir::tf_device::TensorFlowDeviceDialect>(); |
136 | } |
137 | |
138 | void runOnOperation() override { |
139 | mlir::MLIRContext& context = getContext(); |
140 | mlir::OpBuilder op_builder(&context); |
141 | auto module = getOperation(); |
142 | llvm::SmallVector<mlir::TF::StatefulPartitionedCallOp, 4> tpu_functions; |
143 | llvm::SmallVector<Mesh, 4> tpu_meshes; |
144 | IdentifyTPUFunctions(module, &tpu_meshes, &tpu_functions); |
145 | |
146 | for (auto tpu_function_and_mesh : llvm::zip(tpu_meshes, tpu_functions)) { |
147 | mlir::tf_device::ClusterOp cluster; |
148 | |
149 | if (mlir::failed(CreateTPUCluster(std::get<1>(tpu_function_and_mesh), |
150 | &op_builder, &cluster))) |
151 | return signalPassFailure(); |
152 | |
153 | AddMetadataToTPUCluster(std::get<0>(tpu_function_and_mesh), cluster, |
154 | &op_builder); |
155 | } |
156 | }; |
157 | }; |
158 | |
159 | } // namespace |
160 | |
161 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
162 | CreateDTensorTPUIntegration() { |
163 | return std::make_unique<DTensorTPUIntegration>(); |
164 | } |
165 | |
166 | } // namespace dtensor |
167 | } // namespace tensorflow |
168 | |