1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | #include <vector> |
18 | |
19 | #include "tensorflow/core/framework/attr_value.pb.h" |
20 | #include "tensorflow/core/framework/common_shape_fns.h" |
21 | #include "tensorflow/core/framework/op.h" |
22 | #include "tensorflow/core/framework/shape_inference.h" |
23 | #include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" |
24 | #include "tensorflow/core/tpu/tpu_embedding_output_layout_utils.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | // TPUs use a specialized mechanism for performing embedding lookups, |
29 | // necessitating differences in TF Graphs that use embeddings on TPUs relative |
30 | // to CPUs. Embedding lookups on TPU systems are achieved by including the |
31 | // following in the TF Graph. |
32 | // |
33 | // 0. Construct a TPUEmbeddingConfiguration, specifying the embedding tables |
34 | // in the model, the size of the TPU system to be used, and the optimizer to |
35 | // be used for each table. Some of this information is redundant with other |
36 | // pieces of the TF Graph. |
37 | // 1. Pass this TPUEmbeddingConfiguration to tpu.initialize_system() as the |
38 | // tpu_embedding_config parameter. |
39 | // 2. Use the LoadTPUEmbedding Ops to initialize the embedding tables in TPU |
40 | // memories, sharded across the memories attached to each Host. |
41 | // 3. Use EnqueueTPUEmbeddingSparseBatch to provide the TPU with embedding |
42 | // indices and aggregation weights. |
43 | // 4. RecvTPUEmbeddingActivations returns a list of Tensors, containing the |
44 | // activations from each table specified in the configuration. |
45 | // 5. TPUEmbeddingActivations, when used with appropriate Python libraries, |
46 | // enables the automatic differentiation of models that use embeddings. |
// 6. SendTPUEmbeddingGradients takes a list of Tensors (of the same shapes
//    as those returned by RecvTPUEmbeddingActivations) containing gradients
//    to use in updating the embedding tables.
50 | // 7. Before saving a checkpoint, use the RetrieveTPUEmbedding Ops to update |
51 | // the Graph's embedding table Variables from the updated tables in the |
52 | // TPU memories. |
53 | // |
54 | // TPU Embeddings use dedicated ops to enforce Host/TPU consistency in the |
55 | // state of embedding table variables. Before beginning training or inference, |
56 | // the model must Load the optimizer parameters into the TPU memories. Before |
57 | // saving a checkpoint, the model must Retrieve the parameters back into the |
58 | // host CPU memory. |
59 | |
60 | REGISTER_OP("RecvTPUEmbeddingActivations" ) |
61 | .Output("outputs: num_outputs * float32" ) |
62 | .Attr("num_outputs: int >= 1" ) |
63 | .Attr("config: string" ) |
64 | .SetIsStateful() |
65 | .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { |
66 | std::string config_string; |
67 | TF_RETURN_IF_ERROR(c->GetAttr("config" , &config_string)); |
68 | tpu::TPUEmbeddingConfiguration config; |
69 | if (!config.ParseFromString(config_string)) { |
70 | return errors::InvalidArgument("Malformed tpu_embedding_config." ); |
71 | } |
72 | std::vector<TensorShapeProto> output_shapes; |
73 | TF_RETURN_IF_ERROR(ComputeOutputTensorShapes(config, &output_shapes)); |
74 | if (c->num_outputs() != output_shapes.size()) { |
75 | return errors::InvalidArgument("num outputs != size of output shapes" ); |
76 | } |
77 | for (int i = 0; i < c->num_outputs(); ++i) { |
78 | shape_inference::ShapeHandle output_shape; |
79 | TF_RETURN_IF_ERROR( |
80 | c->MakeShapeFromShapeProto(output_shapes[i], &output_shape)); |
81 | c->set_output(i, output_shape); |
82 | } |
83 | return OkStatus(); |
84 | }); |
85 | |
86 | REGISTER_OP("TPUEmbeddingActivations" ) |
87 | .Input("embedding_variable: float32" ) |
88 | .Input("sliced_activations: float32" ) |
89 | .Output("output: float32" ) |
90 | .Attr("table_id: int >= 0" ) |
91 | .Attr("lookup_id: int >= 0" ) |
92 | .SetShapeFn([](shape_inference::InferenceContext *c) { |
93 | c->set_output(0, c->input(1)); |
94 | return OkStatus(); |
95 | }); |
96 | |
97 | REGISTER_OP("SendTPUEmbeddingGradients" ) |
98 | .Input("inputs: N * float32" ) |
99 | .Input("learning_rates: NN * float32" ) |
100 | .Attr("N: int >= 1" ) |
101 | .Attr("NN: int >= 0 = 0" ) |
102 | .Attr("config: string" ) |
103 | .SetIsStateful() |
104 | .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { |
105 | int nn; |
106 | TF_RETURN_IF_ERROR(c->GetAttr("NN" , &nn)); |
107 | std::vector<shape_inference::ShapeHandle> learning_rates; |
108 | TF_RETURN_IF_ERROR(c->input("learning_rates" , &learning_rates)); |
109 | for (int i = 0; i < nn; ++i) { |
110 | // Verify that each learning_rates element is scalar |
111 | shape_inference::ShapeHandle learning_rates_shape; |
112 | TF_RETURN_IF_ERROR( |
113 | c->WithRank(learning_rates[i], 0, &learning_rates_shape)); |
114 | } |
115 | |
116 | return OkStatus(); |
117 | }); |
118 | |
// Enqueues N dense int32 batches of embedding indices for lookup.
// mode_override presumably overrides the configured mode (e.g. training vs
// inference) — TODO confirm against the kernel implementation.
// device_ordinal defaults to -1; semantics of -1 are defined by the kernel.
REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
    .Input("batch: N * int32")
    .Input("mode_override: string")
    .Attr("N: int >= 1")
    .Attr("device_ordinal: int = -1")
    .SetIsStateful()
    // Enqueue has side effects on the TPU embedding system and produces no
    // tensor outputs, so there are no shapes to infer.
    .SetShapeFn(shape_inference::UnknownShape);
