/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/spmd_expander.h"

#include <climits>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/ADT/DenseMap.h"  // from @llvm-project
#include "llvm/ADT/STLExtras.h"  // from @llvm-project
#include "llvm/ADT/SmallVector.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

namespace tensorflow {
namespace dtensor {

// static
SPMDExpanderRegistry* SPMDExpanderRegistry::Global() {
  static SPMDExpanderRegistry* registry = new SPMDExpanderRegistry();
  return registry;
}

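// Returns the expander registered for `op`'s name, or nullptr if no expander
// has been registered for that op.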
SPMDExpanderBase* SPMDExpanderRegistry::GetPropagateFnForOp(
    mlir::Operation* op) {
  auto key = OpName(op);
  auto fn = op_to_propagate_fn_map_.find(key);
  if (fn == op_to_propagate_fn_map_.end()) return nullptr;
  return fn->second.get();
}

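// Registers `prop` as the SPMD expander for ops named `opName`. Registering
// two expanders for the same op name is a bug; the CHECK below crashes the
// process if that happens.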
InitOnStartupMarker SPMDExpanderRegistry::RegisterPropagateFn(
    std::string opName, std::unique_ptr<SPMDExpanderBase> prop) {
  CHECK(op_to_propagate_fn_map_  // Crash ok
            .insert_or_assign(opName, std::move(prop))
            .second);
  return {};
}

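// A minimal registration sketch (illustrative only: `IdentitySPMDExpander` is
// a hypothetical subclass of SPMDExpanderBase, and real code normally
// registers expanders through the macro declared in spmd_expander.h rather
// than calling RegisterPropagateFn directly):
//
//   SPMDExpanderRegistry::Global()->RegisterPropagateFn(
//       mlir::TF::IdentityOp::getOperationName().str(),
//       std::make_unique<IdentitySPMDExpander>());
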
// Runs SPMD expansion for `op`, attaches the already-computed output layouts
// to the expanded op, and sanity-checks the expanded output shapes against
// those layouts.
Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op,
                                              mlir::Operation** output) {
  TF_ASSIGN_OR_RETURN(std::vector<absl::optional<Layout>> computed_layout,
                      ExtractLayoutFromOp(op));

  if (computed_layout.empty() && op->getNumResults() != 0) {
    return errors::InvalidArgument(
        absl::StrCat("No attached layout found for op: ", OpName(op),
                     ". This might be due to an error in layout propagation.")
            .c_str());
  }

  // `op` may be removed/replaced from the graph during SPMD expansion, so
  // extract the global output shape before expansion.
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> global_output_shapes;
  global_output_shapes.reserve(op->getNumResults());
  for (auto output_value : op->getResults()) {
    auto maybe_ranked =
        output_value.getType().dyn_cast<mlir::RankedTensorType>();
    // Do not extract the global shape if the shape isn't statically known.
    //
    // This is a bit subtle and relies on the static-shape check of the output
    // value below when extracting local_shape. We should probably introduce a
    // placeholder for unknown shapes to avoid surprises in the future.
    //
    // Given the nature of the RestoreV2 op and its output ranks, we only
    // special-case RestoreV2 for now.
    if (llvm::isa<mlir::TF::RestoreV2Op, mlir::TF::DTensorRestoreV2Op>(op) &&
        (!maybe_ranked || !maybe_ranked.hasStaticShape()))
      continue;
    TF_ASSIGN_OR_RETURN(auto global_shape,
                        ExtractGlobalOutputShape(output_value));
    global_output_shapes.emplace_back(llvm::SmallVector<int64_t, 4>{
        global_shape.begin(), global_shape.end()});
  }

  TF_ASSIGN_OR_RETURN(*output, this->ExpandOp(op));

  // TODO(hthu): Use ToString() instead.
  SetLayoutOnOp(*output, absl::Span<absl::optional<Layout>>(
                             computed_layout.data(), computed_layout.size()));

  // Verify that the local shape of the expanded operation matches the shape
  // expected from the layout. Note that this does **not** catch all errors:
  // when a tensor dimension is sharded on the wrong mesh but that mesh has the
  // same device cardinality as the correct/expected mesh, this check will
  // still pass.
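  //
  // For example (illustrative numbers only): with a layout that shards the
  // first output dimension over 4 devices, a local expanded shape of [2, 8]
  // reconstructs to a global shape of [8, 8]; if that disagrees with the
  // global shape recorded before expansion, we report an internal error.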
  for (const auto& output_layout_and_index :
       llvm::enumerate(llvm::zip((*output)->getResults(), computed_layout))) {
    const int index = output_layout_and_index.index();
    const auto& output_and_layout = output_layout_and_index.value();

    auto output_value = std::get<0>(output_and_layout);
    // Extract the static shape of `output_value` if possible, otherwise ignore
    // this output.
    auto local_expanded_shape_or_status = GetShapeOfValue(output_value);
    if (!local_expanded_shape_or_status.ok()) continue;

    const auto local_expanded_shape = local_expanded_shape_or_status.value();
    const auto& layout = std::get<1>(output_and_layout);
    const auto expected_global_shape =
        layout->GlobalShapeFromLocalShape(local_expanded_shape);

    for (const auto& expanded_and_true_global_shape :
         llvm::zip(global_output_shapes[index], expected_global_shape)) {
      const auto expanded_shape = std::get<0>(expanded_and_true_global_shape);
      const auto expected_shape = std::get<1>(expanded_and_true_global_shape);
      // If either shape has an unknown dimension, do not check/validate the
      // shape.
      if (expanded_shape <= 0 || expected_shape <= 0) continue;

      if (expanded_shape != expected_shape) {
        return errors::Internal(
            "SPMD expansion resulted in op output inconsistent with the "
            "provided layout.");
      }
    }
  }

  return OkStatus();
}

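// Layout propagation in the forward direction (inputs to outputs). The base
// implementation reports Unimplemented so that ops without an op-specific rule
// are caught; the overload that also receives `output_layouts` defers to the
// input-only overload by default.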
StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  return errors::Unimplemented(
      "ComputeLayoutForward API must be implemented via the subclass.");
}

StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeLayoutForward(op, input_layouts);
}

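// Layout propagation in the backward direction (outputs to inputs), mirroring
// the forward case: Unimplemented unless overridden, with the overload that
// also receives `input_layouts` deferring to the output-only overload.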
StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return errors::Unimplemented(
      "ComputeLayoutBackward API must be implemented via the subclass.");
}

StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeLayoutBackward(op, output_layouts);
}

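// Runs SPMD expansion for `op` using the expander registered for its op name.
// If no expander is registered, `op` is returned unchanged through `*output`.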
Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output) {
  SPMDExpanderBase* expander =
      SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
  if (expander != nullptr) {
    return expander->ExpandOpAndSetLayout(op, output);
  } else {
    VLOG(1) << "No expansion found for " << OpName(op) << "\n";
    *output = op;
  }
  return OkStatus();
}

}  // namespace dtensor
}  // namespace tensorflow