1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/data_flow_ops.cc.
17
18#include <limits.h>
19
20#include <vector>
21
22#include "tensorflow/core/common_runtime/device.h"
23#include "tensorflow/core/framework/device_base.h"
24#include "tensorflow/core/framework/op_kernel.h"
25#include "tensorflow/core/framework/register_types.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/framework/tensor_shape.h"
28#include "tensorflow/core/framework/types.h"
29#include "tensorflow/core/lib/core/errors.h"
30#include "tensorflow/core/lib/gtl/map_util.h"
31#include "tensorflow/core/platform/errors.h"
32#include "tensorflow/core/platform/logging.h"
33#include "tensorflow/core/platform/macros.h"
34#include "tensorflow/core/platform/mutex.h"
35#include "tensorflow/core/platform/thread_annotations.h"
36#include "tensorflow/core/platform/types.h"
37
38namespace tensorflow {
39
40class GetSessionHandleOp : public OpKernel {
41 public:
42 explicit GetSessionHandleOp(OpKernelConstruction* context)
43 : OpKernel(context) {}
44
45 void Compute(OpKernelContext* ctx) override {
46 const Tensor& val = ctx->input(0);
47 auto session_state = ctx->session_state();
48 OP_REQUIRES(ctx, session_state != nullptr,
49 errors::FailedPrecondition(
50 "GetSessionHandle called on null session state"));
51 int64_t id = session_state->GetNewId();
52 TensorStore::TensorAndKey tk{val, id, requested_device()};
53 OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
54
55 Tensor* handle = nullptr;
56 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
57 if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
58 ResourceHandle resource_handle = MakeResourceHandle<Tensor>(
59 ctx, SessionState::kTensorHandleResourceTypeName,
60 tk.GetHandle(name()));
61 resource_handle.set_maybe_type_name(
62 SessionState::kTensorHandleResourceTypeName);
63 handle->scalar<ResourceHandle>()() = resource_handle;
64 } else {
65 // Legacy behavior in V1.
66 handle->flat<tstring>().setConstant(tk.GetHandle(name()));
67 }
68 }
69
70 TF_DISALLOW_COPY_AND_ASSIGN(GetSessionHandleOp);
71};
72
73REGISTER_KERNEL_BUILDER(Name("GetSessionHandle").Device(DEVICE_CPU),
74 GetSessionHandleOp);
75REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2").Device(DEVICE_CPU),
76 GetSessionHandleOp);
77
78#define REGISTER_DEFAULT_KERNEL(type) \
79 REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \
80 .Device(DEVICE_DEFAULT) \
81 .HostMemory("handle") \
82 .TypeConstraint<type>("T"), \
83 GetSessionHandleOp) \
84 REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2") \
85 .Device(DEVICE_DEFAULT) \
86 .HostMemory("handle") \
87 .TypeConstraint<type>("T"), \
88 GetSessionHandleOp)
89
90TF_CALL_NUMBER_TYPES(REGISTER_DEFAULT_KERNEL);
91REGISTER_DEFAULT_KERNEL(bool);
92#undef REGISTER_DEFAULT_KERNEL
93
94class GetSessionTensorOp : public OpKernel {
95 public:
96 explicit GetSessionTensorOp(OpKernelConstruction* context)
97 : OpKernel(context) {}
98
99 void Compute(OpKernelContext* ctx) override {
100 const Tensor& handle = ctx->input(0);
101 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(handle.shape()),
102 errors::InvalidArgument("handle must be scalar"));
103 const string& name = handle.scalar<tstring>()();
104 Tensor val;
105 auto session_state = ctx->session_state();
106 OP_REQUIRES(ctx, session_state != nullptr,
107 errors::FailedPrecondition(
108 "GetSessionTensor called on null session state"));
109 OP_REQUIRES_OK(ctx, session_state->GetTensor(name, &val));
110 ctx->set_output(0, val);
111 }
112
113 TF_DISALLOW_COPY_AND_ASSIGN(GetSessionTensorOp);
114};
115
116REGISTER_KERNEL_BUILDER(Name("GetSessionTensor").Device(DEVICE_CPU),
117 GetSessionTensorOp);
118
119#define REGISTER_DEFAULT_KERNEL(type) \
120 REGISTER_KERNEL_BUILDER(Name("GetSessionTensor") \
121 .Device(DEVICE_DEFAULT) \
122 .HostMemory("handle") \
123 .TypeConstraint<type>("dtype"), \
124 GetSessionTensorOp)
125
126TF_CALL_NUMBER_TYPES(REGISTER_DEFAULT_KERNEL);
127REGISTER_DEFAULT_KERNEL(bool);
128#undef REGISTER_DEFAULT_KERNEL
129
130class DeleteSessionTensorOp : public OpKernel {
131 public:
132 explicit DeleteSessionTensorOp(OpKernelConstruction* context)
133 : OpKernel(context) {}
134
135 void Compute(OpKernelContext* ctx) override {
136 const Tensor& handle = ctx->input(0);
137 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(handle.shape()),
138 errors::InvalidArgument("`handle` must be scalar"));
139 const string& name = handle.scalar<tstring>()();
140 auto session_state = ctx->session_state();
141 OP_REQUIRES(ctx, session_state != nullptr,
142 errors::FailedPrecondition(
143 "DeleteSessionTensor called on null session state"));
144 OP_REQUIRES_OK(ctx, session_state->DeleteTensor(name));
145 }
146
147 TF_DISALLOW_COPY_AND_ASSIGN(DeleteSessionTensorOp);
148};
149
150REGISTER_KERNEL_BUILDER(Name("DeleteSessionTensor").Device(DEVICE_CPU),
151 DeleteSessionTensorOp);
152REGISTER_KERNEL_BUILDER(
153 Name("DeleteSessionTensor").Device(DEVICE_DEFAULT).HostMemory("handle"),
154 DeleteSessionTensorOp);
155
156} // namespace tensorflow
157