1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19
20namespace tensorflow {
21
// Reduces `input` across the members of a collective group; the output has
// the same shape as the input (UnchangedShape).
REGISTER_OP("CollectiveReduce")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {bfloat16, float, float16, float64, int32, int64}")
    .Attr("group_size: int")
    .Attr("group_key: int")
    .Attr("instance_key: int")
    // Elementwise op used to combine tensors from group members.
    .Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
    // Op applied to the merged result, e.g. 'Div' to turn a sum into a mean.
    .Attr("final_op: {'Id', 'Div'}")
    .Attr("subdiv_offsets: list(int)")
    .Attr("wait_for: list(int) = []")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    // Stateful: must not be constant-folded or deduplicated — it
    // communicates with other devices/processes.
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::UnchangedShape);
38
39REGISTER_OP("CollectiveGather")
40 .Input("input: T")
41 .Output("data: T")
42 .Attr("T: {float, float16, float64, int32, int64}")
43 .Attr("group_size: int")
44 .Attr("group_key: int")
45 .Attr("instance_key: int")
46 .Attr("shape: shape")
47 .Attr("communication_hint: string = 'auto'")
48 .Attr("timeout_seconds: float = 0")
49 .SetIsStateful()
50 .SetIsDistributedCommunication()
51 .SetShapeFn([](shape_inference::InferenceContext* c) {
52 // Scalar input is not supported.
53 shape_inference::ShapeHandle unused;
54 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
55
56 shape_inference::ShapeHandle in_subshape;
57 TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, &in_subshape));
58
59 auto input_first_dim_value = c->Value(c->Dim(c->input(0), 0));
60
61 // This output should have the same shape as its input except the first
62 // dimension should be multiplied by group size.
63 shape_inference::ShapeHandle output_first_dim_as_shape;
64 if (input_first_dim_value ==
65 shape_inference::InferenceContext::kUnknownDim) {
66 output_first_dim_as_shape =
67 c->Vector(shape_inference::InferenceContext::kUnknownDim);
68 } else {
69 int group_size;
70 TF_CHECK_OK(c->GetAttr("group_size", &group_size));
71 std::vector<shape_inference::DimensionHandle> output_first_dim;
72 output_first_dim.push_back(
73 c->MakeDim(group_size * input_first_dim_value));
74 output_first_dim_as_shape = c->MakeShape(output_first_dim);
75 }
76
77 shape_inference::ShapeHandle out;
78 TF_RETURN_IF_ERROR(
79 c->Concatenate(output_first_dim_as_shape, in_subshape, &out));
80 c->set_output(0, out);
81 return OkStatus();
82 });
83
// Broadcast-send side: emits `input` to the other members of the collective
// group.  The output shape is taken from the `shape` attr (ExplicitShape).
REGISTER_OP("CollectiveBcastSend")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Attr("group_size: int")
    .Attr("group_key: int")
    .Attr("instance_key: int")
    .Attr("shape: shape")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::ExplicitShape);
97
// Broadcast-receive side: has no tensor input; the received tensor's shape
// is declared statically via the `shape` attr (ExplicitShape).
REGISTER_OP("CollectiveBcastRecv")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Attr("group_size: int")
    .Attr("group_key: int")
    .Attr("instance_key: int")
    .Attr("shape: shape")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::ExplicitShape);
110
// Derives a scalar group_size and group_key for this device from a
// group_assignment tensor, a device index, and a base key.
REGISTER_OP("CollectiveAssignGroupV2")
    .Input("group_assignment: int32")
    .Input("device_index: int32")
    .Input("base_key: int32")
    .Output("group_size: int32")
    .Output("group_key: int32")
    // To avoid tensorflow::constant_folding.
    .SetDoNotOptimize()  // Also marked in auto_control_dep.py and
                         // function_optimizer.cc
    .SetIsDistributedCommunication()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Both outputs are scalars regardless of input shapes.
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return OkStatus();
    });
126
// V2 of CollectiveReduce: group_size/group_key/instance_key are runtime
// tensor inputs rather than attrs.  Output shape equals input shape.
REGISTER_OP("CollectiveReduceV2")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {bfloat16, float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    // Variadic resource inputs used to order concurrent collectives;
    // may be empty (Nordering_token defaults to 0).
    .Input("ordering_token: Nordering_token * resource")
    .Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
    .Attr("final_op: {'Id', 'Div'}")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    .Attr("Nordering_token: int >= 0 = 0")
    .Attr("max_subdivs_per_device: int = -1")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::UnchangedShape);
144
// V2 of CollectiveGather: group parameters are runtime tensor inputs, so
// the gathered first dimension cannot be inferred statically.
REGISTER_OP("CollectiveGatherV2")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    // Variadic resource inputs used to order concurrent collectives;
    // may be empty (Nordering_token defaults to 0).
    .Input("ordering_token: Nordering_token * resource")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    .Attr("Nordering_token: int >= 0 = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Scalar input is not supported.
      shape_inference::ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
      // This output should have the same shape as its input except the first
      // dimension is unknown, since the group size is unknown.
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(
          c->ReplaceDim(c->input(0), /*dim_index*/ 0, c->UnknownDim(), &out));
      c->set_output(0, out);
      return OkStatus();
    });
170
// V2 broadcast-send: group parameters are runtime tensor inputs; the output
// simply mirrors the input's shape (UnchangedShape).
REGISTER_OP("CollectiveBcastSendV2")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::UnchangedShape);
183
// V2 broadcast-receive: the expected output shape arrives as the `shape`
// tensor input (int32/int64) instead of a static attr.
REGISTER_OP("CollectiveBcastRecvV2")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    .Input("shape: Tshape")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The output shape is given by the `shape` input at index 3
      // (inputs 0-2 are group_size, group_key, instance_key).
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(/*input_idx=*/3, &out));
      c->set_output(/*idx=*/0, out);
      return OkStatus();
    });
203
// Creates a communicator resource handle for the (group_key, rank,
// group_size) triple, used by the V3 collective ops below.
REGISTER_OP("CollectiveInitializeCommunicator")
    .Input("group_key: int32")
    .Input("rank: int32")
    .Input("group_size: int32")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    .Output("communicator: resource")
    .SetDoNotOptimize()  // Also marked in auto_control_dep.py and
                         // function_optimizer.cc
    .SetIsDistributedCommunication()
    // The resource handle output is a scalar.
    .SetShapeFn(shape_inference::ScalarShape);
215
// V3 reduce: uses a pre-initialized communicator resource plus a
// group_assignment tensor.  Output shape equals input shape.
REGISTER_OP("CollectiveReduceV3")
    .Input("input: T")
    .Input("communicator: resource")
    .Input("group_assignment: int32")
    .Output("data: T")
    .Attr("T: {bfloat16, float, float16, float64, int32, int64}")
    // Elementwise op used to combine tensors from group members.
    .Attr("reduction: {'Min', 'Max', 'Mul', 'Add'}")
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::UnchangedShape);
227
// V3 all-to-all exchange over a pre-initialized communicator.  The output
// is declared to have the same shape as the input (UnchangedShape).
REGISTER_OP("CollectiveAllToAllV3")
    .Input("input: T")
    .Input("communicator: resource")
    .Input("group_assignment: int32")
    .Output("data: T")
    .Attr("T: {bfloat16, float, float16, float64, int32, int64}")
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::UnchangedShape);
238
239} // namespace tensorflow
240