/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 | |
// See docs in ../ops/parsing_ops.cc.
#include <limits>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
24 | |
25 | namespace tensorflow { |
26 | |
27 | class DecodeCSVOp : public OpKernel { |
28 | public: |
29 | explicit DecodeCSVOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
30 | string delim; |
31 | |
32 | OP_REQUIRES_OK(ctx, ctx->GetAttr("OUT_TYPE" , &out_type_)); |
33 | OP_REQUIRES(ctx, out_type_.size() < std::numeric_limits<int>::max(), |
34 | errors::InvalidArgument("Out type too large" )); |
35 | OP_REQUIRES_OK(ctx, ctx->GetAttr("field_delim" , &delim)); |
36 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_quote_delim" , &use_quote_delim_)); |
37 | OP_REQUIRES_OK(ctx, ctx->GetAttr("select_cols" , &select_cols_)); |
38 | OP_REQUIRES( |
39 | ctx, out_type_.size() == select_cols_.size() || select_cols_.empty(), |
40 | errors::InvalidArgument("select_cols should match output size" )); |
41 | select_all_cols_ = select_cols_.empty(); |
42 | for (int i = 1; i < select_cols_.size(); i++) { |
43 | OP_REQUIRES(ctx, select_cols_[i - 1] < select_cols_[i], |
44 | errors::InvalidArgument( |
45 | "select_cols should be strictly increasing indices" )); |
46 | } |
47 | OP_REQUIRES( |
48 | ctx, select_cols_.empty() || select_cols_.front() >= 0, |
49 | errors::InvalidArgument("select_cols should be non-negative indices" )); |
50 | OP_REQUIRES(ctx, delim.size() == 1, |
51 | errors::InvalidArgument("field_delim should be only 1 char" )); |
52 | delim_ = delim[0]; |
53 | OP_REQUIRES_OK(ctx, ctx->GetAttr("na_value" , &na_value_)); |
54 | } |
55 | |
56 | void Compute(OpKernelContext* ctx) override { |
57 | const Tensor* records; |
58 | OpInputList record_defaults; |
59 | |
60 | OP_REQUIRES_OK(ctx, ctx->input("records" , &records)); |
61 | OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults" , &record_defaults)); |
62 | |
63 | for (int i = 0; i < record_defaults.size(); ++i) { |
64 | OP_REQUIRES(ctx, record_defaults[i].dims() <= 1, |
65 | errors::InvalidArgument( |
66 | "Each record default should be at most rank 1" )); |
67 | OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2, |
68 | errors::InvalidArgument( |
69 | "There should only be 1 default per field but field " , i, |
70 | " has " , record_defaults[i].NumElements())); |
71 | } |
72 | |
73 | auto records_t = records->flat<tstring>(); |
74 | int64_t records_size = records_t.size(); |
75 | |
76 | OpOutputList output; |
77 | OP_REQUIRES_OK(ctx, ctx->output_list("output" , &output)); |
78 | |
79 | for (int i = 0; i < static_cast<int>(out_type_.size()); ++i) { |
80 | Tensor* out = nullptr; |
81 | OP_REQUIRES_OK(ctx, output.allocate(i, records->shape(), &out)); |
82 | } |
83 | |
84 | for (int64_t i = 0; i < records_size; ++i) { |
85 | const StringPiece record(records_t(i)); |
86 | std::vector<string> fields; |
87 | ExtractFields(ctx, record, &fields); |
88 | OP_REQUIRES(ctx, fields.size() == out_type_.size(), |
89 | errors::InvalidArgument("Expect " , out_type_.size(), |
90 | " fields but have " , fields.size(), |
91 | " in record " , i)); |
92 | |
93 | // Check each field in the record |
94 | for (int f = 0; f < static_cast<int>(out_type_.size()); ++f) { |
95 | const DataType& dtype = out_type_[f]; |
96 | switch (dtype) { |
97 | case DT_INT32: { |
98 | // If this field is empty or NA value, check if default is given: |
99 | // If yes, use default value; Otherwise report error. |
100 | if (fields[f].empty() || fields[f] == na_value_) { |
101 | OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, |
102 | errors::InvalidArgument( |
103 | "Field " , f, |
104 | " is required but missing in record " , i, "!" )); |
105 | |
106 | output[f]->flat<int32>()(i) = record_defaults[f].flat<int32>()(0); |
107 | } else { |
108 | int32_t value; |
109 | OP_REQUIRES(ctx, strings::safe_strto32(fields[f], &value), |
110 | errors::InvalidArgument( |
111 | "Field " , f, " in record " , i, |
112 | " is not a valid int32: " , fields[f])); |
113 | output[f]->flat<int32>()(i) = value; |
114 | } |
115 | break; |
116 | } |
117 | case DT_INT64: { |
118 | // If this field is empty or NA value, check if default is given: |
119 | // If yes, use default value; Otherwise report error. |
120 | if (fields[f].empty() || fields[f] == na_value_) { |
121 | OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, |
122 | errors::InvalidArgument( |
123 | "Field " , f, |
124 | " is required but missing in record " , i, "!" )); |
125 | |
126 | output[f]->flat<int64_t>()(i) = |
127 | record_defaults[f].flat<int64_t>()(0); |
128 | } else { |
129 | int64_t value; |
130 | OP_REQUIRES(ctx, strings::safe_strto64(fields[f], &value), |
131 | errors::InvalidArgument( |
132 | "Field " , f, " in record " , i, |
133 | " is not a valid int64: " , fields[f])); |
134 | output[f]->flat<int64_t>()(i) = value; |
135 | } |
136 | break; |
137 | } |
138 | case DT_FLOAT: { |
139 | // If this field is empty or NA value, check if default is given: |
140 | // If yes, use default value; Otherwise report error. |
141 | if (fields[f].empty() || fields[f] == na_value_) { |
142 | OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, |
143 | errors::InvalidArgument( |
144 | "Field " , f, |
145 | " is required but missing in record " , i, "!" )); |
146 | output[f]->flat<float>()(i) = record_defaults[f].flat<float>()(0); |
147 | } else { |
148 | float value; |
149 | OP_REQUIRES(ctx, strings::safe_strtof(fields[f], &value), |
150 | errors::InvalidArgument( |
151 | "Field " , f, " in record " , i, |
152 | " is not a valid float: " , fields[f])); |
153 | output[f]->flat<float>()(i) = value; |
154 | } |
155 | break; |
156 | } |
157 | case DT_DOUBLE: { |
158 | // If this field is empty or NA value, check if default is given: |
159 | // If yes, use default value; Otherwise report error. |
160 | if (fields[f].empty() || fields[f] == na_value_) { |
161 | OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, |
162 | errors::InvalidArgument( |
163 | "Field " , f, |
164 | " is required but missing in record " , i, "!" )); |
165 | output[f]->flat<double>()(i) = |
166 | record_defaults[f].flat<double>()(0); |
167 | } else { |
168 | double value; |
169 | OP_REQUIRES(ctx, strings::safe_strtod(fields[f], &value), |
170 | errors::InvalidArgument( |
171 | "Field " , f, " in record " , i, |
172 | " is not a valid double: " , fields[f])); |
173 | output[f]->flat<double>()(i) = value; |
174 | } |
175 | break; |
176 | } |
177 | case DT_STRING: { |
178 | // If this field is empty or NA value, check if default is given: |
179 | // If yes, use default value; Otherwise report error. |
180 | if (fields[f].empty() || fields[f] == na_value_) { |
181 | OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, |
182 | errors::InvalidArgument( |
183 | "Field " , f, |
184 | " is required but missing in record " , i, "!" )); |
185 | output[f]->flat<tstring>()(i) = |
186 | record_defaults[f].flat<tstring>()(0); |
187 | } else { |
188 | output[f]->flat<tstring>()(i) = std::move(fields[f]); |
189 | } |
190 | break; |
191 | } |
192 | default: |
193 | OP_REQUIRES(ctx, false, |
194 | errors::InvalidArgument("csv: data type " , dtype, |
195 | " not supported in field " , f)); |
196 | } |
197 | } |
198 | } |
199 | } |
200 | |
201 | private: |
202 | std::vector<DataType> out_type_; |
203 | std::vector<int64_t> select_cols_; |
204 | char delim_; |
205 | bool use_quote_delim_; |
206 | bool select_all_cols_; |
207 | string na_value_; |
208 | |
209 | void (OpKernelContext* ctx, StringPiece input, |
210 | std::vector<string>* result) { |
211 | int64_t current_idx = 0; |
212 | int64_t num_fields_parsed = 0; |
213 | int64_t selector_idx = 0; // Keep track of index into select_cols |
214 | |
215 | if (!input.empty()) { |
216 | while (static_cast<size_t>(current_idx) < input.size()) { |
217 | if (input[current_idx] == '\n' || input[current_idx] == '\r') { |
218 | current_idx++; |
219 | continue; |
220 | } |
221 | |
222 | bool quoted = false; |
223 | bool include = |
224 | (select_all_cols_ || select_cols_[selector_idx] == |
225 | static_cast<size_t>(num_fields_parsed)); |
226 | |
227 | if (use_quote_delim_ && input[current_idx] == '"') { |
228 | quoted = true; |
229 | current_idx++; |
230 | } |
231 | |
232 | // This is the body of the field; |
233 | string field; |
234 | if (!quoted) { |
235 | while (static_cast<size_t>(current_idx) < input.size() && |
236 | input[current_idx] != delim_) { |
237 | OP_REQUIRES(ctx, |
238 | (!use_quote_delim_ || input[current_idx] != '"') && |
239 | input[current_idx] != '\n' && |
240 | input[current_idx] != '\r', |
241 | errors::InvalidArgument( |
242 | "Unquoted fields cannot have quotes/CRLFs inside" )); |
243 | if (include) field += input[current_idx]; |
244 | current_idx++; |
245 | } |
246 | |
247 | // Go to next field or the end |
248 | current_idx++; |
249 | } else if (use_quote_delim_) { |
250 | // Quoted field needs to be ended with '"' and delim or end |
251 | while ( |
252 | (static_cast<size_t>(current_idx) < input.size() - 1) && |
253 | (input[current_idx] != '"' || input[current_idx + 1] != delim_)) { |
254 | if (input[current_idx] != '"') { |
255 | if (include) field += input[current_idx]; |
256 | current_idx++; |
257 | } else { |
258 | OP_REQUIRES( |
259 | ctx, input[current_idx + 1] == '"', |
260 | errors::InvalidArgument("Quote inside a string has to be " |
261 | "escaped by another quote" )); |
262 | if (include) field += '"'; |
263 | current_idx += 2; |
264 | } |
265 | } |
266 | |
267 | OP_REQUIRES( |
268 | ctx, |
269 | (static_cast<size_t>(current_idx) < input.size() && |
270 | input[current_idx] == '"' && |
271 | (static_cast<size_t>(current_idx) == input.size() - 1 || |
272 | input[current_idx + 1] == delim_)), |
273 | errors::InvalidArgument("Quoted field has to end with quote " |
274 | "followed by delim or end" )); |
275 | |
276 | current_idx += 2; |
277 | } |
278 | |
279 | num_fields_parsed++; |
280 | if (include) { |
281 | result->push_back(field); |
282 | selector_idx++; |
283 | if (selector_idx == select_cols_.size()) return; |
284 | } |
285 | } |
286 | |
287 | bool include = |
288 | (select_all_cols_ || select_cols_[selector_idx] == |
289 | static_cast<size_t>(num_fields_parsed)); |
290 | // Check if the last field is missing |
291 | if (include && input[input.size() - 1] == delim_) |
292 | result->push_back(string()); |
293 | } |
294 | } |
295 | }; |
296 | |
297 | REGISTER_KERNEL_BUILDER(Name("DecodeCSV" ).Device(DEVICE_CPU), DecodeCSVOp); |
298 | |
299 | } // namespace tensorflow |
300 | |