1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
19 | #include "mlir/IR/Attributes.h" // from @llvm-project |
20 | #include "mlir/IR/Operation.h" // from @llvm-project |
21 | #include "mlir/IR/SymbolTable.h" // from @llvm-project |
22 | #include "mlir/Pass/Pass.h" // from @llvm-project |
23 | #include "tensorflow/dtensor/cc/constants.h" |
24 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
25 | |
26 | namespace tensorflow { |
27 | namespace dtensor { |
28 | |
29 | namespace { |
30 | #define GEN_PASS_DEF_DTENSORFUNCTIONRENAMING |
31 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
32 | |
33 | struct DTensorFunctionRenaming |
34 | : public impl::DTensorFunctionRenamingBase<DTensorFunctionRenaming> { |
35 | void runOnOperation() override { |
36 | mlir::ModuleOp module = getOperation(); |
37 | |
38 | const std::string append = |
39 | module->getAttrOfType<mlir::StringAttr>(dtensor::kCacheKey) |
40 | .getValue() |
41 | .str(); |
42 | |
43 | // If the cache key isn't set, simply return without renameing functions. |
44 | if (append.empty()) return; |
45 | |
46 | mlir::SymbolTableCollection symbol_table; |
47 | mlir::SymbolUserMap symbolUsers(symbol_table, module); |
48 | |
49 | for (mlir::func::FuncOp func_op : |
50 | llvm::make_early_inc_range(module.getOps<mlir::func::FuncOp>())) { |
51 | // Only rename private functions, functions which are public (i.e. the |
52 | // main function of the module), must have stable names since they are |
53 | // public and may be used by other modules/pieces of code. |
54 | if (func_op.getVisibility() != mlir::SymbolTable::Visibility::Private) |
55 | continue; |
56 | std::string new_name = absl::StrCat( |
57 | mlir::SymbolTable::getSymbolName(func_op).getValue().str(), append); |
58 | symbolUsers.replaceAllUsesWith( |
59 | func_op, mlir::StringAttr::get(&getContext(), new_name)); |
60 | mlir::SymbolTable::setSymbolName(func_op, new_name); |
61 | } |
62 | }; |
63 | }; |
64 | |
65 | } // namespace |
66 | |
67 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
68 | CreateFunctionRenamingPass() { |
69 | return std::make_unique<DTensorFunctionRenaming>(); |
70 | } |
71 | |
72 | } // namespace dtensor |
73 | } // namespace tensorflow |
74 | |