1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <vector> |
17 | |
18 | #include "tensorflow/core/framework/common_shape_fns.h" |
19 | #include "tensorflow/core/framework/op.h" |
20 | #include "tensorflow/core/framework/resource_mgr.h" |
21 | #include "tensorflow/core/framework/shape_inference.h" |
22 | #include "tensorflow/core/framework/tensor_shape.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | using shape_inference::DimensionHandle; |
28 | using shape_inference::InferenceContext; |
29 | using shape_inference::ShapeHandle; |
30 | |
31 | REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource); |
32 | |
33 | REGISTER_OP("IsBoostedTreesEnsembleInitialized" ) |
34 | .Input("tree_ensemble_handle: resource" ) |
35 | .Output("is_initialized: bool" ) |
36 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
37 | shape_inference::ShapeHandle unused_input; |
38 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
39 | c->set_output(0, c->Scalar()); |
40 | return OkStatus(); |
41 | }); |
42 | |
43 | REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature" ) |
44 | .Input("node_id_range: int32" ) |
45 | .Input("stats_summary_list: num_features * float32" ) |
46 | .Input("l1: float" ) |
47 | .Input("l2: float" ) |
48 | .Input("tree_complexity: float" ) |
49 | .Input("min_node_weight: float" ) |
50 | .Attr("max_splits: int >= 1" ) |
51 | .Attr("num_features: int >= 1" ) // not passed but populated automatically. |
52 | .Output("node_ids_list: num_features * int32" ) |
53 | .Output("gains_list: num_features * float32" ) |
54 | .Output("thresholds_list: num_features * int32" ) |
55 | .Output("left_node_contribs_list: num_features * float32" ) |
56 | .Output("right_node_contribs_list: num_features * float32" ) |
57 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
58 | // Confirms the rank of the inputs and sets the shape of the outputs. |
59 | int max_splits; |
60 | int num_features; |
61 | TF_RETURN_IF_ERROR(c->GetAttr("max_splits" , &max_splits)); |
62 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
63 | shape_inference::ShapeHandle node_id_range_shape; |
64 | shape_inference::ShapeHandle unused_shape; |
65 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape)); |
66 | TF_RETURN_IF_ERROR( |
67 | c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape)); |
68 | // Checks that all stats summary entries are of the same shape. |
69 | shape_inference::ShapeHandle summary_shape_base; |
70 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &summary_shape_base)); |
71 | TF_RETURN_IF_ERROR(c->Merge(summary_shape_base, |
72 | c->MakeShape({max_splits, -1, 2}), |
73 | &unused_shape)); |
74 | for (int i = 1; i < num_features; ++i) { |
75 | shape_inference::ShapeHandle summary_shape; |
76 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 3, &summary_shape)); |
77 | TF_RETURN_IF_ERROR( |
78 | c->Merge(summary_shape_base, summary_shape, &unused_shape)); |
79 | } |
80 | TF_RETURN_IF_ERROR( |
81 | c->WithRank(c->input(num_features + 1), 0, &unused_shape)); |
82 | TF_RETURN_IF_ERROR( |
83 | c->WithRank(c->input(num_features + 2), 0, &unused_shape)); |
84 | TF_RETURN_IF_ERROR( |
85 | c->WithRank(c->input(num_features + 3), 0, &unused_shape)); |
86 | // Sets the output lists. |
87 | std::vector<shape_inference::ShapeHandle> output_shapes_vec( |
88 | num_features, c->MakeShape({-1})); |
89 | TF_RETURN_IF_ERROR(c->set_output("node_ids_list" , output_shapes_vec)); |
90 | TF_RETURN_IF_ERROR(c->set_output("gains_list" , output_shapes_vec)); |
91 | TF_RETURN_IF_ERROR(c->set_output("thresholds_list" , output_shapes_vec)); |
92 | std::vector<shape_inference::ShapeHandle> output_shapes_contribs( |
93 | num_features, c->MakeShape({-1, 1})); |
94 | TF_RETURN_IF_ERROR( |
95 | c->set_output("left_node_contribs_list" , output_shapes_contribs)); |
96 | TF_RETURN_IF_ERROR( |
97 | c->set_output("right_node_contribs_list" , output_shapes_contribs)); |
98 | return OkStatus(); |
99 | }); |
100 | |
101 | REGISTER_OP("BoostedTreesCalculateBestFeatureSplit" ) |
102 | .Input("node_id_range: int32" ) |
103 | .Input("stats_summary: float32" ) |
104 | .Input("l1: float" ) |
105 | .Input("l2: float" ) |
106 | .Input("tree_complexity: float" ) |
107 | .Input("min_node_weight: float" ) |
108 | .Attr("logits_dimension: int >= 1" ) |
109 | .Attr("split_type: {'inequality', 'equality'} = 'inequality'" ) |
110 | .Output("node_ids: int32" ) |
111 | .Output("gains: float32" ) |
112 | .Output("feature_dimensions: int32" ) |
113 | .Output("thresholds: int32" ) |
114 | .Output("left_node_contribs: float32" ) |
115 | .Output("right_node_contribs: float32" ) |
116 | .Output("split_with_default_directions: string" ) |
117 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
118 | shape_inference::ShapeHandle node_id_range_shape; |
119 | shape_inference::ShapeHandle unused_shape; |
120 | // node id range is rank 1 with 2 values. |
121 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape)); |
122 | TF_RETURN_IF_ERROR( |
123 | c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape)); |
124 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &unused_shape)); |
125 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape)); |
126 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape)); |
127 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); |
128 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape)); |
129 | ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()}); |
130 | c->set_output(0, rank_1_output_shape); |
131 | c->set_output(1, rank_1_output_shape); |
132 | c->set_output(2, rank_1_output_shape); |
133 | c->set_output(3, rank_1_output_shape); |
134 | c->set_output(6, rank_1_output_shape); |
135 | int logits_dimension; |
136 | TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension" , &logits_dimension)); |
137 | ShapeHandle contribs_output_shape = |
138 | c->MakeShape({c->UnknownDim(), logits_dimension}); |
139 | c->set_output(4, contribs_output_shape); |
140 | c->set_output(5, contribs_output_shape); |
141 | return OkStatus(); |
142 | }); |
143 | |
144 | REGISTER_OP("BoostedTreesCalculateBestFeatureSplitV2" ) |
145 | .Input("node_id_range: int32" ) |
146 | .Input("stats_summaries_list: num_features * float32" ) |
147 | .Input("split_types: string" ) |
148 | .Input("candidate_feature_ids: int32" ) |
149 | .Input("l1: float" ) |
150 | .Input("l2: float" ) |
151 | .Input("tree_complexity: float" ) |
152 | .Input("min_node_weight: float" ) |
153 | .Attr("num_features: int >= 1" ) // not passed but populated automatically. |
154 | .Attr("logits_dimension: int >= 1" ) |
155 | .Output("node_ids: int32" ) |
156 | .Output("gains: float32" ) |
157 | .Output("feature_ids: int32" ) |
158 | .Output("feature_dimensions: int32" ) |
159 | .Output("thresholds: int32" ) |
160 | .Output("left_node_contribs: float32" ) |
161 | .Output("right_node_contribs: float32" ) |
162 | .Output("split_with_default_directions: string" ) |
163 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
164 | // Attributes. |
165 | int num_features; |
166 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
167 | int logits_dimension; |
168 | TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension" , &logits_dimension)); |
169 | // Inputs. |
170 | shape_inference::ShapeHandle unused_shape; |
171 | // node id range is rank 1 with 2 values. |
172 | shape_inference::ShapeHandle node_id_range_shape; |
173 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape)); |
174 | TF_RETURN_IF_ERROR( |
175 | c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape)); |
176 | // Stats summary validation. |
177 | shape_inference::ShapeHandle summary_shape_base; |
178 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &summary_shape_base)); |
179 | // All stats summary entries are of the same shape. |
180 | for (int i = 1; i < num_features; ++i) { |
181 | shape_inference::ShapeHandle summary_shape; |
182 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 4, &summary_shape)); |
183 | TF_RETURN_IF_ERROR( |
184 | c->Merge(summary_shape_base, summary_shape, &unused_shape)); |
185 | } |
186 | // Validate rank 1 split_types. |
187 | TF_RETURN_IF_ERROR( |
188 | c->WithRank(c->input(1 + num_features), 1, &unused_shape)); |
189 | // Validate rank 1 feature_ids. |
190 | TF_RETURN_IF_ERROR( |
191 | c->WithRank(c->input(2 + num_features), 1, &unused_shape)); |
192 | // Validate rank 0: l1, l2, tree_complexity, min_node_weight. |
193 | for (int i = 0; i < 4; ++i) { |
194 | TF_RETURN_IF_ERROR( |
195 | c->WithRank(c->input(3 + num_features + i), 0, &unused_shape)); |
196 | } |
197 | // Output shapes. |
198 | ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()}); |
199 | c->set_output(0, rank_1_output_shape); |
200 | c->set_output(1, rank_1_output_shape); |
201 | c->set_output(2, rank_1_output_shape); |
202 | c->set_output(3, rank_1_output_shape); |
203 | c->set_output(4, rank_1_output_shape); |
204 | ShapeHandle contribs_output_shape = |
205 | c->MakeShape({c->UnknownDim(), logits_dimension}); |
206 | c->set_output(5, contribs_output_shape); |
207 | c->set_output(6, contribs_output_shape); |
208 | c->set_output(7, rank_1_output_shape); |
209 | return OkStatus(); |
210 | }); |
211 | |
212 | REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit" ) |
213 | .Input("node_id_range: int32" ) |
214 | .Input("stats_summary_indices: int32" ) |
215 | .Input("stats_summary_values: float" ) |
216 | .Input("stats_summary_shape: int32" ) |
217 | .Input("l1: float" ) |
218 | .Input("l2: float" ) |
219 | .Input("tree_complexity: float" ) |
220 | .Input("min_node_weight: float" ) |
221 | .Attr("logits_dimension: int >= 1" ) |
222 | .Attr("split_type: {'inequality'} = 'inequality'" ) |
223 | .Output("node_ids: int32" ) |
224 | .Output("gains: float32" ) |
225 | .Output("feature_dimensions: int32" ) |
226 | .Output("thresholds: int32" ) |
227 | .Output("left_node_contribs: float32" ) |
228 | .Output("right_node_contribs: float32" ) |
229 | .Output("split_with_default_directions: string" ) |
230 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
231 | shape_inference::ShapeHandle node_id_range_shape; |
232 | shape_inference::ShapeHandle unused_shape; |
233 | // node id range is rank 1 with 2 values. |
234 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape)); |
235 | TF_RETURN_IF_ERROR( |
236 | c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape)); |
237 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); |
238 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused_shape)); |
239 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused_shape)); |
240 | shape_inference::ShapeHandle summary_shape; |
241 | TF_RETURN_IF_ERROR( |
242 | c->Merge(summary_shape, c->MakeShape({4}), &unused_shape)); |
243 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); |
244 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape)); |
245 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape)); |
246 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape)); |
247 | ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()}); |
248 | c->set_output(0, rank_1_output_shape); |
249 | c->set_output(1, rank_1_output_shape); |
250 | c->set_output(2, rank_1_output_shape); |
251 | c->set_output(3, rank_1_output_shape); |
252 | c->set_output(6, rank_1_output_shape); |
253 | int logits_dimension; |
254 | TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension" , &logits_dimension)); |
255 | ShapeHandle contribs_output_shape = |
256 | c->MakeShape({c->UnknownDim(), logits_dimension}); |
257 | c->set_output(4, contribs_output_shape); |
258 | c->set_output(5, contribs_output_shape); |
259 | return OkStatus(); |
260 | }); |
261 | |
262 | REGISTER_OP("BoostedTreesCreateEnsemble" ) |
263 | .Input("tree_ensemble_handle: resource" ) |
264 | .Input("stamp_token: int64" ) |
265 | .Input("tree_ensemble_serialized: string" ) |
266 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
267 | shape_inference::ShapeHandle unused_input; |
268 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
269 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); |
270 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); |
271 | return OkStatus(); |
272 | }); |
273 | |
274 | REGISTER_OP("BoostedTreesDeserializeEnsemble" ) |
275 | .Input("tree_ensemble_handle: resource" ) |
276 | .Input("stamp_token: int64" ) |
277 | .Input("tree_ensemble_serialized: string" ) |
278 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
279 | shape_inference::ShapeHandle unused_input; |
280 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
281 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); |
282 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); |
283 | return OkStatus(); |
284 | }); |
285 | |
286 | REGISTER_OP("BoostedTreesGetEnsembleStates" ) |
287 | .Input("tree_ensemble_handle: resource" ) |
288 | .Output("stamp_token: int64" ) |
289 | .Output("num_trees: int32" ) |
290 | .Output("num_finalized_trees: int32" ) |
291 | .Output("num_attempted_layers: int32" ) |
292 | .Output("last_layer_nodes_range: int32" ) |
293 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
294 | shape_inference::ShapeHandle unused_input; |
295 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
296 | c->set_output(0, c->Scalar()); |
297 | c->set_output(1, c->Scalar()); |
298 | c->set_output(2, c->Scalar()); |
299 | c->set_output(3, c->Scalar()); |
300 | c->set_output(4, c->Vector(2)); |
301 | return OkStatus(); |
302 | }); |
303 | |
304 | REGISTER_OP("BoostedTreesMakeStatsSummary" ) |
305 | .Input("node_ids: int32" ) |
306 | .Input("gradients: float" ) |
307 | .Input("hessians: float" ) |
308 | .Input("bucketized_features_list: num_features * int32" ) |
309 | .Attr("max_splits: int >= 1" ) |
310 | .Attr("num_buckets: int >= 1" ) |
311 | .Attr("num_features: int >= 1" ) |
312 | .Output("stats_summary: float" ) |
313 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
314 | // Sets the shape of the output as a Rank 4 Tensor. |
315 | int max_splits; |
316 | int num_buckets; |
317 | int num_features; |
318 | TF_RETURN_IF_ERROR(c->GetAttr("max_splits" , &max_splits)); |
319 | TF_RETURN_IF_ERROR(c->GetAttr("num_buckets" , &num_buckets)); |
320 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
321 | shape_inference::ShapeHandle node_ids_shape; |
322 | shape_inference::ShapeHandle gradients_shape; |
323 | shape_inference::ShapeHandle hessians_shape; |
324 | shape_inference::ShapeHandle bucketized_feature_shape; |
325 | shape_inference::ShapeHandle unused_shape; |
326 | shape_inference::DimensionHandle unused_dim; |
327 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape)); |
328 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape)); |
329 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape)); |
330 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0), |
331 | c->Dim(gradients_shape, 0), &unused_dim)); |
332 | TF_RETURN_IF_ERROR( |
333 | c->Merge(gradients_shape, hessians_shape, &unused_shape)); |
334 | for (int f = 0; f < num_features; ++f) { |
335 | TF_RETURN_IF_ERROR( |
336 | c->WithRank(c->input(3 + f), 1, &bucketized_feature_shape)); |
337 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0), |
338 | c->Dim(bucketized_feature_shape, 0), |
339 | &unused_dim)); |
340 | } |
341 | c->set_output(0, |
342 | c->MakeShape({num_features, max_splits, num_buckets, 2})); |
343 | return OkStatus(); |
344 | }); |
345 | |
346 | // V2 of BoostedTreesMakeStatsSummary. Supports multi-dim dense Tensor and |
347 | // multi class. |
348 | REGISTER_OP("BoostedTreesAggregateStats" ) |
349 | .Input("node_ids: int32" ) |
350 | .Input("gradients: float" ) |
351 | .Input("hessians: float" ) |
352 | .Input("feature: int32" ) |
353 | .Attr("max_splits: int >= 1" ) |
354 | .Attr("num_buckets: int >= 1" ) |
355 | .Output("stats_summary: float" ) |
356 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
357 | // Sets the shape of the output as a Rank 4 Tensor. |
358 | int max_splits; |
359 | int num_buckets; |
360 | TF_RETURN_IF_ERROR(c->GetAttr("max_splits" , &max_splits)); |
361 | TF_RETURN_IF_ERROR(c->GetAttr("num_buckets" , &num_buckets)); |
362 | |
363 | shape_inference::ShapeHandle node_ids_shape; |
364 | shape_inference::ShapeHandle gradients_shape; |
365 | shape_inference::ShapeHandle hessians_shape; |
366 | shape_inference::ShapeHandle feature_shape; |
367 | |
368 | shape_inference::DimensionHandle batch_size = c->Dim(c->input(0), 0); |
369 | |
370 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape)); |
371 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape)); |
372 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape)); |
373 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &feature_shape)); |
374 | |
375 | // Verify all three inputs have same first dimension, i.e., batch_size. |
376 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(gradients_shape, 0), |
377 | c->Dim(node_ids_shape, 0), &batch_size)); |
378 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(hessians_shape, 0), |
379 | c->Dim(node_ids_shape, 0), &batch_size)); |
380 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0), |
381 | c->Dim(node_ids_shape, 0), &batch_size)); |
382 | |
383 | DimensionHandle logits_dim = c->Dim(c->input(1), 1); |
384 | DimensionHandle hessian_dim = c->Dim(c->input(2), 1); |
385 | DimensionHandle feature_dim = c->Dim(c->input(3), 1); |
386 | DimensionHandle stats_dim; |
387 | TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim)); |
388 | c->set_output(0, c->MakeShape({max_splits, feature_dim, |
389 | num_buckets + 1, // +1 for missing bucket. |
390 | stats_dim})); |
391 | return OkStatus(); |
392 | }); |
393 | |
394 | // Sparse Version of BoostedTreesAggregatesStats. |
395 | REGISTER_OP("BoostedTreesSparseAggregateStats" ) |
396 | .Input("node_ids: int32" ) |
397 | .Input("gradients: float" ) |
398 | .Input("hessians: float" ) |
399 | .Input("feature_indices: int32" ) |
400 | .Input("feature_values: int32" ) |
401 | .Input("feature_shape: int32" ) |
402 | .Attr("max_splits: int >= 1" ) |
403 | .Attr("num_buckets: int >= 1" ) |
404 | .Output("stats_summary_indices: int32" ) |
405 | .Output("stats_summary_values: float" ) |
406 | .Output("stats_summary_shape: int32" ) |
407 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
408 | int max_splits; |
409 | int num_buckets; |
410 | TF_RETURN_IF_ERROR(c->GetAttr("max_splits" , &max_splits)); |
411 | TF_RETURN_IF_ERROR(c->GetAttr("num_buckets" , &num_buckets)); |
412 | |
413 | shape_inference::ShapeHandle node_ids_shape; |
414 | shape_inference::ShapeHandle gradients_shape; |
415 | shape_inference::ShapeHandle hessians_shape; |
416 | shape_inference::ShapeHandle feature_indices_shape; |
417 | shape_inference::ShapeHandle feature_values_shape; |
418 | shape_inference::ShapeHandle feature_shape; |
419 | |
420 | shape_inference::DimensionHandle batch_size = c->Dim(c->input(0), 0); |
421 | shape_inference::DimensionHandle num_entries; |
422 | |
423 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape)); |
424 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape)); |
425 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape)); |
426 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &feature_indices_shape)); |
427 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_values_shape)); |
428 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &feature_shape)); |
429 | |
430 | // Verify all inputs have same first dimension, i.e., batch_size. |
431 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(gradients_shape, 0), |
432 | c->Dim(node_ids_shape, 0), &batch_size)); |
433 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(hessians_shape, 0), |
434 | c->Dim(node_ids_shape, 0), &batch_size)); |
435 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_indices_shape, 0), |
436 | c->Dim(feature_values_shape, 0), |
437 | &num_entries)); |
438 | |
439 | DimensionHandle unused; |
440 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(feature_shape, 0), 2, &unused)); |
441 | |
442 | DimensionHandle logits_dim = c->Dim(c->input(1), 1); |
443 | DimensionHandle hessian_dim = c->Dim(c->input(2), 1); |
444 | DimensionHandle stats_dim; |
445 | TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim)); |
446 | |
447 | c->set_output(0, c->MakeShape({c->UnknownDim(), 4})); |
448 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
449 | c->set_output(2, c->MakeShape({4})); |
450 | return OkStatus(); |
451 | }); |
452 | |
453 | // TODO(nponomareva): when/if creating the new op for unbucketized data, rename |
454 | // bucketized_features to features. |
455 | REGISTER_OP("BoostedTreesPredict" ) |
456 | .Input("tree_ensemble_handle: resource" ) |
457 | .Input("bucketized_features: num_bucketized_features * int32" ) |
458 | .Attr("num_bucketized_features: int >= 1" ) // Inferred. |
459 | .Attr("logits_dimension: int" ) |
460 | .Output("logits: float" ) |
461 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
462 | shape_inference::ShapeHandle feature_shape; |
463 | int num_bucketized_features; |
464 | TF_RETURN_IF_ERROR( |
465 | c->GetAttr("num_bucketized_features" , &num_bucketized_features)); |
466 | shape_inference::DimensionHandle batch_size = c->Dim(c->input(1), 0); |
467 | for (int i = 0; i < num_bucketized_features; ++i) { |
468 | TF_RETURN_IF_ERROR( |
469 | c->WithRankAtMost(c->input(i + 1), 2, &feature_shape)); |
470 | // Check that all bucketized features have the same batch size. |
471 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(c->input(1), 0), |
472 | c->Dim(c->input(i + 1), 0), &batch_size)); |
473 | } |
474 | |
475 | int logits_dimension; |
476 | TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension" , &logits_dimension)); |
477 | auto logits_shape = |
478 | c->MakeShape({c->Dim(feature_shape, 0), logits_dimension}); |
479 | // Logits. |
480 | c->set_output(0, logits_shape); |
481 | return OkStatus(); |
482 | }); |
483 | |
484 | REGISTER_OP("BoostedTreesExampleDebugOutputs" ) |
485 | .Input("tree_ensemble_handle: resource" ) |
486 | .Input("bucketized_features: num_bucketized_features * int32" ) |
487 | .Attr("num_bucketized_features: int >= 1" ) // Inferred. |
488 | .Attr("logits_dimension: int" ) |
489 | .Output("examples_debug_outputs_serialized: string" ) |
490 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
491 | shape_inference::ShapeHandle feature_shape; |
492 | int num_bucketized_features; |
493 | TF_RETURN_IF_ERROR( |
494 | c->GetAttr("num_bucketized_features" , &num_bucketized_features)); |
495 | shape_inference::DimensionHandle batch_dim = c->Dim(c->input(1), 0); |
496 | for (int i = 0; i < num_bucketized_features; ++i) { |
497 | TF_RETURN_IF_ERROR( |
498 | c->WithRankAtMost(c->input(i + 1), 2, &feature_shape)); |
499 | // Check that all bucketized features have the same batch size. |
500 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(c->input(1), 0), |
501 | c->Dim(c->input(i + 1), 0), &batch_dim)); |
502 | } |
503 | |
504 | // Multi-class will be supported by modifying the proto. |
505 | auto batch_size = c->MakeShape({c->Dim(feature_shape, 0)}); |
506 | c->set_output(0, batch_size); |
507 | return OkStatus(); |
508 | }); |
509 | |
510 | REGISTER_OP("BoostedTreesSerializeEnsemble" ) |
511 | .Input("tree_ensemble_handle: resource" ) |
512 | .Output("stamp_token: int64" ) |
513 | .Output("tree_ensemble_serialized: string" ) |
514 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
515 | shape_inference::ShapeHandle unused_input; |
516 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
517 | c->set_output(0, c->Scalar()); |
518 | c->set_output(1, c->Scalar()); |
519 | return OkStatus(); |
520 | }); |
521 | |
522 | REGISTER_OP("BoostedTreesTrainingPredict" ) |
523 | .Input("tree_ensemble_handle: resource" ) |
524 | .Input("cached_tree_ids: int32" ) |
525 | .Input("cached_node_ids: int32" ) |
526 | .Input("bucketized_features: num_bucketized_features * int32" ) |
527 | .Attr("num_bucketized_features: int >= 1" ) |
528 | .Attr("logits_dimension: int" ) |
529 | .Output("partial_logits: float" ) |
530 | .Output("tree_ids: int32" ) |
531 | .Output("node_ids: int32" ) |
532 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
533 | shape_inference::ShapeHandle feature_shape; |
534 | int num_bucketized_features; |
535 | TF_RETURN_IF_ERROR( |
536 | c->GetAttr("num_bucketized_features" , &num_bucketized_features)); |
537 | |
538 | shape_inference::ShapeHandle unused_input; |
539 | shape_inference::DimensionHandle batch_size = c->Dim(c->input(3), 0); |
540 | for (int i = 0; i < num_bucketized_features; ++i) { |
541 | TF_RETURN_IF_ERROR( |
542 | c->WithRankAtMost(c->input(i + 3), 2, &feature_shape)); |
543 | TF_RETURN_IF_ERROR( |
544 | c->Merge(c->input(i + 3), feature_shape, &unused_input)); |
545 | } |
546 | shape_inference::ShapeHandle tree_ids_shape; |
547 | shape_inference::ShapeHandle node_ids_shape; |
548 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &tree_ids_shape)); |
549 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &node_ids_shape)); |
550 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(tree_ids_shape, 0), |
551 | c->Dim(node_ids_shape, 0), &batch_size)); |
552 | |
553 | int logits_dimension; |
554 | TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension" , &logits_dimension)); |
555 | auto logits_shape = |
556 | c->MakeShape({c->Dim(feature_shape, 0), logits_dimension}); |
557 | // Partial logits. |
558 | c->set_output(0, logits_shape); |
559 | // Tree ids. |
560 | c->set_output(1, c->MakeShape({c->Dim(feature_shape, 0)})); |
561 | // Node ids. |
562 | c->set_output(2, c->MakeShape({c->Dim(feature_shape, 0)})); |
563 | return OkStatus(); |
564 | }); |
565 | |
566 | REGISTER_OP("BoostedTreesUpdateEnsemble" ) |
567 | .Input("tree_ensemble_handle: resource" ) |
568 | .Input("feature_ids: int32" ) |
569 | .Input("node_ids: num_features * int32" ) |
570 | .Input("gains: num_features * float" ) |
571 | .Input("thresholds: num_features * int32" ) |
572 | .Input("left_node_contribs: num_features * float" ) |
573 | .Input("right_node_contribs: num_features * float" ) |
574 | .Input("max_depth: int32" ) |
575 | .Input("learning_rate: float" ) |
576 | .Attr("pruning_mode: int >=0" ) |
577 | .Attr("num_features: int >= 0" ) // Inferred. |
578 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
579 | shape_inference::ShapeHandle shape_handle; |
580 | int num_features; |
581 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
582 | |
583 | // Feature_ids, should be one for each feature. |
584 | shape_inference::ShapeHandle feature_ids_shape; |
585 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape)); |
586 | TF_RETURN_IF_ERROR( |
587 | c->Merge(c->input(1), c->Vector(num_features), &shape_handle)); |
588 | |
589 | for (int i = 0; i < num_features; ++i) { |
590 | // Node ids. |
591 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle)); |
592 | auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)}); |
593 | auto shape_rank_2 = c->MakeShape({c->Dim(shape_handle, 0), 1}); |
594 | |
595 | // Gains. |
596 | TF_RETURN_IF_ERROR( |
597 | c->WithRank(c->input(i + num_features + 2), 1, &shape_handle)); |
598 | // TODO(nponomareva): replace this with input("name",vector of shapes). |
599 | TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features + 2), |
600 | shape_rank_1, &shape_handle)); |
601 | // Thresholds. |
602 | TF_RETURN_IF_ERROR( |
603 | c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle)); |
604 | TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2), |
605 | shape_rank_1, &shape_handle)); |
606 | // Left and right node contribs. |
607 | TF_RETURN_IF_ERROR( |
608 | c->WithRank(c->input(i + num_features * 3 + 2), 2, &shape_handle)); |
609 | TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2), |
610 | shape_rank_2, &shape_handle)); |
611 | TF_RETURN_IF_ERROR( |
612 | c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle)); |
613 | TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2), |
614 | shape_rank_2, &shape_handle)); |
615 | } |
616 | return OkStatus(); |
617 | }); |
618 | |
619 | REGISTER_OP("BoostedTreesUpdateEnsembleV2" ) |
620 | .Input("tree_ensemble_handle: resource" ) |
621 | .Input("feature_ids: num_groups * int32" ) |
622 | .Input("dimension_ids: num_features * int32" ) |
623 | .Input("node_ids: num_features * int32" ) |
624 | .Input("gains: num_features * float" ) |
625 | .Input("thresholds: num_features * int32" ) |
626 | .Input("left_node_contribs: num_features * float" ) |
627 | .Input("right_node_contribs: num_features * float" ) |
628 | .Input("split_types: num_features * string" ) |
629 | .Input("max_depth: int32" ) |
630 | .Input("learning_rate: float" ) |
631 | .Input("pruning_mode: int32" ) |
632 | .Attr("num_features: int >= 0" ) // Inferred. |
633 | .Attr("logits_dimension: int = 1" ) |
634 | .Attr("num_groups: int = 1" ) // Inferred; number of groups to process. |
635 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
636 | int num_features; |
637 | int logits_dimension; |
638 | int num_groups; |
639 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
640 | TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension" , &logits_dimension)); |
641 | TF_RETURN_IF_ERROR(c->GetAttr("num_groups" , &num_groups)); |
642 | // num_features was kept for backwards compatibility reasons. It now |
643 | // represents number of groups. |
644 | DCHECK_EQ(num_features, num_groups); |
645 | shape_inference::ShapeHandle shape_handle; |
646 | for (int i = 0; i < num_groups; ++i) { |
647 | int offset = i + 1; |
648 | |
649 | // Feature ids |
650 | TF_RETURN_IF_ERROR(c->WithRank(c->input(offset), 1, &shape_handle)); |
651 | // TODO(nponomareva): replace this with input("name",vector of shapes). |
652 | auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)}); |
653 | TF_RETURN_IF_ERROR( |
654 | c->Merge(c->input(offset), shape_rank_1, &shape_handle)); |
655 | |
656 | // Dimension ids. |
657 | TF_RETURN_IF_ERROR( |
658 | c->WithRank(c->input(offset + num_features), 1, &shape_handle)); |
659 | TF_RETURN_IF_ERROR( |
660 | c->Merge(c->input(offset), shape_rank_1, &shape_handle)); |
661 | |
662 | // Node ids. |
663 | TF_RETURN_IF_ERROR( |
664 | c->WithRank(c->input(offset + num_features * 2), 1, &shape_handle)); |
665 | TF_RETURN_IF_ERROR( |
666 | c->Merge(c->input(offset), shape_rank_1, &shape_handle)); |
667 | |
668 | // Gains. |
669 | TF_RETURN_IF_ERROR( |
670 | c->WithRank(c->input(offset + num_features * 3), 1, &shape_handle)); |
671 | TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 3), |
672 | shape_rank_1, &shape_handle)); |
673 | |
674 | // Thresholds. |
675 | TF_RETURN_IF_ERROR( |
676 | c->WithRank(c->input(offset + num_features * 4), 1, &shape_handle)); |
677 | TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 4), |
678 | shape_rank_1, &shape_handle)); |
679 | |
680 | // Left and right node contribs. |
681 | auto shape_rank_2 = |
682 | c->MakeShape({c->Dim(shape_handle, 0), logits_dimension}); |
683 | TF_RETURN_IF_ERROR( |
684 | c->WithRank(c->input(offset + num_features * 5), 2, &shape_handle)); |
685 | TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 5), |
686 | shape_rank_2, &shape_handle)); |
687 | TF_RETURN_IF_ERROR( |
688 | c->WithRank(c->input(offset + num_features * 6), 2, &shape_handle)); |
689 | TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 6), |
690 | shape_rank_2, &shape_handle)); |
691 | |
692 | // Split types. |
693 | TF_RETURN_IF_ERROR( |
694 | c->WithRank(c->input(offset + num_features * 7), 1, &shape_handle)); |
695 | TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 7), |
696 | shape_rank_1, &shape_handle)); |
697 | } |
698 | return OkStatus(); |
699 | }); |
700 | |
701 | REGISTER_OP("BoostedTreesCenterBias" ) |
702 | .Input("tree_ensemble_handle: resource" ) |
703 | .Input("mean_gradients: float" ) |
704 | .Input("mean_hessians: float" ) |
705 | // Regularization-related. |
706 | .Input("l1: float" ) |
707 | .Input("l2: float" ) |
708 | .Output("continue_centering: bool" ) |
709 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
710 | shape_inference::ShapeHandle gradients_shape; |
711 | shape_inference::ShapeHandle hessians_shape; |
712 | shape_inference::ShapeHandle unused_shape; |
713 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape)); |
714 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape)); |
715 | TF_RETURN_IF_ERROR( |
716 | c->Merge(gradients_shape, hessians_shape, &unused_shape)); |
717 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape)); |
718 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); |
719 | |
720 | c->set_output(0, c->Scalar()); |
721 | return OkStatus(); |
722 | }); |
723 | |
724 | REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource); |
725 | |
726 | REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized" ) |
727 | .Input("quantile_stream_resource_handle: resource" ) |
728 | .Output("is_initialized: bool" ) |
729 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
730 | shape_inference::ShapeHandle unused_input; |
731 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
732 | c->set_output(0, c->Scalar()); |
733 | return OkStatus(); |
734 | }); |
735 | |
736 | REGISTER_OP("BoostedTreesCreateQuantileStreamResource" ) |
737 | .Attr("max_elements: int = 1099511627776" ) // 1 << 40 |
738 | .Input("quantile_stream_resource_handle: resource" ) |
739 | .Input("epsilon: float" ) |
740 | .Input("num_streams: int64" ) |
741 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
742 | shape_inference::ShapeHandle unused_input; |
743 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
744 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); |
745 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); |
746 | return OkStatus(); |
747 | }); |
748 | |
749 | REGISTER_OP("BoostedTreesMakeQuantileSummaries" ) |
750 | .Attr("num_features: int >= 0" ) |
751 | .Input("float_values: num_features * float" ) |
752 | .Input("example_weights: float" ) |
753 | .Input("epsilon: float" ) |
754 | .Output("summaries: num_features * float" ) |
755 | .SetShapeFn([](InferenceContext* c) { |
756 | int num_features; |
757 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
758 | ShapeHandle example_weights_shape; |
759 | TF_RETURN_IF_ERROR( |
760 | c->WithRank(c->input(num_features), 1, &example_weights_shape)); |
761 | for (int i = 0; i < num_features; ++i) { |
762 | ShapeHandle feature_shape; |
763 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape)); |
764 | // the columns are value, weight, min_rank, max_rank. |
765 | c->set_output(i, c->MakeShape({c->UnknownDim(), 4})); |
766 | } |
767 | // epsilon must be a scalar. |
768 | ShapeHandle unused_input; |
769 | TF_RETURN_IF_ERROR( |
770 | c->WithRank(c->input(num_features + 1), 0, &unused_input)); |
771 | return OkStatus(); |
772 | }); |
773 | |
774 | REGISTER_OP("BoostedTreesFlushQuantileSummaries" ) |
775 | .Attr("num_features: int >= 0" ) |
776 | .Input("quantile_stream_resource_handle: resource" ) |
777 | .Output("summaries: num_features * float" ) |
778 | .SetShapeFn([](InferenceContext* c) { |
779 | int num_features; |
780 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
781 | for (int i = 0; i < num_features; ++i) { |
782 | // the columns are value, weight, min_rank, max_rank. |
783 | c->set_output(i, c->MakeShape({c->UnknownDim(), 4})); |
784 | } |
785 | return OkStatus(); |
786 | }); |
787 | |
788 | REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries" ) |
789 | .Attr("num_features: int >= 0" ) |
790 | .Input("quantile_stream_resource_handle: resource" ) |
791 | .Input("summaries: num_features * float" ) |
792 | .SetShapeFn([](InferenceContext* c) { |
793 | int num_features; |
794 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
795 | // resource handle must be a scalar. |
796 | shape_inference::ShapeHandle unused_input; |
797 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
798 | // each summary must be rank 2. |
799 | for (int i = 1; i < num_features + 1; i++) { |
800 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input)); |
801 | } |
802 | return OkStatus(); |
803 | }); |
804 | |
805 | REGISTER_OP("BoostedTreesQuantileStreamResourceDeserialize" ) |
806 | .Attr("num_streams: int" ) |
807 | .Input("quantile_stream_resource_handle: resource" ) |
808 | .Input("bucket_boundaries: num_streams * float" ) |
809 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
810 | shape_inference::ShapeHandle unused_input; |
811 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
812 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); |
813 | return OkStatus(); |
814 | }); |
815 | |
816 | REGISTER_OP("BoostedTreesQuantileStreamResourceFlush" ) |
817 | .Attr("generate_quantiles: bool = False" ) |
818 | .Input("quantile_stream_resource_handle: resource" ) |
819 | .Input("num_buckets: int64" ) |
820 | .SetShapeFn([](InferenceContext* c) { |
821 | // All the inputs are scalars. |
822 | shape_inference::ShapeHandle unused_input; |
823 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
824 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); |
825 | return OkStatus(); |
826 | }); |
827 | |
828 | REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries" ) |
829 | .Attr("num_features: int >= 0" ) |
830 | .Input("quantile_stream_resource_handle: resource" ) |
831 | .Output("bucket_boundaries: num_features * float" ) |
832 | .SetShapeFn([](InferenceContext* c) { |
833 | int num_features; |
834 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
835 | shape_inference::ShapeHandle unused_input; |
836 | // resource handle must be a scalar. |
837 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); |
838 | for (int i = 0; i < num_features; i++) { |
839 | c->set_output(i, c->Vector(c->UnknownDim())); |
840 | } |
841 | return OkStatus(); |
842 | }); |
843 | |
844 | REGISTER_OP("BoostedTreesBucketize" ) |
845 | .Attr("num_features: int >= 0" ) |
846 | .Input("float_values: num_features * float" ) |
847 | .Input("bucket_boundaries: num_features * float" ) |
848 | .Output("buckets: num_features * int32" ) |
849 | .SetShapeFn([](InferenceContext* c) { |
850 | int num_features; |
851 | TF_RETURN_IF_ERROR(c->GetAttr("num_features" , &num_features)); |
852 | ShapeHandle feature_shape; |
853 | DimensionHandle unused_dim; |
854 | for (int i = 0; i < num_features; i++) { |
855 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape)); |
856 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0), |
857 | c->Dim(c->input(0), 0), &unused_dim)); |
858 | } |
859 | // Bucketized result should have same dimension as input. |
860 | for (int i = 0; i < num_features; i++) { |
861 | c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0)})); |
862 | } |
863 | return OkStatus(); |
864 | }); |
865 | |
866 | } // namespace tensorflow |
867 | |