// llvm/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp

//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can only be represented by
// these intrinsics, whereas fixed-width vectors are recognized in both forms:
// shufflevector instructions and intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementValue is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementValue fields of that Node are relevant, where the ReplacementValue
// should be pre-populated.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>

usingnamespacellvm;
usingnamespacePatternMatch;

#define DEBUG_TYPE

STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);

/// Checks whether \p Mask is an interleaving mask.
///
/// An interleaving mask alternates between `i` and `i + (Length / 2)`, and
/// must contain every number in the range `[0..Length)` (e.g. a 4-element
/// interleaving mask is <0, 2, 1, 3>).
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks whether \p Mask is a deinterleaving mask.
///
/// A deinterleaving mask starts with either 0 or 1 and increments in steps
/// of 2. For example, deinterleaving an 8-element vector uses either
/// <0, 2, 4, 6> (the even lanes) or <1, 3, 5, 7> (the odd lanes).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if \p V is itself a negation (of some other value). Works for
/// both integer and floating-point negations.
static bool isNeg(Value *V);

/// Returns the value being negated by the negation \p V (i.e. `X` in `-X`).
/// \p V is expected to be a negation, as reported by isNeg().
static Value *getNegOperand(Value *V);

namespace {

// NOTE(review): every class body below is empty in this copy, yet later
// definitions in this file refer to members such as
// ComplexDeinterleavingLegacyPass::ID and ::runOnFunction,
// ComplexDeinterleaving::evaluateBasicBlock and
// ComplexDeinterleavingGraph::NodePtr. The member declarations appear to have
// been stripped; restore them from upstream before building.

/// Legacy pass-manager wrapper (derives from FunctionPass) for the
/// transformation.
class ComplexDeinterleavingLegacyPass : public FunctionPass {};

class ComplexDeinterleavingGraph;
/// One candidate transformation point. See "ComplexDeinterleavingCompositeNode"
/// in the file header for its intended contents (Real/Imaginary instructions,
/// operation type, rotation, Operands edges, ReplacementValue).
struct ComplexDeinterleavingCompositeNode {};

/// Graph of composite nodes built by identification and consumed by
/// replacement. See "ComplexDeinterleavingGraph" in the file header.
class ComplexDeinterleavingGraph {};

/// Per-function driver; its runOnFunction() and evaluateBasicBlock() are
/// defined out of line later in this file.
class ComplexDeinterleaving {};

} // namespace

// Pass identifier for the legacy pass manager. The pass is identified by the
// address of this object, so its value is irrelevant; `= 0` by convention.
// (`ID =;` with no initializer is a syntax error.)
char ComplexDeinterleavingLegacyPass::ID = 0;

// Registers ComplexDeinterleavingLegacyPass with the legacy pass registry
// under the DEBUG_TYPE argument name and the description
// "Complex Deinterleaving". The two trailing `false` arguments are the
// cfg-only and is-analysis flags. No INITIALIZE_PASS_DEPENDENCY lines appear
// between BEGIN and END.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)

/// New pass-manager entry point for function \p F, with \p AM providing
/// function analyses.
/// NOTE(review): the body is empty in this copy. Flowing off the end of a
/// function returning PreservedAnalyses is undefined behavior if it is called.
/// Presumably it runs ComplexDeinterleaving over \p F and reports which
/// analyses are preserved. Restore from upstream.
PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {}

/// Factory returning the legacy-PM FunctionPass for target machine \p TM.
/// NOTE(review): the body is empty in this copy, so no pass is returned (UB if
/// called). Presumably it should construct a ComplexDeinterleavingLegacyPass
/// from \p TM. Confirm against upstream.
FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {}

/// Legacy pass-manager hook (the FunctionPass::runOnFunction override) for
/// \p F.
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called. No visible code reads the
/// -enable-complex-deinterleaving option (ComplexDeinterleavingEnabled);
/// presumably the restored body should check it. TODO confirm.
bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {}

/// Driver for function \p F. The name and bool return suggest it reports
/// whether \p F was changed (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called. Restore from upstream.
bool ComplexDeinterleaving::runOnFunction(Function &F) {}

