#include "btMiniSDF.h"

//
//Based on code from DiscreGrid, https://github.com/InteractiveComputerGraphics/Discregrid
//example:
//GenerateSDF.exe -r "32 32 32" -d "-1.6 -1.6 -.6 1.6 1.6 .6" concave_box.obj
//The MIT License (MIT)
//
//Copyright (c) 2017 Dan Koschier
//

#include <limits.h>
#include <string.h> //memcpy

struct btSdfDataStream
{
	const char* m_data;
	int m_size;
	
	int m_currentOffset;
	
	btSdfDataStream(const char* data, int size)
		:m_data(data),
		m_size(size),
		m_currentOffset(0)
	{
		
	}

	template<class T> bool read(T& val)
	{
		int bytes = sizeof(T);
		if (m_currentOffset+bytes<=m_size)
		{
			char* dest = (char*)&val;
			memcpy(dest,&m_data[m_currentOffset],bytes);
			m_currentOffset+=bytes;
			return true;
		}
		btAssert(0);
		return false;
	}
};


bool btMiniSDF::load(const char* data, int size)
{
		int fileSize = -1;

		btSdfDataStream ds(data,size);
		{
			double buf[6];
			ds.read(buf);
			m_domain.m_min[0] = buf[0];
			m_domain.m_min[1] = buf[1];
			m_domain.m_min[2] = buf[2];
			m_domain.m_min[3] = 0;
			m_domain.m_max[0] = buf[3];
			m_domain.m_max[1] = buf[4];
			m_domain.m_max[2] = buf[5];
			m_domain.m_max[3] = 0;
		}
		{
			unsigned int buf2[3];
			ds.read(buf2);
			m_resolution[0] = buf2[0];
			m_resolution[1] = buf2[1];
			m_resolution[2] = buf2[2];
		}
		{
			double buf[3];
			ds.read(buf);
			m_cell_size[0] = buf[0];
			m_cell_size[1] = buf[1];
			m_cell_size[2] = buf[2];
		}
		{
			double buf[3];
			ds.read(buf);
			m_inv_cell_size[0] = buf[0];
			m_inv_cell_size[1] = buf[1];
			m_inv_cell_size[2] = buf[2];
		}
		{
			unsigned long long int cells;
			ds.read(cells);
			m_n_cells = cells;
		}
		{
			unsigned long long int fields;
			ds.read(fields);
			m_n_fields = fields;
		}

		
		unsigned long long int nodes0;
		std::size_t n_nodes0;
		ds.read(nodes0);
		n_nodes0 = nodes0;
		if (n_nodes0 > 1024 * 1024 * 1024)
		{
			return m_isValid;
		}
		m_nodes.resize(n_nodes0);
		for (unsigned int i=0;i<n_nodes0;i++)
		{
			unsigned long long int n_nodes1;
			ds.read(n_nodes1);
			btAlignedObjectArray<double>& nodes = m_nodes[i];
			nodes.resize(n_nodes1);
			for ( int j=0;j<nodes.size();j++)
			{
				double& node = nodes[j];
				ds.read(node);
			}
		}

		unsigned long long int n_cells0;
		ds.read(n_cells0);
		m_cells.resize(n_cells0);
		for (int i=0;i<n_cells0;i++)
		{
			unsigned long long int n_cells1;
			btAlignedObjectArray<btCell32 >& cells = m_cells[i];
			ds.read(n_cells1);
			cells.resize(n_cells1);
			for (int j=0;j<n_cells1;j++)
			{
				btCell32& cell = cells[j];
				ds.read(cell);
			}
		}

		{
			unsigned long long int n_cell_maps0;
			ds.read(n_cell_maps0);
		
			m_cell_map.resize(n_cell_maps0);
			for (int i=0;i<n_cell_maps0;i++)
			{
				unsigned long long int n_cell_maps1;
				btAlignedObjectArray<unsigned int>& cell_maps = m_cell_map[i];
				ds.read(n_cell_maps1);
				cell_maps.resize(n_cell_maps1);
				for (int j=0;j<n_cell_maps1;j++)
				{
					unsigned int& cell_map = cell_maps[j];
					ds.read(cell_map);
				}
			}
		}

		m_isValid = (ds.m_currentOffset == ds.m_size);
		return m_isValid;
}


unsigned int btMiniSDF::multiToSingleIndex(btMultiIndex const & ijk) const
{
	return m_resolution[1] * m_resolution[0] * ijk.ijk[2] + m_resolution[0] * ijk.ijk[1] + ijk.ijk[0];
}
	
