1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_IR_UTILITY_H_
17#define TENSORFLOW_CORE_IR_UTILITY_H_
18
19#include "llvm/ADT/STLExtras.h"
20#include "mlir/IR/Block.h" // from @llvm-project
21#include "mlir/IR/OperationSupport.h" // from @llvm-project
22#include "mlir/IR/Value.h" // from @llvm-project
23#include "mlir/Support/LLVM.h" // from @llvm-project
24#include "tensorflow/core/ir/dialect.h"
25
26namespace mlir {
27namespace tfg {
28
// Region-based loop ops store all control tokens after all of the data values,
// unlike functions, which store each data value and its control token as an
// adjacent pair. This layout is required by RegionBranchOpInterface's API,
// which uses MutableOperandRange; i.e. the data operands must be stored
// contiguously.
33
// TODO(jeffniu): These functions are no longer specific to "loop regions"; they
// apply to any region-based op (if/case have explicit capture forms).
36
// Given a region belonging to a region-based loop operation (e.g. a while
// loop), return the subrange of block arguments that are data values. Under
// the region-based layout, these are the leading block arguments, which come
// before all of the control tokens.
Block::BlockArgListType GetLoopRegionDataArgs(Region &region);
// Given a region belonging to a region-based loop operation (e.g. a while
// loop), return the subrange of block arguments that are control tokens. Under
// the region-based layout, these are the trailing block arguments, which come
// after all of the data values.
Block::BlockArgListType GetLoopRegionControlTokens(Region &region);
// Given a data value block argument of a region belonging to a region-based
// loop operation (e.g. a while loop), return the block argument that holds its
// control token. Control tokens are not interleaved with data here, so this is
// generally NOT the next block argument.
// NOTE(review): presumably data argument i maps to control token
// (num_data_args + i) -- confirm against the definition.
BlockArgument GetLoopRegionControlOf(BlockArgument data);
// Given a control token block argument of a region belonging to a region-based
// loop operation (e.g. a while loop), return the data value block argument it
// belongs to. This is the inverse of GetLoopRegionControlOf.
BlockArgument GetLoopRegionDataOf(BlockArgument ctl);
51
// Given a TFG value, look up its associated control token. For op results, the
// token is the last result of the defining op. For block arguments, the token
// is the immediately following argument. A data value always has an associated
// control token.
// NOTE(review): the "following argument" rule describes the pairwise
// (function-style) layout; for region-based loop ops, whose control tokens
// come after all data values, GetLoopRegionControlOf may be the right helper --
// confirm.
Value LookupControlDependency(Value data);

// Given a TFG control token, look up its associated data value. A block
// argument always has an associated data value: the immediately preceding
// argument. For an op, returns None if the control token is the op's only
// result; otherwise returns the op's first result.
Optional<Value> LookupDataValue(Value ctl);
63
64// Given a range of values, operands, or results, that contains data and control
65// values, where all control tokens come after the data values, split the range
66// between the two.
67template <typename RangeT>
68std::pair<RangeT, RangeT> SplitDataAndControlValues(RangeT values,
69 ControlType ctl_type) {
70 unsigned num_ctl = 0;
71 for (Value value : llvm::reverse(values)) {
72 if (value.getType() == ctl_type)
73 ++num_ctl;
74 else
75 break;
76 }
77 unsigned split_idx = llvm::size(values) - num_ctl;
78 return std::make_pair(values.slice(0, split_idx),
79 values.slice(split_idx, num_ctl));
80}
81
82} // namespace tfg
83} // namespace mlir
84
85#endif // TENSORFLOW_CORE_IR_UTILITY_H_
86