1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "llvm/ADT/StringRef.h"
17#include "llvm/Support/FormatVariadic.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
19#include "mlir/IR/Attributes.h" // from @llvm-project
20#include "mlir/IR/Builders.h" // from @llvm-project
21#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
22#include "mlir/IR/Diagnostics.h" // from @llvm-project
23#include "mlir/IR/Operation.h" // from @llvm-project
24#include "mlir/IR/Types.h" // from @llvm-project
25#include "mlir/IR/Value.h" // from @llvm-project
26#include "mlir/Pass/Pass.h" // from @llvm-project
27#include "mlir/Pass/PassManager.h" // from @llvm-project
28#include "mlir/Support/LogicalResult.h" // from @llvm-project
29#include "mlir/Transforms/Passes.h" // from @llvm-project
30#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
33#include "tensorflow/dtensor/mlir/layout_parsing.h"
34
35namespace tensorflow {
36namespace dtensor {
37
38namespace {
39#define GEN_PASS_DEF_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE
40#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
41
// Name of the function-argument attribute used to record a device placement
// ("tf.device") on resource arguments.
constexpr char kFuncDeviceAttr[] = "tf.device";
43
44// Returns whether `val` is of resource type.
45bool IsResourceType(mlir::Value val) {
46 return val.isa<mlir::BlockArgument>() && val.getType()
47 .cast<mlir::TensorType>()
48 .getElementType()
49 .isa<mlir::TF::ResourceType>();
50}
51
52// Adds device attribute to `arg` with the device placement of `execute_op`
53void AddPlaceholderDeviceAttributeToResource(
54 mlir::BlockArgument arg, mlir::TF::TPUExecuteOp execute_op) {
55 // TPUExecute op is wrapped inside tf_device.Launch op for device assignment.
56 auto tpu_execute_device_launch =
57 execute_op->getParentOfType<mlir::tf_device::LaunchOp>();
58 mlir::StringRef tpu_device_attr = tpu_execute_device_launch.getDevice();
59
60 auto function = execute_op->getParentOfType<mlir::func::FuncOp>();
61 mlir::OpBuilder builder(execute_op);
62 function.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
63 builder.getStringAttr(tpu_device_attr));
64}
65
66// Returns AssignVariableOp that consumes output of `val`. `val` is a output
67// from TPUExecute op which is wrapped inside a single tf_device.Launch
68// operation. As so, output of parent launch op is queried to identify connected
69// AssignVariable op.
70mlir::Operation* IdentifyConnectedAssignVariableOp(mlir::Value val) {
71 for (mlir::OpOperand& use : val.getUses()) {
72 auto return_op = llvm::dyn_cast<mlir::tf_device::ReturnOp>(use.getOwner());
73 if (!return_op) continue;
74
75 auto parent_launch =
76 val.getDefiningOp()->getParentOfType<mlir::tf_device::LaunchOp>();
77 mlir::Value launch_output = parent_launch.getResult(use.getOperandNumber());
78 for (mlir::Operation* user : launch_output.getUsers()) {
79 auto assign_variable = llvm::dyn_cast<mlir::TF::AssignVariableOp>(user);
80 if (!assign_variable) continue;
81
82 return assign_variable;
83 }
84 }
85 return nullptr;
86}
87
88struct DTensorTpuAddResourceDeviceAttribute
89 : public impl::DTensorTpuAddResourceDeviceAttributeBase<
90 DTensorTpuAddResourceDeviceAttribute> {
91 void runOnOperation() override {
92 mlir::MLIRContext& context = getContext();
93 mlir::OpBuilder op_builder(&context);
94 mlir::ModuleOp module = getOperation();
95 // For each resource value that is input or that is consumed by TPUExecute
96 // op, add placeholder device attribute to the resource argument.
97 mlir::WalkResult walk_result =
98 module.walk([](mlir::TF::TPUExecuteOp tpu_execute) {
99 for (mlir::Value tpu_input : tpu_execute.getOperands()) {
100 if (IsResourceType(tpu_input))
101 AddPlaceholderDeviceAttributeToResource(
102 tpu_input.cast<mlir::BlockArgument>(), tpu_execute);
103
104 mlir::Operation* input_op = tpu_input.getDefiningOp();
105 auto read_variable_op =
106 llvm::dyn_cast_or_null<mlir::TF::ReadVariableOp>(input_op);
107 if (!read_variable_op) continue;
108
109 AddPlaceholderDeviceAttributeToResource(
110 read_variable_op.resource().cast<mlir::BlockArgument>(),
111 tpu_execute);
112 }
113
114 for (mlir::Value result : tpu_execute.getResults()) {
115 mlir::Operation* assign_variable =
116 IdentifyConnectedAssignVariableOp(result);
117 if (assign_variable == nullptr) continue;
118
119 AddPlaceholderDeviceAttributeToResource(
120 llvm::cast<mlir::TF::AssignVariableOp>(assign_variable)
121 .resource()
122 .cast<mlir::BlockArgument>(),
123 tpu_execute);
124 }
125
126 return mlir::WalkResult::advance();
127 });
128
129 if (walk_result.wasInterrupted()) return signalPassFailure();
130 };
131};
132
133} // namespace
134
135// Adds placeholder device attributes to resource arguments of TPU functions.
136// Device attribute added is consistent with device placement of TPUExecute op.
137// This is required for enabling CreateTPUMergeVariablesWithExecutePass as the
138// pass checks that all resources must have consistent device placement with
139// TPUExecute op in order to enable buffer aliasing.
140std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
141CreateDTensorTpuAddResourceDeviceAttribute() {
142 return std::make_unique<DTensorTpuAddResourceDeviceAttribute>();
143}
144
145} // namespace dtensor
146} // namespace tensorflow
147