1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/errors.h"
19 | |
20 | namespace tensorflow { |
21 | |
// Element-wise reduction across the members of a collective group.
// `merge_op` picks the pairwise combiner and `final_op` is applied to the
// merged result (e.g. Add + Div presumably yields a mean — confirm in the
// collective kernel implementation).
REGISTER_OP("CollectiveReduce")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {bfloat16, float, float16, float64, int32, int64}")
    // Static (graph-construction-time) group identifiers: ops sharing
    // group_key/group_size form one group; instance_key names a single
    // collective instance within that group.
    .Attr("group_size: int")
    .Attr("group_key: int")
    .Attr("instance_key: int")
    .Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
    .Attr("final_op: {'Id', 'Div'}")
    // NOTE(review): subdiv_offsets/wait_for look like ring-algorithm tuning
    // and cross-op ordering knobs — semantics live in the runtime, not here.
    .Attr("subdiv_offsets: list(int)")
    .Attr("wait_for: list(int) = []")
    // 'auto' defers implementation choice to the runtime.
    .Attr("communication_hint: string = 'auto'")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    // Reduction is element-wise, so the output shape equals the input shape.
    .SetShapeFn(shape_inference::UnchangedShape);
38 | |
39 | REGISTER_OP("CollectiveGather" ) |
40 | .Input("input: T" ) |
41 | .Output("data: T" ) |
42 | .Attr("T: {float, float16, float64, int32, int64}" ) |
43 | .Attr("group_size: int" ) |
44 | .Attr("group_key: int" ) |
45 | .Attr("instance_key: int" ) |
46 | .Attr("shape: shape" ) |
47 | .Attr("communication_hint: string = 'auto'" ) |
48 | .Attr("timeout_seconds: float = 0" ) |
49 | .SetIsStateful() |
50 | .SetIsDistributedCommunication() |
51 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
52 | // Scalar input is not supported. |
53 | shape_inference::ShapeHandle unused; |
54 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused)); |
55 | |
56 | shape_inference::ShapeHandle in_subshape; |
57 | TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, &in_subshape)); |
58 | |
59 | auto input_first_dim_value = c->Value(c->Dim(c->input(0), 0)); |
60 | |
61 | // This output should have the same shape as its input except the first |
62 | // dimension should be multiplied by group size. |
63 | shape_inference::ShapeHandle output_first_dim_as_shape; |
64 | if (input_first_dim_value == |
65 | shape_inference::InferenceContext::kUnknownDim) { |
66 | output_first_dim_as_shape = |
67 | c->Vector(shape_inference::InferenceContext::kUnknownDim); |
68 | } else { |
69 | int group_size; |
70 | TF_CHECK_OK(c->GetAttr("group_size" , &group_size)); |
71 | std::vector<shape_inference::DimensionHandle> output_first_dim; |
72 | output_first_dim.push_back( |
73 | c->MakeDim(group_size * input_first_dim_value)); |
74 | output_first_dim_as_shape = c->MakeShape(output_first_dim); |
75 | } |
76 | |
77 | shape_inference::ShapeHandle out; |
78 | TF_RETURN_IF_ERROR( |
79 | c->Concatenate(output_first_dim_as_shape, in_subshape, &out)); |
80 | c->set_output(0, out); |
81 | return OkStatus(); |
82 | }); |
83 | |
// Broadcast source: sends `input` to the other members of the collective
// group and also returns it as `data`.
REGISTER_OP("CollectiveBcastSend")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Attr("group_size: int")
    .Attr("group_key: int")
    .Attr("instance_key: int")
    // Declared tensor shape; used by ExplicitShape below for inference.
    .Attr("shape: shape")
    .Attr("communication_hint: string = 'auto'")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    // Output shape comes from the `shape` attr, not from the input.
    .SetShapeFn(shape_inference::ExplicitShape);
97 | |
// Broadcast receiver: produces the tensor sent by the group's
// CollectiveBcastSend. Has no tensor inputs — identity comes entirely from
// the group/instance attrs.
REGISTER_OP("CollectiveBcastRecv")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Attr("group_size: int")
    .Attr("group_key: int")
    .Attr("instance_key: int")
    // Declared shape of the broadcast tensor; used by ExplicitShape below.
    .Attr("shape: shape")
    .Attr("communication_hint: string = 'auto'")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    // Output shape comes from the `shape` attr.
    .SetShapeFn(shape_inference::ExplicitShape);
110 | |
// Computes scalar group_size/group_key tensors from a group assignment,
// a device index, and a base key, for use by the V2 collective ops (which
// take these as runtime inputs rather than attrs).
REGISTER_OP("CollectiveAssignGroupV2")
    .Input("group_assignment: int32")
    .Input("device_index: int32")
    .Input("base_key: int32")
    .Output("group_size: int32")
    .Output("group_key: int32")
    // To avoid tensorflow::constant_folding.
    .SetDoNotOptimize()  // Also marked in auto_control_dep.py and
                         // function_optimizer.cc
    .SetIsDistributedCommunication()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Both outputs are scalars regardless of input shapes.
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return OkStatus();
    });
