godot/thirdparty/embree/kernels/builders/heuristic_strand_array.h

// Copyright 2009-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "priminfo.h"
#include "../../common/algorithms/parallel_reduce.h"
#include "../../common/algorithms/parallel_partition.h"

namespace embree
{
  namespace isa
  { 
    /*! Strand split heuristic: separates primitives into two groups (strands) by comparing their direction against two reference axes */
    struct HeuristicStrandSplit
    {
      typedef range<size_t> Set;
  
      static const size_t PARALLEL_THRESHOLD = 10000;
      static const size_t PARALLEL_FIND_BLOCK_SIZE = 4096;
      static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 64;

      /*! Result of a strand split search: the SAH cost together with the
       *  two axes the strands are aligned to. */
      struct Split
      {
        /*! default construction yields an invalid split (infinite cost) */
        __forceinline Split()
          : sah(inf),
            axis0(zero),
            axis1(zero)
        {}

        /*! builds a split from an explicit cost and a pair of axes */
        __forceinline Split(const float sah, const Vec3fa& axis0, const Vec3fa& axis1)
          : sah(sah),
            axis0(axis0),
            axis1(axis1)
        {}

        /*! returns the stored SAH cost of this split */
        __forceinline float splitSAH() const {
          return sah;
        }

        /*! a split is valid as long as its cost differs from float(inf) */
        __forceinline bool valid() const {
          const float invalidCost = float(inf);
          return sah != invalidCost;
        }

      public:
        float sah;      //!< SAH cost of the split
        Vec3fa axis0;   //!< axis the first strand is aligned to
        Vec3fa axis1;   //!< axis the second strand is aligned to
      };

      /*! creates a heuristic with neither scene nor primitive array attached */
      __forceinline HeuristicStrandSplit () // FIXME: required?
        : scene(nullptr),
          prims(nullptr)
      {}

      /*! stores the scene and the primitive array the other methods operate on */
      __forceinline HeuristicStrandSplit (Scene* scene, PrimRef* prims)
        : scene(scene),
          prims(prims)
      {}
      
      /*! asks the geometry referenced by the primitive for the primitive's direction */
      __forceinline const Vec3fa direction(const PrimRef& prim)
      {
        const auto geometry = scene->get(prim.geomID());
        return geometry->computeDirection(prim.primID());
      }
      
      /*! asks the geometry referenced by the primitive for the primitive's bounds */
      __forceinline const BBox3fa bounds(const PrimRef& prim)
      {
        const auto geometry = scene->get(prim.geomID());
        return geometry->vbounds(prim.primID());
      }

      /*! asks the geometry referenced by the primitive for the primitive's bounds, passing along the given space */
      __forceinline const BBox3fa bounds(const LinearSpace3fa& space, const PrimRef& prim)
      {
        const auto geometry = scene->get(prim.geomID());
        return geometry->vbounds(space,prim.primID());
      }

