1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/ir/ops.h" |
17 | |
#include <algorithm>
#include <array>
#include <cstdint>
#include <list>
#include <memory>
#include <string>
#include <utility>
24 | |
25 | #include "llvm/ADT/ArrayRef.h" |
26 | #include "llvm/ADT/STLExtras.h" |
27 | #include "llvm/ADT/TypeSwitch.h" |
28 | #include "llvm/Support/ErrorHandling.h" |
29 | #include "llvm/Support/SMLoc.h" |
30 | #include "llvm/Support/raw_ostream.h" |
31 | #include "mlir/IR/Builders.h" // from @llvm-project |
32 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
33 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
34 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
35 | #include "mlir/IR/Diagnostics.h" // from @llvm-project |
36 | #include "mlir/IR/Dialect.h" // from @llvm-project |
37 | #include "mlir/IR/DialectImplementation.h" // from @llvm-project |
38 | #include "mlir/IR/FunctionImplementation.h" // from @llvm-project |
39 | #include "mlir/IR/FunctionInterfaces.h" // from @llvm-project |
40 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
41 | #include "mlir/IR/OpImplementation.h" // from @llvm-project |
42 | #include "mlir/IR/OperationSupport.h" // from @llvm-project |
43 | #include "mlir/IR/SymbolTable.h" // from @llvm-project |
44 | #include "mlir/IR/TypeRange.h" // from @llvm-project |
45 | #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
46 | #include "mlir/IR/Value.h" // from @llvm-project |
47 | #include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project |
48 | #include "mlir/Support/LLVM.h" // from @llvm-project |
49 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
50 | #include "tensorflow/core/ir/dialect.h" |
51 | #include "tensorflow/core/ir/interfaces.h" |
52 | #include "tensorflow/core/ir/types/dialect.h" |
53 | #include "tensorflow/core/ir/utility.h" |
54 | |
55 | // Generated definitions. |
56 | #include "tensorflow/core/ir/dialect.cc.inc" |
57 | |
58 | namespace mlir { |
59 | namespace tfg { |
60 | |
61 | //===----------------------------------------------------------------------===// |
62 | // TFGraph dialect. |
63 | //===----------------------------------------------------------------------===// |
64 | |
65 | // Name operation results with the operation name, except control outputs which |
66 | // are named "ctl". MLIR will automatically use a numerical suffix to unique. |
67 | static void GenericGetAsmResultNames(Operation *op, |
68 | OpAsmSetValueNameFn set_name_fn) { |
69 | // We only name the results when there are results to name, an op like `print` |
70 | // which does not have results will just use the `ctl` name for the control |
71 | // output. |
72 | if (op->getNumResults() > 1 && !op->getResult(0).getType().isa<ControlType>()) |
73 | set_name_fn(op->getResult(0), op->getName().stripDialect()); |
74 | for (Value result : op->getResults()) { |
75 | if (result.getType().isa<ControlType>()) { |
76 | set_name_fn(op->getResult(op->getNumResults() - 1), "ctl" ); |
77 | break; |
78 | } |
79 | } |
80 | } |
81 | |
82 | // TFGraph support for interacting with the AsmPrinter. |
83 | // Gives prettier names to SSA values. |
84 | struct TFGraphOpAsmInterface |
85 | : public OpAsmOpInterface::FallbackModel<TFGraphOpAsmInterface> { |
86 | static bool classof(Operation *op) { return true; } |
87 | |
88 | void getAsmResultNames(Operation *op, OpAsmSetValueNameFn set_name_fn) const { |
89 | GenericGetAsmResultNames(op, set_name_fn); |
90 | } |
91 | void getAsmBlockArgumentNames(Operation *op, Region ®ion, |
92 | OpAsmSetValueNameFn setNameFn) const {} |
93 | void getAsmBlockNames(Operation *op, |
94 | mlir::OpAsmSetBlockNameFn setNameFn) const {} |
95 | }; |
96 | |
97 | // Dialect construction: there is one instance per context and it registers its |
98 | // operations, types, and interfaces here. |
99 | void TFGraphDialect::initialize() { |
100 | getContext()->getOrLoadDialect<tf_type::TFTypeDialect>(); |
101 | addOperations< |
102 | #define GET_OP_LIST |
103 | #include "tensorflow/core/ir/ops.cc.inc" |
104 | >(); |
105 | addAttributes< |
106 | #define GET_ATTRDEF_LIST |
107 | #include "tensorflow/core/ir/attributes.cc.inc" |
108 | >(); |
109 | |
110 | // Support unknown operations because not all TensorFlow operations are |
111 | // registered. |
112 | allowUnknownOperations(); |
113 | |
114 | // Create the fallback OpAsmOpInterface instance. |
115 | fallbackOpAsmInterface_ = new TFGraphOpAsmInterface; |
116 | |
117 | // Register the memory effects interface adaptor. |
118 | addInterfaces<StatefulMemoryEffectInterface>(); |
119 | |
120 | // Initialized the cached operation names. |
121 | #define GET_OP_NAME_DEFS |
122 | #include "tensorflow/core/ir/tf_op_names.inc" |
123 | |
124 | // Caching some often used context-owned informations for fast-access. |
125 | name_key_ = StringAttr::get(getContext(), getNameAttrKey()); |
126 | device_key_ = StringAttr::get(getContext(), getDeviceAttrKey()); |
127 | assigned_device_key_ = |
128 | StringAttr::get(getContext(), getAssignedDeviceAttrKey()); |
129 | fulltype_key_ = StringAttr::get(getContext(), getFullTypeAttrKey()); |
130 | lifted_graph_func_name_ = |
131 | StringAttr::get(getContext(), getLiftedGraphFuncNameKey()); |
132 | tfg_name_key_ = StringAttr::get(getContext(), getTfgNameAttrKey()); |
133 | tfg_description_key_ = |
134 | StringAttr::get(getContext(), getTfgDescriptionAttrKey()); |
135 | tfg_is_ref_key_ = StringAttr::get(getContext(), getTfgIsRefAttrKey()); |
136 | tfg_handle_data_key_ = |
137 | StringAttr::get(getContext(), getTfgHandleDataAttrKey()); |
138 | tfg_full_type_key_ = StringAttr::get(getContext(), getTfgFullTypeAttrKey()); |
139 | |
140 | control_ty_ = ControlType::get(getContext()); |
141 | } |
142 | |
143 | // Provides a hook for op interface. |
144 | void *TFGraphDialect::getRegisteredInterfaceForOp(TypeID interface, |
145 | OperationName opName) { |
146 | if (interface == TypeID::get<OpAsmOpInterface>()) { |
147 | return fallbackOpAsmInterface_; |
148 | } |
149 | |
150 | // Intrinsic operations explicitly implement intefaces. |
151 | if (opName.hasTrait<OpTrait::IntrinsicOperation>()) { |
152 | return nullptr; |
153 | } |
154 | |
155 | if (interface == TypeID::get<TensorFlowRegistryInterface>()) { |
156 | if (auto *instance = |
157 | getRegisteredInterface<TensorFlowRegistryInterfaceBase>()) { |
158 | // Important: cast to (Concept *) to shift the pointer off the vtable. |
159 | return static_cast<TensorFlowRegistryInterfaceBase::Concept *>( |
160 | const_cast<TensorFlowRegistryInterfaceBase *>(instance)); |
161 | } |
162 | } else if (interface == TypeID::get<MemoryEffectOpInterface>()) { |
163 | auto *instance = getRegisteredInterface<StatefulMemoryEffectInterface>(); |
164 | assert(instance && "expected the memory interface to be registered" ); |
165 | return static_cast<StatefulMemoryEffectInterface::Concept *>( |
166 | const_cast<StatefulMemoryEffectInterface *>(instance)); |
167 | } |
168 | |
169 | return nullptr; |
170 | } |
171 | |
172 | TFGraphDialect::~TFGraphDialect() { delete fallbackOpAsmInterface_; } |
173 | |
174 | // The name of certain optional attributes. |
175 | static std::array<StringRef, 3> keyword_attrs{ |
176 | "_mlir_device" , "_mlir_assigned_device" , "_mlir_name" }; |
177 | |
178 | static void PrintKeywordAttributes(Operation *op, OpAsmPrinter &printer, |
179 | ArrayRef<StringRef> elided_attrs = {}) { |
180 | // Handles the optional "device" and "name" attribute. |
181 | for (StringRef keyword : keyword_attrs) { |
182 | if (StringAttr value_attr = op->getAttrOfType<StringAttr>(keyword)) { |
183 | assert(!value_attr.getValue().empty()); |
184 | printer << " " << keyword.drop_front(/*len(_mlir_)*/ 6) << "(\"" |
185 | << value_attr.getValue() << "\")" ; |
186 | } |
187 | } |
188 | |
189 | // Print attributes (other than name and device). |
190 | SmallVector<StringRef> attrs_to_elide = llvm::to_vector(elided_attrs); |
191 | llvm::append_range(attrs_to_elide, keyword_attrs); |
192 | printer.printOptionalAttrDict(op->getAttrs(), attrs_to_elide); |
193 | } |
194 | |
195 | // Print an operation that belongs to this dialect, if unregistered. |
196 | // The general syntax is: |
197 | // tfg.OpName(%input1, %input2, %input3) [%control_dep1, %control_dep2] |
198 | // name("<node_name>") device("<device>") { attribute-dict } : |
199 | // (input types) -> (result_types) |
200 | void TFGraphDialect::printCustomTfOp(Operation *op, |
201 | OpAsmPrinter &printer) const { |
202 | ControlType control_ty = getControlType(); |
203 | |
204 | // Check that all control dependencies are after the regular values, |
205 | // otherwise print the generic form. We don't expect this to happen but |
206 | // we're defensive in the printer since this may happen in "hard-to-debug" |
207 | // issues. |
208 | { |
209 | bool has_control_dep = false; |
210 | for (Value operand : op->getOperands()) { |
211 | if (operand.getType() == control_ty) { |
212 | has_control_dep = true; |
213 | continue; |
214 | } |
215 | if (has_control_dep) { |
216 | printer.printGenericOp(op); |
217 | return; |
218 | } |
219 | } |
220 | has_control_dep = false; |
221 | for (Value result : op->getResults()) { |
222 | if (result.getType() == control_ty) { |
223 | has_control_dep = true; |
224 | continue; |
225 | } |
226 | if (has_control_dep) { |
227 | printer.printGenericOp(op); |
228 | return; |
229 | } |
230 | } |
231 | } |
232 | |
233 | // Print the inputs (other than the control dependencies), if any. |
234 | TFOp tfg_op(op); |
235 | OperandRange data = tfg_op.getNonControlOperands(); |
236 | if (!data.empty()) printer << '(' << data << ')'; |
237 | // Print the control dependencies (if any). |
238 | OperandRange ctls = tfg_op.getControlOperands(); |
239 | if (!ctls.empty()) printer << " [" << ctls << ']'; |
240 | |
241 | // Print the keyword attributes and optional attribute dictionary. |
242 | PrintKeywordAttributes(op, printer); |
243 | |
244 | // Print the type, but omit control dependencies. |
245 | // If there is a single control return, just print the list of input types, |
246 | // otherwise print the complete type in a "function-style" way: (operands) |
247 | // -> (results). |
248 | ResultRange results = tfg_op.getNonControlResults(); |
249 | if (results.empty()) { |
250 | if (!data.empty()) printer << " : " << data.getTypes(); |
251 | } else { |
252 | printer << " : (" << data.getTypes() << ") -> (" << results.getTypes() |
253 | << ")" ; |
254 | } |
255 | } |
256 | |
257 | // Print a custom TFG op. |
258 | static void PrintCustomTfOp(Operation *op, OpAsmPrinter &printer) { |
259 | cast<TFGraphDialect>(op->getDialect())->printCustomTfOp(op, printer); |
260 | } |
261 | |
262 | llvm::unique_function<void(Operation *, OpAsmPrinter &)> |
263 | TFGraphDialect::getOperationPrinter(Operation *op) const { |
264 | return [this](Operation *op, OpAsmPrinter &printer) { |
265 | this->printCustomTfOp(op, printer); |
266 | }; |
267 | } |
268 | |
269 | // Try to parse optional keyword attributes and prefix them with `_mlir_`, of |
270 | // `device`, `assigned_device`, and `name`. |
271 | static ParseResult ParseKeywordAttributes(OpAsmParser &parser, |
272 | OperationState &result) { |
273 | for (const char *keyword : {"device" , "assigned_device" , "name" }) { |
274 | if (succeeded(parser.parseOptionalKeyword(keyword))) { |
275 | StringAttr value; |
276 | if (parser.parseLParen() || |
277 | parser.parseAttribute<StringAttr>( |
278 | value, NoneType::get(parser.getContext())) || |
279 | parser.parseRParen()) |
280 | return failure(); |
281 | result.addAttribute((Twine("_mlir_" ) + keyword).str(), value); |
282 | } |
283 | } |
284 | return parser.parseOptionalAttrDict(result.attributes); |
285 | } |
286 | |
287 | // Parse an operation that belongs to this dialect, if unregistered. |
288 | // The general syntax is: |
289 | // tfg.OpName(%input1, %input2, %input3) [%control_dep1, %control_dep2] |
290 | // name("<node_name>") device("<device>") { attribute-dict } : |
291 | // (input types) -> (result_types) |
292 | static ParseResult ParseCustomTfOp(OpAsmParser &parser, |
293 | OperationState &result) { |
294 | MLIRContext *context = parser.getBuilder().getContext(); |
295 | // Parse optional argument list |
296 | SmallVector<OpAsmParser::UnresolvedOperand, 4> op_infos; |
297 | if (parser.parseOperandList(op_infos, AsmParser::Delimiter::OptionalParen)) |
298 | return failure(); |
299 | unsigned numNonControlOperands = op_infos.size(); |
300 | // Optional control list, in between brackets. |
301 | if (parser.parseOperandList(op_infos, AsmParser::Delimiter::OptionalSquare)) |
302 | return failure(); |
303 | |
304 | // Parse the optional keyword attributes and optional attribute dictionary. |
305 | if (ParseKeywordAttributes(parser, result)) return failure(); |
306 | |
307 | // Parse the functional type. |
308 | SmallVector<Type> arg_types; |
309 | arg_types.reserve(op_infos.size()); |
310 | llvm::SMLoc loc = parser.getCurrentLocation(); |
311 | Type control_type = ControlType::get(context); |
312 | if (failed(parser.parseOptionalColonTypeList(arg_types))) return failure(); |
313 | if (arg_types.size() == 1 && arg_types.front().isa<FunctionType>()) { |
314 | auto funcType = arg_types.front().cast<FunctionType>(); |
315 | if (funcType.getNumInputs() != numNonControlOperands) |
316 | return parser.emitError(loc) |
317 | << "got " << numNonControlOperands |
318 | << " non-control operands, but the type defines " |
319 | << funcType.getNumInputs() << " input types" ; |
320 | arg_types.clear(); |
321 | arg_types.append(funcType.getInputs().begin(), funcType.getInputs().end()); |
322 | result.types.append(funcType.getResults().begin(), |
323 | funcType.getResults().end()); |
324 | } |
325 | |
326 | // The control input are elided from the type list, add them here. |
327 | arg_types.resize(op_infos.size(), control_type); |
328 | if (!arg_types.empty()) |
329 | if (parser.resolveOperands(op_infos, arg_types, loc, result.operands)) |
330 | return failure(); |
331 | if (result.name.getStringRef() != "tfg.return" ) |
332 | result.types.push_back(control_type); |
333 | return success(); |
334 | } |
335 | |
336 | Optional<Dialect::ParseOpHook> TFGraphDialect::getParseOperationHook( |
337 | StringRef opName) const { |
338 | return ParseOpHook(ParseCustomTfOp); |
339 | } |
340 | |
341 | static bool VerifyGenericTFGOperation(Operation &op) { |
342 | TFGraphDialect *dialect = dyn_cast<TFGraphDialect>(op.getDialect()); |
343 | if (!dialect) return true; |
344 | ControlType control_ty = dialect->getControlType(); |
345 | |
346 | // verifies that control operands (or results) are always after regular |
347 | // inputs (or results). |
348 | auto check_ctl_at_end = [&](TypeRange types, StringRef input_or_output) { |
349 | int has_control_dep = -1; |
350 | for (auto &indexed_operand : llvm::enumerate(types)) { |
351 | if (indexed_operand.value() == control_ty) { |
352 | has_control_dep = indexed_operand.index(); |
353 | continue; |
354 | } |
355 | if (has_control_dep != -1) { |
356 | op.emitOpError() << "found non-control " << input_or_output |
357 | << " in position #" << indexed_operand.index() |
358 | << " after control " << input_or_output |
359 | << " in position #" << has_control_dep; |
360 | return false; |
361 | } |
362 | } |
363 | return true; |
364 | }; |
365 | if (!check_ctl_at_end(op.getOperandTypes(), "input" )) return false; |
366 | if (!check_ctl_at_end(op.getResultTypes(), "result" )) return false; |
367 | |
368 | // Certain attributes are supposed to be inserted with non-empty value. |
369 | for (StringRef keyword : keyword_attrs) { |
370 | if (StringAttr value_attr = op.getAttrOfType<StringAttr>(keyword)) { |
371 | if (value_attr.getValue().empty()) { |
372 | op.emitOpError() << keyword |
373 | << " has empty value. Only insert this attribute when " |
374 | "it has a value" ; |
375 | } |
376 | } |
377 | } |
378 | |
379 | return true; |
380 | } |
381 | |
382 | //===----------------------------------------------------------------------===// |
383 | // Graph Operation |
384 | //===----------------------------------------------------------------------===// |
385 | |
386 | LogicalResult GraphOp::verify() { |
387 | GraphOp op = *this; |
388 | // Check all ops in the body. |
389 | if (!all_of(*op.getBody(), VerifyGenericTFGOperation)) return failure(); |
390 | |
391 | return success(); |
392 | } |
393 | //===----------------------------------------------------------------------===// |
394 | // Func Operation |
395 | //===----------------------------------------------------------------------===// |
396 | |
397 | bool GraphFuncOp::isMarkedForCompilation() { |
398 | auto is_enabled = [this](StringRef attr_name) -> bool { |
399 | Attribute attr = (*this)->getAttr(attr_name); |
400 | if (!attr) return false; |
401 | if (auto bool_attr = attr.dyn_cast<BoolAttr>()) return bool_attr.getValue(); |
402 | if (auto str_attr = attr.dyn_cast<StringAttr>()) |
403 | return !str_attr.getValue().empty(); |
404 | return false; |
405 | }; |
406 | return is_enabled("_xla_compile_id" ) || is_enabled("_tpu_replicate" ) || |
407 | is_enabled("_XlaMustCompile" ); |
408 | } |
409 | |
410 | // Hook for OpTrait::FunctionLike, called after verifying that the 'type' |
411 | // attribute is present and checks if it holds a function type. Ensures |
412 | // getType, getNumFuncArguments, and getNumFuncResults can be called safely |
413 | LogicalResult GraphFuncOp::verifyType() { |
414 | auto type = getFunctionTypeAttr().getValue(); |
415 | if (!type.isa<FunctionType>()) |
416 | return emitOpError("requires '" + getTypeAttrName() + |
417 | "' attribute of function type" ); |
418 | return success(); |
419 | } |
420 | |
421 | // Hook for OpTrait::FunctionLike, called after verifying the function |
422 | // type and the presence of the (potentially empty) function body. |
423 | LogicalResult GraphFuncOp::verifyBody() { |
424 | FunctionType type = getFunctionType(); |
425 | Block *body = SingleBlock::getBody(); |
426 | // Check that the body is terminated with a tfg.return. |
427 | if (getRegion().empty() || body->empty()) |
428 | return emitOpError() << "expects a non empty body" ; |
429 | |
430 | if (body->getNumArguments() != type.getNumInputs()) |
431 | return emitOpError() << "function type indicated " << type.getNumInputs() |
432 | << " args but block has " << body->getNumArguments(); |
433 | |
434 | for (auto &arg_types : |
435 | llvm::enumerate(llvm::zip(type.getInputs(), body->getArgumentTypes()))) { |
436 | Type signature_arg = std::get<0>(arg_types.value()); |
437 | Type block_arg = std::get<1>(arg_types.value()); |
438 | if (signature_arg != block_arg) |
439 | return emitOpError() << "type mismatch for arg #" << arg_types.index() |
440 | << ", signature defines " << signature_arg |
441 | << " block arg is " << block_arg; |
442 | } |
443 | |
444 | if (!isa<ReturnOp>(body->back())) |
445 | return emitOpError() |
446 | << "expects body to be terminated with a tfg.return, but got: " |
447 | << body->back().getName().getStringRef(); |
448 | |
449 | ReturnOp return_op = cast<ReturnOp>(body->getTerminator()); |
450 | |
451 | if (type.getNumResults() > return_op->getNumOperands()) |
452 | return emitOpError() << "expects " << type.getNumResults() |
453 | << " returned values but tfg.return has " |
454 | << return_op->getNumOperands() << " operands" ; |
455 | for (auto &indexed_type : llvm::enumerate(type.getResults())) { |
456 | Type expected_type = indexed_type.value(); |
457 | int res_num = indexed_type.index(); |
458 | Type actual_type = return_op->getOperand(res_num).getType(); |
459 | if (!tf_type::AreCastCompatible({expected_type, actual_type})) { |
460 | return emitOpError() << "type mismatch for returned value #" << res_num |
461 | << ", expected " << expected_type << " but got " |
462 | << actual_type; |
463 | } |
464 | } |
465 | Type control_type = getDialect()->getControlType(); |
466 | for (auto &indexed_type : llvm::enumerate(llvm::drop_begin( |
467 | return_op->getOperandTypes(), type.getNumResults()))) { |
468 | Type actual_type = indexed_type.value(); |
469 | if (actual_type != control_type) { |
470 | return emitOpError() << "returned value #" << indexed_type.index() |
471 | << " overflow the expected " << type.getNumResults() |
472 | << " returned value for function " << getName() |
473 | << ", expected a ControlType but got " |
474 | << actual_type; |
475 | } |
476 | } |
477 | |
478 | // Check all ops in the body. |
479 | if (!all_of(*SingleBlock::getBody(), VerifyGenericTFGOperation)) |
480 | return failure(); |
481 | |
482 | return success(); |
483 | } |
484 | |
485 | LogicalResult GraphFuncOp::canonicalize(GraphFuncOp func_op, |
486 | PatternRewriter &rewriter) { |
487 | // Prune function body: the body is a graph where feeds/fetches a materialized |
488 | // with function arguments and returned values. As such any operation not |
489 | // reachable from the "fetches" can be pruned. The return statement also has |
490 | // control input so that side-effecting operations without results (print for |
491 | // example) aren't pruned. |
492 | bool changed = true; |
493 | while (changed) { |
494 | changed = false; |
495 | for (Operation &op : llvm::make_early_inc_range( |
496 | llvm::reverse(*func_op.SingleBlock::getBody()))) { |
497 | if (isa<ReturnOp>(op)) continue; |
498 | if (op.getUses().empty()) { |
499 | rewriter.eraseOp(&op); |
500 | changed = true; |
501 | } |
502 | } |
503 | } |
504 | return failure(); |
505 | } |
506 | |
507 | LogicalResult GraphFuncOp::verify() { |
508 | GraphFuncOp func_op = *this; |
509 | if (func_op.getNumArguments() % 2) |
510 | return func_op.emitOpError() << "expects an even number of arguments" ; |
511 | ArrayAttr args_attrs = func_op.getAllArgAttrs(); |
512 | if (args_attrs && args_attrs.size() != func_op.getNumArguments()) |
513 | return func_op.emitOpError() |
514 | << "expects argument attributes for each argument (" |
515 | << args_attrs.size() << " vs " << func_op.getNumArguments() << ")" ; |
516 | ArrayAttr res_attrs = func_op.getAllResultAttrs(); |
517 | if (res_attrs && res_attrs.size() != func_op.getNumResults()) |
518 | return func_op.emitOpError() |
519 | << "expects results attributes for each result (" << res_attrs.size() |
520 | << " vs " << func_op.getNumResults() << ")" ; |
521 | return success(); |
522 | } |
523 | |
524 | ParseResult GraphFuncOp::parse(OpAsmParser &parser, OperationState &result) { |
525 | SmallVector<OpAsmParser::UnresolvedOperand> entry_args; |
526 | SmallVector<Attribute> arg_attrs; |
527 | SmallVector<Attribute> result_attrs; |
528 | SmallVector<Type> arg_types; |
529 | SmallVector<Type> result_types; |
530 | auto &builder = parser.getBuilder(); |
531 | MLIRContext *context = builder.getContext(); |
532 | |
533 | // Parse visibility. |
534 | StringRef visibility; |
535 | if (!parser.parseOptionalKeyword(&visibility, |
536 | {"public" , "private" , "nested" })) { |
537 | StringAttr visibility_attr = parser.getBuilder().getStringAttr(visibility); |
538 | result.attributes.push_back(parser.getBuilder().getNamedAttr( |
539 | SymbolTable::getVisibilityAttrName(), visibility_attr)); |
540 | } |
541 | |
542 | if (succeeded(parser.parseOptionalKeyword("generic" ))) |
543 | result.addAttribute("generic" , builder.getUnitAttr()); |
544 | |
545 | // Parse the name as a symbol. |
546 | StringAttr name_attr; |
547 | if (parser.parseSymbolName(name_attr, SymbolTable::getSymbolAttrName(), |
548 | result.attributes)) |
549 | return failure(); |
550 | |
551 | // Parse the function signature. |
552 | // The difference with usual functions, is that for every single argument |
553 | // parsed, we create two block arguments: one for the expected value and one |
554 | // for the control dependency. |
555 | if (parser.parseLParen()) return failure(); |
556 | Type control_ty = ControlType::get(builder.getContext()); |
557 | std::list<std::string> control_operand_names; |
558 | |
559 | // Helper to parse a single argument and its attributes. |
560 | auto parse_argument = [&]() -> ParseResult { |
561 | // Parse argument name if present. |
562 | entry_args.emplace_back(); |
563 | arg_types.emplace_back(); |
564 | if (parser.parseOperand(entry_args.back(), /*allowResultNumber=*/false) || |
565 | parser.parseColonType(arg_types.back())) |
566 | return failure(); |
567 | |
568 | // Parse any argument attributes. |
569 | NamedAttrList attrs; |
570 | if (parser.parseOptionalAttrDict(attrs)) return failure(); |
571 | arg_attrs.push_back(attrs.getDictionary(context)); |
572 | |
573 | // Define the control input: it's not printed but is added as block |
574 | // argument. Note the name computed here (suffixed ".ctl") is coupled to the |
575 | // implementation of: |
576 | // TFGraphOpAsmInterface::getAsmBlockArgumentNames() |
577 | // at the top of this file. |
578 | OpAsmParser::UnresolvedOperand control_operand = entry_args.back(); |
579 | control_operand_names.push_back((control_operand.name + ".ctl" ).str()); |
580 | control_operand.name = control_operand_names.back(); |
581 | entry_args.push_back(control_operand); |
582 | arg_types.push_back(control_ty); |
583 | arg_attrs.push_back(DictionaryAttr::get(context)); |
584 | return success(); |
585 | }; |
586 | |
587 | // Parse the function arguments and their attributes. |
588 | if (failed(parser.parseOptionalRParen())) { |
589 | do { |
590 | if (parse_argument()) return failure(); |
591 | } while (succeeded(parser.parseOptionalComma())); |
592 | if (parser.parseRParen()) return failure(); |
593 | } |
594 | |
595 | // Parse the result types and their attributes. |
596 | if (succeeded(parser.parseOptionalArrow())) { |
597 | if (failed(parser.parseLParen())) return failure(); |
598 | if (failed(parser.parseOptionalRParen())) { |
599 | // Parse individual function results. |
600 | do { |
601 | result_types.emplace_back(); |
602 | NamedAttrList result_attr; |
603 | if (parser.parseType(result_types.back()) || |
604 | parser.parseOptionalAttrDict(result_attr)) { |
605 | return failure(); |
606 | } |
607 | result_attrs.push_back(builder.getDictionaryAttr(result_attr)); |
608 | } while (succeeded(parser.parseOptionalComma())); |
609 | if (parser.parseRParen()) return failure(); |
610 | } |
611 | } |
612 | |
613 | auto type = builder.getFunctionType(arg_types, result_types); |
614 | result.addAttribute(GraphFuncOp::getTypeAttrName(), TypeAttr::get(type)); |
615 | |
616 | // If function attributes are present, parse them. |
617 | NamedAttrList parsed_attributes; |
618 | if (parser.parseOptionalAttrDictWithKeyword(parsed_attributes)) |
619 | return failure(); |
620 | result.attributes.append(parsed_attributes); |
621 | |
622 | // Add the attributes to the function arguments. |
623 | assert(arg_attrs.size() == arg_types.size()); |
624 | assert(result_attrs.size() == result_types.size()); |
625 | result.attributes.append( |
626 | builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), |
627 | builder.getArrayAttr(arg_attrs))); |
628 | result.attributes.append( |
629 | builder.getNamedAttr(FunctionOpInterface::getResultDictAttrName(), |
630 | builder.getArrayAttr(result_attrs))); |
631 | |
632 | // Parse the function body. |
633 | auto *body = result.addRegion(); |
634 | llvm::SMLoc loc = parser.getCurrentLocation(); |
635 | SmallVector<OpAsmParser::Argument> args; |
636 | if (entry_args.size()) { |
637 | for (auto argAndType : llvm::zip(entry_args, arg_types)) { |
638 | auto &arg = args.emplace_back(); |
639 | arg.ssaName = std::get<0>(argAndType); |
640 | arg.type = std::get<1>(argAndType); |
641 | } |
642 | } |
643 | |
644 | if (failed(parser.parseRegion(*body, args, /*enableNameShadowing=*/false))) |
645 | return failure(); |
646 | |
647 | // Function body was parsed, make sure it's not empty. |
648 | if (body->empty()) |
649 | return parser.emitError(loc, "expected non-empty function body" ); |
650 | |
651 | return success(); |
652 | } |
653 | |
654 | void GraphFuncOp::print(OpAsmPrinter &p) { |
655 | // Print the operation and the function name. |
656 | Operation *op = *this; |
657 | p << " " ; |
658 | int argIndentSize = op->getName().getStringRef().size() + 3; |
659 | StringRef visibility_attr_name = SymbolTable::getVisibilityAttrName(); |
660 | if (auto visibility = op->getAttrOfType<StringAttr>(visibility_attr_name)) { |
661 | p << visibility.getValue() << ' '; |
662 | argIndentSize += visibility.getValue().size() + 1; |
663 | } |
664 | if (getGeneric()) p << "generic " ; |
665 | auto funcName = |
666 | op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()) |
667 | .getValue(); |
668 | p.printSymbolName(funcName); |
669 | argIndentSize += funcName.size(); |
670 | std::string indent(argIndentSize, ' '); |
671 | FunctionType fnType = getFunctionType(); |
672 | ArrayRef<Type> arg_types = fnType.getInputs(); |
673 | ArrayRef<Type> result_types = fnType.getResults(); |
674 | assert((arg_types.size() % 2) == 0); |
675 | // Print operand list with attributes. |
676 | p << '('; |
677 | ArrayAttr args_attr = getAllArgAttrs(); |
678 | for (unsigned i = 0, e = arg_types.size(); i < e; i += 2) { |
679 | // Args come by pair: input+control. |
680 | p.printOperand(getArgument(i)); |
681 | p << ": " ; |
682 | p.printType(arg_types[i]); |
683 | if (auto arg_attrs = args_attr[i].dyn_cast<DictionaryAttr>()) |
684 | p.printOptionalAttrDict(arg_attrs.getValue()); |
685 | if (i != e - 2) { |
686 | p << ", " ; |
687 | p.printNewline(); |
688 | p << indent; |
689 | } |
690 | } |
691 | p << ')'; |
692 | |
693 | // Print result types, if any. |
694 | if (!result_types.empty()) { |
695 | p.printNewline(); |
696 | p.getStream() << " -> (" ; |
697 | indent = std::string(9, ' '); |
698 | ArrayAttr results_attr = getAllResultAttrs(); |
699 | for (int i = 0, e = result_types.size(); i < e; ++i) { |
700 | p.printType(result_types[i]); |
701 | if (auto result_attrs = results_attr[i].dyn_cast<DictionaryAttr>()) |
702 | p.printOptionalAttrDict(result_attrs.getValue()); |
703 | if (i != e - 1) { |
704 | p << ", " ; |
705 | p.printNewline(); |
706 | p << indent; |
707 | } |
708 | } |
709 | p << ")" ; |
710 | } |
711 | // Print attributes. |
712 | if (!op->getAttrs().empty()) { |
713 | p.printNewline(); |
714 | function_interface_impl::printFunctionAttributes( |
715 | p, *this, fnType.getNumInputs(), fnType.getNumResults(), |
716 | {"generic" , SymbolTable::getVisibilityAttrName()}); |
717 | } |
718 | // Print body. |
719 | p << ' '; |
720 | p.printRegion(getBody(), /*printEntryBlockArgs=*/false); |
721 | } |
722 | |
723 | GraphFuncOp GraphFuncOp::getCalledFunction(Operation *op, |
724 | SymbolTable &symbol_table) { |
725 | // Check if a node does indirect function call via PartitionedCallOp. |
726 | // TODO(aminim): consider replacing with isa<...> when possible. |
727 | if (op->getName().getStringRef() == "tfg.PartitionCall" || |
728 | op->getName().getStringRef() == "tfg.StatefulPartitionedCall" ) { |
729 | auto func_attr = op->getAttrOfType<FuncAttr>("f" ); |
730 | if (!func_attr) return {}; |
731 | GraphFuncOp callee = symbol_table.lookup<GraphFuncOp>( |
732 | func_attr.getName().getLeafReference()); |
733 | if (callee) return callee; |
734 | } |
735 | return symbol_table.lookup<GraphFuncOp>(op->getName().stripDialect()); |
736 | } |
737 | |
738 | BlockArgument GraphFuncOp::getDataValueOf(BlockArgument ctl) { |
739 | return ctl.getOwner()->getArgument(ctl.getArgNumber() - 1); |
740 | } |
741 | |
742 | BlockArgument GraphFuncOp::getControlTokenOf(BlockArgument data) { |
743 | return data.getOwner()->getArgument(data.getArgNumber() + 1); |
744 | } |
745 | |
746 | BlockArgument GraphFuncOp::getDataValue(Region ®ion, unsigned idx) { |
747 | return region.getArgument(idx * 2); |
748 | } |
749 | |
750 | // This is naming block arguments for GraphFuncOp, we rely on the arg attributes |
751 | // for computing the names. |
752 | void GraphFuncOp::getAsmBlockArgumentNames(Region ®ion, |
753 | OpAsmSetValueNameFn set_name_fn) { |
754 | ArrayRef<BlockArgument> args = SingleBlock::getBody()->getArguments(); |
755 | ControlType control_ty = ControlType::get(getContext()); |
756 | // Sanity checking: this is verified by the op but this may be called before |
757 | // the verifier or in some diagnostic/debug context, let's not crash. |
758 | // We expect the function block operands to come as pair: tensor+control. |
759 | if (args.size() % 2) return; |
760 | for (unsigned i = 0, e = args.size(); i < e; i += 2) |
761 | if (args[i].getType() == control_ty || args[i + 1].getType() != control_ty) |
762 | return; |
763 | |
764 | // Name the values based on the `tfg.name` arg attribute retrieved from the |
765 | // func_op. |
766 | ArrayAttr args_attr = getAllArgAttrs(); |
767 | if (!args_attr || args_attr.size() != args.size()) return; |
768 | for (int arg_num = 0, e = args.size(); arg_num < e; arg_num += 2) { |
769 | DictionaryAttr arg_attrs = args_attr[arg_num].dyn_cast<DictionaryAttr>(); |
770 | if (!arg_attrs) continue; |
771 | if (auto strAttr = arg_attrs.getAs<StringAttr>("tfg.name" )) { |
772 | set_name_fn(args[arg_num], strAttr.getValue()); |
773 | set_name_fn(args[arg_num + 1], (strAttr.getValue() + ".ctl" ).str()); |
774 | } |
775 | } |
776 | } |
777 | |
778 | //===----------------------------------------------------------------------===// |
779 | // ReturnOp |
780 | //===----------------------------------------------------------------------===// |
781 | |
782 | LogicalResult ReturnOp::verify() { |
783 | ReturnOp op = *this; |
784 | // If the control result attributes are present, there must be the same number |
785 | // of entries as control results. |
786 | if (op.getControlRetAttrs().size() != TFOp(op).getControlOperands().size()) { |
787 | return op.emitOpError( |
788 | "expected as many control result attributes as there are control " |
789 | "operands" ); |
790 | } |
791 | return success(); |
792 | } |
793 | |
794 | ParseResult ReturnOp::parse(OpAsmParser &parser, OperationState &result) { |
795 | // ReturnOp has the same assembly format as generic TFG ops except that the |
796 | // control result attributes are embedded with the control operands: |
797 | // [%ctl {tfg.name = "foo"}, %ctl_0 {tfg.name = "bar"}] |
798 | SmallVector<OpAsmParser::UnresolvedOperand> operands; |
799 | if (parser.parseOperandList(operands, AsmParser::Delimiter::OptionalParen)) |
800 | return failure(); |
801 | |
802 | SmallVector<Attribute> control_ret_attrs; |
803 | if (succeeded(parser.parseOptionalLSquare())) { |
804 | OpAsmParser::UnresolvedOperand operand; |
805 | do { |
806 | NamedAttrList attrs; |
807 | OptionalParseResult parse_result = parser.parseOptionalOperand(operand); |
808 | if (!parse_result.hasValue()) break; |
809 | if (failed(parse_result.getValue())) return failure(); |
810 | if (parser.parseOptionalAttrDict(attrs)) return failure(); |
811 | control_ret_attrs.push_back(attrs.getDictionary(result.getContext())); |
812 | operands.push_back(std::move(operand)); |
813 | } while (succeeded(parser.parseOptionalComma())); |
814 | if (parser.parseRSquare()) return failure(); |
815 | } |
816 | |
817 | if (ParseKeywordAttributes(parser, result)) return failure(); |
818 | result.addAttribute(ReturnOp::control_ret_attrsAttrName(result.name), |
819 | ArrayAttr::get(result.getContext(), control_ret_attrs)); |
820 | |
821 | SmallVector<Type> types; |
822 | if (parser.parseOptionalColonTypeList(types)) return failure(); |
823 | types.resize(operands.size(), ControlType::get(result.getContext())); |
824 | if (parser.resolveOperands(operands, types, parser.getCurrentLocation(), |
825 | result.operands)) |
826 | return failure(); |
827 | return success(); |
828 | } |
829 | |
830 | void ReturnOp::print(OpAsmPrinter &printer) { |
831 | TFOp tfg_op(*this); |
832 | OperandRange data = tfg_op.getNonControlOperands(); |
833 | if (!data.empty()) printer << '(' << data << ')'; |
834 | |
835 | OperandRange ctls = tfg_op.getControlOperands(); |
836 | if (!ctls.empty()) { |
837 | printer << " [" ; |
838 | llvm::interleave( |
839 | llvm::zip(ctls, getControlRetAttrs().getAsRange<DictionaryAttr>()), |
840 | printer, |
841 | [&](auto it) { |
842 | printer << std::get<0>(it); |
843 | if (!std::get<1>(it).empty()) printer << ' ' << std::get<1>(it); |
844 | }, |
845 | ", " ); |
846 | printer << ']'; |
847 | } |
848 | |
849 | PrintKeywordAttributes(*this, printer, {"control_ret_attrs" }); |
850 | |
851 | if (!data.empty()) printer << " : " << data.getTypes(); |
852 | } |
853 | |
854 | void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
855 | ValueRange operands, ValueRange control_operands) { |
856 | odsState.addOperands(operands); |
857 | odsState.addOperands(control_operands); |
858 | // Populate `control_ret_attrs` with empty dictionaries. |
859 | odsState.addAttribute( |
860 | ReturnOp::control_ret_attrsAttrName(odsState.name), |
861 | odsBuilder.getArrayAttr(SmallVector<Attribute>( |
862 | control_operands.size(), odsBuilder.getDictionaryAttr({})))); |
863 | } |
864 | |
865 | //===----------------------------------------------------------------------===// |
866 | // Concrete Ops |
867 | //===----------------------------------------------------------------------===// |
868 | |
869 | // The ODS definitions of TFG ops can be autogenerated TODO(jeffniu) as well as |
870 | // parts of their verifiers. These hand-written verifiers focus on verifying the |
871 | // ops' operand and result types with respect to their functions' types, the |
872 | // logic for which is slightly different between operations. |
873 | |
874 | // Verify that all control operands follow non-control operands, and return the |
875 | // subrange of non-control operands. |
876 | static FailureOr<TypeRange> VerifyOperands(Operation *op) { |
877 | ControlType control_ty = |
878 | cast<TFGraphDialect>(op->getDialect())->getControlType(); |
879 | Operation::operand_type_iterator it = |
880 | llvm::find(op->getOperandTypes(), control_ty); |
881 | if (!std::all_of(it, op->operand_type_end(), |
882 | [&](Type type) { return type == control_ty; })) { |
883 | return op->emitOpError( |
884 | "not all control tokens come after non-control operands" ); |
885 | } |
886 | return {Operation::operand_type_range(op->operand_type_begin(), it)}; |
887 | } |
888 | |
889 | // Verify that the last result of an operation is the only control result, and |
890 | // return a subrange of the non-control results. |
891 | static FailureOr<TypeRange> VerifyResults(Operation *op) { |
892 | ControlType control_ty = |
893 | cast<TFGraphDialect>(op->getDialect())->getControlType(); |
894 | Operation::result_type_iterator it = |
895 | llvm::find(op->getResultTypes(), control_ty); |
896 | if (it == op->result_type_end()) |
897 | return op->emitOpError("does not define a control result" ); |
898 | if (it != std::prev(op->result_type_end())) { |
899 | return op->emitOpError( |
900 | "must have a control token result as and only as its last result" ); |
901 | } |
902 | return {Operation::result_type_range(op->result_type_begin(), it)}; |
903 | } |
904 | |
905 | // Verify that the signature of the function matches the operation's operands |
906 | // and results. |
907 | static LogicalResult VerifySignature(GraphFuncOp func, Operation *op, |
908 | TypeRange operands, TypeRange results, |
909 | const Twine &func_name) { |
910 | auto attach_func = [&](InFlightDiagnostic diag) -> LogicalResult { |
911 | return diag.attachNote(func.getLoc()).appendOp(*func, OpPrintingFlags()) |
912 | << "\nsee referenced function" ; |
913 | }; |
914 | |
915 | ArrayRef<Type> arguments = func.getFunctionType().getInputs(); |
916 | ArrayRef<Type> returns = func.getFunctionType().getResults(); |
917 | if (operands.size() * 2 != arguments.size()) { |
918 | return attach_func(op->emitOpError(func_name) |
919 | << " function has " << arguments.size() / 2 |
920 | << " arguments but was provided " << operands.size()); |
921 | } |
922 | if (results.size() != returns.size()) { |
923 | return attach_func(op->emitOpError(func_name) |
924 | << " function has " << returns.size() |
925 | << " results but expected " << results.size()); |
926 | } |
927 | |
928 | if (func.getGeneric()) return success(); |
929 | |
930 | for (auto &it : llvm::enumerate(operands)) { |
931 | Type arg_type = arguments[it.index() * 2]; |
932 | Type op_type = it.value(); |
933 | if (!tf_type::HasCompatibleElementTypes(arg_type, op_type)) { |
934 | return attach_func( |
935 | op->emitOpError(func_name) |
936 | << " function argument #" << it.index() << " type " << arg_type |
937 | << " is not compatible with corresponding operand type: " << op_type); |
938 | } |
939 | } |
940 | for (auto &it : llvm::enumerate(results)) { |
941 | Type ret_type = returns[it.index()]; |
942 | Type res_type = it.value(); |
943 | if (!tf_type::HasCompatibleElementTypes(ret_type, res_type)) { |
944 | return attach_func( |
945 | op->emitOpError(func_name) |
946 | << " function result #" << it.index() << " type " << ret_type |
947 | << " is not compatible with corresponding result type: " << res_type); |
948 | } |
949 | } |
950 | return success(); |
951 | } |
952 | |
953 | // This function verifies that the types of `values`, which are either operands |
954 | // or results of `op`, match the types specified in `types`, which is expected |
955 | // to be an array of type attributes. |
956 | static LogicalResult VerifyTypeArray(Operation *op, ValueRange values, |
957 | ArrayAttr types, StringRef kind) { |
958 | // Don't verify if the types are not present. |
959 | if (!types) return success(); |
960 | if (values.size() != types.size()) { |
961 | return op->emitOpError("has " ) << values.size() << " " << kind << "s but " |
962 | << types.size() << " " << kind << " types" ; |
963 | } |
964 | for (auto it : |
965 | llvm::zip(llvm::enumerate(values), types.getAsRange<TypeAttr>())) { |
966 | Type type = std::get<0>(it).value().getType(); |
967 | Type dtype = std::get<1>(it).getValue(); |
968 | if (!tf_type::HasCompatibleElementTypes(type, |
969 | UnrankedTensorType::get(dtype))) { |
970 | return op->emitOpError(kind) |
971 | << " #" << std::get<0>(it).index() |
972 | << " is incompatible with dtype " << dtype << ", got: " << type; |
973 | } |
974 | } |
975 | return success(); |
976 | } |
977 | |
978 | namespace detail { |
979 | // Check if the op type has `T`. |
980 | template <typename OpT> |
981 | using has_T = decltype(std::declval<OpT>().T()); |
982 | template <typename OpT> |
983 | using detect_has_T = llvm::is_detected<has_T, OpT>; |
984 | |
985 | // Get the input and output type arrays. If the op has a single type array, |
986 | // use it for both input and output. Otherwise, return separate type arrays. |
987 | template <typename OpT, bool = detect_has_T<OpT>::value> |
988 | struct GetTypeArray { |
989 | static ArrayAttr getInputTypes(OpT op) { return op.getTinAttr(); } |
990 | static ArrayAttr getOutputTypes(OpT op) { return op.getToutAttr(); } |
991 | }; |
992 | template <typename OpT> |
993 | struct GetTypeArray<OpT, true> { |
994 | static ArrayAttr getInputTypes(OpT op) { return op.getTAttr(); } |
995 | static ArrayAttr getOutputTypes(OpT op) { return op.getTAttr(); } |
996 | }; |
997 | } // namespace detail |
998 | |
999 | // Verify a functional op's inputs and outputs against its data type arrays. For |
1000 | // loop ops, this also checks that the number of inputs and outputs match. This |
1001 | // is guaranteed to be valid on import but may be violated by a transformation. |
1002 | template <typename OpT> |
1003 | static LogicalResult VerifyTypeArrayAttributes(OpT op) { |
1004 | using GetTypeArray = typename detail::GetTypeArray<OpT>; |
1005 | ValueRange args = |
1006 | SplitDataAndControlValues(op.getArgs(), ControlType::get(op.getContext())) |
1007 | .first; |
1008 | return success( |
1009 | succeeded(VerifyTypeArray(op, args, GetTypeArray::getInputTypes(op), |
1010 | "argument" )) && |
1011 | succeeded(VerifyTypeArray(op, op.getOuts(), |
1012 | GetTypeArray::getOutputTypes(op), "result" ))); |
1013 | } |
1014 | |
1015 | //===----------------------------------------------------------------------===// |
1016 | // If-Like Ops |
1017 | |
1018 | template <typename IfLikeOp> |
1019 | static LogicalResult VerifyIfLikeOp(IfLikeOp op, |
1020 | SymbolTableCollection &symbol_table) { |
1021 | if (failed(op.verifyInvariants())) return failure(); |
1022 | FailureOr<TypeRange> ins = VerifyOperands(op); |
1023 | if (failed(ins)) return failure(); |
1024 | FailureOr<TypeRange> outs = VerifyResults(op); |
1025 | if (failed(outs)) return failure(); |
1026 | |
1027 | // The first operand is the condition and is not passed to the functions. |
1028 | TypeRange func_args = ins->drop_front(); |
1029 | |
1030 | auto then_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>( |
1031 | op, op.getThenBranch().getName()); |
1032 | if (then_func && |
1033 | failed(VerifySignature(then_func, op, func_args, *outs, "then" ))) |
1034 | return failure(); |
1035 | |
1036 | auto else_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>( |
1037 | op, op.getElseBranch().getName()); |
1038 | if (else_func && |
1039 | failed(VerifySignature(else_func, op, func_args, *outs, "else" ))) |
1040 | return failure(); |
1041 | |
1042 | return VerifyTypeArrayAttributes(op); |
1043 | } |
1044 | |
1045 | //===----------------------------------------------------------------------===// |
1046 | // Case-Like Ops |
1047 | |
1048 | template <typename CaseLikeOp> |
1049 | static LogicalResult VerifyCaseLikeOp(CaseLikeOp op, |
1050 | SymbolTableCollection &symbol_table) { |
1051 | if (failed(op.verifyInvariants())) return failure(); |
1052 | FailureOr<TypeRange> ins = VerifyOperands(op); |
1053 | if (failed(ins)) return failure(); |
1054 | FailureOr<TypeRange> outs = VerifyResults(op); |
1055 | if (failed(outs)) return failure(); |
1056 | |
1057 | // The first operand is the branch index and is not passed to the functions. |
1058 | TypeRange func_args = ins->drop_front(); |
1059 | |
1060 | for (auto &it : llvm::enumerate(op.getBranches())) { |
1061 | SymbolRefAttr func_name = it.value().template cast<FuncAttr>().getName(); |
1062 | auto func = |
1063 | symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(op, func_name); |
1064 | if (func && failed(VerifySignature(func, op, func_args, *outs, |
1065 | "branch #" + Twine(it.index())))) |
1066 | return failure(); |
1067 | } |
1068 | |
1069 | return VerifyTypeArrayAttributes(op); |
1070 | } |
1071 | |
1072 | //===----------------------------------------------------------------------===// |
1073 | // While-Like Ops |
1074 | |
1075 | template <typename WhileLikeOp> |
1076 | static LogicalResult VerifyWhileLikeOp(WhileLikeOp op, |
1077 | SymbolTableCollection &symbol_table) { |
1078 | if (failed(op.verifyInvariants())) return failure(); |
1079 | FailureOr<TypeRange> ins = VerifyOperands(op); |
1080 | if (failed(ins)) return failure(); |
1081 | FailureOr<TypeRange> outs = VerifyResults(op); |
1082 | if (failed(outs)) return failure(); |
1083 | |
1084 | SymbolRefAttr body_name = op.getBody().getName(); |
1085 | |
1086 | auto cond_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>( |
1087 | op, op.getCond().getName()); |
1088 | auto i1_type = UnrankedTensorType::get(Builder(op.getContext()).getI1Type()); |
1089 | if (cond_func && |
1090 | failed(VerifySignature(cond_func, op, *ins, i1_type, "cond" ))) |
1091 | return failure(); |
1092 | |
1093 | auto body_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>( |
1094 | op, op.getBody().getName()); |
1095 | if (body_func && failed(VerifySignature(body_func, op, *ins, *outs, "body" ))) |
1096 | return failure(); |
1097 | |
1098 | return VerifyTypeArrayAttributes(op); |
1099 | } |
1100 | |
1101 | //===----------------------------------------------------------------------===// |
1102 | // ForOp |
1103 | |
1104 | LogicalResult ForOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1105 | if (failed(verifyInvariants())) return failure(); |
1106 | FailureOr<TypeRange> ins = VerifyOperands(*this); |
1107 | if (failed(ins)) return failure(); |
1108 | FailureOr<TypeRange> outs = VerifyResults(*this); |
1109 | if (failed(outs)) return failure(); |
1110 | |
1111 | auto body_func = symbolTable.lookupNearestSymbolFrom<GraphFuncOp>( |
1112 | *this, getBody().getName()); |
1113 | // The first three arguments are the for-loop indices, but the current loop |
1114 | // index is passed in. |
1115 | TypeRange func_args = llvm::drop_begin(*ins, /*N=*/2); |
1116 | if (body_func && |
1117 | failed(VerifySignature(body_func, *this, func_args, *outs, "body" ))) |
1118 | return failure(); |
1119 | |
1120 | return VerifyTypeArrayAttributes(*this); |
1121 | } |
1122 | |
1123 | //===----------------------------------------------------------------------===// |
1124 | // Region Ops and Terminators |
1125 | //===----------------------------------------------------------------------===// |
1126 | |
1127 | // If a region op has preserved attributes, verify that they match the number of |
1128 | // results and block arguments. |
1129 | static LogicalResult VerifyPreservedAttrs(Operation *op, |
1130 | ArrayRef<Attribute> preserved_attrs) { |
1131 | assert(op->getNumRegions() == preserved_attrs.size()); |
1132 | for (auto it : llvm::zip(preserved_attrs, op->getRegions())) { |
1133 | // Preserved attributes for a particular region may not exist. |
1134 | auto attrs = std::get<0>(it).dyn_cast_or_null<RegionAttr>(); |
1135 | if (!attrs) continue; |
1136 | Region ®ion = std::get<1>(it); |
1137 | |
1138 | const auto emit_region_error = [&](StringRef msg) { |
1139 | return op->emitOpError("region #" ) |
1140 | << region.getRegionNumber() << " " << msg; |
1141 | }; |
1142 | |
1143 | unsigned num_args = GetLoopRegionDataArgs(region).size(); |
1144 | if (num_args != attrs.getArgAttrs().size()) { |
1145 | return emit_region_error("has " ) |
1146 | << num_args << " argument(s) but preserved attributes has " |
1147 | << attrs.getArgAttrs().size(); |
1148 | } |
1149 | |
1150 | // All regions are terminated by either a YieldOp or a ConditionOp. In the |
1151 | // latter case, the function will only have one result. |
1152 | unsigned num_rets; |
1153 | Operation *terminator = region.front().getTerminator(); |
1154 | if (isa<ConditionOp>(terminator)) { |
1155 | num_rets = 1; |
1156 | } else { |
1157 | num_rets = cast<RegionBranchTerminatorOpInterface>(terminator) |
1158 | .getMutableSuccessorOperands(region.getRegionNumber()) |
1159 | .size(); |
1160 | } |
1161 | if (num_rets != attrs.getResAttrs().size()) { |
1162 | return emit_region_error("has " ) |
1163 | << num_rets << " result(s) but preserved attributes has " |
1164 | << attrs.getResAttrs().size(); |
1165 | } |
1166 | } |
1167 | return success(); |
1168 | } |
1169 | |
1170 | //===----------------------------------------------------------------------===// |
1171 | // YieldOp |
1172 | |
1173 | MutableOperandRange YieldOp::getMutableSuccessorOperands( |
1174 | Optional<unsigned> index) { |
1175 | // Get the subrange of non-control operands. |
1176 | return getArgsMutable(); |
1177 | } |
1178 | |
1179 | static bool TerminatedByYield(Block &block) { |
1180 | return isa<YieldOp>(block.getTerminator()); |
1181 | } |
1182 | |
1183 | //===----------------------------------------------------------------------===// |
1184 | // IfLikeRegionOp |
1185 | |
1186 | // Verify an if-like region op. |
1187 | template <typename IfLikeRegionOp> |
1188 | static LogicalResult VerifyIfLikeRegionOp(IfLikeRegionOp op) { |
1189 | // Verify terminators. |
1190 | if (!TerminatedByYield(op.then_block())) |
1191 | return op.emitOpError("then region must be terminated by a 'tfg.yield'" ); |
1192 | if (!TerminatedByYield(op.else_block())) |
1193 | return op.emitOpError("else region must be terminated by a 'tfg.yield'" ); |
1194 | return VerifyPreservedAttrs( |
1195 | op, {op.getThenRegionAttrsAttr(), op.getElseRegionAttrsAttr()}); |
1196 | } |
1197 | |
1198 | // Given an potentially null attribute that would represent a constant value, |
1199 | // try to narrow it to a statically known condition. |
1200 | // TODO(jeffniu): Incorporate the other cases of `tf.ToBool`. |
1201 | static Optional<bool> GetStaticallyKnownBranch(Attribute cond_attr) { |
1202 | // Only handle the case of a scalar tensor of i1. |
1203 | auto cond = cond_attr.dyn_cast_or_null<ElementsAttr>(); |
1204 | if (cond && cond.getNumElements() == 1 && |
1205 | cond.getElementType().isSignlessInteger(1)) |
1206 | return cond.getSplatValue<bool>(); |
1207 | return {}; |
1208 | } |
1209 | |
1210 | // Get the successor of the regions of an if-like op. |
1211 | template <typename IfLikeRegionOp> |
1212 | void GetIfLikeRegionOpSuccessorRegions( |
1213 | IfLikeRegionOp op, Optional<unsigned> index, ArrayRef<Attribute> operands, |
1214 | SmallVectorImpl<RegionSuccessor> ®ions) { |
1215 | assert(index.has_value() || |
1216 | !operands.empty() && "if-like op expected at least 1 operand" ); |
1217 | // Both regions branch back to the parent op. |
1218 | if (index.has_value()) { |
1219 | // Ignore the control token. |
1220 | regions.emplace_back( |
1221 | ResultRange(op->result_begin(), std::prev(op->result_end()))); |
1222 | } else if (auto cond = GetStaticallyKnownBranch(operands[0])) { |
1223 | // Add only 1 possible successor if the condition is known. |
1224 | Region ®ion = *cond ? op.getThenRegion() : op.getElseRegion(); |
1225 | regions.emplace_back(®ion, GetLoopRegionDataArgs(region)); |
1226 | } else { |
1227 | // Unknown successor. |
1228 | regions.emplace_back(&op.getThenRegion(), |
1229 | GetLoopRegionDataArgs(op.getThenRegion())); |
1230 | regions.emplace_back(&op.getElseRegion(), |
1231 | GetLoopRegionDataArgs(op.getElseRegion())); |
1232 | } |
1233 | } |
1234 | |
1235 | //===----------------------------------------------------------------------===// |
1236 | // CaseLikeRegionOp |
1237 | |
1238 | // Verify a case-like region op. |
1239 | template <typename CaseLikeRegionOp> |
1240 | static LogicalResult VerifyCaseLikeRegionOp(CaseLikeRegionOp op) { |
1241 | for (auto &it : llvm::enumerate(op.getBranches())) { |
1242 | if (!TerminatedByYield(it.value().front())) { |
1243 | return op.emitOpError("branch region #" ) |
1244 | << it.index() << " is not terminated by a 'tfg.yield' op" ; |
1245 | } |
1246 | } |
1247 | |
1248 | if (op.getBranchAttrs() && |
1249 | op.getBranches().size() != op.getBranchAttrs()->size()) { |
1250 | return op.emitOpError("has " ) |
1251 | << op.getBranches().size() << " regions but " |
1252 | << op.getBranchAttrs()->size() << " branch function attributes" ; |
1253 | } |
1254 | if (auto region_attrs = op.getRegionAttrsAttr()) { |
1255 | if (region_attrs.size() != op.getNumRegions()) { |
1256 | return op.emitOpError("expected " ) |
1257 | << op.getNumRegions() << " region attribute(s) but got " |
1258 | << region_attrs.size(); |
1259 | } |
1260 | if (failed(VerifyPreservedAttrs(op, region_attrs.getValue()))) |
1261 | return failure(); |
1262 | } |
1263 | return success(); |
1264 | } |
1265 | |
1266 | // Given a potentially null attribute that would represent a constant value, |
1267 | // try to narrow it to a statically known branch index. |
1268 | static Optional<unsigned> GetStaticallyKnownCaseBranch(Attribute branch_attr) { |
1269 | auto branch = branch_attr.dyn_cast_or_null<ElementsAttr>(); |
1270 | if (branch && branch.getNumElements() == 1 && |
1271 | branch.getElementType().isSignlessInteger(32)) |
1272 | return branch.getSplatValue<unsigned>(); |
1273 | return {}; |
1274 | } |
1275 | |
1276 | // Get the successor of the regions of a case-like op. |
1277 | template <typename CaseLikeRegionOp> |
1278 | void GetCaseLikeRegionOpSuccessorRegions( |
1279 | CaseLikeRegionOp op, Optional<unsigned> index, ArrayRef<Attribute> operands, |
1280 | SmallVectorImpl<RegionSuccessor> ®ions) { |
1281 | assert(index.has_value() || |
1282 | !operands.empty() && "case-like op expected at least 1 operand" ); |
1283 | // All branch regions branch back to the parent op. |
1284 | if (index.has_value()) { |
1285 | // Ignore the control token. |
1286 | regions.emplace_back( |
1287 | ResultRange(op->result_begin(), std::prev(op->result_end()))); |
1288 | } else if (auto branch_index = GetStaticallyKnownCaseBranch(operands[0])) { |
1289 | // Add only 1 possible successor if the condition is known. |
1290 | Region ®ion = op.getBranches()[*branch_index]; |
1291 | regions.emplace_back(®ion, GetLoopRegionDataArgs(region)); |
1292 | } else { |
1293 | // Unknown successor. Add all of them. |
1294 | for (Region &branch : op.getBranches()) |
1295 | regions.emplace_back(&branch, GetLoopRegionDataArgs(branch)); |
1296 | } |
1297 | } |
1298 | |
1299 | //===----------------------------------------------------------------------===// |
1300 | // ConditionOp |
1301 | |
1302 | MutableOperandRange ConditionOp::getMutableSuccessorOperands( |
1303 | Optional<unsigned> index) { |
1304 | // Get the subrange of non-control operands that are forwarded to the |
1305 | // successor region. |
1306 | return getArgsMutable(); |
1307 | } |
1308 | |
1309 | //===----------------------------------------------------------------------===// |
1310 | // WhileLikeRegionOp |
1311 | |
1312 | // Verify that the loop regions of a region-based loop op have N control tokens |
1313 | // immediately following N data values in their entry block arguments. |
1314 | // `RegionBranchOpInterface` will verify the number of arguments and their |
1315 | // types. |
1316 | static LogicalResult VerifyLoopRegionArgs(Operation *op, Region ®ion) { |
1317 | const auto arg_error = [&](BlockArgument arg) { |
1318 | return op->emitOpError("region #" ) |
1319 | << region.getRegionNumber() << " argument #" << arg.getArgNumber() |
1320 | << " " ; |
1321 | }; |
1322 | |
1323 | // The interface trait verifies the number of data and control arguments. If |
1324 | // the first half of the arguments are not control tokens, then we know for |
1325 | // sure that the second half is only control tokens. |
1326 | for (BlockArgument data : GetLoopRegionDataArgs(region)) |
1327 | if (data.getType().isa<ControlType>()) |
1328 | return arg_error(data) << "should not be a control token" ; |
1329 | return success(); |
1330 | } |
1331 | |
1332 | // Verify a while-like region op. |
1333 | template <typename WhileLikeRegionOp> |
1334 | static LogicalResult VerifyWhileLikeRegionOp(WhileLikeRegionOp op) { |
1335 | // Verify terminators. |
1336 | if (!isa<ConditionOp>(op.cond_block().getTerminator())) { |
1337 | return op.emitOpError( |
1338 | "condition region must be terminated by a 'tfg.condition' op" ); |
1339 | } |
1340 | if (!TerminatedByYield(op.body_block())) |
1341 | op.emitOpError("body region must be terminated by a 'tfg.yield' op" ); |
1342 | |
1343 | if (failed(VerifyLoopRegionArgs(op, op.getCondRegion())) || |
1344 | failed(VerifyLoopRegionArgs(op, op.getBodyRegion()))) |
1345 | return failure(); |
1346 | if (failed(VerifyPreservedAttrs( |
1347 | op, {op.getCondRegionAttrsAttr(), op.getBodyRegionAttrsAttr()}))) |
1348 | return failure(); |
1349 | |
1350 | return success(); |
1351 | } |
1352 | |
1353 | template <typename WhileLikeRegionOp> |
1354 | static void GetWhileLikeRegionOpSuccessorRegions( |
1355 | WhileLikeRegionOp op, Optional<unsigned> index, |
1356 | ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) { |
1357 | // The parent op and the body region always branch to the condion region. |
1358 | if (!index || *index == 1) { |
1359 | regions.emplace_back(&op.getCondRegion(), |
1360 | GetLoopRegionDataArgs(op.getCondRegion())); |
1361 | return; |
1362 | } |
1363 | assert(*index == 0 && "invalid region index" ); |
1364 | // The condition regions branches to the loop body or back to the parent. |
1365 | // Try to narrow the condition value to a constant. |
1366 | auto condition = |
1367 | cast<ConditionOp>(op.getCondRegion().front().getTerminator()); |
1368 | Attribute cond_attr; |
1369 | matchPattern(condition.getCond(), m_Constant(&cond_attr)); |
1370 | Optional<bool> cond = GetStaticallyKnownBranch(cond_attr); |
1371 | if (!cond || *cond) { |
1372 | regions.emplace_back(&op.getBodyRegion(), |
1373 | GetLoopRegionDataArgs(op.getBodyRegion())); |
1374 | } |
1375 | if (!cond || !*cond) { |
1376 | // Drop the control token. |
1377 | regions.emplace_back(op.getResults().drop_back()); |
1378 | } |
1379 | } |
1380 | |
1381 | //===----------------------------------------------------------------------===// |
1382 | // ForRegionOp |
1383 | |
1384 | LogicalResult ForRegionOp::verify() { |
1385 | if (!TerminatedByYield(body_block())) { |
1386 | return emitOpError("body region must be terminated by a 'tfg.yield' op" ); |
1387 | } |
1388 | |
1389 | Block::BlockArgListType args = body_block().getArguments(); |
1390 | if (args.empty()) { |
1391 | return emitOpError( |
1392 | "expected the body block to have at least have the loop index as an " |
1393 | "argument" ); |
1394 | } |
1395 | auto index = args.front().getType().dyn_cast<TensorType>(); |
1396 | if (!index || !index.getElementType().isSignlessInteger(32)) { |
1397 | return emitOpError( |
1398 | "expected first body block argument to be an i32 tensor" ); |
1399 | } |
1400 | |
1401 | if (failed(VerifyLoopRegionArgs(*this, getBodyRegion()))) return failure(); |
1402 | return VerifyPreservedAttrs(*this, {getRegionAttrsAttr()}); |
1403 | } |
1404 | |
1405 | OperandRange ForRegionOp::getSuccessorEntryOperands(Optional<unsigned> index) { |
1406 | return getInit(); |
1407 | } |
1408 | |
1409 | void ForRegionOp::getSuccessorRegions( |
1410 | Optional<unsigned> index, ArrayRef<Attribute> operands, |
1411 | SmallVectorImpl<RegionSuccessor> ®ions) { |
1412 | // Both the parent op and the body region branch to the body. Ignore the loop |
1413 | // index block argument, as it is not modified by the loop body itself. |
1414 | regions.emplace_back(&getBodyRegion(), |
1415 | GetLoopRegionDataArgs(getBodyRegion()).drop_front()); |
1416 | if (!index) return; |
1417 | // The body might branch back to the parent. Drop the control token. |
1418 | regions.emplace_back((*this)->getResults().drop_back()); |
1419 | } |
1420 | |
1421 | BlockArgument ForRegionOp::getDataValueOf(BlockArgument ctl) { |
1422 | return GetLoopRegionDataOf(ctl); |
1423 | } |
1424 | BlockArgument ForRegionOp::getControlTokenOf(BlockArgument data) { |
1425 | return GetLoopRegionControlOf(data); |
1426 | } |
1427 | BlockArgument ForRegionOp::getDataValue(Region ®ion, unsigned idx) { |
1428 | return GetLoopRegionDataArgs(region)[idx]; |
1429 | } |
1430 | BlockArgument ForRegionOp::getControlToken(Region ®ion, unsigned idx) { |
1431 | return GetLoopRegionControlTokens(region)[idx]; |
1432 | } |
1433 | |
1434 | //===----------------------------------------------------------------------===// |
1435 | // Function Table |
1436 | //===----------------------------------------------------------------------===// |
1437 | |
1438 | FunctionTable::FunctionTable(ModuleOp module) { |
1439 | // Collect function names (to be used for disambiguating legacy call |
1440 | // behavior). |
1441 | for (auto &op : module.getOps()) { |
1442 | if (auto func = dyn_cast<GraphFuncOp>(op)) functions.insert(func.getName()); |
1443 | } |
1444 | } |
1445 | |
1446 | bool FunctionTable::MayBeCall(Operation *op) const { |
1447 | if (IsLegacyCall(op)) return true; |
1448 | // The operation might be a call if it references a symbol. |
1449 | bool references_symbol = false; |
1450 | op->getAttrDictionary().walkSubAttrs( |
1451 | [&](Attribute attr) { references_symbol |= attr.isa<SymbolRefAttr>(); }); |
1452 | return references_symbol; |
1453 | } |
1454 | |
1455 | bool FunctionTable::IsLegacyCall(Operation *op) const { |
1456 | // If the operation name refers to a function in the module, then it is |
1457 | // guaranteed to be a legacy call. Otherwise, it is not. |
1458 | return functions.count(op->getName().stripDialect()); |
1459 | } |
1460 | |
1461 | } // namespace tfg |
1462 | } // namespace mlir |
1463 | |
1464 | //===----------------------------------------------------------------------===// |
1465 | // ODS Definitions |
1466 | //===----------------------------------------------------------------------===// |
1467 | |
1468 | #define GET_OP_CLASSES |
1469 | #include "tensorflow/core/ir/ops.cc.inc" |
1470 | #define GET_ATTRDEF_CLASSES |
1471 | #include "tensorflow/core/ir/attributes.cc.inc" |
1472 | |