/// Returns true if \p Mask interleaves two equally sized halves, i.e. it has
/// the form <0, H, 1, H + 1, ..., H - 1, 2H - 1> with H = Mask.size() / 2
/// (e.g. <0, 2, 1, 3> for four elements). Matching this exact form also
/// guarantees that every number in [0, Mask.size()) appears exactly once.
///
/// Empty and odd-length masks are rejected: they cannot pair each element of
/// the first half with an element of the second half.
static bool isInterleavingMask(ArrayRef<int> Mask) {
  if (Mask.empty() || Mask.size() % 2 != 0)
    return false;

  const size_t Half = Mask.size() / 2;
  for (size_t Idx = 0; Idx < Half; ++Idx) {
    // Even positions walk the first half; odd positions walk the second half.
    if (Mask[2 * Idx] != static_cast<int>(Idx) ||
        Mask[2 * Idx + 1] != static_cast<int>(Idx + Half))
      return false;
  }
  return true;
}

/// Returns true if \p Mask takes every other lane of its source vector,
/// starting at lane 0 or lane 1 (e.g. <0, 2, 4, 6> or <1, 3, 5, 7>).
///
/// An empty mask selects nothing and is rejected.
static bool isDeinterleavingMask(ArrayRef<int> Mask) {
  if (Mask.empty())
    return false;

  // The first entry decides whether even (0) or odd (1) lanes are taken.
  const int Offset = Mask[0];
  if (Offset != 0 && Offset != 1)
    return false;

  for (size_t Idx = 1; Idx < Mask.size(); ++Idx)
    if (Mask[Idx] != static_cast<int>(2 * Idx) + Offset)
      return false;
  return true;
}

/// Returns true if \p V is a negation. This covers both floating-point
/// negations (`fneg X`, or the `fsub -0.0, X` form accepted by m_FNeg) and
/// integer negations (`sub 0, X`, matched by m_Neg).
static bool isNeg(Value *V) {
  // Local using-directive so this definition does not rely on file-scope
  // directives.
  using namespace llvm::PatternMatch;
  return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
}

/// Returns the value negated by \p V (the `X` in `-X`).
///
/// \p V must be a negation as reported by isNeg(); this is checked with an
/// assertion.
static Value *getNegOperand(Value *V) {
  assert(isNeg(V) && "getNegOperand expects a negation");
  // Go through User/Operator rather than Instruction so that constant
  // expressions such as `sub 0, X`, which m_Neg can also match, are handled.
  auto *U = cast<User>(V);
  // `fneg X` is unary, so X is operand 0. The binary forms `sub 0, X` and
  // `fsub -0.0, X` keep the negated value in operand 1.
  if (Operator::getOpcode(U) == Instruction::FNeg)
    return U->getOperand(0);
  return U->getOperand(1);
}

/// Processes basic block \p B. The name suggests it scans \p B for
/// transformable complex patterns and returns whether anything changed
/// (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called.
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {}

/// Takes a (\p Real, \p Imag) instruction pair and a \p PartialMatch value
/// pair, which is passed by non-const reference and so may be read or
/// updated.
/// NOTE(review): the body is empty in this copy, so the matching rules are not
/// visible. Flowing off the end of this NodePtr-returning function is UB if
/// it is called. Restore from upstream.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
    Instruction *Real, Instruction *Imag,
    std::pair<Value *, Value *> &PartialMatch) {}

/// The name suggests it tries to match a partial complex multiplication
/// rooted at the (\p Real, \p Imag) pair (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {}

/// The name suggests it tries to match a complex addition for the
/// (\p Real, \p Imag) pair (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {}

/// The name suggests it checks whether \p A and \p B together form an
/// add/sub style pair (TODO confirm against upstream).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called.
static bool isInstructionPairAdd(Instruction *A, Instruction *B) {}

/// The name suggests it checks whether \p A and \p B together form a
/// multiplication-based pair (TODO confirm against upstream).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called.
static bool isInstructionPairMul(Instruction *A, Instruction *B) {}

/// The name suggests it decides whether \p I is an operation that may be
/// handled by identifySymmetricOperation (TODO confirm the opcode set against
/// upstream).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called.
static bool isInstructionPotentiallySymmetric(Instruction *I) {}

/// The name suggests it builds a node when \p Real and \p Imag apply the same
/// operation to their respective components (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
                                                       Instruction *Imag) {}

/// Identifies a node for the (real \p R, imaginary \p I) value pair. Unlike
/// most identify* helpers here, it takes Values rather than Instructions. Per
/// the file header, identification follows operand pairs depth-first.
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {}

