1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17#include "tensorflow/core/kernels/variable_ops.h"
18
19#include "tensorflow/core/framework/control_flow.h"
20#include "tensorflow/core/framework/op_kernel.h"
21#include "tensorflow/core/framework/register_types.h"
22#include "tensorflow/core/lib/core/errors.h"
23#include "tensorflow/core/platform/strcat.h"
24#include "tensorflow/core/platform/types.h"
25
26namespace tensorflow {
27
28namespace {
29
30// Makes a unique name for a temporary variable inside a while loop body,
31// because loop can be executed in multiple iterations in parallel.
32string TemporaryVariableName(const string& var_name,
33 const FrameAndIter& control_frame) {
34 if (control_frame.frame_id != kIllegalFrameId &&
35 control_frame.iter_id != kIllegalIterId) {
36 return strings::StrCat(var_name, "/frame:", control_frame.frame_id,
37 "/iter:", control_frame.iter_id);
38 }
39 return var_name;
40}
41
42} // namespace
43
44// Resource stored by variables in the resource manager
45// (legacy, ref-style version).
46class LegacyVar : public ResourceBase {
47 public:
48 explicit LegacyVar(DataType dtype) : tensor_(dtype) {}
49 // Not copyable or movable.
50 LegacyVar(const LegacyVar&) = delete;
51 LegacyVar& operator=(const LegacyVar&) = delete;
52
53 mutex* mu() { return &mu_; }
54 Tensor* tensor() { return &tensor_; }
55
56 string DebugString() const override {
57 return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
58 tensor_.shape().DebugString());
59 }
60
61 private:
62 mutex mu_;
63 Tensor tensor_;
64
65 ~LegacyVar() override {}
66};
67
68VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
69 OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
70 dtype_ = RemoveRefType(context->output_type(0));
71 OP_REQUIRES_OK(context, cinfo_.Init(context->resource_manager(), def(),
72 true /* use name() */));
73}
74
75void VariableOp::Compute(OpKernelContext* ctx) {
76 auto creator = [this](LegacyVar** var) {
77 *var = new LegacyVar(dtype_);
78 (*var)->tensor()->set_shape(shape_);
79 return OkStatus();
80 };
81 LegacyVar* var;
82 OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>(
83 cinfo_.container(), cinfo_.name(), &var, creator));
84 // Output a reference to our tensor, so it may be updated.
85 //
86 // As long as the resource manager hasn't been cleared the ref we return
87 // here is valid because it owns a ref on var.
88 ctx->set_output_ref(0, var->mu(), var->tensor());
89 if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
90 ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
91 }
92 var->Unref();
93}
94
95class TemporaryVariableOp : public OpKernel {
96 public:
97 explicit TemporaryVariableOp(OpKernelConstruction* context)
98 : OpKernel(context) {
99 OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
100 OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
101 OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
102 // Variable name defaults to op name if not specified explicitly.
103 if (var_name_.empty()) var_name_ = name();
104 }
105
106 void Compute(OpKernelContext* context) override {
107 Status s;
108 ResourceMgr* rm = context->resource_manager();
109 OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
110 auto unique_name = TemporaryVariableName(var_name_, context->frame_iter());
111 auto* tmp_var = new TmpVar;
112 OP_REQUIRES(context, tmp_var,
113 errors::ResourceExhausted("Could not allocate TmpVar."));
114 tmp_var->name = unique_name;
115 s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
116 if (!s.ok()) tmp_var->Unref();
117 OP_REQUIRES_OK(context, s);
118 OP_REQUIRES_OK(context,
119 context->step_container()->Create(rm, unique_name, tmp_var));
120 context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
121 if (context->track_allocations()) {
122 context->record_persistent_memory_allocation(
123 tmp_var->val.AllocatedBytes());
124 }
125 }
126
127 private:
128 // Refcounted temporary variable resource.
129 friend class DestroyTemporaryVariableOp;
130 struct TmpVar : public ResourceBase {
131 mutex mu;
132 Tensor val;
133 string name;
134 string DebugString() const override { return name; }
135 ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
136 };
137
138 TensorShape shape_;
139 DataType dtype_;
140 string var_name_;
141};
142
143class DestroyTemporaryVariableOp : public OpKernel {
144 public:
145 explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
146 : OpKernel(context) {
147 OP_REQUIRES(context, IsRefType(context->input_type(0)),
148 errors::InvalidArgument("lhs input needs to be a ref type"));
149 OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
150 OP_REQUIRES(context, !var_name_.empty(),
151 errors::InvalidArgument("Missing var_name attribute"));
152 }
153
154 void Compute(OpKernelContext* context) override {
155 // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
156 // their execution before this DestroyTemporaryVariable op executes.
157 // This is typically achieved using control dependencies.
158 CHECK(IsRefType(context->input_dtype(0)));
159 Tensor tmpvar = context->mutable_input(0, false);
160 context->set_output(0, tmpvar);
161 ResourceMgr* rm = context->resource_manager();
162 OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
163 auto unique_name = TemporaryVariableName(var_name_, context->frame_iter());
164 OP_REQUIRES_OK(
165 context, context->step_container()->Delete<TemporaryVariableOp::TmpVar>(
166 rm, unique_name));
167 if (context->track_allocations()) {
168 context->record_persistent_memory_allocation(
169 -static_cast<int64_t>(tmpvar.AllocatedBytes()));
170 }
171 }
172
173 private:
174 string var_name_;
175};
176
177class IsVariableInitializedOp : public OpKernel {
178 public:
179 explicit IsVariableInitializedOp(OpKernelConstruction* context)
180 : OpKernel(context) {}
181
182 void Compute(OpKernelContext* context) override {
183 // Get a mutable input tensor of the Ref input.
184 const Tensor& input_tensor = context->mutable_input(0, false);
185 Tensor* output = nullptr;
186 OP_REQUIRES_OK(context,
187 context->allocate_output(0, TensorShape({}), &output));
188 auto output_tensor = output->tensor<bool, 0>();
189 bool result = input_tensor.IsInitialized();
190 output_tensor() = result;
191 }
192};
193
// CPU registrations. Unlike the GPU/DEFAULT registrations, these carry no
// TypeConstraint.
REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("VariableV2").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
                        TemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
                        DestroyTemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
                        IsVariableInitializedOp);
