/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/collectives.h"

#include <cstdint>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives_common.h"
#include "tensorflow/dtensor/mlir/dtensor_location.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/sparse_expander_common.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {

namespace ops_util = ::mlir::TF::collection_ops_util;

}  // namespace

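// Example (hypothetical 1-D mesh ["x": 2], global shape [4, 4]): gathering
// from layout ["x", unsharded] to [unsharded, unsharded] turns each device's
// local [2, 4] shard into the full [4, 4] tensor.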
StatusOr<mlir::Value> EmitAllGather(
    mlir::OpBuilder& builder, mlir::Value input,
    const dtensor::Layout& src_layout, const dtensor::Layout& tgt_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  if (src_layout.IsEquivalent(tgt_layout)) return input;

  if (src_layout.rank() != tgt_layout.rank()) {
    return errors::InvalidArgument(
        "Expected source and target layout to have the same rank, got ",
        src_layout.rank(), " vs ", tgt_layout.rank());
  }

  // Check that tgt_layout is less sharded than src_layout.
  for (int i = 0; i < src_layout.rank(); ++i) {
    if (src_layout.sharding_spec(i) != tgt_layout.sharding_spec(i) &&
        Layout::IsShardedDimension(tgt_layout.sharding_spec(i))) {
      return errors::InvalidArgument("source layout (", src_layout.ToString(),
                                     ") for all gather is not less sharded "
                                     "than the target layout (",
                                     tgt_layout.ToString(), ")");
    }
  }

  // For convenience, operate on explicit input shapes. This isn't necessary,
  // as we could instead generate operations on top of the dynamic shape.
  const mlir::TensorType input_type =
      input.getType().dyn_cast<mlir::TensorType>();
  if (!input_type) {
    return errors::Internal(
        llvm::formatv(
            "Cannot cast input_type : {0} to TensorType. Shape must be "
            "statically known before emitting AllGather. This should not "
            "happen as we already cast it when getting its shape.",
            input.getType())
            .str());
  }

  TF_ASSIGN_OR_RETURN(mlir::TensorType global_type,
                      GlobalTypeFromLocalType(src_layout, input_type));
  TF_ASSIGN_OR_RETURN(mlir::TensorType output_type,
                      LocalTypeFromGlobalType(tgt_layout, global_type));

  mlir::Location loc = DT_LOC2(input.getLoc(), "DTensorAllGatherOp");
  mlir::TF::DTensorAllGatherOp all_gather =
      builder.create<mlir::TF::DTensorAllGatherOp>(
          loc, output_type, input,
          mlir::dtensor::LayoutAttr::get(builder.getContext(), src_layout),
          mlir::dtensor::LayoutAttr::get(builder.getContext(), tgt_layout));
  SetSingleLayoutOnOp(all_gather, tgt_layout);

  if (newly_created_ops != nullptr) newly_created_ops->insert(all_gather);

  return all_gather.output();
}

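// Example (hypothetical 1-D mesh ["x": 2], global shape [4, 4]): scattering
// from [unsharded, unsharded] to ["x", unsharded] slices each device's local
// [4, 4] tensor down to its own [2, 4] shard.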
StatusOr<const mlir::Value> EmitAllScatter(
    mlir::OpBuilder& builder, const mlir::Value& original_value,
    const Layout& original_layout, const Layout& desired_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  if (original_layout.IsEquivalent(desired_layout)) return original_value;

  // Return an error early if the desired layout is not more sharded than the
  // original layout.
  assert(original_layout.rank() == desired_layout.rank());
  for (int i = 0; i < original_layout.rank(); ++i) {
    if (original_layout.sharding_spec(i) != desired_layout.sharding_spec(i) &&
        Layout::IsShardedDimension(original_layout.sharding_spec(i))) {
      return errors::InvalidArgument(
          "EmitAllScatter was passed a desired_layout ",
          desired_layout.ToString(),
          " which was not more sharded than the original_layout ",
          original_layout.ToString());
    }
  }

  const mlir::TensorType input_type =
      original_value.getType().dyn_cast<mlir::TensorType>();
  if (!input_type)
    return errors::InvalidArgument(
        "input to EmitAllScatter does not have a TensorType");

  TF_ASSIGN_OR_RETURN(const mlir::TensorType global_type,
                      GlobalTypeFromLocalType(original_layout, input_type));
  TF_ASSIGN_OR_RETURN(const mlir::TensorType output_type,
                      LocalTypeFromGlobalType(desired_layout, global_type));

  mlir::Location loc = DT_LOC2(original_value.getLoc(), "DTensorAllScatterOp");
  mlir::TF::DTensorAllScatterOp all_scatter =
      builder.create<mlir::TF::DTensorAllScatterOp>(
          loc, output_type, original_value,
          mlir::dtensor::LayoutAttr::get(builder.getContext(),
                                         original_layout),
          mlir::dtensor::LayoutAttr::get(builder.getContext(),
                                         desired_layout));
  SetSingleLayoutOnOp(all_scatter, desired_layout);

  if (newly_created_ops != nullptr) newly_created_ops->insert(all_scatter);

  return all_scatter.output();
}

StatusOr<mlir::Value> EmitDenseToSparseToDense(
    mlir::OpBuilder& builder, mlir::Value input,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  // First create a dense-to-sparse conversion. Since there is no
  // DenseToSparseOp, we do it manually by creating the indices, values, and
  // shape tensors through various ops:
  //
  //   indices tensor = tf.where(tf.not_equal(input, tf.zeros_like(input)))
  //   values tensor  = tf.gather_nd(input, indices)
  //   shape tensor   = tf.shape(input)
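  //
  // Illustration (hypothetical input): for input = [[0, 1], [2, 0]], this
  // yields indices = [[0, 1], [1, 0]], values = [1, 2], and shape = [2, 2].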
  mlir::TF::ZerosLikeOp zeros_like =
      builder.create<mlir::TF::ZerosLikeOp>(input.getLoc(), input);
  mlir::TF::NotEqualOp not_equal = builder.create<mlir::TF::NotEqualOp>(
      zeros_like.getLoc(), input, zeros_like, builder.getBoolAttr(false));

  mlir::TF::WhereOp indices = builder.create<mlir::TF::WhereOp>(
      not_equal.getLoc(),
      mlir::RankedTensorType::get(GetShapeOfValue(not_equal).value(),
                                  builder.getI64Type()),
      not_equal);

  mlir::TF::GatherNdOp values = builder.create<mlir::TF::GatherNdOp>(
      input.getLoc(), input.getType(), input, indices);
  auto shape = builder.create<mlir::TF::ShapeOp>(input.getLoc(), input,
                                                 builder.getBoolAttr(false));

  // Emit a SparseToDenseOp and replace the SparseTensor with the result of
  // this new op.
  auto zero_scalar = CreateZeroScalarConst(
      builder, input.getLoc(),
      input.getType().cast<mlir::TensorType>().getElementType());
  if (!zero_scalar.has_value())
    return errors::Internal("Failure in creating a zero scalar const");

  auto dense = builder.create<mlir::TF::SparseToDenseOp>(
      input.getLoc(), input.getType(),
      mlir::ValueRange({indices, shape, values, zero_scalar.value()}));

  if (newly_created_ops != nullptr) {
    for (auto new_op : {dense.getOperation(), shape.getOperation(),
                        values.getOperation(), indices.getOperation(),
                        not_equal.getOperation(), zeros_like.getOperation()}) {
      newly_created_ops->insert(new_op);
    }
  }

  return dense.getResult();
}

StatusOr<mlir::Value> EmitRelayout(
    mlir::Value input, const dtensor::Layout& src_layout,
    const dtensor::Layout& tgt_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  // EmitRelayout is performed by doing a split, an AllGather, and another
  // split. The first split opportunistically splits input tensor dimension i
  // on mesh axis x if:
  // 1. tgt_layout contains x at position i.
  // 2. src_layout is unsharded at position i.
  // 3. src_layout does not contain mesh axis x.
  // This produces intermediate layout 1.
  // Next, an AllGather is performed on any axis in intermediate layout 1
  // that does not agree with the sharding on the output axis.
  // This produces intermediate layout 2.
  // Finally, a split is performed from intermediate layout 2 to the tgt
  // layout.
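  //
  // Worked example (hypothetical mesh ["x": 2, "y": 2]): relayouting from
  // ["y", unsharded] to ["y", "x"] first splits dimension 1 on "x"
  // (intermediate layout 1 = ["y", "x"]), after which the AllGather and the
  // final split are no-ops. Relayouting from [unsharded, "x"] to
  // ["x", unsharded] cannot split first (src already uses "x"), so it
  // all-gathers to [unsharded, unsharded] and then splits dimension 0 on "x".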

  if (src_layout.IsEquivalent(tgt_layout)) return input;

  // Save whether the input is from a SparseToDenseOp. If it is, then we will
  // emit a DenseToSparse and a SparseToDense op.
  bool is_sparse = IsSparseValue(input);
  if (!input.getType().isa<mlir::RankedTensorType>())
    return errors::Internal(
        "attempting to relayout a tensor that does not "
        "have a rank");

  if (src_layout.mesh() != tgt_layout.mesh()) {
    return errors::Internal("Attempted to relayout to a different mesh.");
  }
  if (src_layout.rank() != tgt_layout.rank()) {
    return errors::Internal(
        "Attempted to relayout to a different global shape.");
  }

  absl::flat_hash_set<std::string> src_sharding_dims;
  for (int i = 0; i < src_layout.rank(); ++i)
    src_sharding_dims.emplace(src_layout.sharding_spec(i));

  std::vector<ShardingSpec> intermediate_specs_1(src_layout.rank());
  for (int i = 0; i < src_layout.rank(); ++i) {
    if (Layout::IsShardedSpec(tgt_layout.dim(i)) &&
        !Layout::IsShardedSpec(src_layout.dim(i)) &&
        !src_sharding_dims.contains(tgt_layout.sharding_spec(i)))
      intermediate_specs_1[i] = tgt_layout.dim(i);
    else
      intermediate_specs_1[i] = src_layout.dim(i);
  }
  TF_ASSIGN_OR_RETURN(
      Layout intermediate_layout_1,
      Layout::GetLayout(intermediate_specs_1, src_layout.mesh()));

  mlir::OpBuilder builder(input.getContext());
  TF_RETURN_IF_ERROR(SetBuilderInsertionAfterValue(input, builder));

  TF_ASSIGN_OR_RETURN(mlir::Value split_result,
                      EmitAllScatter(builder, input, src_layout,
                                     intermediate_layout_1,
                                     newly_created_ops));

  std::vector<ShardingSpec> intermediate_specs_2(src_layout.rank());
  for (int i = 0; i < src_layout.rank(); ++i) {
    if (Layout::IsShardedSpec(intermediate_specs_1[i]) &&
        intermediate_specs_1[i].sharding_spec() != tgt_layout.sharding_spec(i))
      intermediate_specs_2[i].set_sharding_spec(Layout::kUnshardedDim);
    else
      intermediate_specs_2[i] = intermediate_specs_1[i];
  }
  TF_ASSIGN_OR_RETURN(
      Layout intermediate_layout_2,
      Layout::GetLayout(intermediate_specs_2, src_layout.mesh()));

  TF_ASSIGN_OR_RETURN(
      mlir::Value concat_result,
      EmitAllGather(builder, split_result, intermediate_layout_1,
                    intermediate_layout_2, newly_created_ops));

  auto all_scatter =
      EmitAllScatter(builder, concat_result, intermediate_layout_2, tgt_layout,
                     newly_created_ops);

  if (!is_sparse) return all_scatter;
  if (!all_scatter.ok()) return all_scatter;
  return EmitDenseToSparseToDense(builder, all_scatter.value(),
                                  newly_created_ops);
}

StatusOr<mlir::Operation*> EmitAllReduce(
    mlir::OpBuilder& builder, const dtensor::Layout& output_layout,
    const absl::flat_hash_set<std::string>& reduced_dims,
    mlir::Operation* input, absl::string_view reduce_op) {
  TF_ASSIGN_OR_RETURN(auto partitions, GetAllReducePartitionsFromReducedDims(
                                           output_layout, reduced_dims));
  const int32 num_partitions = partitions.size();

  // If every device lives in its own partition, we don't need to emit a
  // collective.
  if (num_partitions == output_layout.num_devices()) {
    return InferSPMDExpandedLocalShape(input);
  }

  // Construct a flattened list of reduce partitions. This will be converted
  // into a 2-D const tensor for the DTensorAllReduce op.
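  //
  // Illustration (hypothetical 2x2 mesh ["x": 2, "y": 2], reducing over "x"):
  // devices {0, 2} and {1, 3} each form a partition, so the flattened list is
  // [0, 2, 1, 3] and the group assignment is the 2x2 tensor [[0, 2], [1, 3]].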
  std::vector<int32> partitions_flat;
  for (auto& p : partitions) {
    if (p.second.size() != partitions.begin()->second.size()) {
      return errors::InvalidArgument(
          "AllReduce partitions had different sizes -- this is not supported "
          "in MLIR.");
    }
    partitions_flat.insert(partitions_flat.end(), p.second.begin(),
                           p.second.end());
  }

  int32 partition_size = partitions.begin()->second.size();
  auto shaped_type = mlir::RankedTensorType::get(
      {num_partitions, partition_size},
      mlir::IntegerType::get(builder.getContext(), 32));
  auto group_assignment =
      mlir::DenseIntElementsAttr::get(shaped_type, partitions_flat);

  TF_ASSIGN_OR_RETURN(std::string device_type,
                      DeviceTypeFromMesh(output_layout.mesh()));

  mlir::Location loc = DT_LOC2(input->getLoc(), "DTensorAllReduceOp");
  auto all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      loc, input->getResultTypes()[0], input->getOpResult(0),
      builder.create<mlir::TF::ConstOp>(DT_LOC2(loc, "group_assignment"),
                                        group_assignment),
      builder.getStringAttr(std::string(reduce_op)),
      builder.getStringAttr(device_type));
  SetSingleLayoutOnOp(all_reduce, output_layout);
  input->getOpResult(0).replaceAllUsesExcept(
      all_reduce.getResult(),
      llvm::SmallPtrSet<mlir::Operation*, 1>{all_reduce});
  return all_reduce.getOperation();
}

namespace {

// Returns an offset multiplier used to convert a device id to or from the
// mesh coordinate along `mesh_dim`.
int GetMeshDimensionOffsetWithNeighbor(const Mesh& mesh,
                                       const std::string& mesh_dim) {
  const int index = mesh.GetMeshDimIndexWithName(mesh_dim);
  const std::vector<int64_t> mesh_dim_sizes = mesh.dim_sizes();
  int offset = 1;
  for (int i = index + 1; i < mesh_dim_sizes.size(); ++i) {
    offset = offset * mesh_dim_sizes[i];
  }
  return offset;
}

// Returns the mesh coordinate along dimension `mesh_dim_name` for the device
// with the given `device_id`.
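// Illustration (hypothetical mesh of shape [2, 4] over dims ["x", "y"]): the
// offset of "x" is 4 and the offset of "y" is 1, so device_id 5 has
// coordinates x = (5 / 4) % 2 = 1 and y = (5 / 1) % 4 = 1.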
StatusOr<int> GetMeshCoordinateIndex(const Mesh& mesh,
                                     const std::string& mesh_dim_name,
                                     int device_id) {
  const int offset = GetMeshDimensionOffsetWithNeighbor(mesh, mesh_dim_name);
  TF_ASSIGN_OR_RETURN(int64_t mesh_dim_size, mesh.dim_size(mesh_dim_name));

  return (device_id / offset) % mesh_dim_size;
}

// Returns a 2-D constant tensor of shape [N, 2] holding the source-target
// device pairs used for halo exchange.
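// Illustration (hypothetical 1-D mesh ["x": 4], shift_left = false): each
// device sends to its right neighbor with wraparound, giving the pairs
// [[0, 1], [1, 2], [2, 3], [3, 0]].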
StatusOr<mlir::Value> CreateConstSrcTargetPair(const Mesh& mesh,
                                               const std::string& mesh_dim_name,
                                               bool shift_left,
                                               mlir::Location location,
                                               mlir::OpBuilder& builder) {
  const int mesh_dim_index = mesh.GetMeshDimIndexWithName(mesh_dim_name);
  const std::vector<MeshDimension> mesh_dimensions = mesh.dims();

  llvm::SmallVector<int, 4> src_target_pair_flat;
  src_target_pair_flat.reserve(mesh.local_device_ids().size() * 2);
  for (const int local_device_id : mesh.local_device_ids()) {
    // Calculate the mesh coordinate of the current local device id.
    llvm::SmallVector<int, 4> mesh_coordinate_for_device_id;

    for (const MeshDimension& mesh_dim : mesh_dimensions) {
      TF_ASSIGN_OR_RETURN(
          const int coordinate,
          GetMeshCoordinateIndex(mesh, mesh_dim.name, local_device_id));

      mesh_coordinate_for_device_id.push_back(coordinate);
    }

    // If the mesh coordinate is on the left/right edge, the halo exchange is
    // conducted with the processor that executes the wrapped-around input
    // block.
    const int mesh_coordinate = mesh_coordinate_for_device_id[mesh_dim_index];
    TF_ASSIGN_OR_RETURN(const int dim_size, mesh.dim_size(mesh_dim_name));

    // For tensors requiring halo exchange, we use collective permute.
    const int src_device_id = local_device_id;
    int target_device_id = 0;
    for (const auto& data : llvm::enumerate(mesh_dimensions)) {
      const MeshDimension& mesh_dim = data.value();
      const int index = data.index();

      int target_mesh_coordinate = 1;
      if (mesh_dim.name == mesh_dim_name) {
        target_mesh_coordinate =
            shift_left ? mesh_coordinate - 1 : mesh_coordinate + 1;

        // For processors executing input blocks on the left/right edges, the
        // target processor is the one that executes the wrapped-around input
        // block.
        if (target_mesh_coordinate < 0 || target_mesh_coordinate >= dim_size)
          target_mesh_coordinate =
              (target_mesh_coordinate + dim_size) % dim_size;

      } else {
        target_mesh_coordinate = mesh_coordinate_for_device_id[index];
      }

      target_device_id +=
          target_mesh_coordinate *
          GetMeshDimensionOffsetWithNeighbor(mesh, mesh_dim.name);
    }
    src_target_pair_flat.push_back(src_device_id);
    src_target_pair_flat.push_back(target_device_id);
  }

  const int num_pairs = src_target_pair_flat.size() / 2;
  auto shaped_type = mlir::RankedTensorType::get(
      {num_pairs, 2}, mlir::IntegerType::get(builder.getContext(), 32));

  auto src_target_attr =
      mlir::DenseIntElementsAttr::get(shaped_type, src_target_pair_flat);
  mlir::Value src_target_pair_tensor =
      builder.create<mlir::TF::ConstOp>(location, src_target_attr);
  return src_target_pair_tensor;
}

}  // namespace

StatusOr<mlir::Value> EmitHaloExchange(mlir::OpBuilder& builder, int halo_size,
                                       const std::string& mesh_dim,
                                       const Layout& layout,
                                       mlir::Value mesh_coordinates,
                                       mlir::tf_device::ClusterOp cluster,
                                       mlir::Location location,
                                       mlir::Value tensor) {
  const Mesh& mesh = layout.mesh();

  // Check mesh dimension requirements for halo exchange.
  if (!mesh.IsMeshDim(mesh_dim))
    return errors::InvalidArgument(
        "Requested halo exchange on unknown mesh dim");

  // TODO(hongjunchoi): Add support for halo exchange on GPU/CPU.
  if (!mesh.is_tpu_mesh())
    return errors::InvalidArgument("Halo exchange is only supported on TPU.");

  auto input_tensor_type = tensor.getType().dyn_cast<mlir::RankedTensorType>();
  if (!input_tensor_type || !input_tensor_type.hasStaticShape())
    return errors::InvalidArgument(
        "Static shape of input tensor must be known for halo exchange.");

  llvm::ArrayRef<int64_t> input_tensor_shape = input_tensor_type.getShape();
  const std::vector<std::string> sharding_specs = layout.sharding_spec_strs();
  const int split_dim_index = std::distance(
      sharding_specs.begin(), llvm::find(sharding_specs, mesh_dim));

  if (input_tensor_shape[split_dim_index] < halo_size)
    return errors::InvalidArgument(
        "For halo exchange, the input shard size on each processor must be "
        "at least the halo size");

  TF_ASSIGN_OR_RETURN(const int mesh_dim_index, mesh.idx_for_dim(mesh_dim));

  TF_ASSIGN_OR_RETURN(mlir::Value scalar_mesh_coordinate,
                      SelectScalarValueFromArray(builder, mesh_dim_index,
                                                 location, mesh_coordinates));

  llvm::SmallVector<int64_t, 4> halo_exchange_tensor_shape;
  for (const auto& size_and_index : llvm::enumerate(input_tensor_shape)) {
    const int index = size_and_index.index();
    const int size = size_and_index.value();
    halo_exchange_tensor_shape.push_back(index == split_dim_index ? halo_size
                                                                  : size);
  }

  // Find the halo tensor value to pad on the `left` side. Note that halo
  // exchange can happen on the top/bottom/left/right sides of a spatially
  // partitioned tensor. However, we use `left`/`right` here since the
  // direction is implicit from the mesh dimension.
  //
  // For example, if the mesh dimension splits the input tensor by its height
  // dimension, then `left` actually means the tensor to pad on the top side.
  mlir::Value is_on_left_edge = builder.create<mlir::TF::EqualOp>(
      location, CreateIntScalarConst(0, builder, location, /*use_int64=*/false),
      scalar_mesh_coordinate, builder.getBoolAttr(true));

  TF_ASSIGN_OR_RETURN(const int mesh_dim_size, mesh.dim_size(mesh_dim));
  mlir::Value is_on_right_edge = builder.create<mlir::TF::EqualOp>(
      location,
      CreateIntScalarConst(mesh_dim_size - 1, builder, location,
                           /*use_int64=*/false),
      scalar_mesh_coordinate, builder.getBoolAttr(true));

  // Create a zero ghost tensor to pad on the left side.
  mlir::RankedTensorType halo_tensor_type = mlir::RankedTensorType::get(
      halo_exchange_tensor_shape, input_tensor_type.getElementType());
  auto halo_type = mlir::RankedTensorType::get(
      halo_tensor_type.getShape(), input_tensor_type.getElementType());

  mlir::Attribute const_attr;
  if (halo_type.getElementType().isIntOrIndex()) {
    const_attr =
        mlir::DenseIntElementsAttr::get(halo_type, llvm::SmallVector<int>{0});
  } else {
    const_attr =
        mlir::DenseFPElementsAttr::get(halo_type, llvm::SmallVector<float>{0});
  }

  mlir::Value ghost_tensor_left =
      builder.create<mlir::TF::ConstOp>(location, const_attr).getResult();

  // Get the right-side slice of the input tensor to pad on the left side.
  llvm::SmallVector<int64_t, 4> begin_left(layout.rank(), 0);
  begin_left[split_dim_index] = input_tensor_shape[split_dim_index] - halo_size;
  mlir::Value begin_tensor_left =
      ops_util::GetR1Const(begin_left, builder, location);

  llvm::SmallVector<int64_t, 4> size(input_tensor_shape.begin(),
                                     input_tensor_shape.end());
  size[split_dim_index] = halo_size;

  mlir::Value size_tensor_left = ops_util::GetR1Const(size, builder, location);
  mlir::Value sliced_tensor_left = builder.create<mlir::TF::SliceOp>(
      location, halo_type, tensor, begin_tensor_left, size_tensor_left);

  mlir::Value halo_tensor_left = builder.create<mlir::TF::SelectV2Op>(
      location, is_on_right_edge, ghost_tensor_left, sliced_tensor_left);

  // Invoke collective permute to receive the tensor from the neighboring
  // processor. Halo slices from the left neighbor are received on each
  // processor (they are shifted right).
  TF_ASSIGN_OR_RETURN(
      mlir::Value src_target_pair_left,
      CreateConstSrcTargetPair(mesh, mesh_dim, /*shift_left=*/false, location,
                               builder));

  mlir::Value left_concat_value = builder.create<mlir::TF::CollectivePermuteOp>(
      location, sliced_tensor_left.getType(), halo_tensor_left,
      src_target_pair_left);

  mlir::Value ghost_tensor_right =
      builder.create<mlir::TF::ConstOp>(location, const_attr).getResult();

  // Otherwise, the values to pad come from a different processor; we use
  // collective permute to access the tensor slice from the other device.
  // Get the left-side slice of the input tensor.
  llvm::SmallVector<int64_t, 4> begin_right(layout.rank(), 0);
  mlir::Value begin_tensor_right =
      ops_util::GetR1Const(begin_right, builder, location);
  mlir::Value size_tensor_right = ops_util::GetR1Const(size, builder, location);
  mlir::Value sliced_tensor_right = builder.create<mlir::TF::SliceOp>(
      location, halo_type, tensor, begin_tensor_right, size_tensor_right);

  // Find the halo tensor value to pad on the `right` side.
  // If the input block is on the right edge, we use the zero ghost tensor
  // instead.
  mlir::Value halo_tensor_right = builder.create<mlir::TF::SelectV2Op>(
      location, is_on_left_edge, ghost_tensor_right, sliced_tensor_right);

  // Invoke collective permute to receive the tensor from the neighboring
  // processor. Halo slices from the right neighbor are received on each
  // processor (they are shifted left).
  TF_ASSIGN_OR_RETURN(
      mlir::Value src_target_pair_right,
      CreateConstSrcTargetPair(mesh, mesh_dim, /*shift_left=*/true, location,
                               builder));
  mlir::Value right_concat_value =
      builder.create<mlir::TF::CollectivePermuteOp>(
          location, sliced_tensor_right.getType(), halo_tensor_right,
          src_target_pair_right);

  // The final halo-exchanged value is the concatenation of left_concat_value,
  // tensor, and right_concat_value along the split tensor dimension.
  llvm::SmallVector<int64_t, 4> final_shape(input_tensor_shape.begin(),
                                            input_tensor_shape.end());
  final_shape[split_dim_index] = final_shape[split_dim_index] + 2 * halo_size;

  auto final_type = mlir::RankedTensorType::get(
      final_shape, input_tensor_type.getElementType());
  mlir::Value concat_axis =
      CreateIntScalarConst(split_dim_index, builder, location);
  mlir::Value final_value = builder.create<mlir::TF::ConcatV2Op>(
      location, final_type,
      llvm::SmallVector<mlir::Value, 4>{left_concat_value, tensor,
                                        right_concat_value},
      concat_axis);

  return final_value;
}

}  // namespace dtensor
}  // namespace tensorflow