1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_
17#define TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_
18
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "llvm/ADT/DenseMap.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
29
30namespace tensorflow {
31namespace dtensor {
32
33// Base class for handling SPMD expansion of a MLIR TF Operation.
34class SPMDExpanderBase {
35 public:
36 virtual ~SPMDExpanderBase() {}
37
38 // Converts `op` to a SPMD expanded form. SPMD expansion logic is
39 // a function of op type, op output's layout, and layout of op's
40 // inputs. Must return the `op` that is expanded as the final return value.
41 virtual StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) = 0;
42
43 // Layout propagation functions.
44 //
45 // During the layout algorithm, for each op output we compute a layout by
46 // merging the current layout request from the op producing the output and the
47 // layout requests from the ops consuming the output. These merged layouts
48 // represent the current state of layouts over the entire mlir module.
49 //
50 // For an op, if any of the merged layouts for the inputs or output are
51 // updated, the ComputeLayoutForward and ComputeLayoutBackward functions will
52 // be called with all the updated layout maps populated.
53 //
54 // ComputeLayoutForward should take the input layouts and determine which
55 // output layout these inputs would produce. Likewise, ComputeLayoutBackward
56 // should take the output layouts and determine the what layouts to propagate
57 // to the inputs.
58 //
59 // In both cases the functions should choose layouts that reduce the amount of
60 // cross device communication for the op.
61 //
62 // ComputeLayoutForward should not take into account the current output
63 // layout(s) when computing the new ones. The merge algorithm will decide what
64 // to do. There are only a very few cases where the current output layout may
65 // need to propagated again, in which case those ops can override the
66 // expanded ComputeLayout* functions. This similarly applies to
67 // ComputeLayoutBackward.
68 //
69 // Note that for some ops, where the input layout does not determine output
70 // layout (and visa versa), it is acceptable to either return a replicated
71 // layout. E.g. for tf.Fill, ComputeLayoutForward can return a replicated
72 // output layout and if a consumer requests a more sharded layout, then the
73 // layout algorithm will merge the requests, resulting in the more sharded
74 // layout.
75
76 // Computes output layout(s) of `op` based on the current `input_layouts`
77 // inferred from inputs of `op`. The `input_layouts` parameter maps input
78 // indices to the corresponding layouts. It may be empty if the op has no
79 // operands or if no input layouts have been inferred yet.
80 virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
81 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts);
82
83 // Computes output layout(s) of `op` based on the current `input_layouts` and
84 // `output_layouts` inferred from the inputs and outputs of `op`. Both
85 // parameters maps input/output indices to the corresponding layouts. Either
86 // may be empty.
87 //
88 // NOTE: The other ComputeLayoutForward function should be preferred since in
89 // most cases the output layouts are only computed based on the input layouts.
90 virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
91 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
92 const llvm::DenseMap<int, Layout>& output_layouts);
93
94 // Computes input layout(s) of `op` based on the current `output_layouts`
95 // inferred from outputs of `op`. The `output_layouts` parameter maps output
96 // indices to the corresponding layouts. It may be empty if the op has no
97 // outputs or if no output layouts have been inferred yet.
98 virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
99 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts);
100
101 // Computes input layout(s) of `op` based on the current `output_layouts` and
102 // `input_layouts` inferred from the outputs and inputs of `op`. Both
103 // parameters maps input/output indices to the corresponding layouts. Either
104 // may be empty.
105 //
106 // NOTE: The other ComputeLayoutBackward function should be preferred since in
107 // most cases the input layouts are only computed based on the output layouts.
108 virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
109 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
110 const llvm::DenseMap<int, Layout>& output_layouts);
111
112 // Run ExpandOp() and set layout from the computed layout from original op.
113 // Returns the expanded op in output.
114 Status ExpandOpAndSetLayout(mlir::Operation* op, mlir::Operation** output);
115};
116
// Computes the SPMD expansion for `op`.
//
// Prior to this call, all inputs to `op` have been lowered to local operations
// & shapes. The lowered op must emit a type compatible with the local shape.
// Returns the expanded op in `output`.
Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output);
122
// A registry of SPMD expanders. This map is statically stored and initialized
// with all the registered SPMD expanders.
class SPMDExpanderRegistry {
 public:
  ~SPMDExpanderRegistry() = default;

  // A singleton available at startup.
  static SPMDExpanderRegistry* Global();

  // Returns the expansion for the given operation (or nullptr if no expansion
  // has been registered). The registry retains ownership of the returned
  // expander.
  SPMDExpanderBase* GetPropagateFnForOp(mlir::Operation* op);

  // Registers an expander for the provided opName, taking ownership of
  // `prop`. The returned marker is consumed by the REGISTER_SPMD macro so
  // that registration runs during program startup.
  InitOnStartupMarker RegisterPropagateFn(
      std::string opName, std::unique_ptr<SPMDExpanderBase> prop);

 private:
  // Maps an operation name (as produced by getOperationName()) to the
  // expander handling it; owns the expander instances.
  absl::flat_hash_map<std::string, std::unique_ptr<SPMDExpanderBase>>
      op_to_propagate_fn_map_;
};
144
// Registers an SPMD expander of type `prop` (constructed with `__VA_ARGS__`)
// for operation `mlir::op`, under the unique registration token `name`.
// Registration happens during program startup via InitOnStartupMarker.
//
// All names in the expansion are fully qualified so the macro can be invoked
// from any namespace, not only from inside tensorflow::dtensor.
#define REGISTER_SPMD(name, op, prop, ...)                                 \
  static ::tensorflow::InitOnStartupMarker const spmd_##name =             \
      ::tensorflow::InitOnStartupMarker{}                                  \
      << ::tensorflow::dtensor::SPMDExpanderRegistry::Global()             \
             ->RegisterPropagateFn(::mlir::op ::getOperationName().str(),  \
                                   std::make_unique<prop>(__VA_ARGS__))
151
152} // namespace dtensor
153} // namespace tensorflow
154
155#endif // TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_
156