1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_
18
19#include "tensorflow/core/framework/memory_types.h"
20#include "tensorflow/core/graph/graph.h"
21#include "tensorflow/core/lib/core/status.h"
22
23namespace tensorflow {
24
25// Returns an error iff '*g', when running on a single device of type
26// 'device_type', has a memory type mismatch between any edge's source and destination.
27Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g);
28
29// Updates '*g' so that every edge's source and destination have
30// compatible memory types by inserting the proper HostSend/Recv and
31// Send/HostRecv nodes. 'device_type' specifies the type of the device on
32// which '*g' is going to run, and 'device_name' is the name of that
33// device.
34//
35// Returns OK if '*g' is updated properly (ValidateMemoryTypes(device_type, g)
36// must then be OK). Otherwise, returns an error; '*g' may be left in an
37// invalid state and the caller should discard it.
38Status EnsureMemoryTypes(const DeviceType& device_type,
39 const string& device_name, Graph* g);
40
41// Gets the memory type of the 'index'th output of node 'n' in graph 'g' when
42// running on 'device_type' (presumably stored in '*memory_type' -- confirm).
43Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
44 const Node* n, int index, MemoryType* memory_type);
45
46} // end namespace tensorflow
47
48#endif // TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_
49