/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {

namespace {
#define GEN_PASS_DEF_DTENSORCLUSTERFUNCTIONCONVERSION
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

// Attaches layouts to all returned values so that the custom device can get
// layouts for the handles.
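// Illustratively (hypothetical layouts; string format simplified), a call with
// two results ends up carrying an attribute along the lines of
//   _layout = ["sharding_specs:x,unsharded, mesh:<serialized mesh>",
//              "sharding_specs:unsharded, mesh:<serialized mesh>"]
// i.e. one serialized Layout per result, in result order.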
mlir::LogicalResult AttachRetvalLayouts(
    mlir::OpBuilder* builder, mlir::TF::StatefulPartitionedCallOp sp_call_op) {
  // Find the FuncOp that the StatefulPartitionedCallOp is invoking.
  mlir::SymbolRefAttr sym =
      sp_call_op.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>();
  if (!sym)
    return sp_call_op.emitOpError(
        "has no symbolRef for given StatefulPartitionedCallOp");

  auto func = mlir::dyn_cast<mlir::func::FuncOp>(
      mlir::SymbolTable::lookupNearestSymbolFrom(sp_call_op, sym));
  if (!func)
    return sp_call_op.emitOpError() << "found no FuncOp for symbol " << sym;

  llvm::SmallVector<absl::optional<Layout>, 8> retvals_layouts;
  retvals_layouts.reserve(func.getNumResults());
  for (auto operand : func.front().getTerminator()->getOperands()) {
    auto result_layout_or_status = ExtractLayoutFromOperand(operand);
    if (!result_layout_or_status.ok()) {
      return func.emitOpError("error while parsing result layout for function");
    }

    auto result_layout = result_layout_or_status.value();

    // When a function returns its arguments directly, layout information for
    // the return value of `func` may only be obtainable by looking at its
    // callsite operations. In that case, query the input layouts of the
    // function callsite operations for layout information.
    if (!result_layout) {
      if (auto block_arg = operand.dyn_cast<mlir::BlockArgument>()) {
        auto layout_or_status = ExtractLayoutFromOperand(
            sp_call_op.getOperand(block_arg.getArgNumber()));
        if (!layout_or_status.ok())
          return func.emitOpError(
              "error while parsing result layout for function");
        result_layout = std::move(layout_or_status.value());
      }

      if (!result_layout)
        return func.emitOpError(
            llvm::formatv("missing result layout attribute for function. All "
                          "DTensor functions must have layouts for their "
                          "results."));
    }
    retvals_layouts.emplace_back(result_layout.value());
  }

  // Note that we set this unconditionally: retvals_layouts may be empty, in
  // which case the StatefulPartitionedCallOp simply gets an empty _layout
  // attribute. That is fine; for ops without return values, all we need is a
  // placeholder layout so that no special case is needed in dtensor_device.
  SetLayoutOnOp(sp_call_op,
                absl::Span<const absl::optional<Layout>>(
                    retvals_layouts.data(), retvals_layouts.size()));

  return mlir::success();
}

// Adds an annotation to skip XLA compilation for VarHandleOp and
// DestroyResourceOp.
void MaybeSkipXlaCompilation(mlir::OpBuilder* builder,
                             mlir::Operation* call_op) {
  auto function = MaybeFindFunction(call_op);
  const auto& body_ops = function->getBody().front().without_terminator();
  // VarHandleOp and DestroyResourceOp run in op-by-op mode, so there is only
  // one op in the function body.
  if (std::distance(std::begin(body_ops), std::end(body_ops)) == 1 &&
      llvm::isa<mlir::TF::VarHandleOp, mlir::TF::DestroyResourceOp>(
          *body_ops.begin())) {
    call_op->setAttr(kSkipXlaCompilation, builder->getBoolAttr(true));
  }
}

mlir::LogicalResult ReplaceClusterWithPartitionCallOp(
    mlir::OpBuilder* builder, mlir::tf_device::ClusterFuncOp cluster_func) {
  auto mesh_attr = cluster_func->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  if (!mesh_attr)
    return cluster_func.emitOpError()
           << "requires " << llvm::StringRef(kMeshAttr) << " attribute";

  llvm::SmallVector<mlir::Type, 8> output_types{
      cluster_func.getResultTypes().begin(),
      cluster_func.getResultTypes().end()};

  llvm::StringRef function_name = cluster_func.func();

  builder->setInsertionPoint(cluster_func);
  auto call_op = builder->create<mlir::TF::StatefulPartitionedCallOp>(
      cluster_func.getLoc(), output_types, cluster_func.getOperands(),
      function_name, mesh_attr, /*config_proto=*/builder->getStringAttr(""),
      /*executor_type=*/builder->getStringAttr(""));

  MaybeSkipXlaCompilation(builder, call_op);

  if (mlir::failed(ValidateMetadataAttributes(cluster_func)))
    return mlir::failure();

  // All attributes beginning with `_` have been validated; copy them over to
  // the new call op.
  mlir::TF::CopyUnderscoredAttributes(cluster_func, call_op);

  cluster_func.replaceAllUsesWith(call_op.getResults());
  cluster_func.erase();

  return AttachRetvalLayouts(builder, call_op);
}

// MLIR pass that converts tf_device.cluster_func ops to TF partitioned call
// ops, with the device mesh config carried in the `config` attribute.
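//
// Roughly (illustrative, simplified IR; the mesh string and layouts shown are
// placeholders):
//
//   %0 = "tf_device.cluster_func"(%arg0)
//            {func = @_cluster_fn, _mesh = "<serialized mesh>"}
//            : (tensor<i32>) -> tensor<i32>
//
// becomes
//
//   %0 = "tf.StatefulPartitionedCall"(%arg0)
//            {f = @_cluster_fn, config = "<serialized mesh>",
//             config_proto = "", executor_type = "",
//             _layout = ["<serialized layout of result 0>"]}
//            : (tensor<i32>) -> tensor<i32>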
struct DTensorClusterFunctionConversion
    : public impl::DTensorClusterFunctionConversionBase<
          DTensorClusterFunctionConversion> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();

    // Find all tf_device.ClusterFunc ops and visit them in post order. This
    // order guarantees that ops in a function definition are visited before
    // the function's call site operations. When the Python graph includes
    // tf.functions, this leads to nested tf_device.ClusterFunc ops. As we
    // infer the layout of a function call operation from the layouts attached
    // to the return values in its function definition, ClusterFunc ops in
    // nested/inner functions must be visited before ClusterFunc ops in outer
    // functions.
    llvm::SmallVector<mlir::tf_device::ClusterFuncOp, 8> clusters;
    getOperation().walk([&](mlir::tf_device::ClusterFuncOp cluster_func) {
      clusters.emplace_back(cluster_func);
    });

    mlir::OpBuilder op_builder(&context);
    for (auto cluster_func : llvm::reverse(clusters)) {
      if (mlir::failed(
              ReplaceClusterWithPartitionCallOp(&op_builder, cluster_func))) {
        return signalPassFailure();
      }
    }
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorClusterFunctionConversion() {
  return std::make_unique<DTensorClusterFunctionConversion>();
}
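
// A minimal usage sketch (illustrative only; assumes a ModuleOp `module` that
// already contains tf_device.cluster_func ops produced by earlier DTensor
// passes):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(CreateDTensorClusterFunctionConversion());
//   if (mlir::failed(pm.run(module))) {
//     // Handle the pass failure.
//   }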

}  // namespace dtensor
}  // namespace tensorflow