/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_
#define TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_

#include <string>
#include <vector>

22#include "absl/container/flat_hash_map.h"
23#include "absl/strings/string_view.h"
24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallPtrSet.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/Support/Casting.h"
28#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
29#include "mlir/IR/Builders.h" // from @llvm-project
30#include "mlir/IR/BuiltinOps.h" // from @llvm-project
31#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32#include "mlir/IR/MLIRContext.h" // from @llvm-project
33#include "mlir/IR/Value.h" // from @llvm-project
34#include "mlir/IR/Visitors.h" // from @llvm-project
35#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
38#include "tensorflow/dtensor/cc/dstatus.h"
39#include "tensorflow/dtensor/cc/tensor_layout.h"
40#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"

namespace tensorflow {
namespace dtensor {

constexpr absl::string_view kReduceOpAdd = "Add";
constexpr absl::string_view kReduceOpAll = "All";
constexpr absl::string_view kReduceOpAny = "Any";
constexpr absl::string_view kReduceOpMax = "Max";
constexpr absl::string_view kReduceOpMin = "Min";
constexpr absl::string_view kReduceOpMul = "Mul";
// Mean is not a valid combinator function on its own. It is handled specially
// by the reduce expansion.
constexpr absl::string_view kReduceOpMean = "Mean";

// Returns true if all layouts are replicated.
bool AllReplicated(const std::vector<Layout>& layouts);

// Takes a global type and converts it to a local type. Fails if the number of
// shards does not divide the size of the dimension (if not dynamic).
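//
// Illustrative example (a sketch, not taken from the implementation): with a
// layout that shards dimension 0 over 2 devices, a global type of
// tensor<8x8xf32> would correspond to a local type of tensor<4x8xf32>.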
StatusOr<mlir::TensorType> LocalTypeFromGlobalType(
    const Layout& layout, const mlir::TensorType& original_type);

// Takes a local type and converts it to a global type.
StatusOr<mlir::TensorType> GlobalTypeFromLocalType(
    const Layout& layout, const mlir::TensorType& original_type);

// Creates a tf::SplitOp that splits 'src_input' into 'num_split' pieces along
// the 'split_dimension' dimension and returns the created op via 'split_op'.
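//
// Example usage (a sketch; `location`, `src_input`, and `builder` are assumed
// to already be in scope):
//
//   mlir::TF::SplitOp split_op;
//   Status status = CreateSplitOp(/*num_split=*/2, /*split_dimension=*/0,
//                                 location, src_input, &builder, &split_op);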
Status CreateSplitOp(const int num_split, const int split_dimension,
                     const mlir::Location location, mlir::Value src_input,
                     mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op);

// Given layouts + shapes, determines whether the two are broadcast compatible.
// See the source file for more documentation.
StatusOr<Layout> GetBroadcastLayoutForElementWise(
    const Layout& layout_a, const Layout& layout_b,
    mlir::ArrayRef<int64_t> shape_a, mlir::ArrayRef<int64_t> shape_b,
    int64_t dims_to_ignore, std::vector<std::string>& to_split_a,
    std::vector<std::string>& to_split_b);

// Returns a merged layout computed by `GetBroadcastLayoutForElementWise()`
// given a list of operand layouts.
StatusOr<absl::optional<Layout>> GetMergedOperandLayout(
    const llvm::DenseMap<int, Layout>& operand_layouts, mlir::Operation* op);

// Returns the forwarded input value of the DTensorLayout op for which `value`
// is the output. This must be used after layout propagation and before SPMD
// expansion, when every mlir::Value produced by a tf op is followed by a
// DTensorLayout op that specifies its output layout.
// To keep the implementation safe for the Layout Propagation V1 algorithm, if
// the defining op of `value` is not a DTensorLayout op (only the case for V1),
// returns `value` directly.
// TODO(b/172936130): Remove special casing for v1 Layout Propagation
// algorithm.
mlir::Value GetForwardedDTensorLayoutInput(mlir::Value value);

// The goal of this function is to connect 'mlir::Value's (read
// 'mlir::OpResult's) to the 'mlir::OpOperand's which use them, crossing
// function call boundaries. The only keys in `consumers` which are not
// actually 'mlir::OpResult's are the 'mlir::Value's representing the inputs of
// the main function. The rest are direct outputs of operations, i.e.
// mlir::OpResult. Note that 'mlir::Value's that are not used by any op, or are
// simply returned from the main function, will not be in this map. In these
// cases, there are no conditions on the layouts for these 'mlir::Value's.
//
// A list of current assumptions in this code:
// * Functions are only called once.
// * Functions that are not reachable from main have been trimmed.
// * Input to CopyToMesh can always be traced back to function inputs.
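//
// Example usage (a sketch; `module` and `tf_dialect` are assumed to be in
// scope):
//
//   llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>> consumers;
//   if (mlir::failed(
//           PopulateConsumersFromModule(&module, tf_dialect, consumers)))
//     return mlir::failure();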
mlir::LogicalResult PopulateConsumersFromModule(
    mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers);

// From the device id, returns an mlir::Value for a tensor of shape
// [1, mesh.rank()] whose entries are the mesh coordinates of the device. The
// mesh used is the mesh of the given cluster.
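// For example (illustrative only), on a mesh of shape x=2,y=2 the result is a
// [1, 2] tensor holding the device's (x, y) coordinates, e.g. [[1, 0]] for the
// device at x=1, y=0.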
StatusOr<mlir::Value> GetMeshCoordinatesFromCluster(
    mlir::tf_device::ClusterOp cluster);

// Checks that optional metadata attributes of `op` are valid if they
// exist. More specifically, output layouts of the tf.Shape op and layouts of
// resources inferred from the AssignVariable op are added as metadata.
mlir::LogicalResult ValidateMetadataAttributes(mlir::Operation* op);

// Creates a map from each function (by name) to the op which calls it.
mlir::LogicalResult GetFuncToCaller(
    mlir::ModuleOp module,
    llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller);

// Takes an operand and traces its use across function call and
// tf_device.cluster boundaries. Note that this may turn one operand into
// many.
llvm::SmallVector<mlir::OpOperand*, 4> TraceUseToNextTFOp(
    mlir::OpOperand* operand,
    const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
    llvm::SmallVector<mlir::Value, 4>* skipped_values = nullptr);

// Replaces `cluster` with a new tf_device.cluster that drops result values
// which are not used by any other ops.
//
// For example:
//
//   %unused_value = "tf_device.cluster"() ({
//     %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//     %2 = "tf.Neg"(%1) : (tensor<i32>) -> tensor<i32>
//     tf_device.return %2 : tensor<i32>
//   }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//
// Will be transformed to:
//
//   "tf_device.cluster"() ({
//     %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//     %2 = "tf.Neg"(%1) : (tensor<i32>) -> tensor<i32>
//     tf_device.return
//   }) {_mesh="mesh:CPU,x=2,y=2"} : () -> ()
void RemoveUnusedClusterResults(mlir::tf_device::ClusterOp cluster);

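// Returns a unique function name based on `prefix`, intended for naming
// generated control flow functions.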
mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix,
                                            mlir::OpBuilder& builder);

// Sets the builder insertion point to after `value`. If `value` is a block
// argument, this checks that all users of the value are in the same cluster;
// if they are not, it returns an error. If they are, it sets the insertion
// point to the top of the cluster.
Status SetBuilderInsertionAfterValue(mlir::Value value,
                                     mlir::OpBuilder& builder);

// Inserts a StringFormat op and a Print op; this should only be used for
// debugging on CPU.
Status PrintTensor(mlir::Value value, const std::string& format_string);

// Extracts a vector of strings from an mlir::Value.
Status ExtractConstStringVectorFromValue(
    mlir::Value value, llvm::SmallVectorImpl<std::string>& out_vector);

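// Extracts a scalar string from a constant mlir::Value.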
StatusOr<std::string> ExtractConstScalarStringFromValue(mlir::Value value);

// A general iterator that visits a FuncOp's body in topological order. Note
// that this does not visit the given FuncOp itself. Each function is visited
// exactly once, even if it is called from multiple call sites.
//
// An example usage of this iterator is SPMD Expansion or Sparse Expansion,
// where we expand ops in topological order starting from the `main` FuncOp,
// visiting each function only once so that we don't expand the same ops
// multiple times.
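//
// Example usage (a sketch; `main_func` is assumed to be the module's main
// mlir::func::FuncOp):
//
//   TopologicalIterator iterator(main_func);
//   while (iterator.hasNext()) {
//     mlir::Operation* op = iterator.next();
//     // ... expand or otherwise process `op` ...
//   }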
class TopologicalIterator {
 public:
  explicit TopologicalIterator(mlir::func::FuncOp main_func);

  // Returns whether there are any further ops to visit.
  bool hasNext();

  // Returns the next op to visit in the topological ordering. Returns
  // nullptr if there is no next op to visit.
  mlir::Operation* next();

 private:
  // Stack to keep track of ops to visit.
  llvm::SmallVector<mlir::Operation*, 4> ops_to_visit_;

  // Keeps track of the functions we are currently walking; this is needed to
  // avoid looping on recursive function calls.
  llvm::SmallDenseSet<mlir::StringRef, 4> funcs_visited_in_call_stack_;

  // Keeps track of all visited functions, to guarantee that each function is
  // visited exactly once even if it is used in multiple call sites.
  llvm::SmallDenseSet<mlir::StringRef, 4> funcs_visited_;
};

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_