1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
---|---|
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/tpu/tpu_embedding_output_layout_utils.h"

#include <utility>
#include <vector>

#include "tensorflow/core/platform/errors.h"
19 | |
20 | namespace tensorflow { |
21 | namespace tpu { |
22 | |
23 | Status ComputeOutputTensorShapes( |
24 | const tensorflow::tpu::TPUEmbeddingConfiguration& config, |
25 | std::vector<TensorShapeProto>* shapes) { |
26 | const int64_t core_count_per_replica = |
27 | config.spmd_sharding().enabled() |
28 | ? config.spmd_sharding().num_cores_per_replica() |
29 | : 1; |
30 | if (config.feature_descriptor_size() > 0) { |
31 | for (const TPUEmbeddingConfiguration::FeatureDescriptor& feature : |
32 | config.feature_descriptor()) { |
33 | TensorShapeProto shape; |
34 | for (int32 input_shape : feature.input_shape()) { |
35 | auto* dim = shape.add_dim(); |
36 | dim->set_size(input_shape); |
37 | } |
38 | shape.add_dim()->set_size( |
39 | config.table_descriptor(feature.table_id()).dimension()); |
40 | shape.mutable_dim(0)->set_size(core_count_per_replica * |
41 | shape.dim(0).size()); |
42 | shapes->push_back(shape); |
43 | } |
44 | } else { |
45 | const int batch_size = config.batch_size_per_tensor_core(); |
46 | for (const TPUEmbeddingConfiguration::TableDescriptor& table : |
47 | config.table_descriptor()) { |
48 | TensorShapeProto shape; |
49 | auto* dim0 = shape.add_dim(); |
50 | dim0->set_size(core_count_per_replica * batch_size * |
51 | table.num_features()); |
52 | auto* dim1 = shape.add_dim(); |
53 | dim1->set_size(table.dimension()); |
54 | shapes->push_back(shape); |
55 | } |
56 | } |
57 | return OkStatus(); |
58 | } |
59 | |
60 | } // namespace tpu |
61 | } // namespace tensorflow |
62 |