1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <vector>
17
18#include "tensorflow/core/framework/common_shape_fns.h"
19#include "tensorflow/core/framework/op.h"
20#include "tensorflow/core/framework/resource_mgr.h"
21#include "tensorflow/core/framework/shape_inference.h"
22#include "tensorflow/core/framework/tensor_shape.h"
23#include "tensorflow/core/lib/core/errors.h"
24
25namespace tensorflow {
26
27using shape_inference::DimensionHandle;
28using shape_inference::InferenceContext;
29using shape_inference::ShapeHandle;
30
31REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
32
33REGISTER_OP("IsBoostedTreesEnsembleInitialized")
34 .Input("tree_ensemble_handle: resource")
35 .Output("is_initialized: bool")
36 .SetShapeFn([](shape_inference::InferenceContext* c) {
37 shape_inference::ShapeHandle unused_input;
38 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
39 c->set_output(0, c->Scalar());
40 return OkStatus();
41 });
42
43REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
44 .Input("node_id_range: int32")
45 .Input("stats_summary_list: num_features * float32")
46 .Input("l1: float")
47 .Input("l2: float")
48 .Input("tree_complexity: float")
49 .Input("min_node_weight: float")
50 .Attr("max_splits: int >= 1")
51 .Attr("num_features: int >= 1") // not passed but populated automatically.
52 .Output("node_ids_list: num_features * int32")
53 .Output("gains_list: num_features * float32")
54 .Output("thresholds_list: num_features * int32")
55 .Output("left_node_contribs_list: num_features * float32")
56 .Output("right_node_contribs_list: num_features * float32")
57 .SetShapeFn([](shape_inference::InferenceContext* c) {
58 // Confirms the rank of the inputs and sets the shape of the outputs.
59 int max_splits;
60 int num_features;
61 TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
62 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
63 shape_inference::ShapeHandle node_id_range_shape;
64 shape_inference::ShapeHandle unused_shape;
65 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
66 TF_RETURN_IF_ERROR(
67 c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
68 // Checks that all stats summary entries are of the same shape.
69 shape_inference::ShapeHandle summary_shape_base;
70 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &summary_shape_base));
71 TF_RETURN_IF_ERROR(c->Merge(summary_shape_base,
72 c->MakeShape({max_splits, -1, 2}),
73 &unused_shape));
74 for (int i = 1; i < num_features; ++i) {
75 shape_inference::ShapeHandle summary_shape;
76 TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 3, &summary_shape));
77 TF_RETURN_IF_ERROR(
78 c->Merge(summary_shape_base, summary_shape, &unused_shape));
79 }
80 TF_RETURN_IF_ERROR(
81 c->WithRank(c->input(num_features + 1), 0, &unused_shape));
82 TF_RETURN_IF_ERROR(
83 c->WithRank(c->input(num_features + 2), 0, &unused_shape));
84 TF_RETURN_IF_ERROR(
85 c->WithRank(c->input(num_features + 3), 0, &unused_shape));
86 // Sets the output lists.
87 std::vector<shape_inference::ShapeHandle> output_shapes_vec(
88 num_features, c->MakeShape({-1}));
89 TF_RETURN_IF_ERROR(c->set_output("node_ids_list", output_shapes_vec));
90 TF_RETURN_IF_ERROR(c->set_output("gains_list", output_shapes_vec));
91 TF_RETURN_IF_ERROR(c->set_output("thresholds_list", output_shapes_vec));
92 std::vector<shape_inference::ShapeHandle> output_shapes_contribs(
93 num_features, c->MakeShape({-1, 1}));
94 TF_RETURN_IF_ERROR(
95 c->set_output("left_node_contribs_list", output_shapes_contribs));
96 TF_RETURN_IF_ERROR(
97 c->set_output("right_node_contribs_list", output_shapes_contribs));
98 return OkStatus();
99 });
100
101REGISTER_OP("BoostedTreesCalculateBestFeatureSplit")
102 .Input("node_id_range: int32")
103 .Input("stats_summary: float32")
104 .Input("l1: float")
105 .Input("l2: float")
106 .Input("tree_complexity: float")
107 .Input("min_node_weight: float")
108 .Attr("logits_dimension: int >= 1")
109 .Attr("split_type: {'inequality', 'equality'} = 'inequality'")
110 .Output("node_ids: int32")
111 .Output("gains: float32")
112 .Output("feature_dimensions: int32")
113 .Output("thresholds: int32")
114 .Output("left_node_contribs: float32")
115 .Output("right_node_contribs: float32")
116 .Output("split_with_default_directions: string")
117 .SetShapeFn([](shape_inference::InferenceContext* c) {
118 shape_inference::ShapeHandle node_id_range_shape;
119 shape_inference::ShapeHandle unused_shape;
120 // node id range is rank 1 with 2 values.
121 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
122 TF_RETURN_IF_ERROR(
123 c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
124 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &unused_shape));
125 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
126 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
127 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
128 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
129 ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
130 c->set_output(0, rank_1_output_shape);
131 c->set_output(1, rank_1_output_shape);
132 c->set_output(2, rank_1_output_shape);
133 c->set_output(3, rank_1_output_shape);
134 c->set_output(6, rank_1_output_shape);
135 int logits_dimension;
136 TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
137 ShapeHandle contribs_output_shape =
138 c->MakeShape({c->UnknownDim(), logits_dimension});
139 c->set_output(4, contribs_output_shape);
140 c->set_output(5, contribs_output_shape);
141 return OkStatus();
142 });
143
144REGISTER_OP("BoostedTreesCalculateBestFeatureSplitV2")
145 .Input("node_id_range: int32")
146 .Input("stats_summaries_list: num_features * float32")
147 .Input("split_types: string")
148 .Input("candidate_feature_ids: int32")
149 .Input("l1: float")
150 .Input("l2: float")
151 .Input("tree_complexity: float")
152 .Input("min_node_weight: float")
153 .Attr("num_features: int >= 1") // not passed but populated automatically.
154 .Attr("logits_dimension: int >= 1")
155 .Output("node_ids: int32")
156 .Output("gains: float32")
157 .Output("feature_ids: int32")
158 .Output("feature_dimensions: int32")
159 .Output("thresholds: int32")
160 .Output("left_node_contribs: float32")
161 .Output("right_node_contribs: float32")
162 .Output("split_with_default_directions: string")
163 .SetShapeFn([](shape_inference::InferenceContext* c) {
164 // Attributes.
165 int num_features;
166 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
167 int logits_dimension;
168 TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
169 // Inputs.
170 shape_inference::ShapeHandle unused_shape;
171 // node id range is rank 1 with 2 values.
172 shape_inference::ShapeHandle node_id_range_shape;
173 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
174 TF_RETURN_IF_ERROR(
175 c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
176 // Stats summary validation.
177 shape_inference::ShapeHandle summary_shape_base;
178 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &summary_shape_base));
179 // All stats summary entries are of the same shape.
180 for (int i = 1; i < num_features; ++i) {
181 shape_inference::ShapeHandle summary_shape;
182 TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 4, &summary_shape));
183 TF_RETURN_IF_ERROR(
184 c->Merge(summary_shape_base, summary_shape, &unused_shape));
185 }
186 // Validate rank 1 split_types.
187 TF_RETURN_IF_ERROR(
188 c->WithRank(c->input(1 + num_features), 1, &unused_shape));
189 // Validate rank 1 feature_ids.
190 TF_RETURN_IF_ERROR(
191 c->WithRank(c->input(2 + num_features), 1, &unused_shape));
192 // Validate rank 0: l1, l2, tree_complexity, min_node_weight.
193 for (int i = 0; i < 4; ++i) {
194 TF_RETURN_IF_ERROR(
195 c->WithRank(c->input(3 + num_features + i), 0, &unused_shape));
196 }
197 // Output shapes.
198 ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
199 c->set_output(0, rank_1_output_shape);
200 c->set_output(1, rank_1_output_shape);
201 c->set_output(2, rank_1_output_shape);
202 c->set_output(3, rank_1_output_shape);
203 c->set_output(4, rank_1_output_shape);
204 ShapeHandle contribs_output_shape =
205 c->MakeShape({c->UnknownDim(), logits_dimension});
206 c->set_output(5, contribs_output_shape);
207 c->set_output(6, contribs_output_shape);
208 c->set_output(7, rank_1_output_shape);
209 return OkStatus();
210 });
211
212REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit")
213 .Input("node_id_range: int32")
214 .Input("stats_summary_indices: int32")
215 .Input("stats_summary_values: float")
216 .Input("stats_summary_shape: int32")
217 .Input("l1: float")
218 .Input("l2: float")
219 .Input("tree_complexity: float")
220 .Input("min_node_weight: float")
221 .Attr("logits_dimension: int >= 1")
222 .Attr("split_type: {'inequality'} = 'inequality'")
223 .Output("node_ids: int32")
224 .Output("gains: float32")
225 .Output("feature_dimensions: int32")
226 .Output("thresholds: int32")
227 .Output("left_node_contribs: float32")
228 .Output("right_node_contribs: float32")
229 .Output("split_with_default_directions: string")
230 .SetShapeFn([](shape_inference::InferenceContext* c) {
231 shape_inference::ShapeHandle node_id_range_shape;
232 shape_inference::ShapeHandle unused_shape;
233 // node id range is rank 1 with 2 values.
234 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
235 TF_RETURN_IF_ERROR(
236 c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
237 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
238 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused_shape));
239 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused_shape));
240 shape_inference::ShapeHandle summary_shape;
241 TF_RETURN_IF_ERROR(
242 c->Merge(summary_shape, c->MakeShape({4}), &unused_shape));
243 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
244 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
245 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape));
246 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape));
247 ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
248 c->set_output(0, rank_1_output_shape);
249 c->set_output(1, rank_1_output_shape);
250 c->set_output(2, rank_1_output_shape);
251 c->set_output(3, rank_1_output_shape);
252 c->set_output(6, rank_1_output_shape);
253 int logits_dimension;
254 TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
255 ShapeHandle contribs_output_shape =
256 c->MakeShape({c->UnknownDim(), logits_dimension});
257 c->set_output(4, contribs_output_shape);
258 c->set_output(5, contribs_output_shape);
259 return OkStatus();
260 });
261
262REGISTER_OP("BoostedTreesCreateEnsemble")
263 .Input("tree_ensemble_handle: resource")
264 .Input("stamp_token: int64")
265 .Input("tree_ensemble_serialized: string")
266 .SetShapeFn([](shape_inference::InferenceContext* c) {
267 shape_inference::ShapeHandle unused_input;
268 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
269 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
270 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
271 return OkStatus();
272 });
273
274REGISTER_OP("BoostedTreesDeserializeEnsemble")
275 .Input("tree_ensemble_handle: resource")
276 .Input("stamp_token: int64")
277 .Input("tree_ensemble_serialized: string")
278 .SetShapeFn([](shape_inference::InferenceContext* c) {
279 shape_inference::ShapeHandle unused_input;
280 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
281 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
282 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
283 return OkStatus();
284 });
285
286REGISTER_OP("BoostedTreesGetEnsembleStates")
287 .Input("tree_ensemble_handle: resource")
288 .Output("stamp_token: int64")
289 .Output("num_trees: int32")
290 .Output("num_finalized_trees: int32")
291 .Output("num_attempted_layers: int32")
292 .Output("last_layer_nodes_range: int32")
293 .SetShapeFn([](shape_inference::InferenceContext* c) {
294 shape_inference::ShapeHandle unused_input;
295 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
296 c->set_output(0, c->Scalar());
297 c->set_output(1, c->Scalar());
298 c->set_output(2, c->Scalar());
299 c->set_output(3, c->Scalar());
300 c->set_output(4, c->Vector(2));
301 return OkStatus();
302 });
303
304REGISTER_OP("BoostedTreesMakeStatsSummary")
305 .Input("node_ids: int32")
306 .Input("gradients: float")
307 .Input("hessians: float")
308 .Input("bucketized_features_list: num_features * int32")
309 .Attr("max_splits: int >= 1")
310 .Attr("num_buckets: int >= 1")
311 .Attr("num_features: int >= 1")
312 .Output("stats_summary: float")
313 .SetShapeFn([](shape_inference::InferenceContext* c) {
314 // Sets the shape of the output as a Rank 4 Tensor.
315 int max_splits;
316 int num_buckets;
317 int num_features;
318 TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
319 TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
320 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
321 shape_inference::ShapeHandle node_ids_shape;
322 shape_inference::ShapeHandle gradients_shape;
323 shape_inference::ShapeHandle hessians_shape;
324 shape_inference::ShapeHandle bucketized_feature_shape;
325 shape_inference::ShapeHandle unused_shape;
326 shape_inference::DimensionHandle unused_dim;
327 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
328 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
329 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
330 TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
331 c->Dim(gradients_shape, 0), &unused_dim));
332 TF_RETURN_IF_ERROR(
333 c->Merge(gradients_shape, hessians_shape, &unused_shape));
334 for (int f = 0; f < num_features; ++f) {
335 TF_RETURN_IF_ERROR(
336 c->WithRank(c->input(3 + f), 1, &bucketized_feature_shape));
337 TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
338 c->Dim(bucketized_feature_shape, 0),
339 &unused_dim));
340 }
341 c->set_output(0,
342 c->MakeShape({num_features, max_splits, num_buckets, 2}));
343 return OkStatus();
344 });
345
346// V2 of BoostedTreesMakeStatsSummary. Supports multi-dim dense Tensor and
347// multi class.
348REGISTER_OP("BoostedTreesAggregateStats")
349 .Input("node_ids: int32")
350 .Input("gradients: float")
351 .Input("hessians: float")
352 .Input("feature: int32")
353 .Attr("max_splits: int >= 1")
354 .Attr("num_buckets: int >= 1")
355 .Output("stats_summary: float")
356 .SetShapeFn([](shape_inference::InferenceContext* c) {
357 // Sets the shape of the output as a Rank 4 Tensor.
358 int max_splits;
359 int num_buckets;
360 TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
361 TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
362
363 shape_inference::ShapeHandle node_ids_shape;
364 shape_inference::ShapeHandle gradients_shape;
365 shape_inference::ShapeHandle hessians_shape;
366 shape_inference::ShapeHandle feature_shape;
367
368 shape_inference::DimensionHandle batch_size = c->Dim(c->input(0), 0);
369
370 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
371 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
372 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
373 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &feature_shape));
374
375 // Verify all three inputs have same first dimension, i.e., batch_size.
376 TF_RETURN_IF_ERROR(c->Merge(c->Dim(gradients_shape, 0),
377 c->Dim(node_ids_shape, 0), &batch_size));
378 TF_RETURN_IF_ERROR(c->Merge(c->Dim(hessians_shape, 0),
379 c->Dim(node_ids_shape, 0), &batch_size));
380 TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
381 c->Dim(node_ids_shape, 0), &batch_size));
382
383 DimensionHandle logits_dim = c->Dim(c->input(1), 1);
384 DimensionHandle hessian_dim = c->Dim(c->input(2), 1);
385 DimensionHandle feature_dim = c->Dim(c->input(3), 1);
386 DimensionHandle stats_dim;
387 TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim));
388 c->set_output(0, c->MakeShape({max_splits, feature_dim,
389 num_buckets + 1, // +1 for missing bucket.
390 stats_dim}));
391 return OkStatus();
392 });
393
394// Sparse Version of BoostedTreesAggregatesStats.
395REGISTER_OP("BoostedTreesSparseAggregateStats")
396 .Input("node_ids: int32")
397 .Input("gradients: float")
398 .Input("hessians: float")
399 .Input("feature_indices: int32")
400 .Input("feature_values: int32")
401 .Input("feature_shape: int32")
402 .Attr("max_splits: int >= 1")
403 .Attr("num_buckets: int >= 1")
404 .Output("stats_summary_indices: int32")
405 .Output("stats_summary_values: float")
406 .Output("stats_summary_shape: int32")
407 .SetShapeFn([](shape_inference::InferenceContext* c) {
408 int max_splits;
409 int num_buckets;
410 TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
411 TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
412
413 shape_inference::ShapeHandle node_ids_shape;
414 shape_inference::ShapeHandle gradients_shape;
415 shape_inference::ShapeHandle hessians_shape;
416 shape_inference::ShapeHandle feature_indices_shape;
417 shape_inference::ShapeHandle feature_values_shape;
418 shape_inference::ShapeHandle feature_shape;
419
420 shape_inference::DimensionHandle batch_size = c->Dim(c->input(0), 0);
421 shape_inference::DimensionHandle num_entries;
422
423 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
424 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
425 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
426 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &feature_indices_shape));
427 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_values_shape));
428 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &feature_shape));
429
430 // Verify all inputs have same first dimension, i.e., batch_size.
431 TF_RETURN_IF_ERROR(c->Merge(c->Dim(gradients_shape, 0),
432 c->Dim(node_ids_shape, 0), &batch_size));
433 TF_RETURN_IF_ERROR(c->Merge(c->Dim(hessians_shape, 0),
434 c->Dim(node_ids_shape, 0), &batch_size));
435 TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_indices_shape, 0),
436 c->Dim(feature_values_shape, 0),
437 &num_entries));
438
439 DimensionHandle unused;
440 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(feature_shape, 0), 2, &unused));
441
442 DimensionHandle logits_dim = c->Dim(c->input(1), 1);
443 DimensionHandle hessian_dim = c->Dim(c->input(2), 1);
444 DimensionHandle stats_dim;
445 TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim));
446
447 c->set_output(0, c->MakeShape({c->UnknownDim(), 4}));
448 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
449 c->set_output(2, c->MakeShape({4}));
450 return OkStatus();
451 });
452
453// TODO(nponomareva): when/if creating the new op for unbucketized data, rename
454// bucketized_features to features.
455REGISTER_OP("BoostedTreesPredict")
456 .Input("tree_ensemble_handle: resource")
457 .Input("bucketized_features: num_bucketized_features * int32")
458 .Attr("num_bucketized_features: int >= 1") // Inferred.
459 .Attr("logits_dimension: int")
460 .Output("logits: float")
461 .SetShapeFn([](shape_inference::InferenceContext* c) {
462 shape_inference::ShapeHandle feature_shape;
463 int num_bucketized_features;
464 TF_RETURN_IF_ERROR(
465 c->GetAttr("num_bucketized_features", &num_bucketized_features));
466 shape_inference::DimensionHandle batch_size = c->Dim(c->input(1), 0);
467 for (int i = 0; i < num_bucketized_features; ++i) {
468 TF_RETURN_IF_ERROR(
469 c->WithRankAtMost(c->input(i + 1), 2, &feature_shape));
470 // Check that all bucketized features have the same batch size.
471 TF_RETURN_IF_ERROR(c->Merge(c->Dim(c->input(1), 0),
472 c->Dim(c->input(i + 1), 0), &batch_size));
473 }
474
475 int logits_dimension;
476 TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
477 auto logits_shape =
478 c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
479 // Logits.
480 c->set_output(0, logits_shape);
481 return OkStatus();
482 });
483
484REGISTER_OP("BoostedTreesExampleDebugOutputs")
485 .Input("tree_ensemble_handle: resource")
486 .Input("bucketized_features: num_bucketized_features * int32")
487 .Attr("num_bucketized_features: int >= 1") // Inferred.
488 .Attr("logits_dimension: int")
489 .Output("examples_debug_outputs_serialized: string")
490 .SetShapeFn([](shape_inference::InferenceContext* c) {
491 shape_inference::ShapeHandle feature_shape;
492 int num_bucketized_features;
493 TF_RETURN_IF_ERROR(
494 c->GetAttr("num_bucketized_features", &num_bucketized_features));
495 shape_inference::DimensionHandle batch_dim = c->Dim(c->input(1), 0);
496 for (int i = 0; i < num_bucketized_features; ++i) {
497 TF_RETURN_IF_ERROR(
498 c->WithRankAtMost(c->input(i + 1), 2, &feature_shape));
499 // Check that all bucketized features have the same batch size.
500 TF_RETURN_IF_ERROR(c->Merge(c->Dim(c->input(1), 0),
501 c->Dim(c->input(i + 1), 0), &batch_dim));
502 }
503
504 // Multi-class will be supported by modifying the proto.
505 auto batch_size = c->MakeShape({c->Dim(feature_shape, 0)});
506 c->set_output(0, batch_size);
507 return OkStatus();
508 });
509
510REGISTER_OP("BoostedTreesSerializeEnsemble")
511 .Input("tree_ensemble_handle: resource")
512 .Output("stamp_token: int64")
513 .Output("tree_ensemble_serialized: string")
514 .SetShapeFn([](shape_inference::InferenceContext* c) {
515 shape_inference::ShapeHandle unused_input;
516 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
517 c->set_output(0, c->Scalar());
518 c->set_output(1, c->Scalar());
519 return OkStatus();
520 });
521
522REGISTER_OP("BoostedTreesTrainingPredict")
523 .Input("tree_ensemble_handle: resource")
524 .Input("cached_tree_ids: int32")
525 .Input("cached_node_ids: int32")
526 .Input("bucketized_features: num_bucketized_features * int32")
527 .Attr("num_bucketized_features: int >= 1")
528 .Attr("logits_dimension: int")
529 .Output("partial_logits: float")
530 .Output("tree_ids: int32")
531 .Output("node_ids: int32")
532 .SetShapeFn([](shape_inference::InferenceContext* c) {
533 shape_inference::ShapeHandle feature_shape;
534 int num_bucketized_features;
535 TF_RETURN_IF_ERROR(
536 c->GetAttr("num_bucketized_features", &num_bucketized_features));
537
538 shape_inference::ShapeHandle unused_input;
539 shape_inference::DimensionHandle batch_size = c->Dim(c->input(3), 0);
540 for (int i = 0; i < num_bucketized_features; ++i) {
541 TF_RETURN_IF_ERROR(
542 c->WithRankAtMost(c->input(i + 3), 2, &feature_shape));
543 TF_RETURN_IF_ERROR(
544 c->Merge(c->input(i + 3), feature_shape, &unused_input));
545 }
546 shape_inference::ShapeHandle tree_ids_shape;
547 shape_inference::ShapeHandle node_ids_shape;
548 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &tree_ids_shape));
549 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &node_ids_shape));
550 TF_RETURN_IF_ERROR(c->Merge(c->Dim(tree_ids_shape, 0),
551 c->Dim(node_ids_shape, 0), &batch_size));
552
553 int logits_dimension;
554 TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
555 auto logits_shape =
556 c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
557 // Partial logits.
558 c->set_output(0, logits_shape);
559 // Tree ids.
560 c->set_output(1, c->MakeShape({c->Dim(feature_shape, 0)}));
561 // Node ids.
562 c->set_output(2, c->MakeShape({c->Dim(feature_shape, 0)}));
563 return OkStatus();
564 });
565
566REGISTER_OP("BoostedTreesUpdateEnsemble")
567 .Input("tree_ensemble_handle: resource")
568 .Input("feature_ids: int32")
569 .Input("node_ids: num_features * int32")
570 .Input("gains: num_features * float")
571 .Input("thresholds: num_features * int32")
572 .Input("left_node_contribs: num_features * float")
573 .Input("right_node_contribs: num_features * float")
574 .Input("max_depth: int32")
575 .Input("learning_rate: float")
576 .Attr("pruning_mode: int >=0")
577 .Attr("num_features: int >= 0") // Inferred.
578 .SetShapeFn([](shape_inference::InferenceContext* c) {
579 shape_inference::ShapeHandle shape_handle;
580 int num_features;
581 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
582
583 // Feature_ids, should be one for each feature.
584 shape_inference::ShapeHandle feature_ids_shape;
585 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
586 TF_RETURN_IF_ERROR(
587 c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
588
589 for (int i = 0; i < num_features; ++i) {
590 // Node ids.
591 TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle));
592 auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
593 auto shape_rank_2 = c->MakeShape({c->Dim(shape_handle, 0), 1});
594
595 // Gains.
596 TF_RETURN_IF_ERROR(
597 c->WithRank(c->input(i + num_features + 2), 1, &shape_handle));
598 // TODO(nponomareva): replace this with input("name",vector of shapes).
599 TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features + 2),
600 shape_rank_1, &shape_handle));
601 // Thresholds.
602 TF_RETURN_IF_ERROR(
603 c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle));
604 TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2),
605 shape_rank_1, &shape_handle));
606 // Left and right node contribs.
607 TF_RETURN_IF_ERROR(
608 c->WithRank(c->input(i + num_features * 3 + 2), 2, &shape_handle));
609 TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2),
610 shape_rank_2, &shape_handle));
611 TF_RETURN_IF_ERROR(
612 c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle));
613 TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2),
614 shape_rank_2, &shape_handle));
615 }
616 return OkStatus();
617 });
618
619REGISTER_OP("BoostedTreesUpdateEnsembleV2")
620 .Input("tree_ensemble_handle: resource")
621 .Input("feature_ids: num_groups * int32")
622 .Input("dimension_ids: num_features * int32")
623 .Input("node_ids: num_features * int32")
624 .Input("gains: num_features * float")
625 .Input("thresholds: num_features * int32")
626 .Input("left_node_contribs: num_features * float")
627 .Input("right_node_contribs: num_features * float")
628 .Input("split_types: num_features * string")
629 .Input("max_depth: int32")
630 .Input("learning_rate: float")
631 .Input("pruning_mode: int32")
632 .Attr("num_features: int >= 0") // Inferred.
633 .Attr("logits_dimension: int = 1")
634 .Attr("num_groups: int = 1") // Inferred; number of groups to process.
635 .SetShapeFn([](shape_inference::InferenceContext* c) {
636 int num_features;
637 int logits_dimension;
638 int num_groups;
639 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
640 TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
641 TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
642 // num_features was kept for backwards compatibility reasons. It now
643 // represents number of groups.
644 DCHECK_EQ(num_features, num_groups);
645 shape_inference::ShapeHandle shape_handle;
646 for (int i = 0; i < num_groups; ++i) {
647 int offset = i + 1;
648
649 // Feature ids
650 TF_RETURN_IF_ERROR(c->WithRank(c->input(offset), 1, &shape_handle));
651 // TODO(nponomareva): replace this with input("name",vector of shapes).
652 auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
653 TF_RETURN_IF_ERROR(
654 c->Merge(c->input(offset), shape_rank_1, &shape_handle));
655
656 // Dimension ids.
657 TF_RETURN_IF_ERROR(
658 c->WithRank(c->input(offset + num_features), 1, &shape_handle));
659 TF_RETURN_IF_ERROR(
660 c->Merge(c->input(offset), shape_rank_1, &shape_handle));
661
662 // Node ids.
663 TF_RETURN_IF_ERROR(
664 c->WithRank(c->input(offset + num_features * 2), 1, &shape_handle));
665 TF_RETURN_IF_ERROR(
666 c->Merge(c->input(offset), shape_rank_1, &shape_handle));
667
668 // Gains.
669 TF_RETURN_IF_ERROR(
670 c->WithRank(c->input(offset + num_features * 3), 1, &shape_handle));
671 TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 3),
672 shape_rank_1, &shape_handle));
673
674 // Thresholds.
675 TF_RETURN_IF_ERROR(
676 c->WithRank(c->input(offset + num_features * 4), 1, &shape_handle));
677 TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 4),
678 shape_rank_1, &shape_handle));
679
680 // Left and right node contribs.
681 auto shape_rank_2 =
682 c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
683 TF_RETURN_IF_ERROR(
684 c->WithRank(c->input(offset + num_features * 5), 2, &shape_handle));
685 TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 5),
686 shape_rank_2, &shape_handle));
687 TF_RETURN_IF_ERROR(
688 c->WithRank(c->input(offset + num_features * 6), 2, &shape_handle));
689 TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 6),
690 shape_rank_2, &shape_handle));
691
692 // Split types.
693 TF_RETURN_IF_ERROR(
694 c->WithRank(c->input(offset + num_features * 7), 1, &shape_handle));
695 TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 7),
696 shape_rank_1, &shape_handle));
697 }
698 return OkStatus();
699 });
700
701REGISTER_OP("BoostedTreesCenterBias")
702 .Input("tree_ensemble_handle: resource")
703 .Input("mean_gradients: float")
704 .Input("mean_hessians: float")
705 // Regularization-related.
706 .Input("l1: float")
707 .Input("l2: float")
708 .Output("continue_centering: bool")
709 .SetShapeFn([](shape_inference::InferenceContext* c) {
710 shape_inference::ShapeHandle gradients_shape;
711 shape_inference::ShapeHandle hessians_shape;
712 shape_inference::ShapeHandle unused_shape;
713 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
714 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
715 TF_RETURN_IF_ERROR(
716 c->Merge(gradients_shape, hessians_shape, &unused_shape));
717 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
718 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
719
720 c->set_output(0, c->Scalar());
721 return OkStatus();
722 });
723
724REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);
725
726REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
727 .Input("quantile_stream_resource_handle: resource")
728 .Output("is_initialized: bool")
729 .SetShapeFn([](shape_inference::InferenceContext* c) {
730 shape_inference::ShapeHandle unused_input;
731 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
732 c->set_output(0, c->Scalar());
733 return OkStatus();
734 });
735
736REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
737 .Attr("max_elements: int = 1099511627776") // 1 << 40
738 .Input("quantile_stream_resource_handle: resource")
739 .Input("epsilon: float")
740 .Input("num_streams: int64")
741 .SetShapeFn([](shape_inference::InferenceContext* c) {
742 shape_inference::ShapeHandle unused_input;
743 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
744 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
745 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
746 return OkStatus();
747 });
748
749REGISTER_OP("BoostedTreesMakeQuantileSummaries")
750 .Attr("num_features: int >= 0")
751 .Input("float_values: num_features * float")
752 .Input("example_weights: float")
753 .Input("epsilon: float")
754 .Output("summaries: num_features * float")
755 .SetShapeFn([](InferenceContext* c) {
756 int num_features;
757 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
758 ShapeHandle example_weights_shape;
759 TF_RETURN_IF_ERROR(
760 c->WithRank(c->input(num_features), 1, &example_weights_shape));
761 for (int i = 0; i < num_features; ++i) {
762 ShapeHandle feature_shape;
763 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
764 // the columns are value, weight, min_rank, max_rank.
765 c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
766 }
767 // epsilon must be a scalar.
768 ShapeHandle unused_input;
769 TF_RETURN_IF_ERROR(
770 c->WithRank(c->input(num_features + 1), 0, &unused_input));
771 return OkStatus();
772 });
773
774REGISTER_OP("BoostedTreesFlushQuantileSummaries")
775 .Attr("num_features: int >= 0")
776 .Input("quantile_stream_resource_handle: resource")
777 .Output("summaries: num_features * float")
778 .SetShapeFn([](InferenceContext* c) {
779 int num_features;
780 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
781 for (int i = 0; i < num_features; ++i) {
782 // the columns are value, weight, min_rank, max_rank.
783 c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
784 }
785 return OkStatus();
786 });
787
788REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
789 .Attr("num_features: int >= 0")
790 .Input("quantile_stream_resource_handle: resource")
791 .Input("summaries: num_features * float")
792 .SetShapeFn([](InferenceContext* c) {
793 int num_features;
794 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
795 // resource handle must be a scalar.
796 shape_inference::ShapeHandle unused_input;
797 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
798 // each summary must be rank 2.
799 for (int i = 1; i < num_features + 1; i++) {
800 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
801 }
802 return OkStatus();
803 });
804
805REGISTER_OP("BoostedTreesQuantileStreamResourceDeserialize")
806 .Attr("num_streams: int")
807 .Input("quantile_stream_resource_handle: resource")
808 .Input("bucket_boundaries: num_streams * float")
809 .SetShapeFn([](shape_inference::InferenceContext* c) {
810 shape_inference::ShapeHandle unused_input;
811 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
812 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
813 return OkStatus();
814 });
815
816REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
817 .Attr("generate_quantiles: bool = False")
818 .Input("quantile_stream_resource_handle: resource")
819 .Input("num_buckets: int64")
820 .SetShapeFn([](InferenceContext* c) {
821 // All the inputs are scalars.
822 shape_inference::ShapeHandle unused_input;
823 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
824 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
825 return OkStatus();
826 });
827
828REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
829 .Attr("num_features: int >= 0")
830 .Input("quantile_stream_resource_handle: resource")
831 .Output("bucket_boundaries: num_features * float")
832 .SetShapeFn([](InferenceContext* c) {
833 int num_features;
834 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
835 shape_inference::ShapeHandle unused_input;
836 // resource handle must be a scalar.
837 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
838 for (int i = 0; i < num_features; i++) {
839 c->set_output(i, c->Vector(c->UnknownDim()));
840 }
841 return OkStatus();
842 });
843
844REGISTER_OP("BoostedTreesBucketize")
845 .Attr("num_features: int >= 0")
846 .Input("float_values: num_features * float")
847 .Input("bucket_boundaries: num_features * float")
848 .Output("buckets: num_features * int32")
849 .SetShapeFn([](InferenceContext* c) {
850 int num_features;
851 TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
852 ShapeHandle feature_shape;
853 DimensionHandle unused_dim;
854 for (int i = 0; i < num_features; i++) {
855 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
856 TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
857 c->Dim(c->input(0), 0), &unused_dim));
858 }
859 // Bucketized result should have same dimension as input.
860 for (int i = 0; i < num_features; i++) {
861 c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0)}));
862 }
863 return OkStatus();
864 });
865
866} // namespace tensorflow
867