llvm/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td

//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface used to automatically build PTX (Parallel Thread
// Execution) inline assembly from NVVM Ops. It is used by the
// `convert-nvvm-to-llvm` pass.
//
//===----------------------------------------------------------------------===//

#ifndef BASICPTXBUILDER_OP_INTERFACE
#define BASICPTXBUILDER_OP_INTERFACE

include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"

//===----------------------------------------------------------------------===//
// Basic PTX Builder Interface
//===----------------------------------------------------------------------===//

// Optional `i1` operand constraint. Per the `getPredicate` documentation below,
// a value of this kind guards the generated PTX instruction and must be the
// last argument of the Op.
def PtxPredicate : Optional<I1>;

def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
  let description = [{
    This interface is used to generate inline assembly with PTX for basic
    operations. It's utilized in the `convert-nvvm-to-llvm` pass to lower
    NVVM Ops that implement this interface to PTX (Parallel Thread Execution)
    using inline assembly Ops. Interface methods play a crucial role in this
    lowering process.

    Here's an example of an Op with the `BasicPtxBuilderOpInterface`:
    ```tablegen
      def NVVM_SpecialOp : NVVM_Op<"special.op",
          [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
        Results<(outs LLVM_Type:$res)>,
        Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
        ...
        let extraClassDefinition = [{
          std::string $cppClass::getPtx() {
            return std::string("special.op %0, %1, %2;");
          }
        } ];
      }
    ```

    The above NVVM Op is used as follows:
    ```mlir
      %0 = nvvm.special.op %1, %2 : !llvm.ptr, i32 -> i32
    ```

    The `convert-nvvm-to-llvm` pass generates the inline assembly like below.
    The order of arguments is retained, and the read and write modifiers are
    set based on the input and result types:
    ```mlir
      %0 = llvm.inline_asm
                has_side_effects
                asm_dialect =
                att "special.op %0, %1, %2;", "=r,l,r" %1, %2
                : (!llvm.ptr, i32) -> i32
    ```
  }];
  let cppNamespace = "::mlir::NVVM";
  let methods = [
    // Predicate that optionally guards the generated PTX instruction.
    InterfaceMethod<
        /*desc=*/[{
          Optional function for setting a predicate. When a predicate is
          provided it is a `PtxPredicate` value of type i1; the default
          implementation returns no predicate. If no predicate is provided,
          the instruction is unguarded; otherwise, it's guarded by the
          predicate value. The `PtxPredicate` value must always be the last
          argument. The PTX code provided by `getPtx` should not include the
          predicate usage. The interface automatically handles predicate usage
          in the generated PTX code when necessary.
        }],
        /*retType=*/"std::optional<::mlir::Value>",
        /*methodName=*/"getPredicate",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return {};"
      >,
    // The PTX text itself. No default implementation is given, so every Op
    // implementing the interface has to provide it.
    InterfaceMethod<
        /*desc=*/[{
          Returns the PTX assembly string. Operands are referenced by number
          (e.g. `%0`, `%1`), following the order produced by `getAsmValues`.
        }],
        /*retType=*/"std::string",
        /*methodName=*/"getPtx"
      >,
    InterfaceMethod<
        /*desc=*/[{
          This function indicates whether the operation is supported by LLVM
          intrinsics. It's particularly useful for operations that have
          specific cases with LLVM intrinsic support. Defaults to `false`.
        }],
        /*retType=*/"bool",
        /*methodName=*/"hasIntrinsic",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return false;"
      >,
    InterfaceMethod<
        /*desc=*/[{
          Return whether the operation has memory side effects. Defaults to
          `true`.
        }],
        /*retType=*/"bool",
        /*methodName=*/"hasSideEffect",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return true;"
      >,

    InterfaceMethod<
        /*desc=*/[{
          Helper function that creates an `LLVM::ConstantOp` of 32-bit integer
          type holding `val`, using `rewriter` and the location of this
          operation, and returns it as a value.
        }],
        /*retType=*/"::mlir::Value",
        /*methodName=*/"makeConstantI32",
        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
            mlir::Operation* op = $_op;
            return rewriter.create<LLVM::ConstantOp>(
              op->getLoc(), rewriter.getIntegerType(32), val);
        }]
      >,
    InterfaceMethod<
        /*desc=*/[{
          This function collects the values that are passed to the PTX code,
          together with their register modifier, and appends them to
          `asmValues` in this order:
            1) Results, marked as `Write`
            2) Operands, marked as `Read`
            3) Integer attributes, materialized as i32 constants and marked
               as `Read`; attributes of any other kind are skipped
        }],
        /*retType=*/"void",
        /*methodName=*/"getAsmValues",
        /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
        "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          mlir::Operation* op = $_op;

          // Step 1. Add results: these are written by the PTX code.
          for (mlir::Value result : op->getResults())
            asmValues.emplace_back(result, mlir::NVVM::PTXRegisterMod::Write);

          // Step 2. Add operands: these are read by the PTX code.
          for (mlir::Value operand : op->getOperands())
            asmValues.emplace_back(operand, mlir::NVVM::PTXRegisterMod::Read);

          // Step 3. Add attributes. Only integer attributes are forwarded;
          // each one becomes a freshly created i32 constant.
          for (mlir::NamedAttribute attr : op->getAttrs()) {
            if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
              // getInt() yields a 64-bit value; the narrowing to the `int`
              // taken by makeConstantI32 is made explicit here.
              ::mlir::Value constant = makeConstantI32(
                  rewriter, static_cast<int>(intAttr.getInt()));
              asmValues.emplace_back(constant,
                                     mlir::NVVM::PTXRegisterMod::Read);
            }
          }
        }]
      >
  ];
}

#endif // BASICPTXBUILDER_OP_INTERFACE