godot/thirdparty/embree/kernels/bvh/node_intersector1.h

// Copyright 2009-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "node_intersector.h"

/* __FMA_X4__ is defined when compiling for x86 with AVX2 or for AArch64.
   NOTE(review): presumably consumed by 4-wide code paths (bodies elided in this
   view) to select fused multiply-add formulations — confirm at the use sites. */
#if defined(__AVX2__) || defined(__aarch64__)
#define __FMA_X4__
#endif


namespace embree
{
  namespace isa
  {
    //////////////////////////////////////////////////////////////////////////////////////
    // Ray structure used in single-ray traversal
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Per-ray data for single-ray traversal of an N-wide BVH. The primary template is
        only declared here; `robust` selects one of the two specializations below. */
    template<int N, bool robust>
      struct TravRayBase;
      
    /* Base (without tnear and tfar), fast variant (robust == false).
       NOTE(review): the specialization body is elided in this view. The fast intersectors
       below read ray.org, ray.rdir, ray.org_rdir / ray.neg_org_rdir and the byte offsets
       ray.nearX..ray.farZ, so these presumably live here — confirm. */
    TravRayBase<N, false>;

    /* Base (without tnear and tfar), robust variant (robust == true).
       NOTE(review): body elided in this view. The robust intersectors read ray.org,
       ray.rdir_near, ray.rdir_far and ray.nearX..ray.farZ — presumably declared here. */
    TravRayBase<N, true>;

    /* Full (with tnear and tfar): adds the ray interval on top of TravRayBase.
       NOTE(review): body elided in this view; the intersectors read ray.tnear / ray.tfar. */
    template<int N, bool robust>
      struct TravRay : TravRayBase<N,robust>
    {};
    
    //////////////////////////////////////////////////////////////////////////////////////
    // Point Query structure used in single-ray traversal
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Per-query data for single point-query traversal of an N-wide BVH.
        NOTE(review): the body is empty in this view (presumably elided) — the real
        member list is not visible here. */
    template<int N>
    struct TravPointQuery
    {};
    
    //////////////////////////////////////////////////////////////////////////////////////
    // point query
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Sphere-variant point-query node tests.
        NOTE(review): every body in this group is empty in this view (elided). As shown, a
        size_t-returning function with an empty body would be undefined behaviour if called.
        The comments below describe the signatures only. Each test returns a size_t
        (presumably a per-child mask) and takes `dist` by non-const reference (presumably
        per-child distances are written there) — confirm against the full implementation. */

    /*! Sphere test against child boxes passed as per-axis min/max vectors. */
    template<int N>
    __forceinline size_t pointQuerySphereDistAndMask(
      const TravPointQuery<N>& query, vfloat<N>& dist, vfloat<N> const& minX, vfloat<N> const& maxX, 
      vfloat<N> const& minY, vfloat<N> const& maxY, vfloat<N> const& minZ, vfloat<N> const& maxZ)
    {}

    /*! Sphere test against an AABBNode. */
    template<int N>
    __forceinline size_t pointQueryNodeSphere(const typename BVHN<N>::AABBNode* node, const TravPointQuery<N>& query, vfloat<N>& dist)
    {}
    
    /*! Sphere test against an AABBNodeMB evaluated at `time` (MB presumably = motion blur). */
    template<int N>
    __forceinline size_t pointQueryNodeSphere(const typename BVHN<N>::AABBNodeMB* node, const TravPointQuery<N>& query, const float time, vfloat<N>& dist)
    {}
    
    /*! Sphere test for a node given by reference (4D variant) evaluated at `time`. */
    template<int N>
      __forceinline size_t pointQueryNodeSphereMB4D(const typename BVHN<N>::NodeRef ref, const TravPointQuery<N>& query, const float time, vfloat<N>& dist)
    {}
    
    /*! Sphere test against a QuantizedBaseNode. */
    template<int N>
    __forceinline size_t pointQueryNodeSphere(const typename BVHN<N>::QuantizedBaseNode* node, const TravPointQuery<N>& query, vfloat<N>& dist)
    {}
    
