1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <string>
17
18#include "llvm/Support/Casting.h"
19#include "llvm/Support/FormatVariadic.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
21#include "mlir/IR/Attributes.h" // from @llvm-project
22#include "mlir/IR/Builders.h" // from @llvm-project
23#include "mlir/IR/BuiltinOps.h" // from @llvm-project
24#include "mlir/IR/Operation.h" // from @llvm-project
25#include "mlir/IR/Visitors.h" // from @llvm-project
26#include "mlir/Pass/Pass.h" // from @llvm-project
27#include "mlir/Pass/PassManager.h" // from @llvm-project
28#include "mlir/Support/LogicalResult.h" // from @llvm-project
29#include "mlir/Transforms/Passes.h" // from @llvm-project
30#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
34#include "tensorflow/dtensor/cc/constants.h"
35#include "tensorflow/dtensor/cc/tensor_layout.h"
36#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
37#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
38#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
39#include "tensorflow/dtensor/mlir/layout_parsing.h"
40
41namespace tensorflow {
42namespace dtensor {
43
44namespace {
45#define GEN_PASS_DEF_DTENSOROPTODEVICECLUSTER
46#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
47
48// Extracts mesh config from the Op.
49// We currently hard extract mesh information from all the args and assume they
50// are the same. This should not be the case when we have multiple functions.
mlir::LogicalResult WrapDeviceCluster(mlir::OpBuilder *builder,
                                      mlir::Operation *op) {
  // Create new tf_device.cluster op wrapping a single operation. The cluster
  // is inserted immediately before `op` and mirrors `op`'s result types so it
  // can transparently replace `op` for all downstream users.
  builder->setInsertionPoint(op);
  auto cluster = builder->create<mlir::tf_device::ClusterOp>(
      op->getLoc(), op->getResultTypes());
  if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
    // DTensorLayout carries an explicit layout; its mesh is authoritative for
    // the cluster.
    cluster->setAttr(kMeshAttr, builder->getStringAttr(
                                    layout_op.layout().mesh().ToString()));
  } else if (auto copy_to_mesh = llvm::dyn_cast<mlir::TF::CopyToMeshOp>(op)) {
    // CopyToMesh stores its target layout as a string attribute; parse it and
    // fail the pass on malformed input rather than silently dropping mesh
    // information.
    const std::string layout_string = copy_to_mesh.layout().str();
    auto layout_or = Layout::FromString(layout_string);
    if (!layout_or.ok())
      return op->emitOpError(
          llvm::formatv("Found tf.CopyToMesh Op with unparsable layout : {0}",
                        layout_string));

    cluster->setAttr(kMeshAttr,
                     builder->getStringAttr(layout_or->mesh().ToString()));
  } else {
    // If mesh configuration can be inferred from the op directly, use the mesh
    // information from op attribute directly. If op is not annotated with mesh
    // information, then mesh will be inferred in following
    // DTensorMeshPropagation pass and will be inferred from consumers or
    // operands.
    auto status_or_mesh = ExtractDeviceMeshFromOp(op);

    if (!status_or_mesh.ok())
      return op->emitOpError(
          llvm::formatv("failed to wrap to device cluster. {0}",
                        status_or_mesh.status().error_message()));

    // The extracted mesh is optional: absence is not an error, it simply
    // leaves the cluster un-annotated for the mesh propagation pass.
    const auto mesh_config = status_or_mesh.value();
    if (mesh_config)
      cluster->setAttr(kMeshAttr,
                       builder->getStringAttr(mesh_config->ToString()));
  }

  // Redirect all existing users of `op` to the cluster's results. This must
  // happen BEFORE the tf_device.return below is created, since that return
  // intentionally keeps using `op`'s results from inside the cluster body.
  op->replaceAllUsesWith(cluster);

  // Give the cluster a (single) body block and terminate it by returning
  // `op`'s results through the cluster boundary.
  cluster.getBody().push_back(new mlir::Block);

  builder->setInsertionPointToEnd(&cluster.GetBody());
  builder->create<mlir::tf_device::ReturnOp>(op->getLoc(), op->getResults());

  // Move `op` inside newly created `ClusterOp`.
  op->moveBefore(cluster.GetBody().getTerminator());

  return mlir::success();
}
101
102// MLIR pass that wraps tf_device.cluster op to every TF op.
103struct DTensorOpToDeviceClusterPass
104 : public impl::DTensorOpToDeviceClusterBase<DTensorOpToDeviceClusterPass> {
105 void getDependentDialects(mlir::DialectRegistry &registry) const override {
106 registry.insert<mlir::dtensor::DTensorDialect>();
107 registry.insert<mlir::tf_device::TensorFlowDeviceDialect>();
108 }
109
110 void runOnOperation() override {
111 mlir::MLIRContext &context = getContext();
112 mlir::OpBuilder op_builder(&context);
113 mlir::Dialect *tf =
114 getContext().getLoadedDialect<mlir::TF::TensorFlowDialect>();
115
116 auto walk_result = getOperation().walk([&](mlir::Operation *operation) {
117 const auto op_dialect = operation->getDialect();
118 // Only TF dialects are supported for layout propagation.
119 if (op_dialect != tf) return mlir::WalkResult::advance();
120
121 // For control flow operations, tf.yield ops exists and should not be
122 // wrapped to tf_device.cluster as the op does not need to be transformed
123 // in SPMD expansion and tf.If/tf.While op require all ops to terminate
124 // with tf.Yield op. Wrapping yield op in tf_device.cluster invalidates
125 // this invariant.
126 if (llvm::isa<mlir::TF::YieldOp>(operation))
127 return mlir::WalkResult::advance();
128
129 if (mlir::failed(WrapDeviceCluster(&op_builder, operation)))
130 return mlir::WalkResult::interrupt();
131 return mlir::WalkResult::advance();
132 });
133
134 if (walk_result.wasInterrupted()) signalPassFailure();
135 }
136};
137
138} // namespace
139
140std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
141CreateDTensorOpToDeviceClusterPass() {
142 return std::make_unique<DTensorOpToDeviceClusterPass>();
143}
144
145} // namespace dtensor
146} // namespace tensorflow
147