/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h"
#include "tensorflow/core/tpu/tpu_embedding_output_layout_utils.h"

namespace tensorflow {

// TPUs use a specialized mechanism for performing embedding lookups,
// necessitating differences in TF Graphs that use embeddings on TPUs relative
// to CPUs. Embedding lookups on TPU systems are achieved by including the
// following in the TF Graph.
//
// 0. Construct a TPUEmbeddingConfiguration, specifying the embedding tables
//    in the model, the size of the TPU system to be used, and the optimizer
//    to be used for each table. Some of this information is redundant with
//    other pieces of the TF Graph.
// 1. Pass this TPUEmbeddingConfiguration to tpu.initialize_system() as the
//    tpu_embedding_config parameter.
// 2. Use the LoadTPUEmbedding Ops to initialize the embedding tables in TPU
//    memories, sharded across the memories attached to each Host.
// 3. Use EnqueueTPUEmbeddingSparseBatch to provide the TPU with embedding
//    indices and aggregation weights.
// 4. RecvTPUEmbeddingActivations returns a list of Tensors containing the
//    activations from each table specified in the configuration.
// 5. TPUEmbeddingActivations, when used with appropriate Python libraries,
//    enables the automatic differentiation of models that use embeddings.
// 6. SendTPUEmbeddingGradients takes a list of Tensors (of the same shapes
//    as those returned by RecvTPUEmbeddingActivations) containing gradients
//    to use in updating the embedding tables.
// 7. Before saving a checkpoint, use the RetrieveTPUEmbedding Ops to update
//    the Graph's embedding table Variables from the updated tables in the
//    TPU memories.
//
// TPU Embeddings use dedicated ops to enforce Host/TPU consistency in the
// state of embedding table variables. Before beginning training or inference,
// the model must Load the optimizer parameters into the TPU memories. Before
// saving a checkpoint, the model must Retrieve the parameters back into the
// host CPU memory. An illustrative sketch of this flow follows below.
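//
// An illustrative Python-level sketch of steps 1, 3, 4, and 6 via tf.raw_ops
// (hedged: `embedding_config`, `num_tables`, `mode`, and the feature lists
// are placeholders, and the initialize_system keyword name may differ across
// TF versions):
//
//   config_str = embedding_config.SerializeToString()
//   tf.compat.v1.tpu.initialize_system(embedding_config=config_str)  # step 1
//
//   # Step 3: enqueue per-sample embedding ids and weights for one minibatch.
//   tf.raw_ops.EnqueueTPUEmbeddingSparseBatch(
//       sample_indices=sample_indices, embedding_indices=embedding_indices,
//       aggregation_weights=aggregation_weights, mode_override=mode)
//
//   # Step 4: one activation Tensor per lookup in the configuration.
//   activations = tf.raw_ops.RecvTPUEmbeddingActivations(
//       num_outputs=num_tables, config=config_str)
//
//   # Step 6: send back gradients with the same shapes as the activations.
//   tf.raw_ops.SendTPUEmbeddingGradients(
//       inputs=gradients, learning_rates=[], config=config_str)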

REGISTER_OP("RecvTPUEmbeddingActivations")
    .Output("outputs: num_outputs * float32")
    .Attr("num_outputs: int >= 1")
    .Attr("config: string")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
      std::string config_string;
      TF_RETURN_IF_ERROR(c->GetAttr("config", &config_string));
      tpu::TPUEmbeddingConfiguration config;
      if (!config.ParseFromString(config_string)) {
        return errors::InvalidArgument("Malformed tpu_embedding_config.");
      }
      std::vector<TensorShapeProto> output_shapes;
      TF_RETURN_IF_ERROR(ComputeOutputTensorShapes(config, &output_shapes));
      if (c->num_outputs() != output_shapes.size()) {
        return errors::InvalidArgument("num outputs != size of output shapes");
      }
      for (int i = 0; i < c->num_outputs(); ++i) {
        shape_inference::ShapeHandle output_shape;
        TF_RETURN_IF_ERROR(
            c->MakeShapeFromShapeProto(output_shapes[i], &output_shape));
        c->set_output(i, output_shape);
      }
      return OkStatus();
    });
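
// A hedged sketch of how the `config` attr above might be produced in Python:
// it is a serialized TPUEmbeddingConfiguration proto (see the include at the
// top of this file), and `num_outputs` must match the number of output shapes
// the shape function derives from it. The snippet is illustrative only, not a
// complete or valid configuration; the module path is assumed from the
// proto's location, and `num_tables` is a placeholder.
//
//   from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
//
//   config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
//   table = config.table_descriptor.add()
//   table.name = "table0"
//   table.vocabulary_size = 100000
//   table.dimension = 64
//   activations = tf.raw_ops.RecvTPUEmbeddingActivations(
//       num_outputs=num_tables, config=config.SerializeToString())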

REGISTER_OP("TPUEmbeddingActivations")
    .Input("embedding_variable: float32")
    .Input("sliced_activations: float32")
    .Output("output: float32")
    .Attr("table_id: int >= 0")
    .Attr("lookup_id: int >= 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(1));
      return OkStatus();
    });

REGISTER_OP("SendTPUEmbeddingGradients")
    .Input("inputs: N * float32")
    .Input("learning_rates: NN * float32")
    .Attr("N: int >= 1")
    .Attr("NN: int >= 0 = 0")
    .Attr("config: string")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
      int nn;
      TF_RETURN_IF_ERROR(c->GetAttr("NN", &nn));
      std::vector<shape_inference::ShapeHandle> learning_rates;
      TF_RETURN_IF_ERROR(c->input("learning_rates", &learning_rates));
      for (int i = 0; i < nn; ++i) {
        // Verify that each learning_rates element is scalar.
        shape_inference::ShapeHandle learning_rates_shape;
        TF_RETURN_IF_ERROR(
            c->WithRank(learning_rates[i], 0, &learning_rates_shape));
      }

      return OkStatus();
    });
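
// A hedged tf.raw_ops sketch of the op above (`grads`, `lr`, and `config_str`
// are placeholders). The gradient tensors must have the same shapes as the
// activations returned by RecvTPUEmbeddingActivations, and each optional
// learning_rates entry must be a scalar, which the shape function enforces.
//
//   tf.raw_ops.SendTPUEmbeddingGradients(
//       inputs=grads,          # N float32 Tensors
//       learning_rates=[lr],   # NN scalar float32 Tensors; may be empty
//       config=config_str)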

REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
    .Input("batch: N * int32")
    .Input("mode_override: string")
    .Attr("N: int >= 1")
    .Attr("device_ordinal: int = -1")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
    .Input("sample_indices: N * T1")
    .Input("embedding_indices: N * T2")
    .Input("aggregation_weights: N * T3")
    .Input("mode_override: string")
    .Attr("T1: {int32,int64} = DT_INT32")
    .Attr("T2: {int32,int64} = DT_INT32")
    .Attr("T3: {float32,float64} = DT_FLOAT")
    .Attr("N: int >= 1")
    .Attr("device_ordinal: int = -1")
    .Attr("combiners: list(string) = []")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
      std::vector<string> combiners;
      TF_RETURN_IF_ERROR(c->GetAttr("combiners", &combiners));
      int n;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
      if (!combiners.empty() && combiners.size() != n) {
        return errors::InvalidArgument("Invalid length of combiners. Have ",
                                       combiners.size(), " but expected 0 or ",
                                       n);
      }

      return OkStatus();
    });
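
// A hedged tf.raw_ops sketch of the op above, feeding N rank-2 SparseTensors
// (`sparse_features` and `mode` are placeholders). The three input lists must
// all have length N, and `combiners`, when non-empty, must also have length
// N, as the shape function above checks.
//
//   tf.raw_ops.EnqueueTPUEmbeddingSparseBatch(
//       sample_indices=[sp.indices[:, 0] for sp in sparse_features],
//       embedding_indices=[sp.values for sp in sparse_features],
//       aggregation_weights=[tf.ones_like(sp.values, dtype=tf.float32)
//                            for sp in sparse_features],
//       mode_override=mode,
//       device_ordinal=-1,
//       combiners=["sum"] * len(sparse_features))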

REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
    .Input("sample_indices: N * T1")
    .Input("embedding_indices: N * T2")
    .Input("aggregation_weights: N * T3")
    .Input("mode_override: string")
    .Attr("T1: {int32,int64} = DT_INT32")
    .Attr("T2: {int32,int64} = DT_INT32")
    .Attr("T3: {float32,float64} = DT_FLOAT")
    .Attr("N: int >= 1")
    .Attr("device_ordinal: int = -1")
    .Attr("combiners: list(string) = []")
    .Attr("table_ids: list(int)")
    .Attr("max_sequence_lengths: list(int) = []")
    .Attr("num_features: list(int) = []")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
    .Input("sample_splits: N * T1")
    .Input("embedding_indices: N * T2")
    .Input("aggregation_weights: N * T3")
    .Input("mode_override: string")
    .Attr("T1: {int32,int64} = DT_INT32")
    .Attr("T2: {int32,int64} = DT_INT32")
    .Attr("T3: {float32,float64} = DT_FLOAT")
    .Attr("N: int >= 1")
    .Attr("device_ordinal: int = -1")
    .Attr("combiners: list(string) = []")
    .Attr("table_ids: list(int)")
    .Attr("max_sequence_lengths: list(int) = []")
    .Attr("num_features: list(int) = []")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("EnqueueTPUEmbeddingArbitraryTensorBatch")
    .Input("sample_indices_or_row_splits: N * T1")
    .Input("embedding_indices: N * T2")
    .Input("aggregation_weights: N * T3")
    .Input("mode_override: string")
    .Attr("T1: {int32,int64} = DT_INT32")
    .Attr("T2: {int32,int64} = DT_INT32")
    .Attr("T3: {float32,float64} = DT_FLOAT")
    .Attr("N: int >= 1")
    .Attr("device_ordinal: int = -1")
    .Attr("combiners: list(string) = []")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("DynamicEnqueueTPUEmbeddingArbitraryTensorBatch")
    .Input("sample_indices_or_row_splits: N * T1")
    .Input("embedding_indices: N * T2")
    .Input("aggregation_weights: N * T3")
    .Input("mode_override: string")
    .Input("device_ordinal: int32")
    .Attr("T1: {int32,int64} = DT_INT32")
    .Attr("T2: {int32,int64} = DT_INT32")
    .Attr("T3: {float32,float64} = DT_FLOAT")
    .Attr("N: int >= 1")
    .Attr("combiners: list(string) = []")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
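
// Unlike EnqueueTPUEmbeddingArbitraryTensorBatch above, device_ordinal here
// is a runtime int32 input rather than an attr, so the target device can be
// chosen per call. A hedged sketch (`row_splits_list`, `ids_list`,
// `weights_list`, `mode`, and `ordinal` are placeholders):
//
//   tf.raw_ops.DynamicEnqueueTPUEmbeddingArbitraryTensorBatch(
//       sample_indices_or_row_splits=row_splits_list,
//       embedding_indices=ids_list,
//       aggregation_weights=weights_list,
//       mode_override=mode,
//       device_ordinal=tf.constant(ordinal, dtype=tf.int32))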

}  // namespace tensorflow