/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18
19namespace tensorflow {
20
21using shape_inference::InferenceContext;
22using shape_inference::ShapeHandle;
23
24REGISTER_OP("NcclAllReduce")
25 .Input("input: T")
26 .Output("data: T")
27 .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
28 .Attr("T: {half, float, float64, int32, int64}")
29 .Attr("num_devices: int")
30 .Attr("shared_name: string")
31 .SetIsStateful()
32 .SetShapeFn(shape_inference::UnchangedShape);
33
34// Note: This op has no kernel implementation, but is replaced by
35// _NcclReduceSend and _NcclReduceRecv during graph optimization stage.
36REGISTER_OP("NcclReduce")
37 .Input("input: num_devices * T")
38 .Output("data: T")
39 .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
40 .Attr("T: {half, float, float64, int32, int64}")
41 .Attr("num_devices: int")
42 .SetIsStateful()
43 .SetShapeFn(shape_inference::UnchangedShape);
44
45REGISTER_OP("_NcclReduceSend")
46 .Input("input: T")
47 .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
48 .Attr("T: {half, float, float64, int32, int64}")
49 .Attr("num_devices: int")
50 .Attr("shared_name: string")
51 .SetIsStateful()
52 .SetShapeFn(shape_inference::NoOutputs)
53 .Doc(R"doc(
54Replacement node for NcclReduce.
55
56Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`.
57The graph should be constructed so that 'num_devices-1' devices run
58`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value
59`c`. Failure to do so will cause the graph execution to fail to complete.
60
61input: The input to the reduction.
62reduction: the reduction operation to perform.
63num_devices: The number of devices participating in this reduction.
64shared_name: Identifier that is shared between ops of the same reduce.
65 )doc");
66
67REGISTER_OP("_NcclReduceRecv")
68 .Input("input: T")
69 .Output("data: T")
70 .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
71 .Attr("T: {half, float, float64, int32, int64}")
72 .Attr("num_devices: int")
73 .Attr("shared_name: string")
74 .SetIsStateful()
75 .SetShapeFn(shape_inference::UnchangedShape)
76 .Doc(R"doc(
77Replacement node for NcclReduce.
78
79Reduces 'input' from this op and the NcclReduceSend ops registered in the same
80`shared_name`.
81The graph should be constructed so that 'num_devices-1' devices run
82`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value
83`c`. Failure to do so will cause the graph execution to fail to complete.
84
85input: The input to the reduction.
86data: The reduced data received from this op and the NcclReduceSend op.
87reduction: the reduction operation to perform.
88num_devices: The number of devices participating in this reduction.
89shared_name: Identifier that is shared between ops of the same reduce.
90 )doc");
91
92// Note: This op has no kernel implementation, but is replaced by
93// _NcclBroadcastSend and _NcclBroadcastRecv during graph optimization stage.
94REGISTER_OP("NcclBroadcast")
95 .Input("input: T")
96 .Output("output: T")
97 .Attr("T: {half, float, float64, int32, int64}")
98 .Attr("shape: shape")
99 .SetIsStateful()
100 .SetShapeFn(shape_inference::UnchangedShape);
101
102REGISTER_OP("_NcclBroadcastSend")
103 .Input("input: T")
104 .Attr("T: {half, float, float64, int32, int64}")
105 .Attr("num_devices: int")
106 .Attr("shared_name: string")
107 .SetIsStateful()
108 .SetShapeFn(shape_inference::NoOutputs)
109 .Doc(R"doc(
110Replacement node for NcclBroadcast.
111
112Sends `input` to the _NcclBroadcastRecv ops registered in the same
113`shared_name`.
114The graph should be constructed so that one device runs `_NcclBroadcastSend` and
115`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`.
116Failure to do so will cause the graph execution to fail to complete.
117
118input: The input to the broadcast.
119num_devices: The number of devices participating in this reduction.
120shared_name: Identifier that is shared between ops of the same broadcast.
121 )doc");
122
123REGISTER_OP("_NcclBroadcastRecv")
124 .Input("shape: int32")
125 .Output("output: T")
126 .Attr("T: {half, float, float64, int32, int64}")
127 .Attr("num_devices: int")
128 .Attr("shared_name: string")
129 .SetIsStateful()
130 .SetShapeFn([](InferenceContext* c) {
131 ShapeHandle out;
132 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
133 c->set_output(0, out);
134 return OkStatus();
135 })
136 .Doc(R"doc(
137Replacement node for NcclBroadcast.
138
139Sends data of shape `shape` from the _NcclBroadcastSend op registered in the
140same `shared_name`.
141The graph should be constructed so that one device runs `_NcclBroadcastSend` and
142`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`.
143Failure to do so will cause the graph execution to fail to complete.
144
145shape: The shape of the output.
146output: The broadcast data received from the NcclBroadcastSend op.
147num_devices: The number of devices participating in this reduction.
148shared_name: Identifier that is shared between ops of the same broadcast.
149 )doc");

}  // namespace tensorflow