1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
17#define TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
18
19#include "tensorflow/core/kernels/conditional_accumulator_base.h"
20
21namespace tensorflow {
22
23/*
24 * TypedConditionalAccumulatorBase is a templated companion of
25 * ConditionalAccumulatorBase which allows for subclasses to use different
26 * types for the input gradients. (See ConditionalAccumulator and
27 * SparseConditionalAccumulator.)
28 *
29 * TypedConditionalAccumulatorBase defines virtual methods and implements
30 * methods which depend on the gradient type. These are mainly methods that are
31 * used for adding a new gradient to the accumulator.
32 */
33template <typename GradientTensorType>
34class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
35 public:
36 TypedConditionalAccumulatorBase(const DataType& dtype,
37 const PartialTensorShape& shape,
38 const string& name,
39 const string& reduction_type)
40 : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}
41
42 /**
43 * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
44 * successful (i.e., has its gradient applied) if its local_step >=
45 * current_global_step_ at the time the attempt is processed. Otherwise, if
46 * local_step < current_global_step_, the stale gradient is silently dropped.
47 *
48 * local_step: Time-step at which the gradient was computed.
49 * grad: Gradient tensor to be added to the accumulator.
50 * ctx: Context in which the op is executed.
51 */
52 void TryApplyGrad(int64_t local_step, OpKernelContext* ctx) override {
53 {
54 mutex_lock l(mu_);
55 if (local_step >= current_global_step_) {
56 GradientTensorType* grad = nullptr;
57 bool is_valid = GetAndValidateTensorInputForApplyGrad(ctx, &grad);
58 if (is_valid) {
59 if (counter_ > 0) {
60 AddToAccumGradFunction(ctx, grad);
61 } else {
62 AllocateAndAssignToAccumGradFunction(ctx, grad);
63 }
64 counter_++;
65 }
66 CleanUpGradTensor(grad);
67 }
68 }
69 FlushUnlocked();
70 }
71
72 protected:
73 // Virtual methods to be implemented by sub-classes for different datatypes.
74 // Implements arithmetic operations specific to datatype.
75 virtual void AllocateAndAssignToAccumGradFunction(
76 OpKernelContext* ctx, GradientTensorType* grad) = 0;
77
78 virtual void AddToAccumGradFunction(OpKernelContext* ctx,
79 GradientTensorType* grad) = 0;
80
81 // Method for extracting and validating input provided in an OpKernelContext.
82 // Returns true if input was successfully retrieved and is valid.
83 // Gradient is returned via the GradientTensorType** tensor.
84 virtual bool GetAndValidateTensorInputForApplyGrad(
85 OpKernelContext* ctx, GradientTensorType** tensor)
86 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
87
88 // Method for cleaning up any memory allocated in
89 // GetAndValidateTensorInputForApplyGrad
90 virtual void CleanUpGradTensor(GradientTensorType* tensor) = 0;
91};
92
93} // namespace tensorflow
94
95#endif // TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
96