    /*! Sphere test against a QuantizedBaseNodeMB evaluated at `time`. */
    template<int N>
    __forceinline size_t pointQueryNodeSphere(const typename BVHN<N>::QuantizedBaseNodeMB* node, const TravPointQuery<N>& query, const float time, vfloat<N>& dist)
    {}
    
    /*! Sphere test against an OBBNode. */
    template<int N>
    __forceinline size_t pointQueryNodeSphere(const typename BVHN<N>::OBBNode* node, const TravPointQuery<N>& query, vfloat<N>& dist)
    {}
    
    /*! Sphere test against an OBBNodeMB evaluated at `time`. */
    template<int N>
    __forceinline size_t pointQueryNodeSphere(const typename BVHN<N>::OBBNodeMB* node, const TravPointQuery<N>& query, const float time, vfloat<N>& dist)
    {}

    /*! AABB-variant point-query node tests (same overload set as the sphere variant).
        NOTE(review): every body in this group is empty in this view (elided). As shown, a
        size_t-returning function with an empty body would be undefined behaviour if called.
        Each test returns a size_t (presumably a per-child mask) and takes `dist` by
        non-const reference (presumably written per child) — confirm against the full source. */

    /*! AABB test against child boxes passed as per-axis min/max vectors. */
    template<int N>
    __forceinline size_t pointQueryAABBDistAndMask(
      const TravPointQuery<N>& query, vfloat<N>& dist, vfloat<N> const& minX, vfloat<N> const& maxX, 
      vfloat<N> const& minY, vfloat<N> const& maxY, vfloat<N> const& minZ, vfloat<N> const& maxZ)
    {}

    /*! AABB test against an AABBNode. */
    template<int N>
    __forceinline size_t pointQueryNodeAABB(const typename BVHN<N>::AABBNode* node, const TravPointQuery<N>& query, vfloat<N>& dist)
    {}
    
    /*! AABB test against an AABBNodeMB evaluated at `time`. */
    template<int N>
    __forceinline size_t pointQueryNodeAABB(const typename BVHN<N>::AABBNodeMB* node, const TravPointQuery<N>& query, const float time, vfloat<N>& dist)
    {}
    
    /*! AABB test for a node given by reference (4D variant) evaluated at `time`. */
    template<int N>
      __forceinline size_t pointQueryNodeAABBMB4D(const typename BVHN<N>::NodeRef ref, const TravPointQuery<N>& query, const float time, vfloat<N>& dist)
    {}
    
    /*! AABB test against a QuantizedBaseNode. */
    template<int N>
    __forceinline size_t pointQueryNodeAABB(const typename BVHN<N>::QuantizedBaseNode* node, const TravPointQuery<N>& query, vfloat<N>& dist)
    {}
    
    /*! AABB test against a QuantizedBaseNodeMB evaluated at `time`. */
    template<int N>
    __forceinline size_t pointQueryNodeAABB(const typename BVHN<N>::QuantizedBaseNodeMB* node, const TravPointQuery<N>& query, const float time, vfloat<N>& dist)
    {}
    
    /*! AABB test against an OBBNode. */
    template<int N>
    __forceinline size_t pointQueryNodeAABB(const typename BVHN<N>::OBBNode* node, const TravPointQuery<N>& query, vfloat<N>& dist)
    {}
    
    /*! AABB test against an OBBNodeMB evaluated at `time`. */
    template<int N>
    __forceinline size_t pointQueryNodeAABB(const typename BVHN<N>::OBBNodeMB* node, const TravPointQuery<N>& query, const float time, vfloat<N>& dist)
    {}

    //////////////////////////////////////////////////////////////////////////////////////
    // Fast AABBNode intersection
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Fast AABBNode intersection. The primary template is declared only. Explicit
        specializations are provided for N=4 (below) and, under __AVX__, N=8. */
    template<int N, bool robust>
      __forceinline size_t intersectNode(const typename BVHN<N>::AABBNode* node, const TravRay<N,robust>& ray, vfloat<N>& dist);

