1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <string>
17
18#include "re2/re2.h"
19#include "tensorflow/core/framework/op_kernel.h"
20#include "tensorflow/core/framework/tensor.h"
21#include "tensorflow/core/lib/core/errors.h"
22#include "tensorflow/core/lib/core/status.h"
23#include "tensorflow/core/platform/mutex.h"
24#include "tensorflow/core/platform/thread_annotations.h"
25#include "tensorflow/core/util/ptr_util.h"
26
27namespace tensorflow {
28
29class RegexFullMatchOp : public OpKernel {
30 public:
31 explicit RegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
32
33 ~RegexFullMatchOp() override {}
34
35 void Compute(OpKernelContext* ctx) override {
36 const Tensor* input_tensor;
37 OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
38 const auto& input_flat = input_tensor->flat<tstring>();
39
40 const Tensor* pattern_tensor;
41 OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
42 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
43 errors::InvalidArgument("Pattern must be scalar, but received ",
44 pattern_tensor->shape().DebugString()));
45 const string pattern = pattern_tensor->flat<tstring>()(0);
46 std::shared_ptr<RE2> regex = CachedRE2(pattern);
47 OP_REQUIRES(ctx, regex->ok(),
48 errors::InvalidArgument("Invalid pattern: ", pattern,
49 ", error: ", regex->error()));
50
51 Tensor* output_tensor = nullptr;
52 OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
53 &output_tensor));
54 auto output_flat = output_tensor->flat<bool>();
55 for (size_t i = 0; i < input_flat.size(); ++i) {
56 output_flat(i) = RE2::FullMatch(input_flat(i), *regex);
57 }
58 }
59
60 private:
61 std::shared_ptr<RE2> CachedRE2(const string& pattern) {
62 {
63 tf_shared_lock l(mu_);
64 if (regex_ != nullptr && regex_->pattern() == pattern) {
65 return regex_;
66 }
67 }
68 // Construct the new RE2 object before acquiring the lock.
69 auto regex = std::make_shared<RE2>(pattern);
70 {
71 mutex_lock l(mu_);
72 // Swap instead of assigning so that we destruct the old
73 // RE2 object (when necessary) after releasing the lock.
74 regex_.swap(regex);
75 return regex_;
76 }
77 }
78
79 mutex mu_;
80 std::shared_ptr<RE2> regex_ TF_GUARDED_BY(mu_);
81
82 TF_DISALLOW_COPY_AND_ASSIGN(RegexFullMatchOp);
83};
84
// Registers RegexFullMatchOp as the CPU kernel for the "RegexFullMatch" op.
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
                        RegexFullMatchOp);
87
88class StaticRegexFullMatchOp : public OpKernel {
89 public:
90 explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
91 string pattern;
92 OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
93 re_ = MakeUnique<RE2>(pattern);
94 OP_REQUIRES(ctx, re_->ok(),
95 errors::InvalidArgument("Invalid pattern: ", pattern,
96 ", error: ", re_->error()));
97 }
98
99 void Compute(OpKernelContext* ctx) override {
100 const Tensor* input_tensor;
101 OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
102 const auto& input_flat = input_tensor->flat<tstring>();
103
104 Tensor* output_tensor = nullptr;
105 OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
106 &output_tensor));
107 auto output_flat = output_tensor->flat<bool>();
108 for (size_t i = 0; i < input_flat.size(); ++i) {
109 output_flat(i) = RE2::FullMatch(input_flat(i), *re_);
110 }
111 }
112
113 private:
114 std::unique_ptr<RE2> re_;
115};
116
// Registers StaticRegexFullMatchOp as the CPU kernel for the
// "StaticRegexFullMatch" op.
REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU),
                        StaticRegexFullMatchOp);
119
120} // namespace tensorflow
121