/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {

namespace {
#define GEN_PASS_DEF_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
// Holds information on functions to rewrite. `function` is the function
// definition that needs to be updated, and `callsite_ops` holds the list of
// ops that call `function`.
struct FunctionToChangeInfo {
  mlir::func::FuncOp function;
  llvm::SmallVector<mlir::Operation*, 4> callsite_ops;
};

// Finds all functions in the graph that are not public and retrieves their
// callsite operations.
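//
// For illustration, a minimal sketch (the callee name and types below are
// hypothetical): in a module such as
//
//   func.func @main(%arg0: tensor<i32>) {
//     "tf.StatefulPartitionedCall"(%arg0)
//         {f = @callee, config = "", config_proto = "", executor_type = ""}
//         : (tensor<i32>) -> ()
//     func.return
//   }
//   func.func private @callee(%arg0: tensor<i32>) {
//     func.return
//   }
//
// this function would pair the private @callee with the
// tf.StatefulPartitionedCall op that references it, while the public @main
// is skipped.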
llvm::SmallVector<FunctionToChangeInfo, 4> FindFunctionsToRewrite(
    mlir::ModuleOp module) {
  llvm::SmallVector<FunctionToChangeInfo, 4> functions_to_change;
  module.walk([&](mlir::Operation* op) {
    if (!llvm::isa<mlir::TF::StatefulPartitionedCallOp,
                   mlir::TF::PartitionedCallOp>(op))
      return;

    // Extract the function symbol from the PartitionedCall or
    // StatefulPartitionedCall op.
    llvm::StringRef symbol;
    if (auto call_op =
            llvm::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op)) {
      symbol = call_op.f();
    } else {
      auto symbol_ref = llvm::dyn_cast<mlir::TF::PartitionedCallOp>(op).f();
      if (!symbol_ref.isa<mlir::FlatSymbolRefAttr>()) return;
      symbol = symbol_ref.getRootReference().getValue();
    }

    // If the function definition can be found, extract all of its usages.
    auto function = MaybeFindFunction(op);
    if (!function || function->isPublic()) return;

    auto function_uses = mlir::SymbolTable::getSymbolUses(
        mlir::StringAttr::get(module.getContext(), symbol),
        &module.getBodyRegion());
    if (!function_uses) return;

    llvm::SmallVector<mlir::Operation*, 4> function_use_ops;
    for (auto function_use : *function_uses)
      function_use_ops.emplace_back(function_use.getUser());

    functions_to_change.emplace_back(
        FunctionToChangeInfo{function.value(), function_use_ops});
  });

  return functions_to_change;
}

// Rewrites `function` so that an argument of type `type` is prepended as its
// 0th argument.
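//
// A before/after sketch (the function name and types are hypothetical;
// `type` is in practice the device id type, e.g. tensor<i32>):
//
//   Before: func.func private @fn(%arg0: tensor<f32>) -> tensor<f32>
//   After:  func.func private @fn(%arg0: tensor<i32>, %arg1: tensor<f32>)
//               -> tensor<f32>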
void PrependArgumentToFunction(mlir::func::FuncOp function, mlir::Type type,
                               mlir::OpBuilder* builder) {
  auto& function_body = function.front();
  function_body.insertArgument(static_cast<unsigned>(0), type,
                               function.getLoc());
  auto new_argument_types =
      llvm::to_vector<4>(function_body.getArgumentTypes());
  function.setType(
      mlir::FunctionType::get(builder->getContext(), new_argument_types,
                              function.getFunctionType().getResults()));
}

// Rewrites function callsite ops. As the function signatures have already
// been updated, simply prepend the 0th argument of the enclosing function
// (the device id) as the 0th operand of the callsite operation.
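//
// A before/after sketch for a callsite whose enclosing function takes the
// device id as %device_id (names and types are hypothetical):
//
//   Before: "tf.PartitionedCall"(%x) {f = @fn, ...}
//               : (tensor<f32>) -> tensor<f32>
//   After:  "tf.PartitionedCall"(%device_id, %x) {f = @fn, ...}
//               : (tensor<i32>, tensor<f32>) -> tensor<f32>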
mlir::LogicalResult PrependDeviceIdToCallsites(mlir::OpBuilder* builder,
                                               mlir::Operation* op) {
  auto device_id_or_status = DeviceId(op);
  if (!device_id_or_status.ok())
    return op->emitOpError(
        "Failed during PropagateDeviceIdToFunctionArgs pass. All functions "
        "must have device id as 0th argument.");

  auto new_operands = llvm::to_vector<4>(op->getOperands());
  new_operands.insert(new_operands.begin(), device_id_or_status.value());

  builder->setInsertionPoint(op);
  mlir::Operation* new_call = nullptr;
  if (auto stateful_partitioned_call =
          llvm::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op)) {
    new_call = builder->create<mlir::TF::StatefulPartitionedCallOp>(
        op->getLoc(), op->getResultTypes(), new_operands,
        stateful_partitioned_call.f(), stateful_partitioned_call.config(),
        stateful_partitioned_call.config_proto(),
        stateful_partitioned_call.executor_type());
  } else {
    auto partitioned_call = llvm::cast<mlir::TF::PartitionedCallOp>(op);
    new_call = builder->create<mlir::TF::PartitionedCallOp>(
        op->getLoc(), op->getResultTypes(), new_operands, partitioned_call.f(),
        partitioned_call.config(), partitioned_call.config_proto(),
        partitioned_call.executor_type());
  }

  // Forward all uses of the old call's results to the new call, then erase
  // the old op.
  for (auto results : llvm::zip(op->getResults(), new_call->getResults()))
    std::get<0>(results).replaceAllUsesWith(std::get<1>(results));

  op->erase();

  return mlir::success();
}

// Pass that rewrites the functions in the graph so that the 0th argument of
// the main function (i.e. the device id) is present on all functions in the
// graph.
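//
// For example, assuming the main function already carries the device id as
// its 0th argument (types here are assumed):
//
//   func.func @main(%device_id: tensor<i32>, ...)
//
// every non-public function reached through a (Stateful)PartitionedCall op
// is rewritten to also take the device id as its 0th argument, and every
// such callsite is rewritten to forward it.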
struct DTensorPropagateDeviceIdToFunctionArgs
    : public impl::DTensorPropagateDeviceIdToFunctionArgsBase<
          DTensorPropagateDeviceIdToFunctionArgs> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    auto module = getOperation();
    mlir::OpBuilder builder(&context);

    // Extract the device id argument from the main function.
    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    auto device_id_or_status = DeviceId(&main_func.getBody().front().front());
    if (!device_id_or_status.ok()) {
      main_func.emitOpError(
          "Error in PropagateDeviceIdToFunctionArgs pass. Main function must "
          "have device id as 0th function argument.");
      return signalPassFailure();
    }
    auto device_id_from_main_function = device_id_or_status.value();

    // First, iterate through all functions to rewrite and update their
    // signatures.
    const auto functions_to_update = FindFunctionsToRewrite(module);
    for (const auto& function_to_update : functions_to_update)
      PrependArgumentToFunction(function_to_update.function,
                                device_id_from_main_function.getType(),
                                &builder);

    // Once all function signatures are updated, rewrite the callsite ops.
    for (const auto& function_to_update : functions_to_update) {
      for (auto call_site_op : function_to_update.callsite_ops) {
        if (mlir::failed(PrependDeviceIdToCallsites(&builder, call_site_op)))
          return signalPassFailure();
      }
    }
  }
};

}  // namespace

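// A minimal usage sketch (assuming a standard mlir::PassManager over the
// module; how the pass is actually registered and scheduled may differ):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(CreateDTensorPropagateDeviceIdToFunctionArgs());
//   if (mlir::failed(pm.run(module))) { /* handle the error */ }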
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorPropagateDeviceIdToFunctionArgs() {
  return std::make_unique<DTensorPropagateDeviceIdToFunctionArgs>();
}

}  // namespace dtensor
}  // namespace tensorflow