/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/collectives.h"

#include <cstdint>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives_common.h"
#include "tensorflow/dtensor/mlir/dtensor_location.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/sparse_expander_common.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {

namespace ops_util = ::mlir::TF::collection_ops_util;

}  // namespace

StatusOr<mlir::Value> EmitAllGather(
    mlir::OpBuilder& builder, mlir::Value input,
    const dtensor::Layout& src_layout, const dtensor::Layout& tgt_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  if (src_layout.IsEquivalent(tgt_layout)) return input;

  if (src_layout.rank() != tgt_layout.rank()) {
    return errors::InvalidArgument(
        "Expected source and target layout to have the same rank, got ",
        src_layout.rank(), " vs ", tgt_layout.rank());
  }

  // Check that the tgt_layout is less sharded than src_layout.
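  // For example, going from ["x", "y"] to ["x", unsharded] is allowed
  // (dimension 1 is gathered), while ["x", unsharded] to ["x", "y"] is not.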
  for (int i = 0; i < src_layout.rank(); ++i) {
    if (src_layout.sharding_spec(i) != tgt_layout.sharding_spec(i) &&
        Layout::IsShardedDimension(tgt_layout.sharding_spec(i))) {
      return errors::InvalidArgument("source layout (", src_layout.ToString(),
                                     ") for all gather is not less sharded "
                                     "than the target layout (",
                                     tgt_layout.ToString(), ")");
    }
  }

  // For convenience, operate on explicit input shapes. This isn't necessary,
  // as we could instead generate operations on top of the dynamic shape.
  const mlir::TensorType input_type =
      input.getType().dyn_cast<mlir::TensorType>();
  if (!input_type) {
    return errors::Internal(
        llvm::formatv(
            "Cannot cast input_type : {0} to TensorType. Shape must be "
            "statically known before emitting AllGather. This should not "
            "happen as we already cast it when getting its shape.",
            input.getType())
            .str());
  }

  TF_ASSIGN_OR_RETURN(mlir::TensorType global_type,
                      GlobalTypeFromLocalType(src_layout, input_type));
  TF_ASSIGN_OR_RETURN(mlir::TensorType output_type,
                      LocalTypeFromGlobalType(tgt_layout, global_type));

  mlir::Location loc = DT_LOC2(input.getLoc(), "DTensorAllGatherOp");
  mlir::TF::DTensorAllGatherOp all_gather =
      builder.create<mlir::TF::DTensorAllGatherOp>(
          loc, output_type, input,
          mlir::dtensor::LayoutAttr::get(builder.getContext(), src_layout),
          mlir::dtensor::LayoutAttr::get(builder.getContext(), tgt_layout));
  SetSingleLayoutOnOp(all_gather, tgt_layout);

  if (newly_created_ops != nullptr) newly_created_ops->insert(all_gather);

  return all_gather.output();
}

StatusOr<const mlir::Value> EmitAllScatter(
    mlir::OpBuilder& builder, const mlir::Value& original_value,
    const Layout& original_layout, const Layout& desired_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  if (original_layout.IsEquivalent(desired_layout)) return original_value;

  // Return early if the desired layout is not more sharded than the
  // original_layout.
  assert(original_layout.rank() == desired_layout.rank());
  for (int i = 0; i < original_layout.rank(); ++i) {
    if (original_layout.sharding_spec(i) != desired_layout.sharding_spec(i) &&
        Layout::IsShardedDimension(original_layout.sharding_spec(i))) {
      return errors::InvalidArgument(
          "EmitAllScatter was passed a desired_layout ",
          desired_layout.ToString(),
          " which was not more sharded than the original_layout ",
          original_layout.ToString());
    }
  }

  const mlir::TensorType input_type =
      original_value.getType().dyn_cast<mlir::TensorType>();
  if (!input_type)
    return errors::InvalidArgument(
        "input to EmitAllScatter does not have a TensorType");

  TF_ASSIGN_OR_RETURN(const mlir::TensorType global_type,
                      GlobalTypeFromLocalType(original_layout, input_type));
  TF_ASSIGN_OR_RETURN(const mlir::TensorType output_type,
                      LocalTypeFromGlobalType(desired_layout, global_type));

  mlir::Location loc = DT_LOC2(original_value.getLoc(), "DTensorAllScatterOp");
  mlir::TF::DTensorAllScatterOp all_scatter =
      builder.create<mlir::TF::DTensorAllScatterOp>(
          loc, output_type, original_value,
          mlir::dtensor::LayoutAttr::get(builder.getContext(), original_layout),
          mlir::dtensor::LayoutAttr::get(builder.getContext(), desired_layout));
  SetSingleLayoutOnOp(all_scatter, desired_layout);

  if (newly_created_ops != nullptr) newly_created_ops->insert(all_scatter);

  return all_scatter.output();
}

StatusOr<mlir::Value> EmitDenseToSparseToDense(
    mlir::OpBuilder& builder, mlir::Value input,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  // First create a dense-to-sparse conversion. Since there is no
  // DenseToSparseOp, we do it manually by creating the indices, values, and
  // shape tensors through various ops:
  //
  // indices tensor = tf.where(tf.not_equal(input, tf.zeros_like(input)))
  // values tensor = tf.gather_nd(input, indices)
  // shape tensor = tf.shape(input)
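  // For illustration: input = [[0, 3], [0, 0]] yields indices = [[0, 1]],
  // values = [3], and shape = [2, 2].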
  mlir::TF::ZerosLikeOp zeros_like =
      builder.create<mlir::TF::ZerosLikeOp>(input.getLoc(), input);
  mlir::TF::NotEqualOp not_equal = builder.create<mlir::TF::NotEqualOp>(
      zeros_like.getLoc(), input, zeros_like, builder.getBoolAttr(false));

  mlir::TF::WhereOp indices = builder.create<mlir::TF::WhereOp>(
      not_equal.getLoc(),
      mlir::RankedTensorType::get(GetShapeOfValue(not_equal).value(),
                                  builder.getI64Type()),
      not_equal);

  mlir::TF::GatherNdOp values = builder.create<mlir::TF::GatherNdOp>(
      input.getLoc(), input.getType(), input, indices);
  auto shape = builder.create<mlir::TF::ShapeOp>(input.getLoc(), input,
                                                 builder.getBoolAttr(false));

  // Emit a SparseToDenseOp and replace the SparseTensor with the result of
  // this new op.
  auto zero_scalar = CreateZeroScalarConst(
      builder, input.getLoc(),
      input.getType().cast<mlir::TensorType>().getElementType());
  if (!zero_scalar.has_value())
    return errors::Internal("Failure in creating a zero scalar const");

  auto dense = builder.create<mlir::TF::SparseToDenseOp>(
      input.getLoc(), input.getType(),
      mlir::ValueRange({indices, shape, values, zero_scalar.value()}));

  if (newly_created_ops != nullptr) {
    for (auto new_op : {dense.getOperation(), shape.getOperation(),
                        values.getOperation(), indices.getOperation(),
                        not_equal.getOperation(), zeros_like.getOperation()}) {
      newly_created_ops->insert(new_op);
    }
  }

  return dense.getResult();
}

StatusOr<mlir::Value> EmitRelayout(
    mlir::Value input, const dtensor::Layout& src_layout,
    const dtensor::Layout& tgt_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  // EmitRelayout is performed as an AllScatter (split), an AllGather, and
  // another AllScatter.
  // The first AllScatter opportunistically splits input tensor dimension i on
  // mesh axis x if:
  // 1. tgt_layout contains x at position i
  // 2. src_layout is unsharded at position i.
  // 3. src_layout does not contain mesh axis x.
  // This produces intermediate layout 1.
  // Next, an AllGather is performed on any axis of intermediate layout 1 whose
  // sharding does not agree with the sharding of the corresponding axis in
  // tgt_layout. This produces intermediate layout 2.
  // Finally, an AllScatter is performed from intermediate layout 2 to
  // tgt_layout.
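  //
  // Illustrative examples (hypothetical 2-D mesh with axes "x" and "y"):
  //   src_layout = [unsharded, "x"], tgt_layout = ["y", "x"]: the first
  //   AllScatter opportunistically splits dimension 0 on "y" (tgt_layout wants
  //   "y" there, src_layout is unsharded there and does not use "y"); the
  //   AllGather and final AllScatter are then no-ops.
  //   src_layout = ["x", unsharded], tgt_layout = [unsharded, "x"]: the first
  //   AllScatter is a no-op, the AllGather unshards dimension 0, and the final
  //   AllScatter splits dimension 1 on "x".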

  if (src_layout.IsEquivalent(tgt_layout)) return input;

  // Save whether the input is from a SparseToDenseOp. If it is, then we will
  // emit a DenseToSparse and a SparseToDense op.
  bool is_sparse = IsSparseValue(input);
  if (!input.getType().isa<mlir::RankedTensorType>())
    return errors::Internal(
        "attempting to relayout a tensor that does not "
        "have a rank");

  if (src_layout.mesh() != tgt_layout.mesh()) {
    return errors::Internal("Attempted to relayout to a different mesh.");
  }
  if (src_layout.rank() != tgt_layout.rank()) {
    return errors::Internal(
        "Attempted to relayout to a different global shape.");
  }

  absl::flat_hash_set<std::string> src_sharding_dims;
  for (int i = 0; i < src_layout.rank(); ++i)
    src_sharding_dims.emplace(src_layout.sharding_spec(i));

  std::vector<ShardingSpec> intermediate_specs_1(src_layout.rank());
  for (int i = 0; i < src_layout.rank(); ++i) {
    if (Layout::IsShardedSpec(tgt_layout.dim(i)) &&
        !Layout::IsShardedSpec(src_layout.dim(i)) &&
        !src_sharding_dims.contains(tgt_layout.sharding_spec(i)))
      intermediate_specs_1[i] = tgt_layout.dim(i);
    else
      intermediate_specs_1[i] = src_layout.dim(i);
  }
  TF_ASSIGN_OR_RETURN(
      Layout intermediate_layout_1,
      Layout::GetLayout(intermediate_specs_1, src_layout.mesh()));

  mlir::OpBuilder builder(input.getContext());
  TF_RETURN_IF_ERROR(SetBuilderInsertionAfterValue(input, builder));

  llvm::SmallPtrSet<mlir::Operation*, 4> local_newly_created_ops;
  TF_ASSIGN_OR_RETURN(mlir::Value split_result,
                      EmitAllScatter(builder, input, src_layout,
                                     intermediate_layout_1, newly_created_ops));

  std::vector<ShardingSpec> intermediate_specs_2(src_layout.rank());
  for (int i = 0; i < src_layout.rank(); ++i) {
    if (Layout::IsShardedSpec(intermediate_specs_1[i]) &&
        intermediate_specs_1[i].sharding_spec() != tgt_layout.sharding_spec(i))
      intermediate_specs_2[i].set_sharding_spec(Layout::kUnshardedDim);
    else
      intermediate_specs_2[i] = intermediate_specs_1[i];
  }
  TF_ASSIGN_OR_RETURN(
      Layout intermediate_layout_2,
      Layout::GetLayout(intermediate_specs_2, src_layout.mesh()));

  TF_ASSIGN_OR_RETURN(
      mlir::Value concat_result,
      EmitAllGather(builder, split_result, intermediate_layout_1,
                    intermediate_layout_2, newly_created_ops));

  auto all_scatter =
      EmitAllScatter(builder, concat_result, intermediate_layout_2, tgt_layout,
                     newly_created_ops);

  if (!is_sparse) return all_scatter;
  if (!all_scatter.ok()) return all_scatter;
  return EmitDenseToSparseToDense(builder, all_scatter.value(),
                                  newly_created_ops);
}

StatusOr<mlir::Operation*> EmitAllReduce(
    mlir::OpBuilder& builder, const dtensor::Layout& output_layout,
    const absl::flat_hash_set<std::string>& reduced_dims,
    mlir::Operation* input, absl::string_view reduce_op) {
  TF_ASSIGN_OR_RETURN(auto partitions, GetAllReducePartitionsFromReducedDims(
                                           output_layout, reduced_dims));
  const int32 num_partitions = partitions.size();

  // If every device lives in its own partition, we don't need to emit a
  // collective.
  if (num_partitions == output_layout.num_devices()) {
    return InferSPMDExpandedLocalShape(input);
  }

  // Construct a flattened list of reduce partitions. This will be converted
  // into a 2-D const tensor for the DTensorAllReduce op.
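  // For illustration (hypothetical 2x2 mesh with axes "x" and "y", devices
  // 0..3 in row-major order, reducing over "y"): the devices are grouped into
  // partitions [0, 1] and [2, 3], so the group assignment tensor is
  // [[0, 1], [2, 3]].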
  std::vector<int32> partitions_flat;
  for (auto& p : partitions) {
    if (p.second.size() != partitions.begin()->second.size()) {
      return errors::InvalidArgument(
          "AllReduce partitions had different sizes -- this is not supported "
          "in MLIR.");
    }
    partitions_flat.insert(partitions_flat.end(), p.second.begin(),
                           p.second.end());
  }

  int32 partition_size = partitions.begin()->second.size();
  auto shaped_type = mlir::RankedTensorType::get(
      {num_partitions, partition_size},
      mlir::IntegerType::get(builder.getContext(), 32));
  auto group_assignment =
      mlir::DenseIntElementsAttr::get(shaped_type, partitions_flat);

  TF_ASSIGN_OR_RETURN(std::string device_type,
                      DeviceTypeFromMesh(output_layout.mesh()));

  mlir::Location loc = DT_LOC2(input->getLoc(), "DTensorAllReduceOp");
  auto all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      loc, input->getResultTypes()[0], input->getOpResult(0),
      builder.create<mlir::TF::ConstOp>(DT_LOC2(loc, "group_assignment"),
                                        group_assignment),
      builder.getStringAttr(std::string(reduce_op)),
      builder.getStringAttr(device_type));
  SetSingleLayoutOnOp(all_reduce, output_layout);
  input->getOpResult(0).replaceAllUsesExcept(
      all_reduce.getResult(),
      llvm::SmallPtrSet<mlir::Operation*, 1>{all_reduce});
  return all_reduce.getOperation();
}

namespace {

// Returns an offset multiplier used to convert between a device id and its
// mesh coordinate along `mesh_dim`.
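// For illustration, with mesh dimension sizes [2, 4] the offset is 4 for the
// first dimension and 1 for the second (row-major device ordering).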
int GetMeshDimensionOffsetWithNeighbor(const Mesh& mesh,
                                       const std::string& mesh_dim) {
  const int index = mesh.GetMeshDimIndexWithName(mesh_dim);
  const std::vector<int64_t> mesh_dim_sizes = mesh.dim_sizes();
  int offset = 1;
  for (int i = index + 1; i < mesh_dim_sizes.size(); ++i) {
    offset = offset * mesh_dim_sizes[i];
  }
  return offset;
}

// Returns the mesh coordinate along dimension `mesh_dim_name` for the given
// `device_id`.
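// For illustration, on a mesh with dimension sizes [2, 4], device_id 6 has
// coordinate (6 / 4) % 2 = 1 along the first dimension and (6 / 1) % 4 = 2
// along the second.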
StatusOr<int> GetMeshCoordinateIndex(const Mesh& mesh,
                                     const std::string& mesh_dim_name,
                                     int device_id) {
  const int offset = GetMeshDimensionOffsetWithNeighbor(mesh, mesh_dim_name);
  TF_ASSIGN_OR_RETURN(int64_t mesh_dim_size, mesh.dim_size(mesh_dim_name));

  return (device_id / offset) % mesh_dim_size;
}

// Returns a 2-D tensor of shape [N, 2] that specifies the source-target device
// pairs to be used for halo exchange.
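// For illustration (hypothetical 1-D mesh of size 4 with devices 0..3 and
// shift_left = false): the pairs are (0,1), (1,2), (2,3), (3,0), i.e. each
// device sends to its right neighbor, wrapping around at the edge.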
StatusOr<mlir::Value> CreateConstSrcTargetPair(const Mesh& mesh,
                                               const std::string& mesh_dim_name,
                                               bool shift_left,
                                               mlir::Location location,
                                               mlir::OpBuilder& builder) {
  const int mesh_dim_index = mesh.GetMeshDimIndexWithName(mesh_dim_name);
  const std::vector<MeshDimension> mesh_dimensions = mesh.dims();

  llvm::SmallVector<int, 4> src_target_pair_flat;
  src_target_pair_flat.reserve(mesh.local_device_ids().size() * 2);
  for (const int local_device_id : mesh.local_device_ids()) {
    // Calculate the mesh coordinate of the current local device id.
    llvm::SmallVector<int, 4> mesh_coordinate_for_device_id;

    for (const MeshDimension& mesh_dim : mesh_dimensions) {
      TF_ASSIGN_OR_RETURN(
          const int coordinate,
          GetMeshCoordinateIndex(mesh, mesh_dim.name, local_device_id));

      mesh_coordinate_for_device_id.push_back(coordinate);
    }

    // If the mesh coordinate is on the left/right edge, then halo exchange is
    // conducted with the processor that holds the wrapped-around input block.
    const int mesh_coordinate = mesh_coordinate_for_device_id[mesh_dim_index];
    TF_ASSIGN_OR_RETURN(const int dim_size, mesh.dim_size(mesh_dim_name));

    // For tensors requiring halo exchange, we use collective permute.
    const int src_device_id = local_device_id;
    int target_device_id = 0;
    for (const auto& data : llvm::enumerate(mesh_dimensions)) {
      const MeshDimension& mesh_dim = data.value();
      const int index = data.index();

      int target_mesh_coordinate = 1;
      if (mesh_dim.name == mesh_dim_name) {
        target_mesh_coordinate =
            shift_left ? mesh_coordinate - 1 : mesh_coordinate + 1;

        // For processors holding input blocks on the left/right edges, the
        // target is the processor that holds the wrapped-around input block.
        if (target_mesh_coordinate < 0 || target_mesh_coordinate >= dim_size)
          target_mesh_coordinate =
              (target_mesh_coordinate + dim_size) % dim_size;

      } else {
        target_mesh_coordinate = mesh_coordinate_for_device_id[index];
      }

      target_device_id +=
          target_mesh_coordinate *
          GetMeshDimensionOffsetWithNeighbor(mesh, mesh_dim.name);
    }
    src_target_pair_flat.push_back(src_device_id);
    src_target_pair_flat.push_back(target_device_id);
  }

  const int num_pairs = src_target_pair_flat.size() / 2;
  auto shaped_type = mlir::RankedTensorType::get(
      {num_pairs, 2}, mlir::IntegerType::get(builder.getContext(), 32));

  auto src_target_attr =
      mlir::DenseIntElementsAttr::get(shaped_type, src_target_pair_flat);
  mlir::Value src_target_pair_tensor =
      builder.create<mlir::TF::ConstOp>(location, src_target_attr);
  return src_target_pair_tensor;
}

}  // namespace

StatusOr<mlir::Value> EmitHaloExchange(mlir::OpBuilder& builder, int halo_size,
                                       const std::string& mesh_dim,
                                       const Layout& layout,
                                       mlir::Value mesh_coordinates,
                                       mlir::tf_device::ClusterOp cluster,
                                       mlir::Location location,
                                       mlir::Value tensor) {
  const Mesh& mesh = layout.mesh();

  // Check mesh dimension requirements for halo exchange.
  if (!mesh.IsMeshDim(mesh_dim))
    return errors::InvalidArgument(
        "Requested halo exchange on unknown mesh dim");

  // TODO(hongjunchoi): Add support for halo exchange on GPU/CPU.
  if (!mesh.is_tpu_mesh())
    return errors::InvalidArgument("Halo exchange is only supported on TPU.");

  auto input_tensor_type = tensor.getType().dyn_cast<mlir::RankedTensorType>();
  if (!input_tensor_type || !input_tensor_type.hasStaticShape())
    return errors::InvalidArgument(
        "Static shape of input tensor must be known for halo exchange.");

  llvm::ArrayRef<int64_t> input_tensor_shape = input_tensor_type.getShape();
  const std::vector<std::string> sharding_specs = layout.sharding_spec_strs();
  const int split_dim_index = std::distance(
      sharding_specs.begin(), llvm::find(sharding_specs, mesh_dim));

  if (input_tensor_shape[split_dim_index] < halo_size)
    return errors::InvalidArgument(
        "For halo exchange, the input shard size on each processor must be at "
        "least the halo size");

  TF_ASSIGN_OR_RETURN(const int mesh_dim_index, mesh.idx_for_dim(mesh_dim));

  TF_ASSIGN_OR_RETURN(mlir::Value scalar_mesh_coordinate,
                      SelectScalarValueFromArray(builder, mesh_dim_index,
                                                 location, mesh_coordinates));

  llvm::SmallVector<int64_t, 4> halo_exchange_tensor_shape;
  for (const auto& size_and_index : llvm::enumerate(input_tensor_shape)) {
    const int index = size_and_index.index();
    const int size = size_and_index.value();
    halo_exchange_tensor_shape.push_back(index == split_dim_index ? halo_size
                                                                  : size);
  }

  // Find the halo tensor value to pad on the `left` side. Note that halo
  // exchange can happen on the top/bottom/left/right sides of a spatially
  // partitioned tensor. However, we use `left`/`right` here since the actual
  // direction is implied by the mesh dimension.
  //
  // For example, if the mesh dimension splits the input tensor along its
  // height dimension, then `left` actually refers to padding on the top side.
  mlir::Value is_on_left_edge = builder.create<mlir::TF::EqualOp>(
      location, CreateIntScalarConst(0, builder, location, /*use_int64=*/false),
      scalar_mesh_coordinate, builder.getBoolAttr(true));

  TF_ASSIGN_OR_RETURN(const int mesh_dim_size, mesh.dim_size(mesh_dim));
  mlir::Value is_on_right_edge = builder.create<mlir::TF::EqualOp>(
      location,
      CreateIntScalarConst(mesh_dim_size - 1, builder, location,
                           /*use_int64=*/false),
      scalar_mesh_coordinate, builder.getBoolAttr(true));

  // Create zero ghost tensor to pad on left side.
  mlir::RankedTensorType halo_tensor_type = mlir::RankedTensorType::get(
      halo_exchange_tensor_shape, input_tensor_type.getElementType());
  auto halo_type = mlir::RankedTensorType::get(
      halo_tensor_type.getShape(), input_tensor_type.getElementType());

  mlir::Attribute const_attr;
  if (halo_type.getElementType().isIntOrIndex()) {
    const_attr =
        mlir::DenseIntElementsAttr::get(halo_type, llvm::SmallVector<int>{0});
  } else {
    const_attr =
        mlir::DenseFPElementsAttr::get(halo_type, llvm::SmallVector<float>{0});
  }

  mlir::Value ghost_tensor_left =
      builder.create<mlir::TF::ConstOp>(location, const_attr).getResult();

  // Get the right side slice of the input tensor to pad on left side.
  llvm::SmallVector<int64_t, 4> begin_left(layout.rank(), 0);
  begin_left[split_dim_index] = input_tensor_shape[split_dim_index] - halo_size;
  mlir::Value begin_tensor_left =
      ops_util::GetR1Const(begin_left, builder, location);

  llvm::SmallVector<int64_t, 4> size(input_tensor_shape.begin(),
                                     input_tensor_shape.end());
  size[split_dim_index] = halo_size;

  mlir::Value size_tensor_left = ops_util::GetR1Const(size, builder, location);
  mlir::Value sliced_tensor_left = builder.create<mlir::TF::SliceOp>(
      location, halo_type, tensor, begin_tensor_left, size_tensor_left);

  mlir::Value halo_tensor_left = builder.create<mlir::TF::SelectV2Op>(
      location, is_on_right_edge, ghost_tensor_left, sliced_tensor_left);

  // Invoke collective permute to receive the tensor from neighboring processor.
  // Halo slices from the left neighbor are received on each processor (they
  // are shifted right).
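  // For illustration (hypothetical 1-D mesh of size 4): the source-target
  // pairs below are (0,1), (1,2), (2,3), (3,0), so each device receives the
  // slice sent by its left neighbor; the right-edge device sends the zero
  // ghost tensor, which becomes the left padding of the wrapped-around
  // left-edge device.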
  TF_ASSIGN_OR_RETURN(
      mlir::Value src_target_pair_left,
      CreateConstSrcTargetPair(mesh, mesh_dim, /*shift_left=*/false, location,
                               builder));

  mlir::Value left_concat_value = builder.create<mlir::TF::CollectivePermuteOp>(
      location, sliced_tensor_left.getType(), halo_tensor_left,
      src_target_pair_left);

  mlir::Value ghost_tensor_right =
      builder.create<mlir::TF::ConstOp>(location, const_attr).getResult();

  // Otherwise, the values to pad come from a different processor. We use
  // collective permute to access the tensor slice from the other device.
  // Get the left side slice of the input tensor.
  llvm::SmallVector<int64_t, 4> begin_right(layout.rank(), 0);
  mlir::Value begin_tensor_right =
      ops_util::GetR1Const(begin_right, builder, location);
  mlir::Value size_tensor_right = ops_util::GetR1Const(size, builder, location);
  mlir::Value sliced_tensor_right = builder.create<mlir::TF::SliceOp>(
      location, halo_type, tensor, begin_tensor_right, size_tensor_right);

  // Find the halo tensor value to pad on the `right` side.
  // If the input block is on the right edge, we use the zero ghost tensor
  // instead.
  mlir::Value halo_tensor_right = builder.create<mlir::TF::SelectV2Op>(
      location, is_on_left_edge, ghost_tensor_right, sliced_tensor_right);

  // Invoke collective permute to receive the tensor from neighboring processor.
  // Halo slices from the right neighbor are received on each processor (they
  // are shifted left).
  TF_ASSIGN_OR_RETURN(
      mlir::Value src_target_pair_right,
      CreateConstSrcTargetPair(mesh, mesh_dim, /*shift_left=*/true, location,
                               builder));
  mlir::Value right_concat_value =
      builder.create<mlir::TF::CollectivePermuteOp>(
          location, sliced_tensor_right.getType(), halo_tensor_right,
          src_target_pair_right);

  // The final halo-exchanged value is the concatenation of left_concat_value,
  // tensor, and right_concat_value along the dimension split by mesh_dim.
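  // For illustration, a local shard of shape [4, 8] split on dimension 0 with
  // halo_size = 1 is concatenated into shape [6, 8].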
  llvm::SmallVector<int64_t, 4> final_shape(input_tensor_shape.begin(),
                                            input_tensor_shape.end());
  final_shape[split_dim_index] = final_shape[split_dim_index] + 2 * halo_size;

  auto final_type = mlir::RankedTensorType::get(
      final_shape, input_tensor_type.getElementType());
  mlir::Value concat_axis =
      CreateIntScalarConst(split_dim_index, builder, location);
  mlir::Value final_value = builder.create<mlir::TF::ConcatV2Op>(
      location, final_type,
      llvm::SmallVector<mlir::Value, 4>{left_concat_value, tensor,
                                        right_concat_value},
      concat_axis);

  return final_value;
}

}  // namespace dtensor
}  // namespace tensorflow