/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/mlir/dtensor_send_recv.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {
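// Pulls in the auto-generated pass base class
// (impl::DTensorInferShapesForRestoreV2OpBase) from the pass declarations.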
#define GEN_PASS_DEF_DTENSORINFERSHAPESFORRESTOREV2OP
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

// From the Operation that produces `value`, set the result type to `type`.
//
// Recursively set the result type to `type` going backward toward
// the tf.RestoreV2Op that produced the unknown shape associated with `value`.
mlir::LogicalResult BackwardShapeInferenceToRestoreOp(mlir::ModuleOp module,
                                                      mlir::OpBuilder* builder,
                                                      mlir::Value value,
                                                      mlir::Type type) {
  mlir::Operation* op = value.getDefiningOp();
  if (op == nullptr) return mlir::success();
  if (!llvm::isa<mlir::TF::IdentityOp, mlir::TF::DTensorRecv,
                 mlir::TF::RestoreV2Op>(op)) {
    return op->emitOpError(llvm::formatv(
        "Expected an Identity, DTensorRecv, or RestoreV2 op, but got: {0}",
        op->getName().getStringRef()));
  }

  builder->setInsertionPointAfter(op);

  // Base case: if we reached the RestoreV2Op, then we reached the root of the
  // unknown shape result. Set the type of the result corresponding to `value`
  // to `type`.
  if (auto restore_op = llvm::dyn_cast_or_null<mlir::TF::RestoreV2Op>(op)) {
    // This is usually a dangerous operation, but since we are backward
    // propagating shapes and correctly setting the shapes backwards,
    // we can modify the value itself here instead of creating a new
    // RestoreV2 op.
    //
    // Creating a new RestoreV2 op and replacing all uses will make this
    // algorithm run in O(N^2) where N = number of outputs of RestoreV2.
    //
    // Using setType(type) modifies in place and makes this algorithm run in
    // O(N).
    value.setType(type);
  } else if (auto identity_op =
                 llvm::dyn_cast_or_null<mlir::TF::IdentityOp>(op)) {
    auto new_identity_op = builder->create<mlir::TF::IdentityOp>(
        identity_op.getLoc(), type, identity_op.input());
    identity_op.output().replaceAllUsesWith(new_identity_op.output());
    identity_op.erase();

    // Recursively run shape inference on the input of the identity op.
    return BackwardShapeInferenceToRestoreOp(module, builder,
                                             new_identity_op.input(), type);
  } else if (auto recv_op = llvm::dyn_cast_or_null<mlir::TF::DTensorRecv>(op)) {
    // If we have a DTensorRecv, then there is cross mesh action and the
    // RestoreV2Op we want to fix is on the mesh of the corresponding
    // DTensorSend. Set the shape of this DTensorRecv first and go to the
    // corresponding DTensorSend.
    auto new_recv_op = builder->create<mlir::TF::DTensorRecv>(
        recv_op.getLoc(), type, builder->getStringAttr(recv_op.key()),
        mlir::TF::ShapeAttr::get(builder->getContext(),
                                 type.dyn_cast<mlir::TensorType>()),
        mlir::dtensor::LayoutAttr::get(builder->getContext(),
                                       recv_op.layout()));

    recv_op.replaceAllUsesWith(new_recv_op.output());
    recv_op.erase();

    auto send_op = GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorRecv>(
        module, new_recv_op);

    // `recv_op` was erased above, so report any lookup failure on the
    // replacement op instead.
    if (!send_op.ok())
      return new_recv_op.emitOpError(send_op.status().error_message());

    // Recursively run shape inference on the input of the send op.
    return BackwardShapeInferenceToRestoreOp(
        module, builder, send_op.value()->getOperand(0), type);
  }
  return mlir::success();
}

// For every AssignVariableOp, if the value X being assigned to the resource
// tensor has unknown shape information, then X might come from the result of
// a tf.RestoreV2 op.
//
// We can infer the unknown shape of a tf.RestoreV2 op result through the
// resource tensors of the AssignVariableOps that consume the results.
//
// Thus, we propagate the underlying resource tensor shape and dtype backwards,
// leading up to the tf.RestoreV2 op.
mlir::LogicalResult PropagateShapeInformationFromAssignVariableOp(
    mlir::ModuleOp module) {
  module.walk([&](mlir::TF::AssignVariableOp assign_op) {
    // Only propagate when `value` has an unknown shape (unknown rank).
    if (ValueRank(assign_op.value()) == -1) {
      StatusOr<llvm::ArrayRef<int64_t>> shape =
          GetShapeOfValue(assign_op.resource());
      if (!shape.ok()) {
        assign_op->emitOpError(
            "Resource tensor was expected to have shape information but was "
            "missing it during CheckpointShapeInference.");
        return mlir::WalkResult::interrupt();
      }
      // Propagate the shape backwards to all the ops that use or produce
      // the value with the missing shape.
      mlir::OpBuilder builder(assign_op);
      mlir::Type known_type = GetSubtypeOrSelf(assign_op.resource());
      if (mlir::failed(BackwardShapeInferenceToRestoreOp(
              module, &builder, assign_op.value(), known_type))) {
        assign_op->emitOpError(
            "Error doing backward shape inference from AssignVariableOp "
            "during CheckpointShapeInference.");
        return mlir::WalkResult::interrupt();
      }
    }
    return mlir::WalkResult::advance();
  });

  return mlir::success();
}

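// Pass that runs the backward shape propagation above: it walks every
// AssignVariableOp in the module and pushes the known resource shapes back to
// the tf.RestoreV2 ops that produced the assigned values.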
struct DTensorInferShapesForRestoreV2Op
    : public impl::DTensorInferShapesForRestoreV2OpBase<
          DTensorInferShapesForRestoreV2Op> {
  void runOnOperation() override {
    auto module = getOperation();
    if (failed(PropagateShapeInformationFromAssignVariableOp(module)))
      return signalPassFailure();
  }
};

}  // namespace

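// Creates the module-level pass that infers shapes for tf.RestoreV2 results.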
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorInferShapesForRestoreV2Op() {
  return std::make_unique<DTensorInferShapesForRestoreV2Op>();
}

}  // namespace dtensor
}  // namespace tensorflow