      /*! Searches for a split of the primitives in \a set into two strands.
       *  Two reference axes are selected, every primitive is assigned to the
       *  axis it is more aligned with, and the cost is computed from the
       *  block counts (blocks of 2^logBlockSize primitives) and the half
       *  areas of the bounds accumulated per strand. When one strand would
       *  stay empty a split with infinite cost is returned. */
      const Split find(const range<size_t>& set, size_t logBlockSize)
      {
        const size_t first = set.begin();
        const size_t last  = set.end();

        /* first axis: direction of the primitive with the smallest ID whose
           direction is not (nearly) zero length; the z-axis otherwise */
        Vec3fa axis0(0,0,1);
        uint64_t minID = -1;
        for (size_t i=first; i<last; i++)
        {
          const uint64_t id = prims[i].ID64();
          if (id < minID)
          {
            const Vec3fa dir = direction(prims[i]);
            if (sqr_length(dir) > 1E-18f) {
              axis0 = normalize(dir);
              minID = id;
            }
          }
        }

        /* second axis: direction with the smallest |cos| towards the first
           axis; equal values are resolved in favor of the smaller ID */
        Vec3fa axis1 = axis0;
        float minCos = 1.0f;
        minID = -1;
        for (size_t i=first; i<last; i++)
        {
          const uint64_t id = prims[i].ID64();
          Vec3fa dir = direction(prims[i]);
          const float len = length(dir);
          if (len == 0.0f) continue; // zero-length directions cannot be normalized
          dir /= len;
          const float c = abs(dot(dir,axis0));
          const bool strictlyBetter = c < minCos;
          const bool tieWithLowerID = (c == minCos) && (id < minID);
          if (strictlyBetter || tieWithLowerID) {
            minCos = c;
            axis1 = dir;
            minID = id;
          }
        }

        /* count the primitives of each strand and accumulate their bounds,
           queried with the space derived from the strand's axis */
        const LinearSpace3fa space0 = frame(axis0).transposed();
        const LinearSpace3fa space1 = frame(axis1).transposed();
        size_t numLeft = 0, numRight = 0;
        BBox3fa boundsLeft = empty, boundsRight = empty;

        for (size_t i=first; i<last; i++)
        {
          const PrimRef& prim = prims[i];
          const Vec3fa dir = normalize(direction(prim));
          const float toAxis0 = abs(dot(dir,axis0));
          const float toAxis1 = abs(dot(dir,axis1));

          const bool goesLeft = toAxis0 > toAxis1;
          if (goesLeft) {
            numLeft++;
            boundsLeft.extend(bounds(space0,prim));
          } else {
            numRight++;
            boundsRight.extend(bounds(space1,prim));
          }
        }

        /* nothing got separated: report an invalid split */
        if (numLeft == 0 || numRight == 0)
          return Split(inf,axis0,axis1);

        /* number of blocks needed for n primitives, rounded up */
        const auto numBlocks = [logBlockSize] (const size_t n) -> size_t {
          return (n+(1ull<<logBlockSize)-1ull) >> logBlockSize;
        };

        /* SAH cost: blocks times half area, summed over both strands */
        const size_t blocksLeft  = numBlocks(numLeft);
        const size_t blocksRight = numBlocks(numRight);
        const float sah = madd(float(blocksLeft),halfArea(boundsLeft),float(blocksRight)*halfArea(boundsRight));
        return Split(sah,axis0,axis1);
      }

      /*! array partitioning */
      void split(const Split& split, const PrimInfoRange& set, PrimInfoRange& lset, PrimInfoRange& rset) 
      {
        if (!split.valid()) {
          deterministic_order(set);
          return splitFallback(set,lset,rset);
        }
        
        const size_t begin = set.begin();
        const size_t end   = set.end();
        CentGeomBBox3fa local_left(empty);
        CentGeomBBox3fa local_right(empty);

        auto primOnLeftSide = [&] (const PrimRef& prim) -> bool { 
          const Vec3fa axisi = normalize(direction(prim));
          const float cos0 = abs(dot(axisi,split.axis0));
          const float cos1 = abs(dot(axisi,split.axis1));
          return cos0 > cos1;
        };

        auto mergePrimBounds = [this] (CentGeomBBox3fa& pinfo,const PrimRef& ref) { 
          pinfo.extend(bounds(ref)); 
        };
        
        size_t center = serial_partitioning(prims,begin,end,local_left,local_right,primOnLeftSide,mergePrimBounds);
        
        new (&lset) PrimInfoRange(begin,center,local_left);
        new (&rset) PrimInfoRange(center,end,local_right);
        assert(area(lset.geomBounds) >= 0.0f);
        assert(area(rset.geomBounds) >= 0.0f);
      }

      void deterministic_order(const Set& set) 
      {
        /* required as parallel partition destroys original primitive order */
        std::sort(&prims[set.begin()],&prims[set.end()]);
      }
      
      void splitFallback(const Set& set, PrimInfoRange& lset, PrimInfoRange& rset)
      {
        const size_t begin = set.begin();
        const size_t end   = set.end();
        const size_t center = (begin + end)/2;
        
        CentGeomBBox3fa left(empty);
        for (size_t i=begin; i<center; i++)
          left.extend(bounds(prims[i]));
        new (&lset) PrimInfoRange(begin,center,left);
        
        CentGeomBBox3fa right(empty);
        for (size_t i=center; i<end; i++)
          right.extend(bounds(prims[i]));	
        new (&rset) PrimInfoRange(center,end,right);
      }
      
    private:
      Scene* const scene;
      PrimRef* const prims;
    };
  }
}