1// This file is MACHINE GENERATED! Do not edit.
2
3#ifndef TENSORFLOW_CC_OPS_DATA_FLOW_OPS_INTERNAL_H_
4#define TENSORFLOW_CC_OPS_DATA_FLOW_OPS_INTERNAL_H_
5
6// This file is MACHINE GENERATED! Do not edit.
7
8#include "tensorflow/cc/framework/ops.h"
9#include "tensorflow/cc/framework/scope.h"
10#include "tensorflow/core/framework/tensor.h"
11#include "tensorflow/core/framework/tensor_shape.h"
12#include "tensorflow/core/framework/types.h"
13#include "tensorflow/core/lib/gtl/array_slice.h"
14
15namespace tensorflow {
16namespace ops {
17namespace internal {
18// NOTE: This namespace has internal TensorFlow details that
19// are not part of TensorFlow's public API.
20
21/// @defgroup data_flow_ops_internal Data Flow Ops Internal
22/// @{
23
24/// Applies a gradient to a given accumulator.
25///
26/// Does not add if local_step is lesser than the accumulator's global_step.
27///
28/// Args:
29/// * scope: A Scope object
30/// * handle: The handle to a accumulator.
31/// * local_step: The local_step value at which the gradient was computed.
32/// * gradient: A tensor of the gradient to be accumulated.
33///
34/// Returns:
35/// * the created `Operation`
36class ResourceAccumulatorApplyGradient {
37 public:
38 ResourceAccumulatorApplyGradient(const ::tensorflow::Scope& scope,
39 ::tensorflow::Input handle,
40 ::tensorflow::Input local_step,
41 ::tensorflow::Input gradient);
42 operator ::tensorflow::Operation() const { return operation; }
43
44 Operation operation;
45};
46
47/// Returns the number of gradients aggregated in the given accumulators.
48///
49/// Args:
50/// * scope: A Scope object
51/// * handle: The handle to an accumulator.
52///
53/// Returns:
54/// * `Output`: The number of gradients aggregated in the given accumulator.
55class ResourceAccumulatorNumAccumulated {
56 public:
57 ResourceAccumulatorNumAccumulated(const ::tensorflow::Scope& scope,
58 ::tensorflow::Input handle);
59 operator ::tensorflow::Output() const { return num_accumulated; }
60 operator ::tensorflow::Input() const { return num_accumulated; }
61 ::tensorflow::Node* node() const { return num_accumulated.node(); }
62
63 Operation operation;
64 ::tensorflow::Output num_accumulated;
65};
66
67/// Updates the accumulator with a new value for global_step.
68///
69/// Logs warning if the accumulator's value is already higher than
70/// new_global_step.
71///
72/// Args:
73/// * scope: A Scope object
74/// * handle: The handle to an accumulator.
75/// * new_global_step: The new global_step value to set.
76///
77/// Returns:
78/// * the created `Operation`
79class ResourceAccumulatorSetGlobalStep {
80 public:
81 ResourceAccumulatorSetGlobalStep(const ::tensorflow::Scope& scope,
82 ::tensorflow::Input handle,
83 ::tensorflow::Input new_global_step);
84 operator ::tensorflow::Operation() const { return operation; }
85
86 Operation operation;
87};
88
89/// Extracts the average gradient in the given ConditionalAccumulator.
90///
91/// The op blocks until sufficient (i.e., more than num_required)
92/// gradients have been accumulated. If the accumulator has already
93/// aggregated more than num_required gradients, it returns the average of
94/// the accumulated gradients. Also automatically increments the recorded
95/// global_step in the accumulator by 1, and resets the aggregate to 0.
96///
97/// Args:
98/// * scope: A Scope object
99/// * handle: The handle to an accumulator.
100/// * num_required: Number of gradients required before we return an aggregate.
101/// * dtype: The data type of accumulated gradients. Needs to correspond to the type
102/// of the accumulator.
103///
104/// Returns:
105/// * `Output`: The average of the accumulated gradients.
106class ResourceAccumulatorTakeGradient {
107 public:
108 ResourceAccumulatorTakeGradient(const ::tensorflow::Scope& scope,
109 ::tensorflow::Input handle, ::tensorflow::Input
110 num_required, DataType dtype);
111 operator ::tensorflow::Output() const { return average; }
112 operator ::tensorflow::Input() const { return average; }
113 ::tensorflow::Node* node() const { return average.node(); }
114
115 Operation operation;
116 ::tensorflow::Output average;
117};
118
119/// A conditional accumulator for aggregating gradients.
120///
121/// The accumulator accepts gradients marked with local_step greater or
122/// equal to the most recent global_step known to the accumulator. The
123/// average can be extracted from the accumulator, provided sufficient
124/// gradients have been accumulated. Extracting the average automatically
125/// resets the aggregate to 0, and increments the global_step recorded by
126/// the accumulator.
127/// This is a resource version of ConditionalAccumulator that will work in TF2.0
128/// with tf.cond version 2.
129///
130/// Args:
131/// * scope: A Scope object
132/// * dtype: The type of the value being accumulated.
133/// * shape: The shape of the values, can be [], in which case shape is unknown.
134///
135/// Optional attributes (see `Attrs`):
136/// * container: If non-empty, this accumulator is placed in the given container.
137/// Otherwise, a default container is used.
138/// * shared_name: If non-empty, this accumulator will be shared under the
139/// given name across multiple sessions.
140///
141/// Returns:
142/// * `Output`: The handle to the accumulator.
143class ResourceConditionalAccumulator {
144 public:
145 /// Optional attribute setters for ResourceConditionalAccumulator
146 struct Attrs {
147 /// If non-empty, this accumulator is placed in the given container.
148 /// Otherwise, a default container is used.
149 ///
150 /// Defaults to ""
151 TF_MUST_USE_RESULT Attrs Container(StringPiece x) {
152 Attrs ret = *this;
153 ret.container_ = x;
154 return ret;
155 }
156
157 /// If non-empty, this accumulator will be shared under the
158 /// given name across multiple sessions.
159 ///
160 /// Defaults to ""
161 TF_MUST_USE_RESULT Attrs SharedName(StringPiece x) {
162 Attrs ret = *this;
163 ret.shared_name_ = x;
164 return ret;
165 }
166
167 /// Defaults to "MEAN"
168 TF_MUST_USE_RESULT Attrs ReductionType(StringPiece x) {
169 Attrs ret = *this;
170 ret.reduction_type_ = x;
171 return ret;
172 }
173
174 StringPiece container_ = "";
175 StringPiece shared_name_ = "";
176 StringPiece reduction_type_ = "MEAN";
177 };
178 ResourceConditionalAccumulator(const ::tensorflow::Scope& scope, DataType
179 dtype, PartialTensorShape shape);
180 ResourceConditionalAccumulator(const ::tensorflow::Scope& scope, DataType
181 dtype, PartialTensorShape shape, const
182 ResourceConditionalAccumulator::Attrs& attrs);
183 operator ::tensorflow::Output() const { return handle; }
184 operator ::tensorflow::Input() const { return handle; }
185 ::tensorflow::Node* node() const { return handle.node(); }
186
187 static Attrs Container(StringPiece x) {
188 return Attrs().Container(x);
189 }
190 static Attrs SharedName(StringPiece x) {
191 return Attrs().SharedName(x);
192 }
193 static Attrs ReductionType(StringPiece x) {
194 return Attrs().ReductionType(x);
195 }
196
197 Operation operation;
198 ::tensorflow::Output handle;
199};
200
201} // namespace internal
202} // namespace ops
203} // namespace tensorflow
204
205#endif // TENSORFLOW_CC_OPS_DATA_FLOW_OPS_INTERNAL_H_
206