1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#define EIGEN_USE_THREADS
16
17#include "tensorflow/core/kernels/lookup_table_init_op.h"
18
19#include <algorithm>
20#include <memory>
21#include <string>
22#include <vector>
23
24#include "tensorflow/core/framework/op_kernel.h"
25#include "tensorflow/core/framework/register_types.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/framework/tensor_shape.h"
28#include "tensorflow/core/framework/types.h"
29#include "tensorflow/core/graph/graph_def_builder.h"
30#include "tensorflow/core/kernels/lookup_util.h"
31#include "tensorflow/core/lib/core/errors.h"
32#include "tensorflow/core/lib/core/status.h"
33#include "tensorflow/core/lib/io/inputbuffer.h"
34#include "tensorflow/core/lib/strings/numbers.h"
35#include "tensorflow/core/lib/strings/str_util.h"
36#include "tensorflow/core/platform/macros.h"
37
38namespace tensorflow {
39
40using InitializerSerializer =
41 lookup::InitializableLookupTable::InitializerSerializer;
42
43// Kernel to initialize a look table given a key and value tensors.
44// After this operation, the table becomes read-only.
45class InitializeTableOp : public OpKernel {
46 public:
47 explicit InitializeTableOp(OpKernelConstruction* context)
48 : OpKernel(context) {}
49
50 void Compute(OpKernelContext* ctx) override {
51 mutex_lock l(mu_);
52 lookup::InitializableLookupTable* table;
53 OP_REQUIRES_OK(ctx,
54 GetInitializableLookupTable("table_handle", ctx, &table));
55 core::ScopedUnref unref_me(table);
56
57 DataType expected_input_0 =
58 (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
59 DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
60 table->value_dtype()};
61 DataTypeVector expected_outputs = {};
62 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
63
64 const Tensor& keys = ctx->input(1);
65 OP_REQUIRES(
66 ctx, TensorShapeUtils::IsVector(keys.shape()),
67 errors::InvalidArgument("Keys must be a vector, but received shape",
68 keys.shape().DebugString()));
69
70 const Tensor& values = ctx->input(2);
71 OP_REQUIRES(
72 ctx, TensorShapeUtils::IsVector(values.shape()),
73 errors::InvalidArgument("Values must be a vector, but received shape",
74 values.shape().DebugString()));
75
76 OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(),
77 errors::InvalidArgument(
78 "Keys and values must have the same size ",
79 keys.NumElements(), " vs ", values.NumElements()));
80
81 int memory_used_before = 0;
82 if (ctx->track_allocations()) {
83 memory_used_before = table->MemoryUsed();
84 }
85 OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
86 if (ctx->track_allocations()) {
87 ctx->record_persistent_memory_allocation(table->MemoryUsed() -
88 memory_used_before);
89 }
90 }
91
92 private:
93 mutex mu_;
94};
95
96REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU),
97 InitializeTableOp);
98REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU),
99 InitializeTableOp);
100
101// Kernel to initialize a lookup table from a text file.
102//
103// After this operation, the table becomes read-only.
104class InitializeTableFromTextFileOp : public OpKernel {
105 public:
106 explicit InitializeTableFromTextFileOp(OpKernelConstruction* ctx)
107 : OpKernel(ctx) {
108 OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_));
109 OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_));
110 OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_));
111 if (ctx->HasAttr("offset")) {
112 OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_));
113 }
114 string delimiter;
115 OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter));
116 OP_REQUIRES(ctx, delimiter.size() == 1,
117 errors::InvalidArgument("delimiter should be only 1 char"));
118 delimiter_ = delimiter[0];
119 }
120
121 void Compute(OpKernelContext* ctx) override {
122 mutex_lock l(mu_);
123 lookup::InitializableLookupTable* table;
124 OP_REQUIRES_OK(ctx,
125 GetInitializableLookupTable("table_handle", ctx, &table));
126 core::ScopedUnref unref_me(table);
127
128 DataType expected_input_0 =
129 (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
130 DataTypeVector expected_inputs = {expected_input_0, DT_STRING};
131 DataTypeVector expected_outputs = {};
132 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
133
134 const Tensor& vocab_filename_tensor = ctx->input(1);
135 OP_REQUIRES(
136 ctx, TensorShapeUtils::IsScalar(vocab_filename_tensor.shape()),
137 errors::InvalidArgument("filename should be a single string, but got ",
138 vocab_filename_tensor.shape().DebugString()));
139
140 const string& vocab_filename = vocab_filename_tensor.scalar<tstring>()();
141 OP_REQUIRES(ctx, !vocab_filename.empty(),
142 errors::InvalidArgument("filename cannot be empty."));
143
144 int64_t memory_used_before = 0;
145 if (ctx->track_allocations()) {
146 memory_used_before = table->MemoryUsed();
147 }
148 OP_REQUIRES_OK(
149 ctx, lookup::InitializeTableFromTextFile(
150 vocab_filename, vocab_size_, delimiter_, key_index_,
151 value_index_, offset_, ctx->env(),
152 MakeInitializerSerializer(vocab_filename_tensor), table));
153 if (ctx->track_allocations()) {
154 ctx->record_persistent_memory_allocation(table->MemoryUsed() -
155 memory_used_before);
156 }
157 }
158
159 private:
160 std::unique_ptr<InitializerSerializer> MakeInitializerSerializer(
161 Tensor vocab_filename) {
162 return std::make_unique<InitializerSerializer>(
163 [vocab_filename, vocab_size = vocab_size_, delimiter = delimiter_,
164 key_index = key_index_, value_index = value_index_,
165 offset = offset_](GraphDefBuilder* builder, Node* table, Node** out) {
166 Node* vocab_filename_node = ops::SourceOp(
167 "Const", builder->opts()
168 .WithAttr("dtype", vocab_filename.dtype())
169 .WithAttr("value", vocab_filename));
170 std::string delimiter_string(1, delimiter);
171 Node* import_table = ops::BinaryOp(
172 "InitializeTableFromTextFileV2", table, vocab_filename_node,
173 builder->opts()
174 .WithAttr("vocab_size", vocab_size)
175 .WithAttr("key_index", key_index)
176 .WithAttr("value_index", value_index)
177 .WithAttr("offset", offset)
178 .WithAttr("delimiter", delimiter_string));
179 *out = ops::UnaryOp("Identity", table,
180 builder->opts().WithControlInput(import_table));
181 return OkStatus();
182 });
183 }
184
185 mutex mu_;
186 int64_t vocab_size_;
187 char delimiter_;
188 int64_t key_index_;
189 int64_t value_index_;
190 int64_t offset_ = 0;
191
192 TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp);
193};
194
195REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU),
196 InitializeTableFromTextFileOp);
197REGISTER_KERNEL_BUILDER(
198 Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
199 InitializeTableFromTextFileOp);
200} // namespace tensorflow
201