1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | #include "tensorflow/core/kernels/variable_ops.h" |
18 | |
19 | #include "tensorflow/core/framework/control_flow.h" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/register_types.h" |
22 | #include "tensorflow/core/lib/core/errors.h" |
23 | #include "tensorflow/core/platform/strcat.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | namespace { |
29 | |
30 | // Makes a unique name for a temporary variable inside a while loop body, |
31 | // because loop can be executed in multiple iterations in parallel. |
32 | string TemporaryVariableName(const string& var_name, |
33 | const FrameAndIter& control_frame) { |
34 | if (control_frame.frame_id != kIllegalFrameId && |
35 | control_frame.iter_id != kIllegalIterId) { |
36 | return strings::StrCat(var_name, "/frame:" , control_frame.frame_id, |
37 | "/iter:" , control_frame.iter_id); |
38 | } |
39 | return var_name; |
40 | } |
41 | |
42 | } // namespace |
43 | |
44 | // Resource stored by variables in the resource manager |
45 | // (legacy, ref-style version). |
46 | class LegacyVar : public ResourceBase { |
47 | public: |
48 | explicit LegacyVar(DataType dtype) : tensor_(dtype) {} |
49 | // Not copyable or movable. |
50 | LegacyVar(const LegacyVar&) = delete; |
51 | LegacyVar& operator=(const LegacyVar&) = delete; |
52 | |
53 | mutex* mu() { return &mu_; } |
54 | Tensor* tensor() { return &tensor_; } |
55 | |
56 | string DebugString() const override { |
57 | return strings::StrCat(DataTypeString(tensor_.dtype()), "/" , |
58 | tensor_.shape().DebugString()); |
59 | } |
60 | |
61 | private: |
62 | mutex mu_; |
63 | Tensor tensor_; |
64 | |
65 | ~LegacyVar() override {} |
66 | }; |
67 | |
68 | VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) { |
69 | OP_REQUIRES_OK(context, context->GetAttr("shape" , &shape_)); |
70 | dtype_ = RemoveRefType(context->output_type(0)); |
71 | OP_REQUIRES_OK(context, cinfo_.Init(context->resource_manager(), def(), |
72 | true /* use name() */)); |
73 | } |
74 | |
75 | void VariableOp::Compute(OpKernelContext* ctx) { |
76 | auto creator = [this](LegacyVar** var) { |
77 | *var = new LegacyVar(dtype_); |
78 | (*var)->tensor()->set_shape(shape_); |
79 | return OkStatus(); |
80 | }; |
81 | LegacyVar* var; |
82 | OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>( |
83 | cinfo_.container(), cinfo_.name(), &var, creator)); |
84 | // Output a reference to our tensor, so it may be updated. |
85 | // |
86 | // As long as the resource manager hasn't been cleared the ref we return |
87 | // here is valid because it owns a ref on var. |
88 | ctx->set_output_ref(0, var->mu(), var->tensor()); |
89 | if (ctx->track_allocations() && var->tensor()->IsInitialized()) { |
90 | ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes()); |
91 | } |
92 | var->Unref(); |
93 | } |
94 | |
95 | class TemporaryVariableOp : public OpKernel { |
96 | public: |
97 | explicit TemporaryVariableOp(OpKernelConstruction* context) |
98 | : OpKernel(context) { |
99 | OP_REQUIRES_OK(context, context->GetAttr("shape" , &shape_)); |
100 | OP_REQUIRES_OK(context, context->GetAttr("dtype" , &dtype_)); |
101 | OP_REQUIRES_OK(context, context->GetAttr("var_name" , &var_name_)); |
102 | // Variable name defaults to op name if not specified explicitly. |
103 | if (var_name_.empty()) var_name_ = name(); |
104 | } |
105 | |
106 | void Compute(OpKernelContext* context) override { |
107 | Status s; |
108 | ResourceMgr* rm = context->resource_manager(); |
109 | OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager." )); |
110 | auto unique_name = TemporaryVariableName(var_name_, context->frame_iter()); |
111 | auto* tmp_var = new TmpVar; |
112 | OP_REQUIRES(context, tmp_var, |
113 | errors::ResourceExhausted("Could not allocate TmpVar." )); |
114 | tmp_var->name = unique_name; |
115 | s = context->allocate_temp(dtype_, shape_, &tmp_var->val); |
116 | if (!s.ok()) tmp_var->Unref(); |
117 | OP_REQUIRES_OK(context, s); |
118 | OP_REQUIRES_OK(context, |
119 | context->step_container()->Create(rm, unique_name, tmp_var)); |
120 | context->set_output_ref(0, &tmp_var->mu, &tmp_var->val); |
121 | if (context->track_allocations()) { |
122 | context->record_persistent_memory_allocation( |
123 | tmp_var->val.AllocatedBytes()); |
124 | } |
125 | } |
126 | |
127 | private: |
128 | // Refcounted temporary variable resource. |
129 | friend class DestroyTemporaryVariableOp; |
130 | struct TmpVar : public ResourceBase { |
131 | mutex mu; |
132 | Tensor val; |
133 | string name; |
134 | string DebugString() const override { return name; } |
135 | ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted" ; } |
136 | }; |
137 | |
138 | TensorShape shape_; |
139 | DataType dtype_; |
140 | string var_name_; |
141 | }; |
142 | |
143 | class DestroyTemporaryVariableOp : public OpKernel { |
144 | public: |
145 | explicit DestroyTemporaryVariableOp(OpKernelConstruction* context) |
146 | : OpKernel(context) { |
147 | OP_REQUIRES(context, IsRefType(context->input_type(0)), |
148 | errors::InvalidArgument("lhs input needs to be a ref type" )); |
149 | OP_REQUIRES_OK(context, context->GetAttr("var_name" , &var_name_)); |
150 | OP_REQUIRES(context, !var_name_.empty(), |
151 | errors::InvalidArgument("Missing var_name attribute" )); |
152 | } |
153 | |
154 | void Compute(OpKernelContext* context) override { |
155 | // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed |
156 | // their execution before this DestroyTemporaryVariable op executes. |
157 | // This is typically achieved using control dependencies. |
158 | CHECK(IsRefType(context->input_dtype(0))); |
159 | Tensor tmpvar = context->mutable_input(0, false); |
160 | context->set_output(0, tmpvar); |
161 | ResourceMgr* rm = context->resource_manager(); |
162 | OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager." )); |
163 | auto unique_name = TemporaryVariableName(var_name_, context->frame_iter()); |
164 | OP_REQUIRES_OK( |
165 | context, context->step_container()->Delete<TemporaryVariableOp::TmpVar>( |
166 | rm, unique_name)); |
167 | if (context->track_allocations()) { |
168 | context->record_persistent_memory_allocation( |
169 | -static_cast<int64_t>(tmpvar.AllocatedBytes())); |
170 | } |
171 | } |
172 | |
173 | private: |
174 | string var_name_; |
175 | }; |
176 | |
177 | class IsVariableInitializedOp : public OpKernel { |
178 | public: |
179 | explicit IsVariableInitializedOp(OpKernelConstruction* context) |
180 | : OpKernel(context) {} |
181 | |
182 | void Compute(OpKernelContext* context) override { |
183 | // Get a mutable input tensor of the Ref input. |
184 | const Tensor& input_tensor = context->mutable_input(0, false); |
185 | Tensor* output = nullptr; |
186 | OP_REQUIRES_OK(context, |
187 | context->allocate_output(0, TensorShape({}), &output)); |
188 | auto output_tensor = output->tensor<bool, 0>(); |
189 | bool result = input_tensor.IsInitialized(); |
190 | output_tensor() = result; |
191 | } |
192 | }; |
193 | |
194 | REGISTER_KERNEL_BUILDER(Name("Variable" ).Device(DEVICE_CPU), VariableOp); |
195 | REGISTER_KERNEL_BUILDER(Name("VariableV2" ).Device(DEVICE_CPU), VariableOp); |
196 | REGISTER_KERNEL_BUILDER(Name("TemporaryVariable" ).Device(DEVICE_CPU), |
197 | TemporaryVariableOp); |
198 | REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable" ).Device(DEVICE_CPU), |
199 | DestroyTemporaryVariableOp); |
200 | REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized" ).Device(DEVICE_CPU), |
201 | IsVariableInitializedOp); |
202 | |
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// GPU kernels are registered only for the subset of dtypes that 'Assign' also
// supports on GPU (see dense_update_ops.cc). IsVariableInitialized keeps its
// boolean "is_initialized" output in host memory.
#define REGISTER_GPU_KERNELS(type)                           \
  REGISTER_KERNEL_BUILDER(Name("Variable")                   \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype"), \
                          VariableOp);                       \
  REGISTER_KERNEL_BUILDER(Name("VariableV2")                 \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype"), \
                          VariableOp);                       \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")          \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype"), \
                          TemporaryVariableOp);              \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")   \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("T"),    \
                          DestroyTemporaryVariableOp);       \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")      \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype") \
                              .HostMemory("is_initialized"), \
                          IsVariableInitializedOp);

TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
232 | |
233 | #define REGISTER_DEFAULT_KERNELS(type) \ |
234 | REGISTER_KERNEL_BUILDER( \ |
235 | Name("Variable").Device(DEVICE_DEFAULT).TypeConstraint<type>("dtype"), \ |
236 | VariableOp); \ |
237 | REGISTER_KERNEL_BUILDER( \ |
238 | Name("VariableV2").Device(DEVICE_DEFAULT).TypeConstraint<type>("dtype"), \ |
239 | VariableOp); \ |
240 | REGISTER_KERNEL_BUILDER(Name("TemporaryVariable") \ |
241 | .Device(DEVICE_DEFAULT) \ |
242 | .TypeConstraint<type>("dtype"), \ |
243 | TemporaryVariableOp); \ |
244 | REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \ |
245 | .Device(DEVICE_DEFAULT) \ |
246 | .TypeConstraint<type>("T"), \ |
247 | DestroyTemporaryVariableOp); \ |
248 | REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \ |
249 | .Device(DEVICE_DEFAULT) \ |
250 | .TypeConstraint<type>("dtype") \ |
251 | .HostMemory("is_initialized"), \ |
252 | IsVariableInitializedOp); |
253 | |
254 | TF_CALL_int64(REGISTER_DEFAULT_KERNELS); |
255 | TF_CALL_uint32(REGISTER_DEFAULT_KERNELS); |
256 | TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS); |
257 | #undef REGISTER_DEFAULT_KERNELS |
258 | |
259 | } // namespace tensorflow |
260 | |