/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#ifndef TENSORFLOW_DTENSOR_MLIR_DTENSOR_SEND_RECV_H_
#define TENSORFLOW_DTENSOR_MLIR_DTENSOR_SEND_RECV_H_

#include <type_traits>

#include "llvm/Support/Casting.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
30
31namespace tensorflow {
32namespace dtensor {
33
34// Given DTensorSend or DTensorRecv op, returns the corresponding DTensorRecv
35// or DTensorSend op with the same key.
36template <typename DTensorOp>
37StatusOr<mlir::Operation*> GetCorrespondingDTensorSendRecvOp(
38 mlir::ModuleOp module, DTensorOp dtensor_op) {
39 mlir::Operation* corresponding_op = nullptr;
40 if (std::is_same<DTensorOp, mlir::TF::DTensorSend>::value) {
41 module.walk([&](mlir::Operation* op) {
42 if (auto xla_recv_tpu = llvm::dyn_cast<mlir::TF::XlaRecvFromHostOp>(op)) {
43 if (dtensor_op.key() == xla_recv_tpu.key()) {
44 corresponding_op = op;
45 return mlir::WalkResult::interrupt();
46 }
47 } else if (auto xla_recv_cpu =
48 llvm::dyn_cast<mlir::TF::_XlaRecvAtHostV2Op>(op)) {
49 if (dtensor_op.key() == xla_recv_cpu.key()) {
50 corresponding_op = op;
51 return mlir::WalkResult::interrupt();
52 }
53 } else if (auto dtensor_recv =
54 llvm::dyn_cast<mlir::TF::DTensorRecv>(op)) {
55 if (dtensor_op.key() == dtensor_recv.key()) {
56 corresponding_op = op;
57 return mlir::WalkResult::interrupt();
58 }
59 } else if (auto host_recv = llvm::dyn_cast<mlir::TF::_HostRecvOp>(op)) {
60 if (dtensor_op.key() == host_recv.tensor_name()) {
61 corresponding_op = op;
62 return mlir::WalkResult::interrupt();
63 }
64 }
65 return mlir::WalkResult::advance();
66 });
67 } else {
68 const bool is_recv = std::is_same<DTensorOp, mlir::TF::DTensorRecv>::value;
69 if (!is_recv) {
70 return errors::Internal(
71 "Error checking if is same for DTensorOp and DTensorRecv.");
72 }
73 module.walk([&](mlir::Operation* op) {
74 if (auto xla_send_tpu = llvm::dyn_cast<mlir::TF::XlaSendToHostOp>(op)) {
75 if (dtensor_op.key() == xla_send_tpu.key()) {
76 corresponding_op = op;
77 return mlir::WalkResult::interrupt();
78 }
79 } else if (auto xla_send_cpu =
80 llvm::dyn_cast<mlir::TF::_XlaSendFromHostV2Op>(op)) {
81 if (dtensor_op.key() == xla_send_cpu.key()) {
82 corresponding_op = op;
83 return mlir::WalkResult::interrupt();
84 }
85 } else if (auto dtensor_send =
86 llvm::dyn_cast<mlir::TF::DTensorSend>(op)) {
87 if (dtensor_op.key() == dtensor_send.key()) {
88 corresponding_op = op;
89 return mlir::WalkResult::interrupt();
90 }
91 } else if (auto host_send = llvm::dyn_cast<mlir::TF::_HostSendOp>(op)) {
92 if (dtensor_op.key() == host_send.tensor_name()) {
93 corresponding_op = op;
94 return mlir::WalkResult::interrupt();
95 }
96 }
97 return mlir::WalkResult::advance();
98 });
99 }
100
101 if (!corresponding_op)
102 return errors::InvalidArgument(
103 "DTensorSend/DTensorRecv op must have corresponding "
104 "DTensorRecv/DTensorSend op.");
105
106 return corresponding_op;
107}
108
// Lowers DTensorRecv op to either one of XlaRecvAtHost or XlaRecvFromHost,
// depending on src mesh cluster configuration. Returns the lowered op on
// success.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv);

// Lowers DTensorRecv op to either one of XlaRecvAtHost or XlaRecvFromHost,
// depending on src mesh cluster configuration. `output_type` can be set to the
// specific local tensor type needed, if different from the Recv op output
// type.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv, mlir::Type output_type);
119
// Lowers DTensorSend Op to either one of XlaSendFromHost op or XlaSendToHost,
// depending on the src mesh cluster. `send_from_device_zero` should be set if
// control flow needs to be inserted to gather data onto and only sent from the
// zero'th device. `send_input` is the value being sent and
// `send_input_layout` its layout. Returns the lowered op on success.
StatusOr<mlir::Operation*> LowerDTensorSendToXlaOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send, bool send_from_device_zero);
127
// Lowers DTensorSend Op to a TF HostSend op. `send_input` is the value being
// sent and `send_input_layout` its layout. Returns the lowered op on success.
StatusOr<mlir::Operation*> LowerDTensorSendFromCPUToTFOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send);
132
// Lowers DTensorRecv op to a TF HostRecv op, receiving from `send_mesh`.
// (The original comment said "DTensorSend", but this lowers the Recv side.)
// Returns the lowered op on success.
StatusOr<mlir::Operation*> LowerDTensorRecvFromCPUToTFOp(
    const Mesh& send_mesh, mlir::TF::DTensorRecv dtensor_recv);
136
137} // namespace dtensor
138} // namespace tensorflow
139
140#endif // TENSORFLOW_DTENSOR_MLIR_DTENSOR_SEND_RECV_H_
141