1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/mlir/dtensor_send_recv.h" |
17 | |
18 | #include <string> |
19 | |
20 | #include "llvm/ADT/SmallVector.h" |
21 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
22 | #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" |
23 | #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" |
24 | #include "tensorflow/dtensor/mlir/device_utils.h" |
25 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
26 | #include "tensorflow/dtensor/mlir/value_utils.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace dtensor { |
30 | namespace { |
31 | |
32 | // Returns compilation key placeholder. This placeholder will be replaced with |
33 | // output of TPUCompile op during TPURewrite pass. Program key (output of |
34 | // TPUCompile op) is used to differentiate TPU computation from which to receive |
35 | // data. |
36 | mlir::Value GetOrCreateCompilationKey(mlir::Operation* op) { |
37 | mlir::Value key; |
38 | auto cluster = op->getParentOfType<mlir::tf_device::ClusterOp>(); |
39 | assert(cluster); |
40 | cluster.walk( |
41 | [&](mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp compilation_key) { |
42 | key = compilation_key.program(); |
43 | }); |
44 | if (key) return key; |
45 | |
46 | mlir::OpBuilder builder(&cluster.GetBody().front()); |
47 | auto result_type = |
48 | mlir::RankedTensorType::get({3}, builder.getType<mlir::TF::StringType>()); |
49 | auto new_compilation_key = |
50 | builder.create<mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp>( |
51 | cluster.getLoc(), /*program=*/result_type, |
52 | llvm::ArrayRef<mlir::Value>{}); |
53 | return new_compilation_key.program(); |
54 | } |
55 | |
56 | } // namespace |
57 | |
58 | StatusOr<mlir::Value> GetDeviceOrdinal(const Mesh& mesh, |
59 | const mlir::Location& loc, |
60 | mlir::func::FuncOp function, |
61 | mlir::OpBuilder* builder, |
62 | bool return_int64_type) { |
63 | // Create as many entries as the number of devices in the entire mesh. |
64 | llvm::SmallVector<int32, 4> device_id_to_ordinal(mesh.num_devices(), 0); |
65 | // Only fill in entries with indices equal to local device IDs. For TPUs, |
66 | // there are usually 8 local devices. |
67 | for (int i = 0; i < mesh.local_device_ids().size(); ++i) { |
68 | device_id_to_ordinal[mesh.local_device_ids()[i]] = i; |
69 | } |
70 | // Slice out the device ordinal using the device ID as index. |
71 | TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(function)); |
72 | mlir::TF::SliceOp device_ordinal = builder->create<mlir::TF::SliceOp>( |
73 | loc, |
74 | /*output=*/EffectivelyScalarR1Type(builder->getIntegerType(32)), |
75 | /*input=*/IntConst(*builder, loc, device_id_to_ordinal), |
76 | /*begin=*/ |
77 | mlir::TF::collection_ops_util::ReshapeScalarToSizeType(*builder, |
78 | device_id, loc), |
79 | /*size=*/IntConst(*builder, loc, {1})); |
80 | mlir::Value device_ordinal_scalar = |
81 | ReshapeSizeTypeToScalar(*builder, loc, device_ordinal); |
82 | if (return_int64_type) { |
83 | device_ordinal_scalar = builder->create<mlir::TF::CastOp>( |
84 | loc, mlir::RankedTensorType::get({}, builder->getI64Type()), |
85 | device_ordinal_scalar); |
86 | } |
87 | return device_ordinal_scalar; |
88 | } |
89 | |
90 | // Lowers DTensorSend Op to either one of XlaSendFromHost op or XlaSendToHost, |
91 | // depending on the src mesh cluster. |
92 | StatusOr<mlir::Operation*> LowerDTensorSendToXlaOp( |
93 | const Layout& send_input_layout, mlir::Value send_input, |
94 | mlir::TF::DTensorSend dtensor_send, bool send_from_device_zero) { |
95 | const bool send_from_cpu = !send_input_layout.mesh().is_tpu_mesh(); |
96 | mlir::OpBuilder builder(dtensor_send); |
97 | |
98 | mlir::Location loc = dtensor_send.getLoc(); |
99 | mlir::Operation* lowered_send_op; |
100 | if (send_from_cpu) { |
101 | llvm::SmallVector<mlir::Value, 4> value_to_send{send_input}; |
102 | mlir::OpBuilder::InsertPoint insertion_point = builder.saveInsertionPoint(); |
103 | mlir::Value program_key = GetOrCreateCompilationKey(dtensor_send); |
104 | builder.restoreInsertionPoint(insertion_point); |
105 | |
106 | mlir::Value device_ordinal; |
107 | if (send_from_device_zero) { |
108 | // For CopyToMesh, we currently only support sending from host device 0 |
109 | // to target TPUs. |
110 | device_ordinal = CreateIntScalarConst(0, builder, loc); |
111 | } else { |
112 | // For special topologies, always send from CPU device i to TPU device i. |
113 | auto send_cluster = |
114 | dtensor_send->getParentOfType<mlir::tf_device::ClusterOp>(); |
115 | if (!send_cluster) { |
116 | return errors::InvalidArgument("DTensorSend is not inside a ClusterOp" ); |
117 | } |
118 | auto send_func = send_cluster->getParentOfType<mlir::func::FuncOp>(); |
119 | if (!send_func) { |
120 | return errors::InvalidArgument("DTensorSend is not inside a FuncOp" ); |
121 | } |
122 | TF_ASSIGN_OR_RETURN( |
123 | device_ordinal, |
124 | GetDeviceOrdinal(send_input_layout.mesh(), loc, send_func, &builder)); |
125 | } |
126 | // Create XlaSendFromHostV2 op |
127 | lowered_send_op = builder.create<mlir::TF::_XlaSendFromHostV2Op>( |
128 | loc, value_to_send, program_key, device_ordinal, dtensor_send.key()); |
129 | } else { |
130 | // Note that for ops running in XLA/TPU, device ordinal input is not needed. |
131 | lowered_send_op = builder.create<mlir::TF::XlaSendToHostOp>( |
132 | loc, send_input, dtensor_send.key()); |
133 | } |
134 | |
135 | dtensor_send.erase(); |
136 | return lowered_send_op; |
137 | } |
138 | |
139 | // Lowers DTensorRecv op to either one of XlaRecvAtHost or XlaRecvFromHost, |
140 | // depending on src mesh cluster configuration. |
141 | StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp( |
142 | mlir::TF::DTensorRecv dtensor_recv) { |
143 | return LowerDTensorRecvToXlaOp(dtensor_recv, dtensor_recv.getType()); |
144 | } |
145 | |
146 | // Lowers DTensorRecv op to either one of XlaRecvAtHost or XlaRecvFromHost, |
147 | // depending on src mesh cluster configuration. `output_type` can be set to the |
148 | // specific local tensor type needed, if different from the Recv op output type. |
149 | StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp( |
150 | mlir::TF::DTensorRecv dtensor_recv, mlir::Type output_type) { |
151 | const bool recv_at_cpu = dtensor_recv.layout().mesh().is_cpu_mesh(); |
152 | mlir::Operation* recv_xla_op = nullptr; |
153 | mlir::OpBuilder builder(dtensor_recv); |
154 | |
155 | if (recv_at_cpu) { |
156 | // Create XlaRecvAtHostV2 op. |
157 | llvm::SmallVector<mlir::Type, 4> output_types{output_type}; |
158 | auto recv_cluster = |
159 | dtensor_recv->getParentOfType<mlir::tf_device::ClusterOp>(); |
160 | |
161 | TF_ASSIGN_OR_RETURN(absl::optional<Mesh> mesh, |
162 | ExtractDeviceMeshFromOp(recv_cluster)); |
163 | if (!mesh.has_value()) |
164 | return errors::InvalidArgument( |
165 | "failed to get device ordinal as mesh for operation is not " |
166 | "specified." ); |
167 | |
168 | mlir::OpBuilder builder(&recv_cluster.GetBody().front()); |
169 | TF_ASSIGN_OR_RETURN( |
170 | mlir::Value device_ordinal, |
171 | GetDeviceOrdinal(*mesh, recv_cluster.getLoc(), |
172 | recv_cluster->getParentOfType<mlir::func::FuncOp>(), |
173 | &builder)); |
174 | |
175 | auto program_key = GetOrCreateCompilationKey(dtensor_recv); |
176 | builder.setInsertionPoint(dtensor_recv); |
177 | recv_xla_op = builder.create<mlir::TF::_XlaRecvAtHostV2Op>( |
178 | dtensor_recv.getLoc(), output_types, |
179 | /*dynamic_key=*/program_key, device_ordinal, dtensor_recv.keyAttr()); |
180 | } else { |
181 | // Create XlaRecvFromHost op. |
182 | recv_xla_op = builder.create<mlir::TF::XlaRecvFromHostOp>( |
183 | dtensor_recv.getLoc(), output_type, |
184 | ConvertTypeToTensorShapeAttr(dtensor_recv.getType()), |
185 | dtensor_recv.keyAttr()); |
186 | } |
187 | |
188 | assert(recv_xla_op); |
189 | |
190 | // TODO(hongjunchoi): After receiving tensor, convert tensor to requested |
191 | // layout with EmitRelayout. |
192 | return recv_xla_op; |
193 | } |
194 | |
195 | // Lowers a DTensorSend Op from a CPU to a TF Send op. |
196 | StatusOr<mlir::Operation*> LowerDTensorSendFromCPUToTFOp( |
197 | const Layout& send_input_layout, mlir::Value send_input, |
198 | mlir::TF::DTensorSend dtensor_send) { |
199 | mlir::OpBuilder builder(dtensor_send); |
200 | builder.setInsertionPointAfter(send_input.getDefiningOp()); |
201 | |
202 | llvm::SmallVector<mlir::Value, 4> value_to_send{send_input}; |
203 | |
204 | // Create multiple send from host. There should be #number of local |
205 | // devices(in target mesh) number of sends. |
206 | absl::Span<const std::string> sending_devices = |
207 | send_input_layout.mesh().local_devices(); |
208 | |
209 | Layout target_layout = dtensor_send.target_layout(); |
210 | absl::Span<const std::string> receiving_devices = |
211 | target_layout.mesh().local_devices(); |
212 | |
213 | std::string tensor_name = dtensor_send.key().str(); |
214 | |
215 | mlir::Operation* lowered_send_op; |
216 | for (size_t i = 0; i < receiving_devices.size(); ++i) |
217 | lowered_send_op = builder.create<mlir::TF::_HostSendOp>( |
218 | send_input.getLoc(), dtensor_send.input(), tensor_name, |
219 | sending_devices[0], |
220 | /*send_device_incarnation=*/0, receiving_devices[i]); |
221 | |
222 | dtensor_send.erase(); |
223 | return lowered_send_op; |
224 | } |
225 | |
226 | // Lowers DTensorRecv op to TF Recv Op. |
227 | StatusOr<mlir::Operation*> LowerDTensorRecvFromCPUToTFOp( |
228 | const Mesh& send_mesh, mlir::TF::DTensorRecv dtensor_recv) { |
229 | const Layout& recv_layout = dtensor_recv.layout(); |
230 | |
231 | auto recv_cluster = |
232 | dtensor_recv->getParentOfType<mlir::tf_device::ClusterOp>(); |
233 | |
234 | mlir::OpBuilder builder(&recv_cluster.GetBody().front()); |
235 | llvm::SmallVector<mlir::Type, 4> output_types{dtensor_recv.getType()}; |
236 | builder.setInsertionPoint(dtensor_recv); |
237 | std::string tensor_name = dtensor_recv.key().str(); |
238 | absl::Span<const std::string> sending_devices = send_mesh.local_devices(); |
239 | absl::Span<const std::string> receiving_devices = |
240 | recv_layout.mesh().local_devices(); |
241 | |
242 | mlir::Operation* lowered_recv_op; |
243 | mlir::Location loc = dtensor_recv.getLoc(); |
244 | for (size_t i = 0; i < receiving_devices.size(); ++i) |
245 | lowered_recv_op = builder.create<mlir::TF::_HostRecvOp>( |
246 | loc, dtensor_recv.getType(), tensor_name, sending_devices[0], |
247 | /*send_device_incarnation=*/0, receiving_devices[i]); |
248 | |
249 | // Replace dtensor_recv with newly created recv op and remove DTensorRecv op. |
250 | assert(lowered_recv_op); |
251 | dtensor_recv.replaceAllUsesWith(lowered_recv_op); |
252 | dtensor_recv.erase(); |
253 | return lowered_recv_op; |
254 | } |
255 | |
256 | } // namespace dtensor |
257 | } // namespace tensorflow |
258 | |