    /*! 4-wide, non-robust specialization (`robust` is deduced from TravRay<4,false>).
        NOTE(review): the body is empty in this view (elided); as shown it would return
        no value, which is undefined behaviour if called. */
    template<>
      __forceinline size_t intersectNode<4>(const typename BVH4::AABBNode* node, const TravRay<4,false>& ray, vfloat4& dist)
    {}

#if defined(__AVX__)

    template<>
      __forceinline size_t intersectNode<8>(const typename BVH8::AABBNode* node, const TravRay<8,false>& ray, vfloat8& dist)
    {
      /* Fast (non-robust) test of one ray against the 8 child boxes of an AABBNode.
         Per lane:
           tNear = max(per-axis near-plane distances, ray.tnear)
           tFar  = min(per-axis far-plane distances, ray.tfar)
         Lane i's bit in the returned mask is set when tNear <= tFar. tNear is written to dist. */

      /* ray.nearX..ray.farZ are byte offsets added to &node->lower_x. They presumably select
         the entry/exit bound on each axis from the ray direction signs — confirm in TravRay. */
      const char* const bounds = (const char*)&node->lower_x;
      const vfloat8 nearBoundX = vfloat8::load((float*)(bounds+ray.nearX));
      const vfloat8 nearBoundY = vfloat8::load((float*)(bounds+ray.nearY));
      const vfloat8 nearBoundZ = vfloat8::load((float*)(bounds+ray.nearZ));
      const vfloat8 farBoundX  = vfloat8::load((float*)(bounds+ray.farX));
      const vfloat8 farBoundY  = vfloat8::load((float*)(bounds+ray.farY));
      const vfloat8 farBoundZ  = vfloat8::load((float*)(bounds+ray.farZ));

      /* Plane distances t = (bound - org) * rdir, per axis. */
#if defined(__AVX2__)
#if defined(__aarch64__)
      /* Single madd per plane; assumes ray.neg_org_rdir holds -org*rdir. */
      const vfloat8 tNearX = madd(nearBoundX, ray.rdir.x, ray.neg_org_rdir.x);
      const vfloat8 tNearY = madd(nearBoundY, ray.rdir.y, ray.neg_org_rdir.y);
      const vfloat8 tNearZ = madd(nearBoundZ, ray.rdir.z, ray.neg_org_rdir.z);
      const vfloat8 tFarX  = madd(farBoundX,  ray.rdir.x, ray.neg_org_rdir.x);
      const vfloat8 tFarY  = madd(farBoundY,  ray.rdir.y, ray.neg_org_rdir.y);
      const vfloat8 tFarZ  = madd(farBoundZ,  ray.rdir.z, ray.neg_org_rdir.z);
#else
      /* Single msub per plane; assumes ray.org_rdir holds org*rdir. */
      const vfloat8 tNearX = msub(nearBoundX, ray.rdir.x, ray.org_rdir.x);
      const vfloat8 tNearY = msub(nearBoundY, ray.rdir.y, ray.org_rdir.y);
      const vfloat8 tNearZ = msub(nearBoundZ, ray.rdir.z, ray.org_rdir.z);
      const vfloat8 tFarX  = msub(farBoundX,  ray.rdir.x, ray.org_rdir.x);
      const vfloat8 tFarY  = msub(farBoundY,  ray.rdir.y, ray.org_rdir.y);
      const vfloat8 tFarZ  = msub(farBoundZ,  ray.rdir.z, ray.org_rdir.z);
#endif
#else
      const vfloat8 tNearX = (nearBoundX - ray.org.x) * ray.rdir.x;
      const vfloat8 tNearY = (nearBoundY - ray.org.y) * ray.rdir.y;
      const vfloat8 tNearZ = (nearBoundZ - ray.org.z) * ray.rdir.z;
      const vfloat8 tFarX  = (farBoundX  - ray.org.x) * ray.rdir.x;
      const vfloat8 tFarY  = (farBoundY  - ray.org.y) * ray.rdir.y;
      const vfloat8 tFarZ  = (farBoundZ  - ray.org.z) * ray.rdir.z;
#endif

      /* Clip to [ray.tnear, ray.tfar] and build the per-lane hit mask. */
#if defined(__AVX2__) || defined(__AVX512F__)
      const vfloat8 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear);
      const vfloat8 tFar  = mini(tFarX ,tFarY ,tFarZ ,ray.tfar);
#if defined(__AVX512F__) // SKX
      /* Compares the float bit patterns as integers.
         NOTE(review): presumably relies on the distances being non-negative — confirm. */
      const vbool8 hit = asInt(tNear) <= asInt(tFar);
      const size_t hitMask = movemask(hit);
#else // HSW
      /* Integer '>' on the float bits, then flip all 8 lane bits to obtain '<='. */
      const vbool8 miss = asInt(tNear) > asInt(tFar);
      const size_t hitMask = movemask(miss) ^ ((1<<8)-1);
#endif
#else
      const vfloat8 tNear = max(tNearX,tNearY,tNearZ,ray.tnear);
      const vfloat8 tFar  = min(tFarX ,tFarY ,tFarZ ,ray.tfar);
      const vbool8 hit = tNear <= tFar;
      const size_t hitMask = movemask(hit);
#endif
      dist = tNear;
      return hitMask;
    }

#endif

