1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_ |
16 | #define TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_ |
17 | |
18 | #include "absl/types/optional.h" |
19 | #include "tensorflow/core/lib/gtl/flatmap.h" |
20 | #include "tensorflow/core/lib/gtl/flatset.h" |
21 | |
22 | // Lookup whether the Op with the given op_name has unused input indices. |
23 | // Returns absl::nullopt if all inputs are used, set of unused indices |
24 | // otherwise. Empty set indicates that all indices are unused. The latter is |
25 | // necessary because sometimes it may not be possible to enumerate all indices |
26 | // just using OpDef e.g. when there are `list(T)` or `N * T` type inputs. |
27 | absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices( |
28 | const tensorflow::string& op_name); |
29 | |
30 | // Lookup whether the Op with the given op_name has unused output indices. |
31 | // Returns absl::nullopt if all outputs are used, set of unused indices |
32 | // otherwise. Empty set indicates that all indices are unused. The latter is |
33 | // necessary because sometimes it may not be possible to enumerate all indices |
34 | // just using OpDef e.g. when there are `list(T)` or `N * T` type outputs. |
35 | absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices( |
36 | const tensorflow::string& op_name); |
37 | |
38 | #endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_ |
39 | |