1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "llvm/ADT/StringRef.h" |
17 | #include "llvm/Support/FormatVariadic.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
19 | #include "mlir/IR/Attributes.h" // from @llvm-project |
20 | #include "mlir/IR/Builders.h" // from @llvm-project |
21 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
22 | #include "mlir/IR/Diagnostics.h" // from @llvm-project |
23 | #include "mlir/IR/Operation.h" // from @llvm-project |
24 | #include "mlir/IR/Types.h" // from @llvm-project |
25 | #include "mlir/IR/Value.h" // from @llvm-project |
26 | #include "mlir/Pass/Pass.h" // from @llvm-project |
27 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
28 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
29 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
30 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
31 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
32 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
33 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
34 | |
35 | namespace tensorflow { |
36 | namespace dtensor { |
37 | |
38 | namespace { |
39 | #define GEN_PASS_DEF_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE |
40 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
41 | |
// Name of the function-argument attribute carrying the device placement of a
// resource argument ("tf.device").
constexpr char kFuncDeviceAttr[] = "tf.device";
43 | |
44 | // Returns whether `val` is of resource type. |
45 | bool IsResourceType(mlir::Value val) { |
46 | return val.isa<mlir::BlockArgument>() && val.getType() |
47 | .cast<mlir::TensorType>() |
48 | .getElementType() |
49 | .isa<mlir::TF::ResourceType>(); |
50 | } |
51 | |
52 | // Adds device attribute to `arg` with the device placement of `execute_op` |
53 | void AddPlaceholderDeviceAttributeToResource( |
54 | mlir::BlockArgument arg, mlir::TF::TPUExecuteOp execute_op) { |
55 | // TPUExecute op is wrapped inside tf_device.Launch op for device assignment. |
56 | auto tpu_execute_device_launch = |
57 | execute_op->getParentOfType<mlir::tf_device::LaunchOp>(); |
58 | mlir::StringRef tpu_device_attr = tpu_execute_device_launch.getDevice(); |
59 | |
60 | auto function = execute_op->getParentOfType<mlir::func::FuncOp>(); |
61 | mlir::OpBuilder builder(execute_op); |
62 | function.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr, |
63 | builder.getStringAttr(tpu_device_attr)); |
64 | } |
65 | |
66 | // Returns AssignVariableOp that consumes output of `val`. `val` is a output |
67 | // from TPUExecute op which is wrapped inside a single tf_device.Launch |
68 | // operation. As so, output of parent launch op is queried to identify connected |
69 | // AssignVariable op. |
70 | mlir::Operation* IdentifyConnectedAssignVariableOp(mlir::Value val) { |
71 | for (mlir::OpOperand& use : val.getUses()) { |
72 | auto return_op = llvm::dyn_cast<mlir::tf_device::ReturnOp>(use.getOwner()); |
73 | if (!return_op) continue; |
74 | |
75 | auto parent_launch = |
76 | val.getDefiningOp()->getParentOfType<mlir::tf_device::LaunchOp>(); |
77 | mlir::Value launch_output = parent_launch.getResult(use.getOperandNumber()); |
78 | for (mlir::Operation* user : launch_output.getUsers()) { |
79 | auto assign_variable = llvm::dyn_cast<mlir::TF::AssignVariableOp>(user); |
80 | if (!assign_variable) continue; |
81 | |
82 | return assign_variable; |
83 | } |
84 | } |
85 | return nullptr; |
86 | } |
87 | |
88 | struct DTensorTpuAddResourceDeviceAttribute |
89 | : public impl::DTensorTpuAddResourceDeviceAttributeBase< |
90 | DTensorTpuAddResourceDeviceAttribute> { |
91 | void runOnOperation() override { |
92 | mlir::MLIRContext& context = getContext(); |
93 | mlir::OpBuilder op_builder(&context); |
94 | mlir::ModuleOp module = getOperation(); |
95 | // For each resource value that is input or that is consumed by TPUExecute |
96 | // op, add placeholder device attribute to the resource argument. |
97 | mlir::WalkResult walk_result = |
98 | module.walk([](mlir::TF::TPUExecuteOp tpu_execute) { |
99 | for (mlir::Value tpu_input : tpu_execute.getOperands()) { |
100 | if (IsResourceType(tpu_input)) |
101 | AddPlaceholderDeviceAttributeToResource( |
102 | tpu_input.cast<mlir::BlockArgument>(), tpu_execute); |
103 | |
104 | mlir::Operation* input_op = tpu_input.getDefiningOp(); |
105 | auto read_variable_op = |
106 | llvm::dyn_cast_or_null<mlir::TF::ReadVariableOp>(input_op); |
107 | if (!read_variable_op) continue; |
108 | |
109 | AddPlaceholderDeviceAttributeToResource( |
110 | read_variable_op.resource().cast<mlir::BlockArgument>(), |
111 | tpu_execute); |
112 | } |
113 | |
114 | for (mlir::Value result : tpu_execute.getResults()) { |
115 | mlir::Operation* assign_variable = |
116 | IdentifyConnectedAssignVariableOp(result); |
117 | if (assign_variable == nullptr) continue; |
118 | |
119 | AddPlaceholderDeviceAttributeToResource( |
120 | llvm::cast<mlir::TF::AssignVariableOp>(assign_variable) |
121 | .resource() |
122 | .cast<mlir::BlockArgument>(), |
123 | tpu_execute); |
124 | } |
125 | |
126 | return mlir::WalkResult::advance(); |
127 | }); |
128 | |
129 | if (walk_result.wasInterrupted()) return signalPassFailure(); |
130 | }; |
131 | }; |
132 | |
133 | } // namespace |
134 | |
135 | // Adds placeholder device attributes to resource arguments of TPU functions. |
136 | // Device attribute added is consistent with device placement of TPUExecute op. |
137 | // This is required for enabling CreateTPUMergeVariablesWithExecutePass as the |
138 | // pass checks that all resources must have consistent device placement with |
139 | // TPUExecute op in order to enable buffer aliasing. |
140 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
141 | CreateDTensorTpuAddResourceDeviceAttribute() { |
142 | return std::make_unique<DTensorTpuAddResourceDeviceAttribute>(); |
143 | } |
144 | |
145 | } // namespace dtensor |
146 | } // namespace tensorflow |
147 | |