1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/dtensor/mlir/device_utils.h"
17
18#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
19#include "mlir/IR/Value.h" // from @llvm-project
20#include "tensorflow/core/platform/errors.h"
21
22namespace tensorflow {
23namespace dtensor {
24
25// Returns an MLIR value representing the current device ID.
26StatusOr<mlir::Value> DeviceId(mlir::Operation* op) {
27 mlir::func::FuncOp function = llvm::dyn_cast<mlir::func::FuncOp>(op);
28 if (!function) {
29 // Device ID is the 0th argument of the enclosing function.
30 function = op->getParentOfType<mlir::func::FuncOp>();
31 if (!function)
32 return errors::InvalidArgument(
33 "operation must be enclosed inside a function.");
34 }
35
36 if (function.getNumArguments() == 0)
37 return errors::InvalidArgument(
38 "enclosing function must contain device id as argument");
39
40 auto device_id = function.getArgument(0);
41 auto device_id_type = device_id.getType().dyn_cast<mlir::RankedTensorType>();
42 if (!device_id_type ||
43 !device_id_type.getElementType().isa<mlir::IntegerType>())
44 return errors::InvalidArgument(
45 "0-th argument of the enclosing function should be integer device id.");
46
47 return device_id;
48}
49
50StatusOr<mlir::Value> DeviceId(mlir::Value val) {
51 if (auto block_arg = val.dyn_cast<mlir::BlockArgument>()) {
52 auto device_id = block_arg.getOwner()->getArgument(0);
53 auto device_id_type =
54 device_id.getType().dyn_cast<mlir::RankedTensorType>();
55 if (!device_id_type ||
56 !device_id_type.getElementType().isa<mlir::IntegerType>())
57 return errors::InvalidArgument(
58 "0-th argument of the enclosing block should be integer device id.");
59 return device_id;
60 }
61 return DeviceId(val.getDefiningOp());
62}
63
64} // namespace dtensor
65} // namespace tensorflow
66