126 | |
127 | REGISTER_OP("EnqueueTPUEmbeddingSparseBatch" ) |
128 | .Input("sample_indices: N * T1" ) |
129 | .Input("embedding_indices: N * T2" ) |
130 | .Input("aggregation_weights: N * T3" ) |
131 | .Input("mode_override: string" ) |
132 | .Attr("T1: {int32,int64} = DT_INT32" ) |
133 | .Attr("T2: {int32,int64} = DT_INT32" ) |
134 | .Attr("T3: {float32,float64} = DT_FLOAT" ) |
135 | .Attr("N: int >= 1" ) |
136 | .Attr("device_ordinal: int = -1" ) |
137 | .Attr("combiners: list(string) = []" ) |
138 | .SetIsStateful() |
139 | .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { |
140 | std::vector<string> combiners; |
141 | TF_RETURN_IF_ERROR(c->GetAttr("combiners" , &combiners)); |
142 | int n; |
143 | TF_RETURN_IF_ERROR(c->GetAttr("N" , &n)); |
144 | if (!combiners.empty() && combiners.size() != n) { |
145 | return errors::InvalidArgument("Invalid length of combiners. Have " , |
146 | combiners.size(), " but expected 0 or " , |
147 | n); |
148 | } |
149 | |
150 | return OkStatus(); |
151 | }); |
152 | |
// Enqueues N SparseTensor-style batches (indices/values/weights) for TPU
// embedding lookup, with per-feature table_ids mapping each input to its
// embedding table. max_sequence_lengths / num_features default to empty;
// their per-feature semantics are defined by the kernel — TODO confirm.
REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
    .Input("sample_indices: N * T1")
    .Input("embedding_indices: N * T2")
    .Input("aggregation_weights: N * T3")
    .Input("mode_override: string")
    .Attr("T1: {int32,int64} = DT_INT32")
    .Attr("T2: {int32,int64} = DT_INT32")
    .Attr("T3: {float32,float64} = DT_FLOAT")
    .Attr("N: int >= 1")
    .Attr("device_ordinal: int = -1")
    .Attr("combiners: list(string) = []")
    .Attr("table_ids: list(int)")
    .Attr("max_sequence_lengths: list(int) = []")
    .Attr("num_features: list(int) = []")
    .SetIsStateful()
    // No tensor outputs; nothing to infer.
    .SetShapeFn(shape_inference::UnknownShape);
169 | |
// Ragged-tensor variant of the sparse enqueue: sample_splits are row splits
// (RaggedTensor style) rather than COO sample_indices; otherwise the
// attributes mirror EnqueueTPUEmbeddingSparseTensorBatch.
REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
    .Input("sample_splits: N * T1")
    .Input("embedding_indices: N * T2")
    .Input("aggregation_weights: N * T3")
    .Input("mode_override: string")
    .Attr("T1: {int32,int64} = DT_INT32")
    .Attr("T2: {int32,int64} = DT_INT32")
    .Attr("T3: {float32,float64} = DT_FLOAT")
    .Attr("N: int >= 1")
    .Attr("device_ordinal: int = -1")
    .Attr("combiners: list(string) = []")
    .Attr("table_ids: list(int)")
    .Attr("max_sequence_lengths: list(int) = []")
    .Attr("num_features: list(int) = []")
    .SetIsStateful()
    // No tensor outputs; nothing to infer.
    .SetShapeFn(shape_inference::UnknownShape);
186 | |
// Generalized enqueue: sample_indices_or_row_splits accepts either COO
// indices or ragged row splits per input, so one op covers both layouts.
// No table_ids attr here, unlike the SparseTensor/Ragged variants.
REGISTER_OP("EnqueueTPUEmbeddingArbitraryTensorBatch")
    .Input("sample_indices_or_row_splits: N * T1")
    .Input("embedding_indices: N * T2")
    .Input("aggregation_weights: N * T3")
    .Input("mode_override: string")
    .Attr("T1: {int32,int64} = DT_INT32")
    .Attr("T2: {int32,int64} = DT_INT32")
    .Attr("T3: {float32,float64} = DT_FLOAT")
    .Attr("N: int >= 1")
    .Attr("device_ordinal: int = -1")
    .Attr("combiners: list(string) = []")
    .SetIsStateful()
    // No tensor outputs; nothing to infer.
    .SetShapeFn(shape_inference::UnknownShape);
200 | |
// Same as EnqueueTPUEmbeddingArbitraryTensorBatch, except device_ordinal is
// a runtime int32 input tensor instead of a compile-time attr, allowing the
// target device to be chosen dynamically.
REGISTER_OP("DynamicEnqueueTPUEmbeddingArbitraryTensorBatch")
    .Input("sample_indices_or_row_splits: N * T1")
    .Input("embedding_indices: N * T2")
    .Input("aggregation_weights: N * T3")
    .Input("mode_override: string")
    .Input("device_ordinal: int32")
    .Attr("T1: {int32,int64} = DT_INT32")
    .Attr("T2: {int32,int64} = DT_INT32")
    .Attr("T3: {float32,float64} = DT_FLOAT")
    .Attr("N: int >= 1")
    .Attr("combiners: list(string) = []")
    .SetIsStateful()
    // No tensor outputs; nothing to infer.
    .SetShapeFn(shape_inference::UnknownShape);
214 | |
215 | } // namespace tensorflow |
216 | |