/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/dtensor_send_recv.h"

#include <string>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Returns a compilation key placeholder. The placeholder is replaced with the
// output of the TPUCompile op during the TPURewrite pass. The program key
// (output of the TPUCompile op) identifies the TPU computation from which to
// receive data.
mlir::Value GetOrCreateCompilationKey(mlir::Operation* op) {
  mlir::Value key;
  auto cluster = op->getParentOfType<mlir::tf_device::ClusterOp>();
  assert(cluster);
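  // Reuse an existing compilation key placeholder if the cluster already
  // contains one.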
  cluster.walk(
      [&](mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp compilation_key) {
        key = compilation_key.program();
      });
  if (key) return key;

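  // No placeholder exists yet; create one at the top of the cluster body.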
  mlir::OpBuilder builder(&cluster.GetBody().front());
  auto result_type = mlir::RankedTensorType::get(
      {3}, builder.getType<mlir::TF::StringType>());
  auto new_compilation_key =
      builder.create<mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
          cluster.getLoc(), /*program=*/result_type,
          llvm::ArrayRef<mlir::Value>{});
  return new_compilation_key.program();
}

}  // namespace

StatusOr<mlir::Value> GetDeviceOrdinal(const Mesh& mesh,
                                       const mlir::Location& loc,
                                       mlir::func::FuncOp function,
                                       mlir::OpBuilder* builder,
                                       bool return_int64_type) {
  // Create as many entries as the number of devices in the entire mesh.
  llvm::SmallVector<int32, 4> device_id_to_ordinal(mesh.num_devices(), 0);
  // Only fill in entries with indices equal to local device IDs. For TPUs,
  // there are usually 8 local devices.
  for (int i = 0; i < mesh.local_device_ids().size(); ++i) {
    device_id_to_ordinal[mesh.local_device_ids()[i]] = i;
  }
  // Slice out the device ordinal using the device ID as index.
  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(function));
  mlir::TF::SliceOp device_ordinal = builder->create<mlir::TF::SliceOp>(
      loc,
      /*output=*/EffectivelyScalarR1Type(builder->getIntegerType(32)),
      /*input=*/IntConst(*builder, loc, device_id_to_ordinal),
      /*begin=*/
      mlir::TF::collection_ops_util::ReshapeScalarToSizeType(*builder,
                                                             device_id, loc),
      /*size=*/IntConst(*builder, loc, {1}));
  mlir::Value device_ordinal_scalar =
      ReshapeSizeTypeToScalar(*builder, loc, device_ordinal);
  if (return_int64_type) {
    device_ordinal_scalar = builder->create<mlir::TF::CastOp>(
        loc, mlir::RankedTensorType::get({}, builder->getI64Type()),
        device_ordinal_scalar);
  }
  return device_ordinal_scalar;
}

// Lowers a DTensorSend op to either an _XlaSendFromHostV2 op or an
// XlaSendToHost op, depending on the source mesh cluster.
StatusOr<mlir::Operation*> LowerDTensorSendToXlaOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send, bool send_from_device_zero) {
  const bool send_from_cpu = !send_input_layout.mesh().is_tpu_mesh();
  mlir::OpBuilder builder(dtensor_send);

  mlir::Location loc = dtensor_send.getLoc();
  mlir::Operation* lowered_send_op;
  if (send_from_cpu) {
    llvm::SmallVector<mlir::Value, 4> value_to_send{send_input};
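    // The program key identifies the compiled TPU computation to send to;
    // reuse the cluster's placeholder key if one already exists.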
    mlir::OpBuilder::InsertPoint insertion_point = builder.saveInsertionPoint();
    mlir::Value program_key = GetOrCreateCompilationKey(dtensor_send);
    builder.restoreInsertionPoint(insertion_point);

    mlir::Value device_ordinal;
    if (send_from_device_zero) {
      // For CopyToMesh, we currently only support sending from host device 0
      // to target TPUs.
      device_ordinal = CreateIntScalarConst(0, builder, loc);
    } else {
      // For special topologies, always send from CPU device i to TPU device i.
      auto send_cluster =
          dtensor_send->getParentOfType<mlir::tf_device::ClusterOp>();
      if (!send_cluster) {
        return errors::InvalidArgument("DTensorSend is not inside a ClusterOp");
      }
      auto send_func = send_cluster->getParentOfType<mlir::func::FuncOp>();
      if (!send_func) {
        return errors::InvalidArgument("DTensorSend is not inside a FuncOp");
      }
      TF_ASSIGN_OR_RETURN(
          device_ordinal,
          GetDeviceOrdinal(send_input_layout.mesh(), loc, send_func, &builder));
    }
    // Create XlaSendFromHostV2 op.
    lowered_send_op = builder.create<mlir::TF::_XlaSendFromHostV2Op>(
        loc, value_to_send, program_key, device_ordinal, dtensor_send.key());
  } else {
    // Note that for ops running in XLA/TPU, device ordinal input is not
    // needed.
    lowered_send_op = builder.create<mlir::TF::XlaSendToHostOp>(
        loc, send_input, dtensor_send.key());
  }

  dtensor_send.erase();
  return lowered_send_op;
}

// Lowers a DTensorRecv op to either an _XlaRecvAtHostV2 op or an
// XlaRecvFromHost op, depending on the source mesh cluster configuration.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv) {
  return LowerDTensorRecvToXlaOp(dtensor_recv, dtensor_recv.getType());
}

// Lowers a DTensorRecv op to either an _XlaRecvAtHostV2 op or an
// XlaRecvFromHost op, depending on the source mesh cluster configuration.
// `output_type` can be set to the specific local tensor type needed, if
// different from the DTensorRecv op's output type.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv, mlir::Type output_type) {
  const bool recv_at_cpu = dtensor_recv.layout().mesh().is_cpu_mesh();
  mlir::Operation* recv_xla_op = nullptr;
  mlir::OpBuilder builder(dtensor_recv);

  if (recv_at_cpu) {
    // Create XlaRecvAtHostV2 op.
    llvm::SmallVector<mlir::Type, 4> output_types{output_type};
    auto recv_cluster =
        dtensor_recv->getParentOfType<mlir::tf_device::ClusterOp>();

    TF_ASSIGN_OR_RETURN(absl::optional<Mesh> mesh,
                        ExtractDeviceMeshFromOp(recv_cluster));
    if (!mesh.has_value()) {
      return errors::InvalidArgument(
          "Failed to get device ordinal: mesh for the operation is not "
          "specified.");
    }

    mlir::OpBuilder builder(&recv_cluster.GetBody().front());
    TF_ASSIGN_OR_RETURN(
        mlir::Value device_ordinal,
        GetDeviceOrdinal(*mesh, recv_cluster.getLoc(),
                         recv_cluster->getParentOfType<mlir::func::FuncOp>(),
                         &builder));

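    // The dynamic key (program key) identifies the compiled TPU program to
    // receive from.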
    auto program_key = GetOrCreateCompilationKey(dtensor_recv);
    builder.setInsertionPoint(dtensor_recv);
    recv_xla_op = builder.create<mlir::TF::_XlaRecvAtHostV2Op>(
        dtensor_recv.getLoc(), output_types,
        /*dynamic_key=*/program_key, device_ordinal, dtensor_recv.keyAttr());
  } else {
    // Create XlaRecvFromHost op.
    recv_xla_op = builder.create<mlir::TF::XlaRecvFromHostOp>(
        dtensor_recv.getLoc(), output_type,
        ConvertTypeToTensorShapeAttr(dtensor_recv.getType()),
        dtensor_recv.keyAttr());
  }

  assert(recv_xla_op);

  // TODO(hongjunchoi): After receiving the tensor, convert it to the requested
  // layout with EmitRelayout.
  return recv_xla_op;
}

// Lowers a DTensorSend op from a CPU mesh to TF _HostSend ops, one per local
// device in the target mesh.
StatusOr<mlir::Operation*> LowerDTensorSendFromCPUToTFOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send) {
  mlir::OpBuilder builder(dtensor_send);
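  // Emit the sends right after the op that defines the value being sent.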
  builder.setInsertionPointAfter(send_input.getDefiningOp());

  // Create one send from the host per local device in the target mesh.
  absl::Span<const std::string> sending_devices =
      send_input_layout.mesh().local_devices();

  Layout target_layout = dtensor_send.target_layout();
  absl::Span<const std::string> receiving_devices =
      target_layout.mesh().local_devices();

  std::string tensor_name = dtensor_send.key().str();

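  // All sends originate from the first local sending device.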
  mlir::Operation* lowered_send_op = nullptr;
  for (size_t i = 0; i < receiving_devices.size(); ++i) {
    lowered_send_op = builder.create<mlir::TF::_HostSendOp>(
        send_input.getLoc(), dtensor_send.input(), tensor_name,
        sending_devices[0],
        /*send_device_incarnation=*/0, receiving_devices[i]);
  }

  dtensor_send.erase();
  return lowered_send_op;
}

// Lowers a DTensorRecv op to TF _HostRecv ops, one per local device in the
// receiving mesh.
StatusOr<mlir::Operation*> LowerDTensorRecvFromCPUToTFOp(
    const Mesh& send_mesh, mlir::TF::DTensorRecv dtensor_recv) {
  const Layout& recv_layout = dtensor_recv.layout();

  auto recv_cluster =
      dtensor_recv->getParentOfType<mlir::tf_device::ClusterOp>();

  mlir::OpBuilder builder(&recv_cluster.GetBody().front());
  builder.setInsertionPoint(dtensor_recv);
  std::string tensor_name = dtensor_recv.key().str();
  absl::Span<const std::string> sending_devices = send_mesh.local_devices();
  absl::Span<const std::string> receiving_devices =
      recv_layout.mesh().local_devices();

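  // Create one receive per local device in the receiving mesh; each receives
  // from the first sending device.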
  mlir::Operation* lowered_recv_op = nullptr;
  mlir::Location loc = dtensor_recv.getLoc();
  for (size_t i = 0; i < receiving_devices.size(); ++i) {
    lowered_recv_op = builder.create<mlir::TF::_HostRecvOp>(
        loc, dtensor_recv.getType(), tensor_name, sending_devices[0],
        /*send_device_incarnation=*/0, receiving_devices[i]);
  }

  // Replace dtensor_recv with the newly created recv op and remove the
  // DTensorRecv op.
  assert(lowered_recv_op);
  dtensor_recv.replaceAllUsesWith(lowered_recv_op);
  dtensor_recv.erase();
  return lowered_recv_op;
}

}  // namespace dtensor
}  // namespace tensorflow