/// The name suggests it identifies nodes for (\p Real, \p Imag) while
/// allowing reassociation of the underlying operations (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                 Instruction *Imag) {}

/// Reads \p RealMuls and \p ImagMuls (const references) and writes to
/// \p PartialMulCandidates (non-const reference, i.e. an out-parameter).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called. Restore from upstream.
bool ComplexDeinterleavingGraph::collectPartialMuls(
    const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
    std::vector<PartialMulCandidate> &PartialMulCandidates) {}

/// Takes \p RealMuls and \p ImagMuls by non-const reference, so it may modify
/// them, plus an optional \p Accumulator node.
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
/// NOTE(review): `= nullptr` on this out-of-class definition only applies to
/// calls that appear after this point in the file. It is ill-formed if the
/// in-class declaration (not visible here) also supplies a default. Consider
/// keeping the default on the in-class declaration only.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyMultiplications(
    std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
    NodePtr Accumulator = nullptr) {}

/// Takes \p RealAddends and \p ImagAddends by non-const reference, so it may
/// modify them, optional fast-math \p Flags, and an optional \p Accumulator
/// node.
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
/// NOTE(review): `= nullptr` on this out-of-class definition only applies to
/// calls that appear after this point in the file. It is ill-formed if the
/// in-class declaration (not visible here) also supplies a default.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdditions(
    std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
    std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {}

/// Takes \p RealAddends and \p ImagAddends by non-const reference. The name
/// suggests it removes a positive addend pair from them and returns it as a
/// node (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::extractPositiveAddend(
    std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {}

/// Entry point of the "Identification" step described in the file header,
/// starting from \p RootI. The bool result presumably reports success
/// (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called.
bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {}

/// The name suggests it gathers reduction candidates from basic block \p B
/// (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called.
bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {}

/// The name suggests it builds graph nodes for previously collected reduction
/// candidates (TODO confirm).
/// NOTE(review): the body is empty in this copy, so calling it currently does
/// nothing. Restore from upstream.
void ComplexDeinterleavingGraph::identifyReductionNodes() {}

/// The name suggests it validates the identified graph before replacement
/// (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// bool-returning function is UB if it is called.
bool ComplexDeinterleavingGraph::checkNodes() {}

/// Identifies the root node for \p RootI. Per the file header, the graph
/// holds the root Instruction and the root node that should replace it.
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {}

/// The name suggests it matches (\p Real, \p Imag) as the two halves of a
/// deinterleave, via shufflevector or vector.deinterleave2 per the file
/// header (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
                                                 Instruction *Imag) {}

/// The name suggests it matches (\p R, \p I) as splatted real/imaginary
/// values (TODO confirm). Takes Values, not Instructions.
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {}

/// The name suggests it matches (\p Real, \p Imag) as a pair of PHI nodes
/// (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
                                            Instruction *Imag) {}

/// The name suggests it matches (\p Real, \p Imag) as a pair of select
/// instructions (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// NodePtr-returning function is UB if it is called.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
                                               Instruction *Imag) {}

/// Uses builder \p B to emit a replacement for a symmetric node described by
/// \p Opcode, optional fast-math \p Flags and operands \p InputA / \p InputB.
/// It presumably emits the single operation `Opcode(InputA, InputB)` with
/// \p Flags applied (TODO confirm).
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// Value*-returning function is UB if it is called.
static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
                                   std::optional<FastMathFlags> Flags,
                                   Value *InputA, Value *InputB) {}

/// Part of the "Replacement" step in the file header: uses \p Builder to
/// produce the replacement Value for \p Node. Per the header, that value
/// becomes the node's ReplacementValue.
/// NOTE(review): the body is empty in this copy. Flowing off the end of this
/// Value*-returning function is UB if it is called.
Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
                                               RawNodePtr Node) {}

/// Receives the replacement value \p OperationReplacement emitted for
/// \p Node. The name suggests it performs the extra wiring a reduction node
/// needs (TODO confirm).
/// NOTE(review): the body is empty in this copy, so calling it currently does
/// nothing.
void ComplexDeinterleavingGraph::processReductionOperation(
    Value *OperationReplacement, RawNodePtr Node) {}

/// Entry point of the "Replacement" step described in the file header.
/// NOTE(review): the body is empty in this copy, so calling it currently does
/// nothing. Restore from upstream.
void ComplexDeinterleavingGraph::replaceNodes() {}