    //////////////////////////////////////////////////////////////////////////////////////
    // Robust AABBNode intersection
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Robust AABBNode intersection for any width N (takes the robust TravRay<N,true>).
        NOTE(review): the body is empty in this view (elided); as shown it would return
        no value, which is undefined behaviour if called. */
    template<int N>
      __forceinline size_t intersectNodeRobust(const typename BVHN<N>::AABBNode* node, const TravRay<N,true>& ray, vfloat<N>& dist)
    {}

    //////////////////////////////////////////////////////////////////////////////////////
    // Fast AABBNodeMB intersection
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Fast AABBNodeMB intersection at `time` (non-robust ray).
        NOTE(review): body is empty in this view (elided). */
    template<int N>
      __forceinline size_t intersectNode(const typename BVHN<N>::AABBNodeMB* node, const TravRay<N,false>& ray, const float time, vfloat<N>& dist)
    {}

    //////////////////////////////////////////////////////////////////////////////////////
    // Robust AABBNodeMB intersection
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Robust AABBNodeMB intersection at `time`.
        NOTE(review): body is empty in this view (elided). */
    template<int N>
      __forceinline size_t intersectNodeRobust(const typename BVHN<N>::AABBNodeMB* node, const TravRay<N,true>& ray, const float time, vfloat<N>& dist)
    {}
    
    //////////////////////////////////////////////////////////////////////////////////////
    // Fast AABBNodeMB4D intersection
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Fast intersection of the node referenced by `ref` (4D variant) at `time`.
        NOTE(review): body is empty in this view (elided). */
    template<int N>
      __forceinline size_t intersectNodeMB4D(const typename BVHN<N>::NodeRef ref, const TravRay<N,false>& ray, const float time, vfloat<N>& dist)
    {}

    //////////////////////////////////////////////////////////////////////////////////////
    // Robust AABBNodeMB4D intersection
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Robust intersection of the node referenced by `ref` (4D variant) at `time`.
        NOTE(review): body is empty in this view (elided). */
    template<int N>
      __forceinline size_t intersectNodeMB4DRobust(const typename BVHN<N>::NodeRef ref, const TravRay<N,true>& ray, const float time, vfloat<N>& dist)
    {}

    //////////////////////////////////////////////////////////////////////////////////////
    // Fast QuantizedBaseNode intersection
    //////////////////////////////////////////////////////////////////////////////////////

