1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_ |
17 | #define TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_ |
18 | |
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "llvm/ADT/DenseMap.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
29 | |
30 | namespace tensorflow { |
31 | namespace dtensor { |
32 | |
33 | // Base class for handling SPMD expansion of a MLIR TF Operation. |
34 | class SPMDExpanderBase { |
35 | public: |
36 | virtual ~SPMDExpanderBase() {} |
37 | |
38 | // Converts `op` to a SPMD expanded form. SPMD expansion logic is |
39 | // a function of op type, op output's layout, and layout of op's |
40 | // inputs. Must return the `op` that is expanded as the final return value. |
41 | virtual StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) = 0; |
42 | |
43 | // Layout propagation functions. |
44 | // |
45 | // During the layout algorithm, for each op output we compute a layout by |
46 | // merging the current layout request from the op producing the output and the |
47 | // layout requests from the ops consuming the output. These merged layouts |
48 | // represent the current state of layouts over the entire mlir module. |
49 | // |
50 | // For an op, if any of the merged layouts for the inputs or output are |
51 | // updated, the ComputeLayoutForward and ComputeLayoutBackward functions will |
52 | // be called with all the updated layout maps populated. |
53 | // |
54 | // ComputeLayoutForward should take the input layouts and determine which |
55 | // output layout these inputs would produce. Likewise, ComputeLayoutBackward |
56 | // should take the output layouts and determine the what layouts to propagate |
57 | // to the inputs. |
58 | // |
59 | // In both cases the functions should choose layouts that reduce the amount of |
60 | // cross device communication for the op. |
61 | // |
62 | // ComputeLayoutForward should not take into account the current output |
63 | // layout(s) when computing the new ones. The merge algorithm will decide what |
64 | // to do. There are only a very few cases where the current output layout may |
65 | // need to propagated again, in which case those ops can override the |
66 | // expanded ComputeLayout* functions. This similarly applies to |
67 | // ComputeLayoutBackward. |
68 | // |
69 | // Note that for some ops, where the input layout does not determine output |
70 | // layout (and visa versa), it is acceptable to either return a replicated |
71 | // layout. E.g. for tf.Fill, ComputeLayoutForward can return a replicated |
72 | // output layout and if a consumer requests a more sharded layout, then the |
73 | // layout algorithm will merge the requests, resulting in the more sharded |
74 | // layout. |
75 | |
76 | // Computes output layout(s) of `op` based on the current `input_layouts` |
77 | // inferred from inputs of `op`. The `input_layouts` parameter maps input |
78 | // indices to the corresponding layouts. It may be empty if the op has no |
79 | // operands or if no input layouts have been inferred yet. |
80 | virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( |
81 | mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts); |
82 | |
83 | // Computes output layout(s) of `op` based on the current `input_layouts` and |
84 | // `output_layouts` inferred from the inputs and outputs of `op`. Both |
85 | // parameters maps input/output indices to the corresponding layouts. Either |
86 | // may be empty. |
87 | // |
88 | // NOTE: The other ComputeLayoutForward function should be preferred since in |
89 | // most cases the output layouts are only computed based on the input layouts. |
90 | virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( |
91 | mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts, |
92 | const llvm::DenseMap<int, Layout>& output_layouts); |
93 | |
94 | // Computes input layout(s) of `op` based on the current `output_layouts` |
95 | // inferred from outputs of `op`. The `output_layouts` parameter maps output |
96 | // indices to the corresponding layouts. It may be empty if the op has no |
97 | // outputs or if no output layouts have been inferred yet. |
98 | virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( |
99 | mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts); |
100 | |
101 | // Computes input layout(s) of `op` based on the current `output_layouts` and |
102 | // `input_layouts` inferred from the outputs and inputs of `op`. Both |
103 | // parameters maps input/output indices to the corresponding layouts. Either |
104 | // may be empty. |
105 | // |
106 | // NOTE: The other ComputeLayoutBackward function should be preferred since in |
107 | // most cases the input layouts are only computed based on the output layouts. |
108 | virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( |
109 | mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts, |
110 | const llvm::DenseMap<int, Layout>& output_layouts); |
111 | |
112 | // Run ExpandOp() and set layout from the computed layout from original op. |
113 | // Returns the expanded op in output. |
114 | Status ExpandOpAndSetLayout(mlir::Operation* op, mlir::Operation** output); |
115 | }; |
116 | |
// Computes the SPMD expansion for `op`, returning the expanded op in
// `*output`.
//
// Prior to this call, all inputs to `op` have been lowered to local operations
// & shapes. The lowered op must emit a type compatible with the local shape.
Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output);
122 | |
123 | // A registry of SPMD expanders. This map is statically stored and initialized |
124 | // with all the registered SPMD expanders. |
// A registry of SPMD expanders. This map is statically stored and initialized
// with all the registered SPMD expanders.
class SPMDExpanderRegistry {
 public:
  ~SPMDExpanderRegistry() = default;

  // A singleton available at startup.
  static SPMDExpanderRegistry* Global();

  // Returns the expansion for the given operation (or nullptr if no expansion
  // has been registered). The returned pointer is owned by the registry and
  // must not be deleted by the caller.
  SPMDExpanderBase* GetPropagateFnForOp(mlir::Operation* op);

  // Registers an expander for the provided opName (the full MLIR operation
  // name, as produced by Op::getOperationName() — see REGISTER_SPMD).
  // The returned marker allows registration from static initializers.
  InitOnStartupMarker RegisterPropagateFn(
      std::string opName, std::unique_ptr<SPMDExpanderBase> prop);

 private:
  // Maps the full MLIR op name to its owning expander instance.
  absl::flat_hash_map<std::string, std::unique_ptr<SPMDExpanderBase>>
      op_to_propagate_fn_map_;
};
144 | |
// Registers an SPMD expander with SPMDExpanderRegistry::Global() at startup.
//
// `name` uniquely identifies the registration (it names the generated static
// marker variable), `op` is the op class relative to the `mlir` namespace
// whose getOperationName() supplies the registry key, `prop` is the
// SPMDExpanderBase subclass to instantiate, and any remaining arguments are
// forwarded to `prop`'s constructor.
#define REGISTER_SPMD(name, op, prop, ...)                     \
  static ::tensorflow::InitOnStartupMarker const spmd_##name = \
      InitOnStartupMarker{}                                    \
      << SPMDExpanderRegistry::Global()->RegisterPropagateFn(  \
             mlir::op ::getOperationName().str(),              \
             std::make_unique<prop>(__VA_ARGS__))
151 | |
152 | } // namespace dtensor |
153 | } // namespace tensorflow |
154 | |
155 | #endif // TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_ |
156 | |