1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
25 | // -------------------------------------------------------------------------- |
26 | static Status ApplySdcaOptimizerShapeFn(InferenceContext* c) { |
27 | std::vector<ShapeHandle> sparse_handles; |
28 | if (c->input("sparse_weights" , &sparse_handles).ok()) { |
29 | TF_RETURN_IF_ERROR( |
30 | c->set_output("out_delta_sparse_weights" , sparse_handles)); |
31 | } |
32 | std::vector<ShapeHandle> dense_handles; |
33 | if (c->input("dense_weights" , &dense_handles).ok()) { |
34 | TF_RETURN_IF_ERROR(c->set_output("out_delta_dense_weights" , dense_handles)); |
35 | } |
36 | return c->set_output( |
37 | "out_example_state_data" , |
38 | {c->Matrix(InferenceContext::kUnknownDim, c->MakeDim(4))}); |
39 | } |
40 | |
// SdcaOptimizer (v1).
// NOTE(review): "Sdca" presumably abbreviates Stochastic Dual Coordinate
// Ascent. The update math lives in the kernel, not in this file; confirm there.
//
// The bool attr below is spelled "adaptative". SdcaOptimizerV2 registers an
// otherwise identical interface with the corrected spelling "adaptive". Attr
// names are part of this op's registered signature, so the misspelling is
// presumably kept on purpose for compatibility with existing graphs. Do not
// rename it here.
//
// Shape inference (ApplySdcaOptimizerShapeFn):
// - each out_delta_* output list takes the shapes of the matching *_weights
//   input list;
// - out_example_state_data is a matrix of shape [?, 4].
REGISTER_OP("SdcaOptimizer")
    .Attr(
        "loss_type: {'logistic_loss', 'squared_loss', 'hinge_loss',"
        "'smooth_hinge_loss', 'poisson_loss'}")
    .Attr("adaptative : bool=false")
    // These three attrs fix the lengths of the list-typed inputs and outputs.
    .Attr("num_sparse_features: int >= 0")
    .Attr("num_sparse_features_with_values: int >= 0")
    .Attr("num_dense_features: int >= 0")
    // NOTE(review): presumably L1 / L2 regularization strengths; confirm in
    // the kernel.
    .Attr("l1: float")
    .Attr("l2: float")
    .Attr("num_loss_partitions: int >= 1")
    .Attr("num_inner_iterations: int >= 1")
    // Example data. The two sparse_* index lists hold num_sparse_features
    // int64 tensors each. sparse_feature_values holds
    // num_sparse_features_with_values float tensors.
    .Input("sparse_example_indices: num_sparse_features * int64")
    .Input("sparse_feature_indices: num_sparse_features * int64")
    .Input("sparse_feature_values: num_sparse_features_with_values * float")
    .Input("dense_features: num_dense_features * float")
    .Input("example_weights: float")
    .Input("example_labels: float")
    // Model state: weight lists (the sparse weights come with a parallel list
    // of int64 index tensors) plus per-example state data.
    .Input("sparse_indices: num_sparse_features * int64")
    .Input("sparse_weights: num_sparse_features * float")
    .Input("dense_weights: num_dense_features * float")
    .Input("example_state_data: float")
    // The out_delta_* lists have the same lengths as the corresponding weight
    // lists.
    .Output("out_example_state_data: float")
    .Output("out_delta_sparse_weights: num_sparse_features * float")
    .Output("out_delta_dense_weights: num_dense_features * float")
    .SetShapeFn(ApplySdcaOptimizerShapeFn);
67 | |
// SdcaOptimizerV2 fixes the "adaptative" attr-name typo in SdcaOptimizer (v1).
// It has the same attrs, inputs, outputs and shape function as v1. The only
// difference is that the bool attr is named "adaptive" (default false).
//
// Shape inference (ApplySdcaOptimizerShapeFn):
// - each out_delta_* output list takes the shapes of the matching *_weights
//   input list;
// - out_example_state_data is a matrix of shape [?, 4].
REGISTER_OP("SdcaOptimizerV2")
    .Attr(
        "loss_type: {'logistic_loss', 'squared_loss', 'hinge_loss',"
        "'smooth_hinge_loss', 'poisson_loss'}")
    .Attr("adaptive : bool=false")
    // These three attrs fix the lengths of the list-typed inputs and outputs.
    .Attr("num_sparse_features: int >= 0")
    .Attr("num_sparse_features_with_values: int >= 0")
    .Attr("num_dense_features: int >= 0")
    // NOTE(review): presumably L1 / L2 regularization strengths; confirm in
    // the kernel.
    .Attr("l1: float")
    .Attr("l2: float")
    .Attr("num_loss_partitions: int >= 1")
    .Attr("num_inner_iterations: int >= 1")
    // Example data.
    .Input("sparse_example_indices: num_sparse_features * int64")
    .Input("sparse_feature_indices: num_sparse_features * int64")
    .Input("sparse_feature_values: num_sparse_features_with_values * float")
    .Input("dense_features: num_dense_features * float")
    .Input("example_weights: float")
    .Input("example_labels: float")
    // Model state: weight lists and per-example state data.
    .Input("sparse_indices: num_sparse_features * int64")
    .Input("sparse_weights: num_sparse_features * float")
    .Input("dense_weights: num_dense_features * float")
    .Input("example_state_data: float")
    .Output("out_example_state_data: float")
    .Output("out_delta_sparse_weights: num_sparse_features * float")
    .Output("out_delta_dense_weights: num_dense_features * float")
    .SetShapeFn(ApplySdcaOptimizerShapeFn);
95 | |
// SdcaShrinkL1 takes a list of num_features float weight tensors by reference
// (Ref) and declares no outputs. Its only effect is therefore presumably an
// in-place update of those weights driven by the l1 / l2 attrs; confirm
// against the kernel.
//
// The shape function is UnknownShape. With no outputs declared, there is
// nothing for it to infer.
REGISTER_OP("SdcaShrinkL1")
    .Attr("num_features: int >= 0")
    .Attr("l1: float")
    .Attr("l2: float")
    .Input("weights: Ref(num_features * float)")
    .SetShapeFn(shape_inference::UnknownShape);
102 | |
103 | REGISTER_OP("SdcaFprint" ) |
104 | .Input("input: string" ) |
105 | .Output("output: int64" ) |
106 | .SetShapeFn([](InferenceContext* c) { |
107 | ShapeHandle handle; |
108 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
109 | ShapeHandle output_shape; |
110 | TF_RETURN_IF_ERROR(c->Concatenate(handle, c->Vector(2), &output_shape)); |
111 | c->set_output(0, output_shape); |
112 | return OkStatus(); |
113 | }); |
114 | |
115 | } // namespace tensorflow |
116 | |