btAlignedBox3d
btMiniSDF::subdomain(btMultiIndex const& ijk) const
{
	btAssert(m_isValid);
	btVector3 tmp;
	tmp.m_floats[0] = m_cell_size[0]*(double)ijk.ijk[0];
	tmp.m_floats[1] = m_cell_size[1]*(double)ijk.ijk[1];
	tmp.m_floats[2] = m_cell_size[2]*(double)ijk.ijk[2];
	

	btVector3 origin = m_domain.min() + tmp;

	btAlignedBox3d box = btAlignedBox3d (origin, origin + m_cell_size);
	return box;
}

btMultiIndex
btMiniSDF::singleToMultiIndex(unsigned int l) const
{
	btAssert(m_isValid);
	unsigned int n01 = m_resolution[0] * m_resolution[1];
	unsigned int  k = l / n01;
	unsigned int  temp = l % n01;
	unsigned int j = temp / m_resolution[0];
	unsigned int i = temp % m_resolution[0];
	btMultiIndex mi;
	mi.ijk[0] = i;
	mi.ijk[1] = j;
	mi.ijk[2] = k;
	return mi;
}

btAlignedBox3d
btMiniSDF::subdomain(unsigned int l) const
{
	btAssert(m_isValid);
	return subdomain(singleToMultiIndex(l));
}


