# RUN: %PYTHON %s 2>&1 | FileCheck %s
from mlir.dialects import arith, func, pdl
from mlir.dialects.builtin import module
from mlir.ir import *
from mlir.rewrite import *
def construct_and_print_in_module(f):
    """Decorator: immediately build a test module with ``f`` and print it.

    Emits a ``TEST: <function name>`` header line, then opens a fresh
    ``Context`` together with an unknown ``Location``, creates an empty
    ``Module`` and invokes ``f`` on it with the insertion point set to the
    module body. The value ``f`` returns is printed unless it is ``None``.

    Args:
        f: Callable that receives the newly created module and returns the
            module to print, or ``None`` to print nothing.

    Returns:
        ``f`` unchanged, so the decorated name remains bound to the function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context():
        with Location.unknown():
            mod = Module.create()
            # Anything f builds without an explicit insertion point lands in
            # the module body.
            with InsertionPoint(mod.body):
                mod = f(mod)
            # Printing happens before leaving the Context/Location scope.
            if mod is not None:
                print(mod)
    return f
# CHECK-LABEL: TEST: test_add_to_mul
# CHECK: arith.muli
@construct_and_print_in_module
def test_add_to_mul(module_):
    """Rewrite an index-typed ``arith.addi`` into ``arith.muli`` with PDL.

    Builds a nested module ``ir`` holding ``add_func``, whose body adds its two
    ``index`` arguments with ``arith.addi``. A separate module is then filled
    with a PDL pattern that replaces such an add by ``arith.muli`` on the same
    operands. That pattern module is frozen and applied greedily to
    ``module_``.

    Args:
        module_: Module supplied by the decorator; the input IR is inserted
            into its body.

    Returns:
        ``module_`` after the rewrite, so the decorator prints it.
    """
    index_type = IndexType.get()

    # Input IR: one function that adds two index values.
    @module(sym_name="ir")
    def ir():
        @func.func(index_type, index_type)
        def add_func(a, b):
            return arith.addi(a, b)

    # Pattern IR lives in its own module. The pattern matches an op that
    #   - is named arith.addi,
    #   - has exactly two operands, both of index type,
    #   - yields one index-typed result.
    with Location.unknown():
        pattern_module = Module.create()
        with InsertionPoint(pattern_module.body):

            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                index_ty = pdl.TypeOp(IndexType.get())
                lhs = pdl.OperandOp(index_ty)
                rhs = pdl.OperandOp(index_ty)
                matched_add = pdl.OperationOp(
                    name="arith.addi", args=[lhs, rhs], types=[index_ty]
                )

                # Rewrite: create arith.muli over the same operands and put it
                # in place of the matched add.
                @pdl.rewrite()
                def rew():
                    mul_op = pdl.OperationOp(
                        name="arith.muli", args=[lhs, rhs], types=[index_ty]
                    )
                    pdl.ReplaceOp(matched_add, with_op=mul_op)

    # Wrapping pattern_module in a PDLModule transfers its ownership to the
    # PDL module. The Python side does not yet model that transfer (it has
    # sharp edges), so keep both constructed in the same scope.
    # FIXME: This should be made more robust.
    frozen_patterns = PDLModule(pattern_module).freeze()
    # A frozen pattern set may be applied any number of times.
    apply_patterns_and_fold_greedily(module_, frozen_patterns)
    return module_