1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_MLIR_DTENSOR_SEND_RECV_H_ |
17 | #define TENSORFLOW_DTENSOR_MLIR_DTENSOR_SEND_RECV_H_ |
18 | |
#include <type_traits>

#include "llvm/Support/Casting.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
30 | |
31 | namespace tensorflow { |
32 | namespace dtensor { |
33 | |
34 | // Given DTensorSend or DTensorRecv op, returns the corresponding DTensorRecv |
35 | // or DTensorSend op with the same key. |
36 | template <typename DTensorOp> |
37 | StatusOr<mlir::Operation*> GetCorrespondingDTensorSendRecvOp( |
38 | mlir::ModuleOp module, DTensorOp dtensor_op) { |
39 | mlir::Operation* corresponding_op = nullptr; |
40 | if (std::is_same<DTensorOp, mlir::TF::DTensorSend>::value) { |
41 | module.walk([&](mlir::Operation* op) { |
42 | if (auto xla_recv_tpu = llvm::dyn_cast<mlir::TF::XlaRecvFromHostOp>(op)) { |
43 | if (dtensor_op.key() == xla_recv_tpu.key()) { |
44 | corresponding_op = op; |
45 | return mlir::WalkResult::interrupt(); |
46 | } |
47 | } else if (auto xla_recv_cpu = |
48 | llvm::dyn_cast<mlir::TF::_XlaRecvAtHostV2Op>(op)) { |
49 | if (dtensor_op.key() == xla_recv_cpu.key()) { |
50 | corresponding_op = op; |
51 | return mlir::WalkResult::interrupt(); |
52 | } |
53 | } else if (auto dtensor_recv = |
54 | llvm::dyn_cast<mlir::TF::DTensorRecv>(op)) { |
55 | if (dtensor_op.key() == dtensor_recv.key()) { |
56 | corresponding_op = op; |
57 | return mlir::WalkResult::interrupt(); |
58 | } |
59 | } else if (auto host_recv = llvm::dyn_cast<mlir::TF::_HostRecvOp>(op)) { |
60 | if (dtensor_op.key() == host_recv.tensor_name()) { |
61 | corresponding_op = op; |
62 | return mlir::WalkResult::interrupt(); |
63 | } |
64 | } |
65 | return mlir::WalkResult::advance(); |
66 | }); |
67 | } else { |
68 | const bool is_recv = std::is_same<DTensorOp, mlir::TF::DTensorRecv>::value; |
69 | if (!is_recv) { |
70 | return errors::Internal( |
71 | "Error checking if is same for DTensorOp and DTensorRecv." ); |
72 | } |
73 | module.walk([&](mlir::Operation* op) { |
74 | if (auto xla_send_tpu = llvm::dyn_cast<mlir::TF::XlaSendToHostOp>(op)) { |
75 | if (dtensor_op.key() == xla_send_tpu.key()) { |
76 | corresponding_op = op; |
77 | return mlir::WalkResult::interrupt(); |
78 | } |
79 | } else if (auto xla_send_cpu = |
80 | llvm::dyn_cast<mlir::TF::_XlaSendFromHostV2Op>(op)) { |
81 | if (dtensor_op.key() == xla_send_cpu.key()) { |
82 | corresponding_op = op; |
83 | return mlir::WalkResult::interrupt(); |
84 | } |
85 | } else if (auto dtensor_send = |
86 | llvm::dyn_cast<mlir::TF::DTensorSend>(op)) { |
87 | if (dtensor_op.key() == dtensor_send.key()) { |
88 | corresponding_op = op; |
89 | return mlir::WalkResult::interrupt(); |
90 | } |
91 | } else if (auto host_send = llvm::dyn_cast<mlir::TF::_HostSendOp>(op)) { |
92 | if (dtensor_op.key() == host_send.tensor_name()) { |
93 | corresponding_op = op; |
94 | return mlir::WalkResult::interrupt(); |
95 | } |
96 | } |
97 | return mlir::WalkResult::advance(); |
98 | }); |
99 | } |
100 | |
101 | if (!corresponding_op) |
102 | return errors::InvalidArgument( |
103 | "DTensorSend/DTensorRecv op must have corresponding " |
104 | "DTensorRecv/DTensorSend op." ); |
105 | |
106 | return corresponding_op; |
107 | } |
108 | |
// Lowers DTensorRecv op to either one of XlaRecvAtHost or XlaRecvFromHost,
// depending on src mesh cluster configuration.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv);

// Lowers DTensorRecv op to either one of XlaRecvAtHost or XlaRecvFromHost,
// depending on src mesh cluster configuration. `output_type` can be set to the
// specific local tensor type needed, if different from the Recv op output type.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv, mlir::Type output_type);

// Lowers DTensorSend Op to either one of XlaSendFromHost op or XlaSendToHost,
// depending on the src mesh cluster. `send_from_device_zero` should be set if
// control flow needs to be inserted to gather data onto and only sent from the
// zero'th device.
StatusOr<mlir::Operation*> LowerDTensorSendToXlaOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send, bool send_from_device_zero);

// Lowers DTensorSend Op to a TF HostSend op.
StatusOr<mlir::Operation*> LowerDTensorSendFromCPUToTFOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send);

// Lowers DTensorRecv Op to a TF HostRecv op. `send_mesh` is the mesh of the
// corresponding DTensorSend op.
StatusOr<mlir::Operation*> LowerDTensorRecvFromCPUToTFOp(
    const Mesh& send_mesh, mlir::TF::DTensorRecv dtensor_recv);
136 | |
137 | } // namespace dtensor |
138 | } // namespace tensorflow |
139 | |
140 | #endif // TENSORFLOW_DTENSOR_MLIR_DTENSOR_SEND_RECV_H_ |
141 | |