btShapeMatrix
btMiniSDF::shape_function_(btVector3 const& xi, btShapeGradients* gradient) const
{
	btAssert(m_isValid);
	btShapeMatrix res;
	
	btScalar x = xi[0];
	btScalar y = xi[1];
	btScalar z = xi[2];

	btScalar  x2 = x*x;
	btScalar  y2 = y*y;
	btScalar  z2 = z*z;

	btScalar  _1mx = 1.0 - x;
	btScalar  _1my = 1.0 - y;
	btScalar  _1mz = 1.0 - z;

	btScalar  _1px = 1.0 + x;
	btScalar  _1py = 1.0 + y;
	btScalar  _1pz = 1.0 + z;

	btScalar  _1m3x = 1.0 - 3.0 * x;
	btScalar  _1m3y = 1.0 - 3.0 * y;
	btScalar  _1m3z = 1.0 - 3.0 * z;

	btScalar  _1p3x = 1.0 + 3.0 * x;
	btScalar  _1p3y = 1.0 + 3.0 * y;
	btScalar  _1p3z = 1.0 + 3.0 * z;

	btScalar  _1mxt1my = _1mx * _1my;
	btScalar  _1mxt1py = _1mx * _1py;
	btScalar  _1pxt1my = _1px * _1my;
	btScalar  _1pxt1py = _1px * _1py;

	btScalar  _1mxt1mz = _1mx * _1mz;
	btScalar  _1mxt1pz = _1mx * _1pz;
	btScalar  _1pxt1mz = _1px * _1mz;
	btScalar  _1pxt1pz = _1px * _1pz;

	btScalar  _1myt1mz = _1my * _1mz;
	btScalar  _1myt1pz = _1my * _1pz;
	btScalar  _1pyt1mz = _1py * _1mz;
	btScalar  _1pyt1pz = _1py * _1pz;

	btScalar  _1mx2 = 1.0 - x2;
	btScalar  _1my2 = 1.0 - y2;
	btScalar  _1mz2 = 1.0 - z2;


	// Corner nodes.
	btScalar  fac = 1.0 / 64.0 * (9.0 * (x2 + y2 + z2) - 19.0);
	res[0] = fac * _1mxt1my * _1mz;
	res[1] = fac * _1pxt1my * _1mz;
	res[2] = fac * _1mxt1py * _1mz;
	res[3] = fac * _1pxt1py * _1mz;
	res[4] = fac * _1mxt1my * _1pz;
	res[5] = fac * _1pxt1my * _1pz;
	res[6] = fac * _1mxt1py * _1pz;
	res[7] = fac * _1pxt1py * _1pz;

	// Edge nodes.

	fac = 9.0 / 64.0 * _1mx2;
	btScalar  fact1m3x = fac * _1m3x;
	btScalar  fact1p3x = fac * _1p3x;
	res[ 8] = fact1m3x * _1myt1mz;
	res[ 9] = fact1p3x * _1myt1mz;
	res[10] = fact1m3x * _1myt1pz;
	res[11] = fact1p3x * _1myt1pz;
	res[12] = fact1m3x * _1pyt1mz;
	res[13] = fact1p3x * _1pyt1mz;
	res[14] = fact1m3x * _1pyt1pz;
	res[15] = fact1p3x * _1pyt1pz;

	fac = 9.0 / 64.0 * _1my2;
	btScalar  fact1m3y = fac * _1m3y;
	btScalar  fact1p3y = fac * _1p3y;
	res[16] = fact1m3y * _1mxt1mz;
	res[17] = fact1p3y * _1mxt1mz;
	res[18] = fact1m3y * _1pxt1mz;
	res[19] = fact1p3y * _1pxt1mz;
	res[20] = fact1m3y * _1mxt1pz;
	res[21] = fact1p3y * _1mxt1pz;
	res[22] = fact1m3y * _1pxt1pz;
	res[23] = fact1p3y * _1pxt1pz;

	fac = 9.0 / 64.0 * _1mz2;
	btScalar  fact1m3z = fac * _1m3z;
	btScalar  fact1p3z = fac * _1p3z;
	res[24] = fact1m3z * _1mxt1my;
	res[25] = fact1p3z * _1mxt1my;
	res[26] = fact1m3z * _1mxt1py;
	res[27] = fact1p3z * _1mxt1py;
	res[28] = fact1m3z * _1pxt1my;
	res[29] = fact1p3z * _1pxt1my;
	res[30] = fact1m3z * _1pxt1py;
	res[31] = fact1p3z * _1pxt1py;

	if (gradient)
	{
		btShapeGradients& dN = *gradient;

		btScalar _9t3x2py2pz2m19 = 9.0 * (3.0 * x2 + y2 + z2) - 19.0;
		btScalar _9tx2p3y2pz2m19 = 9.0 * (x2 + 3.0 * y2 + z2) - 19.0;
		btScalar _9tx2py2p3z2m19 = 9.0 * (x2 + y2 + 3.0 * z2) - 19.0;
		btScalar _18x = 18.0 * x;
		btScalar _18y = 18.0 * y;
		btScalar _18z = 18.0 * z;
		
		btScalar _3m9x2 = 3.0 - 9.0 * x2;
		btScalar _3m9y2 = 3.0 - 9.0 * y2;
		btScalar _3m9z2 = 3.0 - 9.0 * z2;

		btScalar _2x = 2.0 * x;
		btScalar _2y = 2.0 * y;
		btScalar _2z = 2.0 * z;

		btScalar _18xm9t3x2py2pz2m19 = _18x - _9t3x2py2pz2m19;
		btScalar _18xp9t3x2py2pz2m19 = _18x + _9t3x2py2pz2m19;
		btScalar _18ym9tx2p3y2pz2m19 = _18y - _9tx2p3y2pz2m19;
		btScalar _18yp9tx2p3y2pz2m19 = _18y + _9tx2p3y2pz2m19;
		btScalar _18zm9tx2py2p3z2m19 = _18z - _9tx2py2p3z2m19;
		btScalar _18zp9tx2py2p3z2m19 = _18z + _9tx2py2p3z2m19;

		dN(0,0) =_18xm9t3x2py2pz2m19 * _1myt1mz;
		dN(0,1) =_1mxt1mz * _18ym9tx2p3y2pz2m19;
		dN(0,2) =_1mxt1my * _18zm9tx2py2p3z2m19;
		dN(1,0) =_18xp9t3x2py2pz2m19 * _1myt1mz;
		dN(1,1) =_1pxt1mz * _18ym9tx2p3y2pz2m19;
		dN(1,2) =_1pxt1my * _18zm9tx2py2p3z2m19;
		dN(2,0) =_18xm9t3x2py2pz2m19 * _1pyt1mz;
		dN(2,1) =_1mxt1mz * _18yp9tx2p3y2pz2m19;
		dN(2,2) =_1mxt1py * _18zm9tx2py2p3z2m19;
		dN(3,0) =_18xp9t3x2py2pz2m19 * _1pyt1mz;
		dN(3,1) =_1pxt1mz * _18yp9tx2p3y2pz2m19;
		dN(3,2) =_1pxt1py * _18zm9tx2py2p3z2m19;
		dN(4,0) =_18xm9t3x2py2pz2m19 * _1myt1pz;
		dN(4,1) =_1mxt1pz * _18ym9tx2p3y2pz2m19;
		dN(4,2) =_1mxt1my * _18zp9tx2py2p3z2m19;
		dN(5,0) =_18xp9t3x2py2pz2m19 * _1myt1pz;
		dN(5,1) =_1pxt1pz * _18ym9tx2p3y2pz2m19;
		dN(5,2) =_1pxt1my * _18zp9tx2py2p3z2m19;
		dN(6,0) =_18xm9t3x2py2pz2m19 * _1pyt1pz;
		dN(6,1) =_1mxt1pz * _18yp9tx2p3y2pz2m19;
		dN(6,2) =_1mxt1py * _18zp9tx2py2p3z2m19;
		dN(7,0) =_18xp9t3x2py2pz2m19 * _1pyt1pz;
		dN(7,1) =_1pxt1pz * _18yp9tx2p3y2pz2m19;
		dN(7,2) =_1pxt1py * _18zp9tx2py2p3z2m19;

		dN.topRowsDivide(8, 64.0);

		btScalar _m3m9x2m2x = -_3m9x2 - _2x;
		btScalar _p3m9x2m2x =  _3m9x2 - _2x;
		btScalar _1mx2t1m3x = _1mx2 * _1m3x;
		btScalar _1mx2t1p3x = _1mx2 * _1p3x;
		dN( 8,0) =    _m3m9x2m2x * _1myt1mz,
		dN( 8,1) =   -_1mx2t1m3x * _1mz,
		dN( 8,2) =   -_1mx2t1m3x * _1my;
		dN( 9,0) =    _p3m9x2m2x * _1myt1mz,
		dN( 9,1) =   -_1mx2t1p3x * _1mz,
		dN( 9,2) =   -_1mx2t1p3x * _1my;
		dN(10,0) =    _m3m9x2m2x * _1myt1pz,
		dN(10,1) =   -_1mx2t1m3x * _1pz,
		dN(10,2) =   _1mx2t1m3x * _1my;
		dN(11,0) =    _p3m9x2m2x * _1myt1pz,
		dN(11,1) =   -_1mx2t1p3x * _1pz,
		dN(11,2) =   _1mx2t1p3x * _1my;
		dN(12,0) =    _m3m9x2m2x * _1pyt1mz,
		dN(12,1) =   _1mx2t1m3x * _1mz,
		dN(12,2) =   -_1mx2t1m3x * _1py;
		dN(13,0) =    _p3m9x2m2x * _1pyt1mz,
		dN(13,1) =   _1mx2t1p3x * _1mz,
		dN(13,2) =   -_1mx2t1p3x * _1py;
		dN(14,0) =    _m3m9x2m2x * _1pyt1pz,
		dN(14,1) =   _1mx2t1m3x * _1pz,
		dN(14,2) =   _1mx2t1m3x * _1py;
		dN(15,0) =    _p3m9x2m2x * _1pyt1pz,
		dN(15,1) =   _1mx2t1p3x * _1pz,
		dN(15,2) =   _1mx2t1p3x * _1py;

		btScalar _m3m9y2m2y = -_3m9y2 - _2y;
		btScalar _p3m9y2m2y =  _3m9y2 - _2y;
		btScalar _1my2t1m3y = _1my2 * _1m3y;
		btScalar _1my2t1p3y = _1my2 * _1p3y;
		dN(16,0) =    -_1my2t1m3y * _1mz,
		dN(16,1) =   _m3m9y2m2y * _1mxt1mz,
		dN(16,2) =   -_1my2t1m3y * _1mx;
		dN(17,0) =    -_1my2t1p3y * _1mz,
		dN(17,1) =   _p3m9y2m2y * _1mxt1mz,
		dN(17,2) =   -_1my2t1p3y * _1mx;
		dN(18,0) =    _1my2t1m3y * _1mz,
		dN(18,1) =   _m3m9y2m2y * _1pxt1mz,
		dN(18,2) =   -_1my2t1m3y * _1px;
		dN(19,0) =    _1my2t1p3y * _1mz,
		dN(19,1) =   _p3m9y2m2y * _1pxt1mz,
		dN(19,2) =   -_1my2t1p3y * _1px;
		dN(20,0) =    -_1my2t1m3y * _1pz,
		dN(20,1) =   _m3m9y2m2y * _1mxt1pz,
		dN(20,2) =   _1my2t1m3y * _1mx;
		dN(21,0) =    -_1my2t1p3y * _1pz,
		dN(21,1) =   _p3m9y2m2y * _1mxt1pz,
		dN(21,2) =   _1my2t1p3y * _1mx;
		dN(22,0) =    _1my2t1m3y * _1pz,
		dN(22,1) =   _m3m9y2m2y * _1pxt1pz,
		dN(22,2) =   _1my2t1m3y * _1px;
		dN(23,0) =    _1my2t1p3y * _1pz,
		dN(23,1) =   _p3m9y2m2y * _1pxt1pz,
		dN(23,2) =   _1my2t1p3y * _1px;


		btScalar _m3m9z2m2z = -_3m9z2 - _2z;
		btScalar _p3m9z2m2z =  _3m9z2 - _2z;
		btScalar _1mz2t1m3z = _1mz2 * _1m3z;
		btScalar _1mz2t1p3z = _1mz2 * _1p3z;
		dN(24,0) =    -_1mz2t1m3z * _1my,
		dN(24,1) =  -_1mz2t1m3z * _1mx,
		dN(24,2) =   _m3m9z2m2z * _1mxt1my;
		dN(25,0) =    -_1mz2t1p3z * _1my,
		dN(25,1) =   -_1mz2t1p3z * _1mx,
		dN(25,2) =   _p3m9z2m2z * _1mxt1my;
		dN(26,0) =    -_1mz2t1m3z * _1py,
		dN(26,1) =   _1mz2t1m3z * _1mx,
		dN(26,2) =   _m3m9z2m2z * _1mxt1py;
		dN(27,0) =    -_1mz2t1p3z * _1py,
		dN(27,1) =   _1mz2t1p3z * _1mx,
		dN(27,2) =   _p3m9z2m2z * _1mxt1py;
		dN(28,0) =    _1mz2t1m3z * _1my,
		dN(28,1) =   -_1mz2t1m3z * _1px,
		dN(28,2) =   _m3m9z2m2z * _1pxt1my;
		dN(29,0) =    _1mz2t1p3z * _1my,
		dN(29,1) =   -_1mz2t1p3z * _1px,
		dN(29,2) =   _p3m9z2m2z * _1pxt1my;
		dN(30,0) =    _1mz2t1m3z * _1py,
		dN(30,1) =   _1mz2t1m3z * _1px,
		dN(30,2) =   _m3m9z2m2z * _1pxt1py;
		dN(31,0) =    _1mz2t1p3z * _1py,
		dN(31,1) =   _1mz2t1p3z * _1px,
		dN(31,2) =   _p3m9z2m2z * _1pxt1py;

		dN.bottomRowsMul(32u - 8u, 9.0 / 64.0);

	}

	return res;
}



