1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRADIENTS_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_GRADIENTS_H_
18
19#include "tensorflow/core/graph/graph.h"
20#include "tensorflow/core/lib/core/status.h"
21#include "tensorflow/core/lib/gtl/array_slice.h"
22
23namespace tensorflow {
24
25// Represents the output of 'node' at 'index'.
26struct NodeOut {
27 Node* node;
28 int index;
29
30 // Returns the string name that represents the output of this node.
31 string name() const;
32 // Returns the data type of the output of this node.
33 DataType dtype() const;
34};
35
36// NOTE: This API is a work in progress and will likely be changing frequently.
37//
// Given initial gradient-node outputs 'y_grad_node_outputs' (which compute the
// symbolic partial derivatives of some loss function 'L' w.r.t. the node
// outputs 'y_node_outputs'), adds gradient nodes to 'graph' that compute the
// symbolic partial derivatives of 'L' w.r.t. the node outputs 'x_node_outputs'.
42//
// REQUIRES: Each node in 'x_node_outputs' must be unique and must therefore
// have a single output (this restriction will be removed in a subsequent
// change).
45
// TODO(andydavis): Add symbolic gradient support for general graphs (the
// current implementation only supports gradients for functions). In
// particular, the nodes in 'x_node_outputs' are currently restricted to having
// one output.
49
// Parameters:
//   y_node_outputs: node outputs of the loss 'L' for which initial gradients
//     are supplied.
//   x_node_outputs: node outputs w.r.t. which the gradients of 'L' are built.
//   y_grad_node_outputs: initial gradients of 'L' w.r.t. 'y_node_outputs'.
//     NOTE(review): presumably one entry per element of 'y_node_outputs', in
//     the same order -- confirm against the implementation.
//   x_grad_node_outputs: out-parameter. Presumably filled with one gradient
//     output per element of 'x_node_outputs' -- TODO confirm.
//   graph: graph to which the gradient nodes are added.
// Returns a Status. Presumably non-OK when gradient construction fails --
// confirm the error cases in the implementation.
Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
                            gtl::ArraySlice<NodeOut> x_node_outputs,
                            gtl::ArraySlice<NodeOut> y_grad_node_outputs,
                            std::vector<NodeOut>* x_grad_node_outputs,
                            Graph* graph);
55
56} // namespace tensorflow
57
58#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRADIENTS_H_
59