/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/parsing_ops.cc.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"

25namespace tensorflow {
26
27class DecodeCSVOp : public OpKernel {
28 public:
29 explicit DecodeCSVOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
30 string delim;
31
32 OP_REQUIRES_OK(ctx, ctx->GetAttr("OUT_TYPE", &out_type_));
33 OP_REQUIRES(ctx, out_type_.size() < std::numeric_limits<int>::max(),
34 errors::InvalidArgument("Out type too large"));
35 OP_REQUIRES_OK(ctx, ctx->GetAttr("field_delim", &delim));
36 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_quote_delim", &use_quote_delim_));
37 OP_REQUIRES_OK(ctx, ctx->GetAttr("select_cols", &select_cols_));
38 OP_REQUIRES(
39 ctx, out_type_.size() == select_cols_.size() || select_cols_.empty(),
40 errors::InvalidArgument("select_cols should match output size"));
41 select_all_cols_ = select_cols_.empty();
42 for (int i = 1; i < select_cols_.size(); i++) {
43 OP_REQUIRES(ctx, select_cols_[i - 1] < select_cols_[i],
44 errors::InvalidArgument(
45 "select_cols should be strictly increasing indices"));
46 }
47 OP_REQUIRES(
48 ctx, select_cols_.empty() || select_cols_.front() >= 0,
49 errors::InvalidArgument("select_cols should be non-negative indices"));
50 OP_REQUIRES(ctx, delim.size() == 1,
51 errors::InvalidArgument("field_delim should be only 1 char"));
52 delim_ = delim[0];
53 OP_REQUIRES_OK(ctx, ctx->GetAttr("na_value", &na_value_));
54 }
55
56 void Compute(OpKernelContext* ctx) override {
57 const Tensor* records;
58 OpInputList record_defaults;
59
60 OP_REQUIRES_OK(ctx, ctx->input("records", &records));
61 OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
62
63 for (int i = 0; i < record_defaults.size(); ++i) {
64 OP_REQUIRES(ctx, record_defaults[i].dims() <= 1,
65 errors::InvalidArgument(
66 "Each record default should be at most rank 1"));
67 OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
68 errors::InvalidArgument(
69 "There should only be 1 default per field but field ", i,
70 " has ", record_defaults[i].NumElements()));
71 }
72
73 auto records_t = records->flat<tstring>();
74 int64_t records_size = records_t.size();
75
76 OpOutputList output;
77 OP_REQUIRES_OK(ctx, ctx->output_list("output", &output));
78
79 for (int i = 0; i < static_cast<int>(out_type_.size()); ++i) {
80 Tensor* out = nullptr;
81 OP_REQUIRES_OK(ctx, output.allocate(i, records->shape(), &out));
82 }
83
84 for (int64_t i = 0; i < records_size; ++i) {
85 const StringPiece record(records_t(i));
86 std::vector<string> fields;
87 ExtractFields(ctx, record, &fields);
88 OP_REQUIRES(ctx, fields.size() == out_type_.size(),
89 errors::InvalidArgument("Expect ", out_type_.size(),
90 " fields but have ", fields.size(),
91 " in record ", i));
92
93 // Check each field in the record
94 for (int f = 0; f < static_cast<int>(out_type_.size()); ++f) {
95 const DataType& dtype = out_type_[f];
96 switch (dtype) {
97 case DT_INT32: {
98 // If this field is empty or NA value, check if default is given:
99 // If yes, use default value; Otherwise report error.
100 if (fields[f].empty() || fields[f] == na_value_) {
101 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
102 errors::InvalidArgument(
103 "Field ", f,
104 " is required but missing in record ", i, "!"));
105
106 output[f]->flat<int32>()(i) = record_defaults[f].flat<int32>()(0);
107 } else {
108 int32_t value;
109 OP_REQUIRES(ctx, strings::safe_strto32(fields[f], &value),
110 errors::InvalidArgument(
111 "Field ", f, " in record ", i,
112 " is not a valid int32: ", fields[f]));
113 output[f]->flat<int32>()(i) = value;
114 }
115 break;
116 }
117 case DT_INT64: {
118 // If this field is empty or NA value, check if default is given:
119 // If yes, use default value; Otherwise report error.
120 if (fields[f].empty() || fields[f] == na_value_) {
121 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
122 errors::InvalidArgument(
123 "Field ", f,
124 " is required but missing in record ", i, "!"));
125
126 output[f]->flat<int64_t>()(i) =
127 record_defaults[f].flat<int64_t>()(0);
128 } else {
129 int64_t value;
130 OP_REQUIRES(ctx, strings::safe_strto64(fields[f], &value),
131 errors::InvalidArgument(
132 "Field ", f, " in record ", i,
133 " is not a valid int64: ", fields[f]));
134 output[f]->flat<int64_t>()(i) = value;
135 }
136 break;
137 }
138 case DT_FLOAT: {
139 // If this field is empty or NA value, check if default is given:
140 // If yes, use default value; Otherwise report error.
141 if (fields[f].empty() || fields[f] == na_value_) {
142 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
143 errors::InvalidArgument(
144 "Field ", f,
145 " is required but missing in record ", i, "!"));
146 output[f]->flat<float>()(i) = record_defaults[f].flat<float>()(0);
147 } else {
148 float value;
149 OP_REQUIRES(ctx, strings::safe_strtof(fields[f], &value),
150 errors::InvalidArgument(
151 "Field ", f, " in record ", i,
152 " is not a valid float: ", fields[f]));
153 output[f]->flat<float>()(i) = value;
154 }
155 break;
156 }
157 case DT_DOUBLE: {
158 // If this field is empty or NA value, check if default is given:
159 // If yes, use default value; Otherwise report error.
160 if (fields[f].empty() || fields[f] == na_value_) {
161 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
162 errors::InvalidArgument(
163 "Field ", f,
164 " is required but missing in record ", i, "!"));
165 output[f]->flat<double>()(i) =
166 record_defaults[f].flat<double>()(0);
167 } else {
168 double value;
169 OP_REQUIRES(ctx, strings::safe_strtod(fields[f], &value),
170 errors::InvalidArgument(
171 "Field ", f, " in record ", i,
172 " is not a valid double: ", fields[f]));
173 output[f]->flat<double>()(i) = value;
174 }
175 break;
176 }
177 case DT_STRING: {
178 // If this field is empty or NA value, check if default is given:
179 // If yes, use default value; Otherwise report error.
180 if (fields[f].empty() || fields[f] == na_value_) {
181 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
182 errors::InvalidArgument(
183 "Field ", f,
184 " is required but missing in record ", i, "!"));
185 output[f]->flat<tstring>()(i) =
186 record_defaults[f].flat<tstring>()(0);
187 } else {
188 output[f]->flat<tstring>()(i) = std::move(fields[f]);
189 }
190 break;
191 }
192 default:
193 OP_REQUIRES(ctx, false,
194 errors::InvalidArgument("csv: data type ", dtype,
195 " not supported in field ", f));
196 }
197 }
198 }
199 }
200
201 private:
202 std::vector<DataType> out_type_;
203 std::vector<int64_t> select_cols_;
204 char delim_;
205 bool use_quote_delim_;
206 bool select_all_cols_;
207 string na_value_;
208
209 void ExtractFields(OpKernelContext* ctx, StringPiece input,
210 std::vector<string>* result) {
211 int64_t current_idx = 0;
212 int64_t num_fields_parsed = 0;
213 int64_t selector_idx = 0; // Keep track of index into select_cols
214
215 if (!input.empty()) {
216 while (static_cast<size_t>(current_idx) < input.size()) {
217 if (input[current_idx] == '\n' || input[current_idx] == '\r') {
218 current_idx++;
219 continue;
220 }
221
222 bool quoted = false;
223 bool include =
224 (select_all_cols_ || select_cols_[selector_idx] ==
225 static_cast<size_t>(num_fields_parsed));
226
227 if (use_quote_delim_ && input[current_idx] == '"') {
228 quoted = true;
229 current_idx++;
230 }
231
232 // This is the body of the field;
233 string field;
234 if (!quoted) {
235 while (static_cast<size_t>(current_idx) < input.size() &&
236 input[current_idx] != delim_) {
237 OP_REQUIRES(ctx,
238 (!use_quote_delim_ || input[current_idx] != '"') &&
239 input[current_idx] != '\n' &&
240 input[current_idx] != '\r',
241 errors::InvalidArgument(
242 "Unquoted fields cannot have quotes/CRLFs inside"));
243 if (include) field += input[current_idx];
244 current_idx++;
245 }
246
247 // Go to next field or the end
248 current_idx++;
249 } else if (use_quote_delim_) {
250 // Quoted field needs to be ended with '"' and delim or end
251 while (
252 (static_cast<size_t>(current_idx) < input.size() - 1) &&
253 (input[current_idx] != '"' || input[current_idx + 1] != delim_)) {
254 if (input[current_idx] != '"') {
255 if (include) field += input[current_idx];
256 current_idx++;
257 } else {
258 OP_REQUIRES(
259 ctx, input[current_idx + 1] == '"',
260 errors::InvalidArgument("Quote inside a string has to be "
261 "escaped by another quote"));
262 if (include) field += '"';
263 current_idx += 2;
264 }
265 }
266
267 OP_REQUIRES(
268 ctx,
269 (static_cast<size_t>(current_idx) < input.size() &&
270 input[current_idx] == '"' &&
271 (static_cast<size_t>(current_idx) == input.size() - 1 ||
272 input[current_idx + 1] == delim_)),
273 errors::InvalidArgument("Quoted field has to end with quote "
274 "followed by delim or end"));
275
276 current_idx += 2;
277 }
278
279 num_fields_parsed++;
280 if (include) {
281 result->push_back(field);
282 selector_idx++;
283 if (selector_idx == select_cols_.size()) return;
284 }
285 }
286
287 bool include =
288 (select_all_cols_ || select_cols_[selector_idx] ==
289 static_cast<size_t>(num_fields_parsed));
290 // Check if the last field is missing
291 if (include && input[input.size() - 1] == delim_)
292 result->push_back(string());
293 }
294 }
295};
296
297REGISTER_KERNEL_BUILDER(Name("DecodeCSV").Device(DEVICE_CPU), DecodeCSVOp);
298
299} // namespace tensorflow
300