1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/example/feature_util.h"
17
18#include <string>
19
20#include "absl/strings/string_view.h"
21
22namespace tensorflow {
23
24namespace internal {
25Feature& ExampleFeature(absl::string_view name, Example* example) {
26 return *GetFeature(name, example);
27}
28
29} // namespace internal
30
31template <>
32bool HasFeature<>(absl::string_view key, const Features& features) {
33 return features.feature().contains(internal::ProtoMapKey(key));
34}
35
36template <>
37bool HasFeature<protobuf_int64>(absl::string_view key,
38 const Features& features) {
39 auto it = features.feature().find(internal::ProtoMapKey(key));
40 return (it != features.feature().end()) &&
41 (it->second.kind_case() == Feature::KindCase::kInt64List);
42}
43
44template <>
45bool HasFeature<float>(absl::string_view key, const Features& features) {
46 auto it = features.feature().find(internal::ProtoMapKey(key));
47 return (it != features.feature().end()) &&
48 (it->second.kind_case() == Feature::KindCase::kFloatList);
49}
50
51template <>
52bool HasFeature<std::string>(absl::string_view key, const Features& features) {
53 auto it = features.feature().find(internal::ProtoMapKey(key));
54 return (it != features.feature().end()) &&
55 (it->second.kind_case() == Feature::KindCase::kBytesList);
56}
57
58template <>
59bool HasFeature<tstring>(absl::string_view key, const Features& features) {
60 auto it = features.feature().find(internal::ProtoMapKey(key));
61 return (it != features.feature().end()) &&
62 (it->second.kind_case() == Feature::KindCase::kBytesList);
63}
64
65bool HasFeatureList(absl::string_view key,
66 const SequenceExample& sequence_example) {
67 return sequence_example.feature_lists().feature_list().contains(
68 internal::ProtoMapKey(key));
69}
70
71template <>
72const protobuf::RepeatedField<protobuf_int64>& GetFeatureValues<protobuf_int64>(
73 const Feature& feature) {
74 return feature.int64_list().value();
75}
76
77template <>
78protobuf::RepeatedField<protobuf_int64>* GetFeatureValues<protobuf_int64>(
79 Feature* feature) {
80 return feature->mutable_int64_list()->mutable_value();
81}
82
83template <>
84const protobuf::RepeatedField<float>& GetFeatureValues<float>(
85 const Feature& feature) {
86 return feature.float_list().value();
87}
88
89template <>
90protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature) {
91 return feature->mutable_float_list()->mutable_value();
92}
93
94template <>
95const protobuf::RepeatedPtrField<std::string>& GetFeatureValues<tstring>(
96 const Feature& feature) {
97 return feature.bytes_list().value();
98}
99
100template <>
101const protobuf::RepeatedPtrField<std::string>& GetFeatureValues<std::string>(
102 const Feature& feature) {
103 return feature.bytes_list().value();
104}
105
106template <>
107protobuf::RepeatedPtrField<std::string>* GetFeatureValues<tstring>(
108 Feature* feature) {
109 return feature->mutable_bytes_list()->mutable_value();
110}
111
112template <>
113protobuf::RepeatedPtrField<std::string>* GetFeatureValues<std::string>(
114 Feature* feature) {
115 return feature->mutable_bytes_list()->mutable_value();
116}
117
118const protobuf::RepeatedPtrField<Feature>& GetFeatureList(
119 absl::string_view key, const SequenceExample& sequence_example) {
120 return sequence_example.feature_lists()
121 .feature_list()
122 .at(internal::ProtoMapKey(key))
123 .feature();
124}
125
126protobuf::RepeatedPtrField<Feature>* GetFeatureList(
127 absl::string_view feature_list_key, SequenceExample* sequence_example) {
128 return (*sequence_example->mutable_feature_lists()
129 ->mutable_feature_list())[internal::ProtoMapKey(
130 feature_list_key)]
131 .mutable_feature();
132}
133
134template <>
135void ClearFeatureValues<protobuf_int64>(Feature* feature) {
136 feature->mutable_int64_list()->Clear();
137}
138
139template <>
140void ClearFeatureValues<float>(Feature* feature) {
141 feature->mutable_float_list()->Clear();
142}
143
144template <>
145void ClearFeatureValues<std::string>(Feature* feature) {
146 feature->mutable_bytes_list()->Clear();
147}
148
149template <>
150void ClearFeatureValues<tstring>(Feature* feature) {
151 feature->mutable_bytes_list()->Clear();
152}
153
154template <>
155Features* GetFeatures<Features>(Features* proto) {
156 return proto;
157}
158
159template <>
160Features* GetFeatures<Example>(Example* proto) {
161 return proto->mutable_features();
162}
163
164template <>
165Features* GetFeatures<SequenceExample>(SequenceExample* proto) {
166 return proto->mutable_context();
167}
168
169template <>
170const Features& GetFeatures<Features>(const Features& proto) {
171 return proto;
172}
173
174template <>
175const Features& GetFeatures<Example>(const Example& proto) {
176 return proto.features();
177}
178
179template <>
180const Features& GetFeatures<SequenceExample>(const SequenceExample& proto) {
181 return proto.context();
182}
183
184template <>
185const protobuf::RepeatedField<protobuf_int64>& GetFeatureValues<protobuf_int64>(
186 const Feature& feature);
187
188template <>
189protobuf::RepeatedField<protobuf_int64>* GetFeatureValues<protobuf_int64>(
190 Feature* feature);
191
192template <>
193const protobuf::RepeatedField<float>& GetFeatureValues<float>(
194 const Feature& feature);
195
196template <>
197protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature);
198
199template <>
200const protobuf::RepeatedPtrField<std::string>& GetFeatureValues<std::string>(
201 const Feature& feature);
202
203template <>
204const protobuf::RepeatedPtrField<std::string>& GetFeatureValues<tstring>(
205 const Feature& feature);
206
207template <>
208protobuf::RepeatedPtrField<std::string>* GetFeatureValues<std::string>(
209 Feature* feature);
210
211template <>
212protobuf::RepeatedPtrField<std::string>* GetFeatureValues<tstring>(
213 Feature* feature);
214
215} // namespace tensorflow
216