bool btMiniSDF::interpolate(unsigned int field_id, double& dist, btVector3 const& x,
	btVector3* gradient) const
{
	btAssert(m_isValid);
	if (!m_isValid)
		return false;

	if (!m_domain.contains(x))
		return false;

	btVector3 tmpmi = ((x - m_domain.min())*(m_inv_cell_size));//.cast<unsigned int>().eval();
	unsigned int mi[3] = {(unsigned int )tmpmi[0],(unsigned int )tmpmi[1],(unsigned int )tmpmi[2]};
	if (mi[0] >= m_resolution[0])
		mi[0] = m_resolution[0]-1;
	if (mi[1] >= m_resolution[1])
		mi[1] = m_resolution[1]-1;
	if (mi[2] >= m_resolution[2])
		mi[2] = m_resolution[2]-1;
	btMultiIndex mui; 
	mui.ijk[0] = mi[0];
	mui.ijk[1] = mi[1];
	mui.ijk[2] = mi[2];
	int i = multiToSingleIndex(mui);
	unsigned int i_ = m_cell_map[field_id][i];
	if (i_ == UINT_MAX)
		return false;

	btAlignedBox3d sd = subdomain(i);
	i = i_;
	btVector3 d = sd.m_max-sd.m_min;//.diagonal().eval();

	btVector3 denom = (sd.max() - sd.min());
	btVector3 c0 = btVector3(2.0,2.0,2.0)/denom;
	btVector3 c1 = (sd.max() + sd.min())/denom;
	btVector3 xi = (c0*x - c1);

	btCell32 const& cell = m_cells[field_id][i];
	if (!gradient)
	{
		//auto phi = m_coefficients[field_id][i].dot(shape_function_(xi, 0));
		double phi = 0.0;
		btShapeMatrix N = shape_function_(xi, 0);
		for (unsigned int j = 0u; j < 32u; ++j)
		{
			unsigned int v = cell.m_cells[j];
			double c = m_nodes[field_id][v];
			if (c == DBL_MAX)
			{
				return false;;
			}
			phi += c * N[j];
		}

		dist = phi;
		return true;
	}

	btShapeGradients dN;
	btShapeMatrix N = shape_function_(xi, &dN);

	double phi = 0.0;
	gradient->setZero();
	for (unsigned int j = 0u; j < 32u; ++j)
	{
		unsigned int v = cell.m_cells[j];
		double c = m_nodes[field_id][v];
		if (c == DBL_MAX)
		{
			gradient->setZero();
			return false;
		}
		phi += c * N[j];
		(*gradient)[0] += c * dN(j, 0);
		(*gradient)[1] += c * dN(j, 1);
		(*gradient)[2] += c * dN(j, 2);
	}
	(*gradient) *= c0;
	dist = phi;
	return true;
}