1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <string>
17
18#include "re2/re2.h"
19#include "tensorflow/core/framework/op_kernel.h"
20#include "tensorflow/core/framework/tensor.h"
21#include "tensorflow/core/lib/core/errors.h"
22#include "tensorflow/core/lib/core/status.h"
23#include "tensorflow/core/platform/mutex.h"
24#include "tensorflow/core/platform/thread_annotations.h"
25#include "tensorflow/core/util/ptr_util.h"
26
27namespace tensorflow {
28namespace {
29
30// Execute the specified regex using the given context.
31// Context requirements:
32// - "input" string Tensor at input_index=0
33// - "output" string Tensor at output_index=0
34Status InternalCompute(const RE2& regex, const string& rewrite,
35 const bool replace_global, OpKernelContext* ctx) {
36 const Tensor* input_tensor;
37 TF_RETURN_IF_ERROR(ctx->input("input", &input_tensor));
38 Tensor* output_tensor;
39 std::unique_ptr<Tensor> maybe_forwarded =
40 ctx->forward_input(0 /*input_index*/, 0 /*output_index*/,
41 tensorflow::DT_STRING, input_tensor->shape(),
42 ctx->input_memory_type(0), ctx->input_alloc_attr(0));
43 if (maybe_forwarded) {
44 output_tensor = maybe_forwarded.get();
45 TF_RETURN_IF_ERROR(ctx->set_output("output", *output_tensor));
46 } else {
47 TF_RETURN_IF_ERROR(
48 ctx->allocate_output("output", input_tensor->shape(), &output_tensor));
49 output_tensor->flat<tstring>() = input_tensor->flat<tstring>();
50 }
51 auto output_flat = output_tensor->flat<tstring>();
52 for (size_t i = 0; i < output_flat.size(); ++i) {
53 // TODO(dero): Mitigate copy; Global and GlobalReplace below currently only
54 // accept std::string.
55 string buf = output_flat(i);
56 if (replace_global) {
57 RE2::GlobalReplace(&buf, regex, rewrite);
58 } else {
59 RE2::Replace(&buf, regex, rewrite);
60 }
61 output_flat(i) = std::move(buf);
62 }
63 return OkStatus();
64}
65} // namespace
66
67class RegexReplaceOp : public OpKernel {
68 public:
69 explicit RegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
70 OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
71 }
72
73 ~RegexReplaceOp() override {}
74
75 void Compute(OpKernelContext* ctx) override {
76 const Tensor* pattern_tensor;
77 OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
78 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
79 errors::InvalidArgument("Pattern must be scalar, but received ",
80 pattern_tensor->shape().DebugString()));
81 const string& pattern = pattern_tensor->scalar<tstring>()();
82 std::shared_ptr<RE2> regex = CachedRE2(pattern);
83 OP_REQUIRES(ctx, regex->ok(),
84 errors::InvalidArgument("Invalid pattern: ", pattern,
85 ", error: ", regex->error()));
86
87 const Tensor* rewrite_tensor;
88 OP_REQUIRES_OK(ctx, ctx->input("rewrite", &rewrite_tensor));
89 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rewrite_tensor->shape()),
90 errors::InvalidArgument("Rewrite must be scalar, but received ",
91 rewrite_tensor->shape().DebugString()));
92 const string& rewrite = rewrite_tensor->scalar<tstring>()();
93 OP_REQUIRES_OK(ctx, InternalCompute(*regex, rewrite, replace_global_, ctx));
94 }
95
96 private:
97 std::shared_ptr<RE2> CachedRE2(const string& pattern) {
98 {
99 tf_shared_lock l(mu_);
100 if (regex_ != nullptr && regex_->pattern() == pattern) {
101 return regex_;
102 }
103 }
104 // Construct the new RE2 object before acquiring the lock.
105 auto regex = std::make_shared<RE2>(pattern);
106 {
107 mutex_lock l(mu_);
108 // Swap instead of assigning so that we destruct the old
109 // RE2 object (when necessary) after releasing the lock.
110 regex_.swap(regex);
111 return regex_;
112 }
113 }
114
115 bool replace_global_;
116 mutex mu_;
117 std::shared_ptr<RE2> regex_ TF_GUARDED_BY(mu_);
118
119 TF_DISALLOW_COPY_AND_ASSIGN(RegexReplaceOp);
120};
121
122REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
123 RegexReplaceOp);
124
125class StaticRegexReplaceOp : public OpKernel {
126 public:
127 explicit StaticRegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
128 string pattern;
129 OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
130 re_ = MakeUnique<RE2>(pattern);
131 OP_REQUIRES(ctx, re_->ok(),
132 errors::InvalidArgument("Invalid pattern: ", pattern,
133 ", error: ", re_->error()));
134 OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
135 OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
136 }
137
138 void Compute(OpKernelContext* ctx) override {
139 OP_REQUIRES_OK(ctx,
140 InternalCompute(*re_, rewrite_str_, replace_global_, ctx));
141 }
142
143 private:
144 std::unique_ptr<RE2> re_;
145 string rewrite_str_;
146 bool replace_global_;
147};
148
149REGISTER_KERNEL_BUILDER(Name("StaticRegexReplace").Device(DEVICE_CPU),
150 StaticRegexReplaceOp);
151
152} // namespace tensorflow
153