202
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Only register the variable kernels on GPU for the subset of types also
// supported by 'Assign' (see dense_update_ops.cc).
//
// For each type this registers Variable, VariableV2, TemporaryVariable,
// DestroyTemporaryVariable and IsVariableInitialized on DEVICE_GPU.
// DestroyTemporaryVariable is constrained on its "T" attr, the others on
// "dtype". IsVariableInitialized's "is_initialized" output is kept in host
// memory.
#define REGISTER_GPU_KERNELS(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("Variable").Device(DEVICE_GPU).TypeConstraint<type>("dtype"),   \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("VariableV2").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                        \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype"),              \
                          TemporaryVariableOp);                            \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                 \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("T"),                  \
                          DestroyTemporaryVariableOp);                     \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                    \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype")               \
                              .HostMemory("is_initialized"),               \
                          IsVariableInitializedOp);

// int64 and uint32 are instantiated explicitly alongside the
// TF_CALL_GPU_ALL_TYPES list.
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
232
// Registers the same five variable kernels for DEVICE_DEFAULT, with the same
// per-kernel type attrs ("T" for DestroyTemporaryVariable, "dtype" otherwise),
// the same host-memory placement of "is_initialized", and the same type list
// as the GPU registrations.
#define REGISTER_DEFAULT_KERNELS(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Variable").Device(DEVICE_DEFAULT).TypeConstraint<type>("dtype"),   \
      VariableOp);                                                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("VariableV2").Device(DEVICE_DEFAULT).TypeConstraint<type>("dtype"), \
      VariableOp);                                                             \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                            \
                              .Device(DEVICE_DEFAULT)                          \
                              .TypeConstraint<type>("dtype"),                  \
                          TemporaryVariableOp);                                \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                     \
                              .Device(DEVICE_DEFAULT)                          \
                              .TypeConstraint<type>("T"),                      \
                          DestroyTemporaryVariableOp);                         \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                        \
                              .Device(DEVICE_DEFAULT)                          \
                              .TypeConstraint<type>("dtype")                   \
                              .HostMemory("is_initialized"),                   \
                          IsVariableInitializedOp);

// int64 and uint32 are instantiated explicitly alongside the
// TF_CALL_GPU_ALL_TYPES list.
TF_CALL_int64(REGISTER_DEFAULT_KERNELS);
TF_CALL_uint32(REGISTER_DEFAULT_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS);
#undef REGISTER_DEFAULT_KERNELS
258
259} // namespace tensorflow
260