1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRADIENTS_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_GRADIENTS_H_ |
18 | |
#include <vector>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
22 | |
23 | namespace tensorflow { |
24 | |
25 | // Represents the output of 'node' at 'index'. |
26 | struct NodeOut { |
27 | Node* node; |
28 | int index; |
29 | |
30 | // Returns the string name that represents the output of this node. |
31 | string name() const; |
32 | // Returns the data type of the output of this node. |
33 | DataType dtype() const; |
34 | }; |
35 | |
36 | // NOTE: This API is a work in progress and will likely be changing frequently. |
37 | // |
// Given initial gradient-node outputs 'y_grad_node_outputs' (which compute the
// symbolic partial derivatives of some loss function 'L' w.r.t. the node
// outputs 'y_node_outputs'), adds gradient nodes to 'graph' that compute the
// symbolic partial derivatives of 'L' w.r.t. the node outputs
// 'x_node_outputs', and returns the outputs of those gradient nodes in
// 'x_grad_node_outputs'.
42 | // |
// REQUIRES: The nodes referenced by 'x_node_outputs' must be unique, and each
// must have a single output (this restriction will be removed in a subsequent
// change).
45 | |
// TODO(andydavis) Add symbolic gradient support for general graphs (the current
// implementation only supports gradients for functions). In particular,
// the nodes in 'x_node_outputs' are currently restricted to having one output.
49 | |
50 | Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs, |
51 | gtl::ArraySlice<NodeOut> x_node_outputs, |
52 | gtl::ArraySlice<NodeOut> y_grad_node_outputs, |
53 | std::vector<NodeOut>* x_grad_node_outputs, |
54 | Graph* graph); |
55 | |
56 | } // namespace tensorflow |
57 | |
58 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRADIENTS_H_ |
59 | |