1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | #include "tensorflow/core/lib/core/status.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | using shape_inference::InferenceContext; |
24 | using shape_inference::ShapeHandle; |
25 | |
26 | // Configuring a distributed TPU system is achieved by running |
27 | // the following Ops: |
28 | // |
29 | // 1 Run _DisconnectHostFromDistributedTPUSystem on the TPU_SYSTEM of each |
30 | // host. This is needed in case the system had previously been configured. It |
31 | // returns, for each host, the number of TPU chips on the host. |
32 | // |
33 | // 2 Run _ConfigureDistributedTPU on TPU_SYSTEM of worker 0. Takes as input the |
34 | // number of chips on each host. Validates that all hosts have the same number |
35 | // of chips, and that the chips are consistent with the topology set by |
36 | // flags. Has a single output which is a proto describing the requested system |
37 | // configuration, which is sent to all hosts. Note that for multi-client setups |
38 | // the input to _ConfigureDistributedTPU refers only to hosts controlled by the |
39 | // local process/client; the topology set by flags determines the total number |
40 | // of hosts across all clients, and this is reflected in the return value. |
41 | // |
// 3 Run _InitializeHostForDistributedTPU on the TPU_SYSTEM of each host, taking
// as input the output from _ConfigureDistributedTPU. Has a single Tensor output
44 | // which is a vector of int32 indicating, for each TPU on the host, what its |
45 | // global TPU system id is. |
46 | // |
47 | // 4 Run _WaitForDistributedTPU on TPU_SYSTEM, taking as input the |
48 | // outputs from all the _InitializeHostForDistributedTPU |
// Ops. These per-host id vectors are combined in the Op with the outputs
// from the host initialization Ops to construct a mapping from full TPU
// device specs to global TPU ids. Has a single Tensor output which is a
52 | // matrix of int32 indicating, for each host (outer dimension) and for |
53 | // each TPU on the host (inner dimension) what that TPU's global id |
54 | // is. _WaitForDistributedTPU also waits for the TPU distributed |
55 | // system to initialize fully, which may take several minutes for a |
56 | // large system. |
57 | // |
58 | // 5 Run _SetGlobalTPUArray on the TPU_SYSTEM of each host, taking as input the |
59 | // output from _WaitForDistributedTPU. This Op tells each host the global Id of |
60 | // every TPU on every host. |
61 | // |
62 | // Most user code works by placing the ConfigureDistributedTPU Op on the desired |
63 | // TPU_SYSTEM device, and a graph rewrite replaces it by the subgraph described |
64 | // above. |
65 | // |
66 | // |
67 | // A distributed TPU system can be cleanly shut down by running the following |
68 | // Ops: |
69 | // |
70 | // 1 Run _DisconnectHostFromDistributedTPUSystem on the TPU_SYSTEM of each host. |
71 | // |
72 | // 2 Run _ShutdownDistributedTPU on the TPU_SYSTEM where |
73 | // _ConfigureDistributedTPU was run. The Op will return an error if no system is |
74 | // configured. |
75 | // |
76 | // |
77 | // Most user code works by placing the ShutdownDistributedTPU Op on the desired |
78 | // TPU_SYSTEM device, and a graph rewrite replaces it by the subgraph described |
79 | // above. |
80 | |
81 | REGISTER_OP("_ConfigureDistributedTPU" ) |
82 | .Input("inputs: N * int32" ) |
83 | .Output("output: string" ) |
84 | .Attr("N: int >= 1" ) |
85 | .Attr("enable_whole_mesh_compilations: bool = false" ) |
86 | .SetIsStateful() |
87 | .SetShapeFn([](InferenceContext* c) { |
88 | ShapeHandle input; |
89 | // Validate that all the inputs are scalars. |
90 | for (int i = 0; i < c->num_inputs(); ++i) { |
91 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &input)); |
92 | } |
93 | c->set_output(0, c->Scalar()); |
94 | return OkStatus(); |
95 | }) |
96 | .Doc(R"doc( |
97 | An op that sets up the centralized structures for a distributed TPU |
98 | system. |
99 | |
100 | inputs: A scalar tensor for each host indicating how many TPU chips |
101 | there are on the host. |
102 | output: A tensor containing a TPUHostConfiguration proto serialized to |
103 | a string, containing the information necessary to initialize the chips |
104 | in a host. |
105 | enable_whole_mesh_compilations: Usually the master TPU worker is the only |
106 | worker compile ops are sent, and the master worker is the only one which |
107 | can execute them. Other TPU clients distribute TPU compilation across all |
108 | the hosts of the mesh, and setting this flag to True enables such mesh |
109 | initialization mode. |
110 | )doc" ); |
111 | |
112 | REGISTER_OP("_WaitForDistributedTPU" ) |
113 | .Input("inputs: N * int32" ) |
114 | .Output("topology: string" ) |
115 | .Attr("startup_timeout_sec: int = 20" ) |
116 | .Attr("N: int" ) |
117 | .SetIsStateful() |
118 | .SetShapeFn([](InferenceContext* c) { |
119 | ShapeHandle input; |
120 | // Validate that all the inputs have the same vector shape. |
121 | for (int i = 0; i < c->num_inputs(); ++i) { |
122 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &input)); |
123 | } |
124 | c->set_output(0, c->Scalar()); |
125 | return OkStatus(); |
126 | }) |
127 | .Doc(R"doc( |
128 | An op that blocks execution until a distributed TPU system has |
129 | started up. This Op must be run on the same TPU_SYSTEM device as |
130 | _ConfigureDistributedTPU, and takes an inputs the outputs from the |
131 | _InitializeHostForDistributedTPU Ops. |
132 | |
133 | inputs: For each initialized host, a vector giving the global TPU id |
134 | of each TPU on the host. |
135 | topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU |
136 | topology. |
137 | startup_timeout_sec: The number of seconds to wait for the TPU system |
138 | to stabilize. |
139 | )doc" ); |
140 | |
141 | REGISTER_OP("_SetGlobalTPUArray" ) |
142 | .Input("topology: string" ) |
143 | .SetIsStateful() |
144 | .SetShapeFn([](InferenceContext* c) { |
145 | ShapeHandle input; |
146 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input)); |
147 | return OkStatus(); |
148 | }) |
149 | .Doc(R"doc( |
150 | An op that informs a host of the global ids of all the of TPUs in the |
151 | system. |
152 | |
153 | topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU |
154 | topology. |
155 | )doc" ); |
156 | |
// Tears down the centralized distributed-TPU state. Per the shutdown
// sequence documented above, this runs only after
// _DisconnectHostFromDistributedTPUSystem has completed on every host.
REGISTER_OP("_ShutdownDistributedTPU")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running. This Op must be run on the same
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
to start the system, and must be run only after
_DisconnectHostFromDistributedTPUSystem has completed on every host in
the system.
)doc");
168 | |
169 | REGISTER_OP("_InitializeHostForDistributedTPU" ) |
170 | .Input("input: string" ) |
171 | .Output("tpu_ids: int32" ) |
172 | .Attr("enable_whole_mesh_compilations: bool = false" ) |
173 | // Available values: 0 (unset), 1 (enabled) or 2 (disabled). |
174 | // This attribute is ignored in non-TFRT TPU runtime. |
175 | .Attr("tpu_cancellation_closes_chips: int = 0" ) |
176 | .SetIsStateful() |
177 | .SetShapeFn([](InferenceContext* c) { |
178 | ShapeHandle input; |
179 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input)); |
180 | c->set_output(0, c->Vector(c->UnknownDim())); |
181 | return OkStatus(); |
182 | }) |
183 | .Doc(R"doc( |
184 | An op that connects each chip on the host to a centralized UberDriver to allow |
185 | them to operate as a distributed system with chips in other hosts. |
186 | |
187 | input: A string containing the address of the UberDriver to connect to. |
188 | tpu_ids: A vector containing the global TPU id of each TPU on the host. |
189 | enable_whole_mesh_compilations: Usually the master TPU worker is the only |
190 | worker compile ops are sent, and the master worker is the only one which |
191 | can execute them. Other TPU clients distribute TPU compilation across all |
192 | the hosts of the mesh, and setting this flag to True enables such mesh |
193 | initialization mode. |
194 | )doc" ); |
195 | |
// First step of both reconfiguration and shutdown (see the sequences
// documented above): detaches this host's chips from the running system
// and reports how many chips the host has.
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
    .Output("number_of_tpu_chips: int32")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
