1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ |
18 | |
19 | #include "tensorflow/core/framework/memory_types.h" |
20 | #include "tensorflow/core/graph/graph.h" |
21 | #include "tensorflow/core/lib/core/status.h" |
22 | |
namespace tensorflow {

// Returns an error iff '*g', running on a single device of
// 'device_type', has a memory-type mismatch between any edge's source
// output and destination input.
Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g);

// Updates '*g' so that every edge's source and destination have
// compatible memory types by inserting proper HostSend/Recv and
// Send/HostRecv node pairs. 'device_type' specifies the type of the
// device on which '*g' is going to run, and 'device_name' is that
// device's name.
//
// Returns OK if '*g' is updated properly (ValidateMemoryTypes(g) must
// then be OK). Otherwise, returns an error, in which case '*g' may be
// left in an invalid state and the caller should discard it.
Status EnsureMemoryTypes(const DeviceType& device_type,
                         const string& device_name, Graph* g);

// Gets the memory type of the 'index'th output of node 'n' in graph
// 'g' when running on 'device_type', storing the result in
// '*memory_type'.
Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
                           const Node* n, int index, MemoryType* memory_type);

}  // end namespace tensorflow
47 | |
48 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ |
49 | |