1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | using shape_inference::InferenceContext; |
22 | using shape_inference::ShapeHandle; |
23 | |
24 | REGISTER_OP("NcclAllReduce" ) |
25 | .Input("input: T" ) |
26 | .Output("data: T" ) |
27 | .Attr("reduction: {'min', 'max', 'prod', 'sum'}" ) |
28 | .Attr("T: {half, float, float64, int32, int64}" ) |
29 | .Attr("num_devices: int" ) |
30 | .Attr("shared_name: string" ) |
31 | .SetIsStateful() |
32 | .SetShapeFn(shape_inference::UnchangedShape); |
33 | |
34 | // Note: This op has no kernel implementation, but is replaced by |
35 | // _NcclReduceSend and _NcclReduceRecv during graph optimization stage. |
36 | REGISTER_OP("NcclReduce" ) |
37 | .Input("input: num_devices * T" ) |
38 | .Output("data: T" ) |
39 | .Attr("reduction: {'min', 'max', 'prod', 'sum'}" ) |
40 | .Attr("T: {half, float, float64, int32, int64}" ) |
41 | .Attr("num_devices: int" ) |
42 | .SetIsStateful() |
43 | .SetShapeFn(shape_inference::UnchangedShape); |
44 | |
// Internal (underscore-prefixed) op inserted by graph optimization in place of
// NcclReduce on the `num_devices - 1` sending devices. It consumes `input` and
// produces no outputs (shape fn: NoOutputs); the reduced result appears only
// at the paired _NcclReduceRecv op with the same `shared_name`. Stateful so it
// is never constant-folded or pruned despite having no outputs.
REGISTER_OP("_NcclReduceSend")
    .Input("input: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {half, float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs)
    .Doc(R"doc(
Replacement node for NcclReduce.

Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`.
The graph should be constructed so that 'num_devices-1' devices run
`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value
`c`. Failure to do so will cause the graph execution to fail to complete.

input: The input to the reduction.
reduction: the reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");
66 | |
// Internal op inserted by graph optimization in place of NcclReduce on the one
// receiving device. Combines its own `input` with the tensors supplied by the
// _NcclReduceSend ops sharing `shared_name`, and emits the reduced result.
// Shape fn: output shape equals this device's input shape (UnchangedShape).
REGISTER_OP("_NcclReduceRecv")
    .Input("input: T")
    .Output("data: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {half, float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Replacement node for NcclReduce.

Reduces 'input' from this op and the NcclReduceSend ops registered in the same
`shared_name`.
The graph should be constructed so that 'num_devices-1' devices run
`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value
`c`. Failure to do so will cause the graph execution to fail to complete.

input: The input to the reduction.
data: The reduced data received from this op and the NcclReduceSend op.
reduction: the reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");
91 | |
92 | // Note: This op has no kernel implementation, but is replaced by |
93 | // _NcclBroadcastSend and _NcclBroadcastRecv during graph optimization stage. |
94 | REGISTER_OP("NcclBroadcast" ) |
95 | .Input("input: T" ) |
96 | .Output("output: T" ) |
97 | .Attr("T: {half, float, float64, int32, int64}" ) |
98 | .Attr("shape: shape" ) |
99 | .SetIsStateful() |
100 | .SetShapeFn(shape_inference::UnchangedShape); |
101 | |
// Internal op inserted by graph optimization in place of NcclBroadcast on the
// single sending device. Consumes `input` and produces no outputs (shape fn:
// NoOutputs); the data surfaces at the _NcclBroadcastRecv ops sharing
// `shared_name`. Stateful so it is never constant-folded or pruned despite
// having no outputs.
REGISTER_OP("_NcclBroadcastSend")
    .Input("input: T")
    .Attr("T: {half, float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs)
    .Doc(R"doc(
Replacement node for NcclBroadcast.

Sends `input` to the _NcclBroadcastRecv ops registered in the same
`shared_name`.
The graph should be constructed so that one device runs `_NcclBroadcastSend` and
`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`.
Failure to do so will cause the graph execution to fail to complete.

input: The input to the broadcast.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");
122 | |
// Internal op inserted by graph optimization in place of NcclBroadcast on each
// of the `num_devices - 1` receiving devices. Takes the expected tensor shape
// as an int32 `shape` input and outputs the broadcast data of type T.
REGISTER_OP("_NcclBroadcastRecv")
    .Input("shape: int32")
    .Output("output: T")
    .Attr("T: {half, float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // The output shape is not inferable from input *shapes* alone: it is the
      // *value* of input 0, so build the shape from that tensor's contents.
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
      c->set_output(0, out);
      return OkStatus();
    })
    .Doc(R"doc(
Replacement node for NcclBroadcast.

Sends data of shape `shape` from the _NcclBroadcastSend op registered in the
same `shared_name`.
The graph should be constructed so that one device runs `_NcclBroadcastSend` and
`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`.
Failure to do so will cause the graph execution to fail to complete.

shape: The shape of the output.
output: The broadcast data received from the NcclBroadcastSend op.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");
150 | |
151 | } // namespace tensorflow |
152 | |