/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

// Configuring a distributed TPU system is achieved by running
// the following Ops:
//
// 1 Run _DisconnectHostFromDistributedTPUSystem on the TPU_SYSTEM of each
// host. This is needed in case the system had previously been configured. It
// returns, for each host, the number of TPU chips on the host.
//
// 2 Run _ConfigureDistributedTPU on the TPU_SYSTEM of worker 0. Takes as input
// the number of chips on each host. Validates that all hosts have the same
// number of chips, and that the chips are consistent with the topology set by
// flags. Has a single output, a proto describing the requested system
// configuration, which is sent to all hosts. Note that for multi-client setups
// the input to _ConfigureDistributedTPU refers only to hosts controlled by the
// local process/client; the topology set by flags determines the total number
// of hosts across all clients, and this is reflected in the return value.
//
// 3 Run _InitializeHostForDistributedTPU on the TPU_SYSTEM of each host,
// taking as input the output from _ConfigureDistributedTPU. Has a single
// Tensor output which is a vector of int32 indicating, for each TPU on the
// host, what its global TPU system id is.
//
// 4 Run _WaitForDistributedTPU on TPU_SYSTEM, taking as input the outputs
// from all the _InitializeHostForDistributedTPU Ops. The Op combines these
// per-host outputs to construct a mapping from full TPU device specs to
// global TPU ids. Has a single Tensor output which is a matrix of int32
// indicating, for each host (outer dimension) and for each TPU on the host
// (inner dimension), what that TPU's global id is. _WaitForDistributedTPU
// also waits for the TPU distributed system to initialize fully, which may
// take several minutes for a large system.
//
// 5 Run _SetGlobalTPUArray on the TPU_SYSTEM of each host, taking as input
// the output from _WaitForDistributedTPU. This Op tells each host the global
// id of every TPU on every host.
//
// Most user code works by placing the ConfigureDistributedTPU Op on the
// desired TPU_SYSTEM device, and a graph rewrite replaces it with the
// subgraph described above.
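//
// As a rough, illustrative sketch only (it is not compiled as part of this
// file, and the node, job and device names below are hypothetical), such user
// code could build and run the public Op through a Session along these lines:
//
//   #include "tensorflow/core/framework/graph.pb.h"
//   #include "tensorflow/core/public/session.h"
//
//   tensorflow::GraphDef graph;
//   tensorflow::NodeDef* config = graph.add_node();
//   config->set_name("configure_distributed_tpu");
//   config->set_op("ConfigureDistributedTPU");
//   // Place the Op on the desired TPU_SYSTEM device; the rewrite expands it
//   // into the subgraph of steps 1-5 above.
//   config->set_device("/job:worker/replica:0/task:0/device:TPU_SYSTEM:0");
//
//   std::unique_ptr<tensorflow::Session> session(
//       tensorflow::NewSession(tensorflow::SessionOptions()));
//   TF_CHECK_OK(session->Create(graph));
//   std::vector<tensorflow::Tensor> outputs;
//   // The single output is the serialized topology proto for the system.
//   TF_CHECK_OK(
//       session->Run({}, {"configure_distributed_tpu:0"}, {}, &outputs));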
//
//
// A distributed TPU system can be cleanly shut down by running the following
// Ops:
//
// 1 Run _DisconnectHostFromDistributedTPUSystem on the TPU_SYSTEM of each host.
//
// 2 Run _ShutdownDistributedTPU on the TPU_SYSTEM where
// _ConfigureDistributedTPU was run. The Op will return an error if no system
// is configured.
//
//
// Most user code works by placing the ShutdownDistributedTPU Op on the
// desired TPU_SYSTEM device, and a graph rewrite replaces it with the
// subgraph described above.
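//
// Continuing the illustrative sketch above (again not compiled here, and with
// hypothetical node and device names), the user-level shutdown could look
// like:
//
//   tensorflow::GraphDef shutdown_graph;
//   tensorflow::NodeDef* shutdown = shutdown_graph.add_node();
//   shutdown->set_name("shutdown_distributed_tpu");
//   shutdown->set_op("ShutdownDistributedTPU");
//   shutdown->set_device("/job:worker/replica:0/task:0/device:TPU_SYSTEM:0");
//   TF_CHECK_OK(session->Extend(shutdown_graph));
//
//   // ShutdownDistributedTPU has no outputs, so it is run as a target node.
//   std::vector<tensorflow::Tensor> unused;
//   TF_CHECK_OK(session->Run({}, {}, {"shutdown_distributed_tpu"}, &unused));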

REGISTER_OP("_ConfigureDistributedTPU")
    .Input("inputs: N * int32")
    .Output("output: string")
    .Attr("N: int >= 1")
    .Attr("enable_whole_mesh_compilations: bool = false")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      // Validate that all the inputs are scalars.
      for (int i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &input));
      }
      c->set_output(0, c->Scalar());
      return OkStatus();
    })
    .Doc(R"doc(
An op that sets up the centralized structures for a distributed TPU
system.

inputs: A scalar tensor for each host indicating how many TPU chips
there are on the host.
output: A tensor containing a TPUHostConfiguration proto serialized to
a string, containing the information necessary to initialize the chips
in a host.
enable_whole_mesh_compilations: Usually the master TPU worker is the only
worker to which compile Ops are sent, and the master worker is the only one
that can execute them. Other TPU clients distribute TPU compilation across
all the hosts of the mesh, and setting this flag to True enables that mesh
initialization mode.
)doc");

REGISTER_OP("_WaitForDistributedTPU")
    .Input("inputs: N * int32")
    .Output("topology: string")
    .Attr("startup_timeout_sec: int = 20")
    .Attr("N: int")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      // Validate that all the inputs are vectors.
      for (int i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &input));
      }
      c->set_output(0, c->Scalar());
      return OkStatus();
    })
    .Doc(R"doc(
An op that blocks execution until a distributed TPU system has
started up. This Op must be run on the same TPU_SYSTEM device as
_ConfigureDistributedTPU, and takes as inputs the outputs from the
_InitializeHostForDistributedTPU Ops.

inputs: For each initialized host, a vector giving the global TPU id
of each TPU on the host.
topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU
topology.
startup_timeout_sec: The number of seconds to wait for the TPU system
to stabilize.
)doc");
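//
// As an illustrative aside (not compiled here), a consumer holding the scalar
// string Tensor produced by this Op, say a hypothetical `topology_tensor`,
// could decode it with the generated tensorflow.tpu.TopologyProto bindings
// along these lines:
//
//   #include "tensorflow/core/protobuf/tpu/topology.pb.h"
//
//   tensorflow::tpu::TopologyProto topology_proto;
//   const tensorflow::tstring& serialized =
//       topology_tensor.scalar<tensorflow::tstring>()();
//   CHECK(topology_proto.ParseFromArray(serialized.data(), serialized.size()));
//   // mesh_shape gives the dimensions of the TPU mesh; num_tasks and
//   // num_tpu_devices_per_task describe the host/device layout.
//   LOG(INFO) << "Mesh rank: " << topology_proto.mesh_shape_size()
//             << ", tasks: " << topology_proto.num_tasks()
//             << ", TPUs per task: "
//             << topology_proto.num_tpu_devices_per_task();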

REGISTER_OP("_SetGlobalTPUArray")
    .Input("topology: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input));
      return OkStatus();
    })
    .Doc(R"doc(
An op that informs a host of the global ids of all of the TPUs in the
system.

topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU
topology.
)doc");

REGISTER_OP("_ShutdownDistributedTPU")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running. This Op must be run on the same
TPU_SYSTEM device on which the corresponding _ConfigureDistributedTPU was
run to start the system, and must be run only after
_DisconnectHostFromDistributedTPUSystem has completed on every host in
the system.
)doc");

REGISTER_OP("_InitializeHostForDistributedTPU")
    .Input("input: string")
    .Output("tpu_ids: int32")
    .Attr("enable_whole_mesh_compilations: bool = false")
    // Available values: 0 (unset), 1 (enabled) or 2 (disabled).
    // This attribute is ignored in the non-TFRT TPU runtime.
    .Attr("tpu_cancellation_closes_chips: int = 0")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input));
      c->set_output(0, c->Vector(c->UnknownDim()));
      return OkStatus();
    })
    .Doc(R"doc(
An op that connects the chips on the host to a centralized UberDriver,
allowing them to operate as a distributed system with chips in other hosts.

input: A string containing the address of the UberDriver to connect to.
tpu_ids: A vector containing the global TPU id of each TPU on the host.
enable_whole_mesh_compilations: Usually the master TPU worker is the only
worker to which compile Ops are sent, and the master worker is the only one
that can execute them. Other TPU clients distribute TPU compilation across
all the hosts of the mesh, and setting this flag to True enables that mesh
initialization mode.
)doc");

REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
    .Output("number_of_tpu_chips: int32")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
An op that disconnects the TPUs on a host from a running distributed
TPU system.

number_of_tpu_chips: A scalar tensor containing the number of TPU
chips on the host.
)doc");

REGISTER_OP("ConfigureDistributedTPU")
    .Output("topology: string")
    .Attr("embedding_config: string = ''")
    .Attr("tpu_embedding_config: string = ''")
    .Attr("is_global_init: bool = false")
    .Attr("enable_whole_mesh_compilations: bool = false")
    .Attr("compilation_failure_closes_chips: bool = true")
    // Available values: 0 (unset), 1 (enabled) or 2 (disabled).
    // This attribute is ignored in the non-TFRT TPU runtime.
    .Attr("tpu_cancellation_closes_chips: int = 0")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
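
// As a brief, hypothetical illustration (not compiled here), the attributes
// above can be set on the NodeDef built in the earlier sketch before the
// Session is created; the attr names must match the strings registered above:
//
//   (*config->mutable_attr())["enable_whole_mesh_compilations"].set_b(true);
//   (*config->mutable_attr())["tpu_cancellation_closes_chips"].set_i(1);
//   (*config->mutable_attr())["embedding_config"].set_s("");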

REGISTER_OP("ShutdownDistributedTPU")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ConfigureTPUEmbedding")
    .Attr("config: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IsTPUEmbeddingInitialized")
    .Output("is_tpu_embedding_initialized: bool")
    .Attr("config: string = ''")
    .SetDoNotOptimize()
    .SetShapeFn(shape_inference::ScalarShape);

}  // end namespace tensorflow