1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <utility> |
17 | |
18 | #include "llvm/ADT/STLExtras.h" |
19 | #include "llvm/ADT/SmallVector.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
22 | #include "mlir/IR/Attributes.h" // from @llvm-project |
23 | #include "mlir/IR/Builders.h" // from @llvm-project |
24 | #include "mlir/IR/Diagnostics.h" // from @llvm-project |
25 | #include "mlir/IR/Operation.h" // from @llvm-project |
26 | #include "mlir/IR/Value.h" // from @llvm-project |
27 | #include "mlir/Pass/Pass.h" // from @llvm-project |
28 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
29 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
30 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
31 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
32 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
33 | #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" |
34 | #include "tensorflow/dtensor/cc/constants.h" |
35 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
36 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
37 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
38 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
39 | #include "tensorflow/dtensor/mlir/op_utils.h" |
40 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
41 | |
42 | namespace tensorflow { |
43 | namespace dtensor { |
44 | |
45 | namespace { |
46 | #define GEN_PASS_DEF_DTENSORCLUSTERFUNCTIONCONVERSION |
47 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
48 | |
// Attach layouts to all returned values so that the DTensor custom device can
// retrieve the layouts for the result handles.
51 | mlir::LogicalResult AttachRetvalLayouts( |
52 | mlir::OpBuilder* builder, mlir::TF::StatefulPartitionedCallOp sp_call_op) { |
53 | // Find the FuncOp that the StatefulPartitionedCallOp is invoking. |
54 | mlir::SymbolRefAttr sym = |
55 | sp_call_op.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>(); |
56 | if (!sym) |
57 | return sp_call_op.emitOpError( |
58 | "has no symbolRef for given StatefulPartitionedCallOp" ); |
59 | |
60 | auto func = mlir::dyn_cast<mlir::func::FuncOp>( |
61 | mlir::SymbolTable::lookupNearestSymbolFrom(sp_call_op, sym)); |
62 | if (!func) |
63 | return sp_call_op.emitOpError() << "found no FuncOp for symbol " << sym; |
64 | |
65 | llvm::SmallVector<absl::optional<Layout>, 8> retvals_layouts; |
66 | retvals_layouts.reserve(func.getNumResults()); |
67 | for (auto operand : func.front().getTerminator()->getOperands()) { |
68 | auto result_layout_or_status = ExtractLayoutFromOperand(operand); |
69 | if (!result_layout_or_status.ok()) { |
70 | return func.emitOpError("error while parsing result layout for function" ); |
71 | } |
72 | |
73 | auto result_layout = result_layout_or_status.value(); |
74 | |
75 | // When function returns its arguments directly, layout information for the |
76 | // return value of `func` may be only obtainable by looking at it's callsite |
77 | // operations. In that case, query the input layouts for function callsite |
78 | // operations for layout information. |
79 | if (!result_layout) { |
80 | if (auto block_arg = operand.dyn_cast<mlir::BlockArgument>()) { |
81 | auto layout_or_status = ExtractLayoutFromOperand( |
82 | sp_call_op.getOperand(block_arg.getArgNumber())); |
83 | if (!layout_or_status.ok()) |
84 | return func.emitOpError( |
85 | "error while parsing result layout for function" ); |
86 | result_layout = std::move(layout_or_status.value()); |
87 | } |
88 | |
89 | if (!result_layout) |
90 | return func.emitOpError( |
91 | llvm::formatv("missing result layout attribute for function. All " |
92 | "DTensor functions " |
93 | "must have layouts for its results." )); |
94 | } |
95 | retvals_layouts.emplace_back(result_layout.value()); |
96 | } |
97 | |
98 | // Note that we set this unconditionally - retvals_layout could be empty, but |
99 | // that is fine and we will have an empty _layout for the |
100 | // StatefulPartitionedCallOp. This is fine as for op without return values, |
101 | // all we need is a placeholder layout so that no special case is needed in |
102 | // dtensor_device. |
103 | SetLayoutOnOp(sp_call_op, |
104 | absl::Span<const absl::optional<Layout>>( |
105 | retvals_layouts.data(), retvals_layouts.size())); |
106 | |
107 | return mlir::success(); |
108 | } |
109 | |
// Add an annotation to skip XLA compilation for VarHandleOp and
// DestroyResourceOp.
112 | void MaybeSkipXlaCompilation(mlir::OpBuilder* builder, |
113 | mlir::Operation* call_op) { |
114 | auto function = MaybeFindFunction(call_op); |
115 | const auto& body_ops = function->getBody().front().without_terminator(); |
116 | // VarHandleOp and DestroyResourceOp run on op-by-op mode, so there is only |
117 | // one op in the function body. |
118 | if (std::distance(std::begin(body_ops), std::end(body_ops)) == 1 && |
119 | llvm::isa<mlir::TF::VarHandleOp, mlir::TF::DestroyResourceOp>( |
120 | body_ops.begin())) { |
121 | call_op->setAttr(kSkipXlaCompilation, builder->getBoolAttr(true)); |
122 | } |
123 | } |
124 | |
// Rewrites a single tf_device.ClusterFunc op into a
// TF::StatefulPartitionedCallOp that invokes the same function, carrying the
// cluster's mesh configuration in the call's `config` attribute. Returns
// failure if the mesh attribute is missing, metadata validation fails, or
// result layouts cannot be attached.
mlir::LogicalResult ReplaceClusterWithPartitionCallOp(
    mlir::OpBuilder* builder, mlir::tf_device::ClusterFuncOp cluster_func) {
  // The mesh is mandatory: it becomes the `config` of the new call op so the
  // runtime knows which device mesh to execute on.
  auto mesh_attr = cluster_func->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  if (!mesh_attr)
    return cluster_func.emitOpError()
           << "requires " << llvm::StringRef(kMeshAttr) << " attribute";

  // The call op must produce exactly the same result types as the cluster it
  // replaces.
  llvm::SmallVector<mlir::Type, 8> output_types{
      cluster_func.getResultTypes().begin(),
      cluster_func.getResultTypes().end()};

  llvm::StringRef function_name = cluster_func.func();

  // Create the call right where the cluster was, passing `mesh_attr` through
  // the `config` slot; config_proto/executor_type are intentionally empty.
  builder->setInsertionPoint(cluster_func);
  auto call_op = builder->create<mlir::TF::StatefulPartitionedCallOp>(
      cluster_func.getLoc(), output_types, cluster_func.getOperands(),
      function_name, mesh_attr, /*config_proto=*/builder->getStringAttr(""),
      /*executor_type=*/builder->getStringAttr(""));

  // Annotate the call to skip XLA compilation when the callee is a lone
  // VarHandleOp/DestroyResourceOp.
  MaybeSkipXlaCompilation(builder, call_op);

  if (mlir::failed(ValidateMetadataAttributes(cluster_func)))
    return mlir::failure();

  // All attributes beginning with `_` have been validated above; copy them
  // onto the new call op so no metadata is lost.
  mlir::TF::CopyUnderscoredAttributes(cluster_func, call_op);

  // Redirect all users to the call's results, then remove the now-dead
  // cluster. Erase only after replaceAllUsesWith so no dangling uses remain.
  cluster_func.replaceAllUsesWith(call_op.getResults());
  cluster_func.erase();

  // Finally propagate result layouts onto the call op for the custom device.
  return AttachRetvalLayouts(builder, call_op);
}
157 | |
158 | // MLIR pass that converts tf_device.cluster_func to TF partitioned call |
159 | // op with device mesh config added to `config` attribute. |
160 | struct DTensorClusterFunctionConversion |
161 | : public impl::DTensorClusterFunctionConversionBase< |
162 | DTensorClusterFunctionConversion> { |
163 | void runOnOperation() override { |
164 | mlir::MLIRContext& context = getContext(); |
165 | |
166 | // Find all tf_device.ClusterFunc ops and visit them in post order. This |
167 | // order guarantees that ops in function definition is visited before |
168 | // function call site operations. When python graph includes tf.functions |
169 | // this leads to nested tf_device.ClusterFunc ops. As we infer the layout |
170 | // of function call operations with layout attached to return values in the |
171 | // function definition, ClusterFunc op in nested/inner functions must be |
172 | // visited before ClusterFunc op in outer functions. |
173 | llvm::SmallVector<mlir::tf_device::ClusterFuncOp, 8> clusters; |
174 | getOperation().walk([&](mlir::tf_device::ClusterFuncOp cluster_func) { |
175 | clusters.emplace_back(cluster_func); |
176 | }); |
177 | |
178 | mlir::OpBuilder op_builder(&context); |
179 | for (auto cluster_func : llvm::reverse(clusters)) { |
180 | if (mlir::failed( |
181 | ReplaceClusterWithPartitionCallOp(&op_builder, cluster_func))) { |
182 | return signalPassFailure(); |
183 | } |
184 | } |
185 | }; |
186 | }; |
187 | |
188 | } // namespace |
189 | |
190 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
191 | CreateDTensorClusterFunctionConversion() { |
192 | return std::make_unique<DTensorClusterFunctionConversion>(); |
193 | } |
194 | |
195 | } // namespace dtensor |
196 | } // namespace tensorflow |
197 | |