1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_ |
18 | |
19 | #include "tensorflow/core/kernels/conditional_accumulator_base.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | /* |
24 | * TypedConditionalAccumulatorBase is a templated companion of |
25 | * ConditionalAccumulatorBase which allows for subclasses to use different |
26 | * types for the input gradients. (See ConditionalAccumulator and |
27 | * SparseConditionalAccumulator.) |
28 | * |
29 | * TypedConditionalAccumulatorBase defines virtual methods and implements |
30 | * methods which depend on the gradient type. These are mainly methods that are |
31 | * used for adding a new gradient to the accumulator. |
32 | */ |
33 | template <typename GradientTensorType> |
34 | class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase { |
35 | public: |
36 | TypedConditionalAccumulatorBase(const DataType& dtype, |
37 | const PartialTensorShape& shape, |
38 | const string& name, |
39 | const string& reduction_type) |
40 | : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {} |
41 | |
42 | /** |
43 | * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is |
44 | * successful (i.e., has its gradient applied) if its local_step >= |
45 | * current_global_step_ at the time the attempt is processed. Otherwise, if |
46 | * local_step < current_global_step_, the stale gradient is silently dropped. |
47 | * |
48 | * local_step: Time-step at which the gradient was computed. |
49 | * grad: Gradient tensor to be added to the accumulator. |
50 | * ctx: Context in which the op is executed. |
51 | */ |
52 | void TryApplyGrad(int64_t local_step, OpKernelContext* ctx) override { |
53 | { |
54 | mutex_lock l(mu_); |
55 | if (local_step >= current_global_step_) { |
56 | GradientTensorType* grad = nullptr; |
57 | bool is_valid = GetAndValidateTensorInputForApplyGrad(ctx, &grad); |
58 | if (is_valid) { |
59 | if (counter_ > 0) { |
60 | AddToAccumGradFunction(ctx, grad); |
61 | } else { |
62 | AllocateAndAssignToAccumGradFunction(ctx, grad); |
63 | } |
64 | counter_++; |
65 | } |
66 | CleanUpGradTensor(grad); |
67 | } |
68 | } |
69 | FlushUnlocked(); |
70 | } |
71 | |
72 | protected: |
73 | // Virtual methods to be implemented by sub-classes for different datatypes. |
74 | // Implements arithmetic operations specific to datatype. |
75 | virtual void AllocateAndAssignToAccumGradFunction( |
76 | OpKernelContext* ctx, GradientTensorType* grad) = 0; |
77 | |
78 | virtual void AddToAccumGradFunction(OpKernelContext* ctx, |
79 | GradientTensorType* grad) = 0; |
80 | |
81 | // Method for extracting and validating input provided in an OpKernelContext. |
82 | // Returns true if input was successfully retrieved and is valid. |
83 | // Gradient is returned via the GradientTensorType** tensor. |
84 | virtual bool GetAndValidateTensorInputForApplyGrad( |
85 | OpKernelContext* ctx, GradientTensorType** tensor) |
86 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; |
87 | |
88 | // Method for cleaning up any memory allocated in |
89 | // GetAndValidateTensorInputForApplyGrad |
90 | virtual void CleanUpGradTensor(GradientTensorType* tensor) = 0; |
91 | }; |
92 | |
93 | } // namespace tensorflow |
94 | |
95 | #endif // TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_ |
96 | |