    /*! QuantizedBaseNode intersection. The primary template is declared only. Explicit
        specializations exist for N=4 (below) and, under __AVX__, N=8, each in a fast
        (TravRay<N,false>) and a robust (TravRay<N,true>) flavour. */
    template<int N, bool robust>
      __forceinline size_t intersectNode(const typename BVHN<N>::QuantizedBaseNode* node, const TravRay<N,robust>& ray, vfloat<N>& dist);

    /*! 4-wide fast specialization. NOTE(review): body is empty in this view (elided). */
    template<>
      __forceinline size_t intersectNode<4>(const typename BVH4::QuantizedBaseNode* node, const TravRay<4,false>& ray, vfloat4& dist)
    {}

    /*! 4-wide robust specialization. NOTE(review): body is empty in this view (elided). */
    template<>
      __forceinline size_t intersectNode<4>(const typename BVH4::QuantizedBaseNode* node, const TravRay<4,true>& ray, vfloat4& dist)
    {}


#if defined(__AVX__)

    template<>
      __forceinline size_t intersectNode<8>(const typename BVH8::QuantizedBaseNode* node, const TravRay<8,false>& ray, vfloat8& dist)
    {
      /* Fast (non-robust) test of one ray against the 8 quantized child boxes of a
         QuantizedBaseNode. Bounds are dequantized first, then clipped like the AABBNode path.
         The returned per-lane hit mask (tNear <= tFar) is ANDed with node->validMask().
         tNear is written to dist. */
      const size_t validChildren = movemask(node->validMask());

      /* Dequantize near/far bounds per axis as madd(q, scale, start).
         NOTE(review): ">> 2" rescales the ray's byte offsets for dequantize<8>() —
         confirm against QuantizedBaseNode::dequantize. */
      const vfloat8 startX(node->start.x);
      const vfloat8 startY(node->start.y);
      const vfloat8 startZ(node->start.z);
      const vfloat8 scaleX(node->scale.x);
      const vfloat8 scaleY(node->scale.y);
      const vfloat8 scaleZ(node->scale.z);
      const vfloat8 nearBoundX = madd(node->dequantize<8>(ray.nearX >> 2),scaleX,startX);
      const vfloat8 nearBoundY = madd(node->dequantize<8>(ray.nearY >> 2),scaleY,startY);
      const vfloat8 nearBoundZ = madd(node->dequantize<8>(ray.nearZ >> 2),scaleZ,startZ);
      const vfloat8 farBoundX  = madd(node->dequantize<8>(ray.farX  >> 2),scaleX,startX);
      const vfloat8 farBoundY  = madd(node->dequantize<8>(ray.farY  >> 2),scaleY,startY);
      const vfloat8 farBoundZ  = madd(node->dequantize<8>(ray.farZ  >> 2),scaleZ,startZ);

      /* Plane distances t = (bound - org) * rdir, per axis. */
#if defined(__AVX2__)
#if defined(__aarch64__)
      /* Assumes ray.neg_org_rdir holds -org*rdir. */
      const vfloat8 tNearX = madd(nearBoundX, ray.rdir.x, ray.neg_org_rdir.x);
      const vfloat8 tNearY = madd(nearBoundY, ray.rdir.y, ray.neg_org_rdir.y);
      const vfloat8 tNearZ = madd(nearBoundZ, ray.rdir.z, ray.neg_org_rdir.z);
      const vfloat8 tFarX  = madd(farBoundX,  ray.rdir.x, ray.neg_org_rdir.x);
      const vfloat8 tFarY  = madd(farBoundY,  ray.rdir.y, ray.neg_org_rdir.y);
      const vfloat8 tFarZ  = madd(farBoundZ,  ray.rdir.z, ray.neg_org_rdir.z);
#else
      /* Assumes ray.org_rdir holds org*rdir. */
      const vfloat8 tNearX = msub(nearBoundX, ray.rdir.x, ray.org_rdir.x);
      const vfloat8 tNearY = msub(nearBoundY, ray.rdir.y, ray.org_rdir.y);
      const vfloat8 tNearZ = msub(nearBoundZ, ray.rdir.z, ray.org_rdir.z);
      const vfloat8 tFarX  = msub(farBoundX,  ray.rdir.x, ray.org_rdir.x);
      const vfloat8 tFarY  = msub(farBoundY,  ray.rdir.y, ray.org_rdir.y);
      const vfloat8 tFarZ  = msub(farBoundZ,  ray.rdir.z, ray.org_rdir.z);
#endif
#else
      const vfloat8 tNearX = (nearBoundX - ray.org.x) * ray.rdir.x;
      const vfloat8 tNearY = (nearBoundY - ray.org.y) * ray.rdir.y;
      const vfloat8 tNearZ = (nearBoundZ - ray.org.z) * ray.rdir.z;
      const vfloat8 tFarX  = (farBoundX  - ray.org.x) * ray.rdir.x;
      const vfloat8 tFarY  = (farBoundY  - ray.org.y) * ray.rdir.y;
      const vfloat8 tFarZ  = (farBoundZ  - ray.org.z) * ray.rdir.z;
#endif

      /* Clip to [ray.tnear, ray.tfar] and build the per-lane hit mask. */
#if defined(__AVX2__) || defined(__AVX512F__)
      const vfloat8 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear);
      const vfloat8 tFar  = mini(tFarX ,tFarY ,tFarZ ,ray.tfar);
#if defined(__AVX512F__) // SKX
      /* Compares float bit patterns as integers.
         NOTE(review): presumably relies on non-negative distances — confirm. */
      const vbool8 hit = asInt(tNear) <= asInt(tFar);
      const size_t hitMask = movemask(hit);
#else // HSW
      /* Integer '>' on the float bits, then flip all 8 lane bits to obtain '<='. */
      const vbool8 miss = asInt(tNear) > asInt(tFar);
      const size_t hitMask = movemask(miss) ^ ((1<<8)-1);
#endif
#else
      const vfloat8 tNear = max(tNearX,tNearY,tNearZ,ray.tnear);
      const vfloat8 tFar  = min(tFarX ,tFarY ,tFarZ ,ray.tfar);
      const vbool8 hit = tNear <= tFar;
      const size_t hitMask = movemask(hit);
#endif
      dist = tNear;
      return hitMask & validChildren;
    }

    template<>
      __forceinline size_t intersectNode<8>(const typename BVH8::QuantizedBaseNode* node, const TravRay<8,true>& ray, vfloat8& dist)
    {
      /* Robust test of one ray against the 8 quantized child boxes of a QuantizedBaseNode.
         Differences from the fast variant:
           - near planes are scaled by ray.rdir_near and far planes by ray.rdir_far;
           - the plain float max/min/compare path is always used, with no ISA-specific branches.
         Returns the per-lane hit mask (tNear <= tFar) ANDed with node->validMask().
         tNear is written to dist. */
      const size_t validChildren = movemask(node->validMask());

      /* Dequantize near/far bounds per axis as madd(q, scale, start). */
      const vfloat8 startX(node->start.x);
      const vfloat8 startY(node->start.y);
      const vfloat8 startZ(node->start.z);
      const vfloat8 scaleX(node->scale.x);
      const vfloat8 scaleY(node->scale.y);
      const vfloat8 scaleZ(node->scale.z);
      const vfloat8 nearBoundX = madd(node->dequantize<8>(ray.nearX >> 2),scaleX,startX);
      const vfloat8 nearBoundY = madd(node->dequantize<8>(ray.nearY >> 2),scaleY,startY);
      const vfloat8 nearBoundZ = madd(node->dequantize<8>(ray.nearZ >> 2),scaleZ,startZ);
      const vfloat8 farBoundX  = madd(node->dequantize<8>(ray.farX  >> 2),scaleX,startX);
      const vfloat8 farBoundY  = madd(node->dequantize<8>(ray.farY  >> 2),scaleY,startY);
      const vfloat8 farBoundZ  = madd(node->dequantize<8>(ray.farZ  >> 2),scaleZ,startZ);

      /* Plane distances; near and far use separate reciprocal directions. */
      const vfloat8 tNearX = (nearBoundX - ray.org.x) * ray.rdir_near.x;
      const vfloat8 tNearY = (nearBoundY - ray.org.y) * ray.rdir_near.y;
      const vfloat8 tNearZ = (nearBoundZ - ray.org.z) * ray.rdir_near.z;
      const vfloat8 tFarX  = (farBoundX  - ray.org.x) * ray.rdir_far.x;
      const vfloat8 tFarY  = (farBoundY  - ray.org.y) * ray.rdir_far.y;
      const vfloat8 tFarZ  = (farBoundZ  - ray.org.z) * ray.rdir_far.z;

      /* Clip to [ray.tnear, ray.tfar] and keep lanes whose interval is non-empty. */
      const vfloat8 tNear = max(tNearX,tNearY,tNearZ,ray.tnear);
      const vfloat8 tFar  = min(tFarX ,tFarY ,tFarZ ,ray.tfar);
      const vbool8 hit = tNear <= tFar;
      const size_t hitMask = movemask(hit);

      dist = tNear;
      return hitMask & validChildren;
    }


