1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_ |
17 | |
18 | #include <string> |
19 | |
20 | #include "tensorflow/core/common_runtime/device.h" |
21 | #include "tensorflow/core/common_runtime/device_mgr.h" |
22 | #include "tensorflow/core/framework/collective.h" |
23 | #include "tensorflow/core/framework/device_attributes.pb.h" |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace collective_util { |
29 | |
30 | Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, |
31 | const string& device_name, Device** device, |
32 | DeviceLocality* device_locality); |
33 | string SubdivPermDebugString(const CollectiveParams& col_params); |
34 | |
35 | // Used for executing a sub-operation, e.g. a merge_op instance, with |
36 | // an OpKernelContext based on the one passed into this Op. |
37 | class SubContext { |
38 | public: |
39 | OpKernelContext::Params sub_params_; |
40 | gtl::InlinedVector<TensorValue, 4> sub_inputs_; |
41 | gtl::InlinedVector<AllocatorAttributes, 4> sub_input_attr_; |
42 | gtl::InlinedVector<DeviceContext*, 4> sub_input_dc_; |
43 | // Used only for Binary and Unary Ops for which we require |
44 | // the calculation to be in-place on the first input. |
45 | int forward_from_ = 0; |
46 | std::unique_ptr<OpKernelContext> sub_ctx_; |
47 | SubContext(OpKernelContext* ctx, OpKernelContext::Params* params, |
48 | OpKernel* op, Tensor* output, Tensor* input); |
49 | ~SubContext() = default; |
50 | }; |
51 | |
52 | Status ComputeBinOp(OpKernelContext* op_ctx, OpKernelContext::Params* params, |
53 | Device* device, OpKernel* op, Tensor* output, |
54 | Tensor* input); |
55 | |
56 | } // namespace collective_util |
57 | } // namespace tensorflow |
58 | |
59 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_ |
60 | |