126 | |
// V2 of CollectiveReduce: group_size/group_key/instance_key are runtime
// tensor inputs instead of attrs, so groups can be formed dynamically.
REGISTER_OP("CollectiveReduceV2")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {bfloat16, float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    // NOTE(review): ordering_token resources appear to exist so that data
    // dependencies can order collectives — confirm against the runtime.
    .Input("ordering_token: Nordering_token * resource")
    .Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
    .Attr("final_op: {'Id', 'Div'}")
    .Attr("communication_hint: string = 'auto'")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    // Number of ordering_token inputs; defaults to none.
    .Attr("Nordering_token: int >= 0 = 0")
    // -1 presumably means "no limit / runtime default" — confirm in kernel.
    .Attr("max_subdivs_per_device: int = -1")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    // Element-wise reduction: output shape equals input shape.
    .SetShapeFn(shape_inference::UnchangedShape);
144 | |
// V2 of CollectiveGather: group parameters are runtime tensor inputs, so
// the group size is not known at graph-construction time.
REGISTER_OP("CollectiveGatherV2")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    // NOTE(review): ordering_token resources appear to exist so that data
    // dependencies can order collectives — confirm against the runtime.
    .Input("ordering_token: Nordering_token * resource")
    .Attr("communication_hint: string = 'auto'")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    // Number of ordering_token inputs; defaults to none.
    .Attr("Nordering_token: int >= 0 = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Scalar input is not supported.
      shape_inference::ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
      // This output should have the same shape as its input except the first
      // dimension is unknown, since the group size is unknown.
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(
          c->ReplaceDim(c->input(0), /*dim_index*/ 0, c->UnknownDim(), &out));
      c->set_output(0, out);
      return OkStatus();
    });
170 | |
// V2 broadcast source: group parameters are runtime tensor inputs, and the
// output shape is simply that of `input` (no `shape` attr as in V1).
REGISTER_OP("CollectiveBcastSendV2")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    .Attr("communication_hint: string = 'auto'")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    // Sender forwards its input, so the shape passes through unchanged.
    .SetShapeFn(shape_inference::UnchangedShape);
183 | |
// V2 broadcast receiver: group parameters and the expected output shape are
// runtime tensor inputs (shape as an int32/int64 vector, not an attr).
REGISTER_OP("CollectiveBcastRecvV2")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    // Shape tensor describing the incoming broadcast tensor.
    .Input("shape: Tshape")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .Attr("communication_hint: string = 'auto'")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The output shape is given by the `shape` input at index 3
      // (inputs: group_size=0, group_key=1, instance_key=2, shape=3).
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(/*input_idx=*/3, &out));
      c->set_output(/*idx=*/0, out);
      return OkStatus();
    });
203 | |
// Creates a communicator resource for the V3 collective ops below, bound to
// a (group_key, rank, group_size) triple supplied at runtime.
REGISTER_OP("CollectiveInitializeCommunicator")
    .Input("group_key: int32")
    .Input("rank: int32")
    .Input("group_size: int32")
    .Attr("communication_hint: string = 'auto'")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    .Output("communicator: resource")
    .SetDoNotOptimize()  // Also marked in auto_control_dep.py and
                         // function_optimizer.cc
    .SetIsDistributedCommunication()
    // The communicator handle is a scalar resource tensor.
    .SetShapeFn(shape_inference::ScalarShape);
215 | |
// V3 reduction: operates through a communicator resource created by
// CollectiveInitializeCommunicator, with a runtime group_assignment tensor.
REGISTER_OP("CollectiveReduceV3")
    .Input("input: T")
    .Input("communicator: resource")
    .Input("group_assignment: int32")
    .Output("data: T")
    .Attr("T: {bfloat16, float, float16, float64, int32, int64}")
    // Pairwise combiner (V1/V2 call this merge_op; there is no final_op).
    .Attr("reduction: {'Min', 'Max', 'Mul', 'Add'}")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    // Element-wise reduction: output shape equals input shape.
    .SetShapeFn(shape_inference::UnchangedShape);
227 | |
// V3 all-to-all exchange through a communicator resource. The declared
// output shape equals the input shape (UnchangedShape below); how data is
// redistributed across members is defined by the kernel, not here.
REGISTER_OP("CollectiveAllToAllV3")
    .Input("input: T")
    .Input("communicator: resource")
    .Input("group_assignment: int32")
    .Output("data: T")
    .Attr("T: {bfloat16, float, float16, float64, int32, int64}")
    // 0 means no timeout.
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::UnchangedShape);
238 | |
239 | } // namespace tensorflow |
240 | |