/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/spmd_expander.h"

#include <climits>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

namespace tensorflow {
namespace dtensor {

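// Returns the process-wide expander registry. The registry is created on
// first use and intentionally never destroyed.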
// static
SPMDExpanderRegistry* SPMDExpanderRegistry::Global() {
  static SPMDExpanderRegistry* registry = new SPMDExpanderRegistry();
  return registry;
}

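// Returns the expander registered for `op`'s name, or nullptr if no expander
// has been registered for that op.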
SPMDExpanderBase* SPMDExpanderRegistry::GetPropagateFnForOp(
    mlir::Operation* op) {
  auto key = OpName(op);
  auto fn = op_to_propagate_fn_map_.find(key);
  if (fn == op_to_propagate_fn_map_.end()) return nullptr;
  return fn->second.get();
}

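// Registers `prop` as the SPMD expander for the op named `op_name`.
// Registering the same op name twice is a programming error and crashes via
// the CHECK below. A registration sketch (`mlir::TF::MyOp` and
// `MyOpSPMDExpander` are hypothetical stand-ins, not names defined here):
//
//   SPMDExpanderRegistry::Global()->RegisterPropagateFn(
//       mlir::TF::MyOp::getOperationName().str(),
//       std::make_unique<MyOpSPMDExpander>());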
InitOnStartupMarker SPMDExpanderRegistry::RegisterPropagateFn(
    std::string op_name, std::unique_ptr<SPMDExpanderBase> prop) {
  CHECK(op_to_propagate_fn_map_  // Crash ok
            .insert_or_assign(op_name, std::move(prop))
            .second);
  return {};
}

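// Expands `op` to its SPMD-local form: extracts the layouts attached to
// `op`, records the global output shapes, runs the op-specific ExpandOp,
// attaches the layouts to the expanded op, and finally verifies that the
// expanded op's local output shapes are consistent with those layouts.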
Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op,
                                              mlir::Operation** output) {
  TF_ASSIGN_OR_RETURN(std::vector<absl::optional<Layout>> computed_layout,
                      ExtractLayoutFromOp(op));

  if (computed_layout.empty() && op->getNumResults() != 0) {
    return errors::InvalidArgument(absl::StrCat(
        "No attached layout found for op: ", OpName(op),
        ". This might be due to an error in layout propagation."));
  }

  // `op` may be removed/replaced from the graph during SPMD expansion, so
  // extract the global output shape before expansion.
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> global_output_shapes;
  global_output_shapes.reserve(op->getNumResults());
  for (auto output_value : op->getResults()) {
    auto maybe_ranked =
        output_value.getType().dyn_cast<mlir::RankedTensorType>();
    // Do not extract the global shape if the shape isn't statically known.
    //
    // This is a bit subtle and relies on the static-shape check performed on
    // each output value below, when the local shape is extracted. We should
    // probably introduce a placeholder for unknown shapes to avoid surprises
    // in the future.
    //
    // Given the nature of the RestoreV2 op and its output ranks, we only
    // special-case RestoreV2 for now.
    if (llvm::isa<mlir::TF::RestoreV2Op, mlir::TF::DTensorRestoreV2Op>(op) &&
        (!maybe_ranked || !maybe_ranked.hasStaticShape())) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(auto global_shape,
                        ExtractGlobalOutputShape(output_value));
    global_output_shapes.emplace_back(llvm::SmallVector<int64_t, 4>{
        global_shape.begin(), global_shape.end()});
  }

  TF_ASSIGN_OR_RETURN(*output, this->ExpandOp(op));

  // TODO(hthu): Use ToString() instead.
  SetLayoutOnOp(*output, absl::Span<absl::optional<Layout>>(
                             computed_layout.data(), computed_layout.size()));

  // Verify that the local shape of the expanded op matches the shape expected
  // from the layout. Note that this does **not** catch all errors: when a
  // tensor dimension is sharded over the wrong mesh dimension, but one with
  // the same device cardinality as the correct/expected mesh dimension, this
  // check will still pass.
  for (const auto& output_layout_and_index :
       llvm::enumerate(llvm::zip((*output)->getResults(), computed_layout))) {
    const int index = output_layout_and_index.index();
    const auto& output_and_layout = output_layout_and_index.value();

    auto output_value = std::get<0>(output_and_layout);
    // Extract the static shape of `output_value` if possible, otherwise
    // ignore this output.
    auto local_expanded_shape_or_status = GetShapeOfValue(output_value);
    if (!local_expanded_shape_or_status.ok()) continue;

    const auto local_expanded_shape = local_expanded_shape_or_status.value();
    const auto& layout = std::get<1>(output_and_layout);
    const auto expected_global_shape =
        layout->GlobalShapeFromLocalShape(local_expanded_shape);

    for (const auto& expanded_and_true_global_shape :
         llvm::zip(global_output_shapes[index], expected_global_shape)) {
      const auto expanded_shape = std::get<0>(expanded_and_true_global_shape);
      const auto expected_shape = std::get<1>(expanded_and_true_global_shape);
      // If either dimension is unknown (non-positive), skip validation for
      // this dimension.
      if (expanded_shape <= 0 || expected_shape <= 0) continue;

      if (expanded_shape != expected_shape) {
        return errors::Internal(
            "SPMD expansion resulted in op output inconsistent with the "
            "provided layout.");
      }
    }
  }

  return OkStatus();
}

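// Default implementations of the layout-propagation hooks. Subclasses
// override these to describe how layouts flow through the op. A minimal
// sketch of an override (`EltwiseExpander` and its forward-operand-0 policy
// are hypothetical, not part of this file):
//
//   StatusOr<llvm::DenseMap<int, Layout>>
//   EltwiseExpander::ComputeLayoutForward(
//       mlir::Operation* op,
//       const llvm::DenseMap<int, Layout>& input_layouts) {
//     // Forward the layout of operand 0 to result 0 unchanged.
//     llvm::DenseMap<int, Layout> output_layouts;
//     auto it = input_layouts.find(0);
//     if (it != input_layouts.end()) output_layouts.insert({0, it->second});
//     return output_layouts;
//   }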
StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  return errors::Unimplemented(
      "ComputeLayoutForward API must be implemented by the subclass.");
}

StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeLayoutForward(op, input_layouts);
}

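// The backward direction mirrors ComputeLayoutForward: given the layouts
// requested for the op's results, produce the layouts to request for its
// operands.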
StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return errors::Unimplemented(
      "ComputeLayoutBackward API must be implemented by the subclass.");
}

StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeLayoutBackward(op, output_layouts);
}

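// Runs SPMD expansion for `op` if an expander is registered for its op name;
// otherwise leaves `op` unchanged and returns it as `output`.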
Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output) {
  SPMDExpanderBase* expander =
      SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
  if (expander != nullptr) {
    return expander->ExpandOpAndSetLayout(op, output);
  } else {
    VLOG(1) << "No expansion found for " << OpName(op);
    *output = op;
  }
  return OkStatus();
}

}  // namespace dtensor
}  // namespace tensorflow