An op that disconnects the TPUs on a host from a running distributed
TPU system.

number_of_tpu_chips: A scalar tensor containing the number of TPU
chips on the host.
)doc");
207 | |
// Public entry point: user code places this Op on the desired TPU_SYSTEM
// device and a graph rewrite replaces it with the internal subgraph
// (_DisconnectHostFromDistributedTPUSystem, _ConfigureDistributedTPU,
// _InitializeHostForDistributedTPU, _WaitForDistributedTPU,
// _SetGlobalTPUArray) described at the top of this file.
REGISTER_OP("ConfigureDistributedTPU")
    .Output("topology: string")
    // NOTE(review): both embedding_config and tpu_embedding_config exist;
    // presumably one supersedes the other — confirm against the kernel /
    // rewrite pass before relying on either.
    .Attr("embedding_config: string = ''")
    .Attr("tpu_embedding_config: string = ''")
    .Attr("is_global_init: bool = false")
    .Attr("enable_whole_mesh_compilations: bool = false")
    .Attr("compilation_failure_closes_chips: bool = true")
    // Available values: 0 (unset), 1 (enabled) or 2 (disabled).
    // This attribute is ignored in non-TFRT TPU runtime.
    .Attr("tpu_cancellation_closes_chips: int = 0")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
220 | |
// Public entry point: placed on the desired TPU_SYSTEM device and replaced
// by a graph rewrite with the shutdown subgraph
// (_DisconnectHostFromDistributedTPUSystem then _ShutdownDistributedTPU)
// described at the top of this file.
REGISTER_OP("ShutdownDistributedTPU")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
224 | |
// Registers the TPU-embedding configuration op. `config` is required
// (no default); its contents are interpreted by the kernel, not here.
REGISTER_OP("ConfigureTPUEmbedding")
    .Attr("config: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
229 | |
// Queries whether the TPU embedding system has been initialized. Returns a
// scalar bool. SetDoNotOptimize keeps the op from being constant-folded or
// pruned, so the check runs at execution time rather than graph build time.
REGISTER_OP("IsTPUEmbeddingInitialized")
    .Output("is_tpu_embedding_initialized: bool")
    .Attr("config: string = ''")
    .SetDoNotOptimize()
    .SetShapeFn(shape_inference::ScalarShape);
235 | } // end namespace tensorflow |
236 | |