#endif

    /*! Fast QuantizedBaseNodeMB intersection at `time`.
        NOTE(review): body is empty in this view (elided). */
    template<int N>
      __forceinline size_t intersectNode(const typename BVHN<N>::QuantizedBaseNodeMB* node, const TravRay<N,false>& ray, const float time, vfloat<N>& dist)
    {}

    /*! Robust QuantizedBaseNodeMB intersection at `time`.
        NOTE(review): body is empty in this view (elided). */
    template<int N>
      __forceinline size_t intersectNode(const typename BVHN<N>::QuantizedBaseNodeMB* node, const TravRay<N,true>& ray, const float time, vfloat<N>& dist)
    {}

    //////////////////////////////////////////////////////////////////////////////////////
    // Fast OBBNode intersection
    //////////////////////////////////////////////////////////////////////////////////////

    /*! OBBNode intersection. A single template serves both the fast and robust rays.
        NOTE(review): body is empty in this view (elided). */
    template<int N, bool robust>
      __forceinline size_t intersectNode(const typename BVHN<N>::OBBNode* node, const TravRay<N,robust>& ray, vfloat<N>& dist)
    {}

    //////////////////////////////////////////////////////////////////////////////////////
    // Fast OBBNodeMB intersection
    //////////////////////////////////////////////////////////////////////////////////////

    /*! OBBNodeMB intersection at `time`, for both fast and robust rays.
        NOTE(review): body is empty in this view (elided). */
    template<int N, bool robust>
      __forceinline size_t intersectNode(const typename BVHN<N>::OBBNodeMB* node, const TravRay<N,robust>& ray, const float time, vfloat<N>& dist)
    {}
    
    //////////////////////////////////////////////////////////////////////////////////////
    // Node intersectors used in point query traversal
    //////////////////////////////////////////////////////////////////////////////////////
    
    /*! Computes traversal information for N nodes with 1 point query (sphere variant).
        The primary template is declared only; `types` selects a specialization below.
        NOTE(review): each `BVHNNodePointQuerySphere1<N, BVH_...>;` line stands for a partial
        specialization whose body is elided in this view. The BVH_* flags are defined
        elsewhere. Presumably AN = AABB node, UN = unaligned/OBB node, QN = quantized node,
        4D = time-dependent node — confirm against the BVH_* definitions. */
    template<int N, int types>
    struct BVHNNodePointQuerySphere1;

    BVHNNodePointQuerySphere1<N, BVH_AN1>;

    BVHNNodePointQuerySphere1<N, BVH_AN2>;

    BVHNNodePointQuerySphere1<N, BVH_AN2_AN4D>;

    BVHNNodePointQuerySphere1<N, BVH_AN1_UN1>;
    
    BVHNNodePointQuerySphere1<N, BVH_AN2_UN2>;

    BVHNNodePointQuerySphere1<N, BVH_AN2_AN4D_UN2>;

    BVHNNodePointQuerySphere1<N, BVH_QN1>;
    
    /*! Sphere point-query helper for quantized base nodes.
        NOTE(review): body is empty in this view (elided). */
    template<int N>
    struct BVHNQuantizedBaseNodePointQuerySphere1
    {};

    /*! Computes traversal information for N nodes with 1 point query (AABB variant).
        The primary template is declared only; `types` selects a specialization below.
        NOTE(review): each `BVHNNodePointQueryAABB1<N, BVH_...>;` line stands for a partial
        specialization whose body is elided in this view. The BVH_* flags are defined
        elsewhere — confirm their meaning there. */
    template<int N, int types>
    struct BVHNNodePointQueryAABB1;

    BVHNNodePointQueryAABB1<N, BVH_AN1>;

    BVHNNodePointQueryAABB1<N, BVH_AN2>;

    BVHNNodePointQueryAABB1<N, BVH_AN2_AN4D>;

    BVHNNodePointQueryAABB1<N, BVH_AN1_UN1>;
    
    BVHNNodePointQueryAABB1<N, BVH_AN2_UN2>;

    BVHNNodePointQueryAABB1<N, BVH_AN2_AN4D_UN2>;

    BVHNNodePointQueryAABB1<N, BVH_QN1>;
    
    /*! AABB point-query helper for quantized base nodes.
        NOTE(review): body is empty in this view (elided). */
    template<int N>
    struct BVHNQuantizedBaseNodePointQueryAABB1
    {};

    
    //////////////////////////////////////////////////////////////////////////////////////
    // Node intersectors used in ray traversal
    //////////////////////////////////////////////////////////////////////////////////////

    /*! Intersects N nodes with 1 ray.
        The primary template is declared only. `types` selects the node-type set and
        `robust` selects fast vs. robust ray handling; every (types, robust) pair below has
        its own partial specialization.
        NOTE(review): each `BVHNNodeIntersector1<N, BVH_..., bool>;` line stands for a
        specialization whose body is elided in this view. */
    template<int N, int types, bool robust>
    struct BVHNNodeIntersector1;

    BVHNNodeIntersector1<N, BVH_AN1, false>;

    BVHNNodeIntersector1<N, BVH_AN1, true>;

    BVHNNodeIntersector1<N, BVH_AN2, false>;

    BVHNNodeIntersector1<N, BVH_AN2, true>;

    BVHNNodeIntersector1<N, BVH_AN2_AN4D, false>;

    BVHNNodeIntersector1<N, BVH_AN2_AN4D, true>;

    BVHNNodeIntersector1<N, BVH_AN1_UN1, false>;

    BVHNNodeIntersector1<N, BVH_AN1_UN1, true>;

    BVHNNodeIntersector1<N, BVH_AN2_UN2, false>;

    BVHNNodeIntersector1<N, BVH_AN2_UN2, true>;

    BVHNNodeIntersector1<N, BVH_AN2_AN4D_UN2, false>;

    BVHNNodeIntersector1<N, BVH_AN2_AN4D_UN2, true>;

    BVHNNodeIntersector1<N, BVH_QN1, false>;

    BVHNNodeIntersector1<N, BVH_QN1, true>;

    /*! Intersects N quantized base nodes with 1 ray (single-ray variant, like
        BVHNNodeIntersector1); `robust` selects fast vs. robust ray handling.
        NOTE(review): the two specialization lines below have their bodies elided in this view. */
    template<int N, bool robust>
      struct BVHNQuantizedBaseNodeIntersector1;

    BVHNQuantizedBaseNodeIntersector1<N, false>;

    BVHNQuantizedBaseNodeIntersector1<N, true>;


  }
}