Commit 80d36f43 authored by Davis King's avatar Davis King

Fleshed out the AVX SIMD support

parent 4fec4476
......@@ -6,6 +6,7 @@
#include "simd/simd4f.h"
#include "simd/simd4i.h"
#include "simd/simd8f.h"
#include "simd/simd8i.h"
#endif // DLIB_SIMd_H__
......@@ -5,6 +5,7 @@
#include "simd_check.h"
#include "simd4f.h"
#include "simd8i.h"
namespace dlib
......@@ -24,6 +25,7 @@ namespace dlib
inline simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7)
{ x = _mm256_setr_ps(r0,r1,r2,r3,r4,r5,r6,r7); }
simd8f(const simd8i& val):x(_mm256_cvtepi32_ps(val)) {}
simd8f(const __m256& val):x(val) {}
simd8f& operator=(const __m256& val)
{
......@@ -32,6 +34,9 @@ namespace dlib
}
inline operator __m256() const { return x; }
// truncate to 32bit integers
operator __m256i() const { return _mm256_cvttps_epi32(x); }
void load_aligned(const type* ptr) { x = _mm256_load_ps(ptr); }
void store_aligned(type* ptr) const { _mm256_store_ps(ptr, x); }
void load(const type* ptr) { x = _mm256_loadu_ps(ptr); }
......@@ -51,6 +56,33 @@ namespace dlib
private:
__m256 x;
};
class simd8f_bool
{
public:
typedef float type;
simd8f_bool() {}
simd8f_bool(const __m256& val):x(val) {}
simd8f_bool(const simd4f_bool& low, const simd4f_bool& high)
{
x = _mm256_insertf128_ps(_mm256_castps128_ps256(low),high,1);
}
simd8f_bool& operator=(const __m256& val)
{
x = val;
return *this;
}
operator __m256() const { return x; }
private:
__m256 x;
};
#else
class simd8f
{
......@@ -62,6 +94,16 @@ namespace dlib
simd8f(float f) :_low(f),_high(f) {}
simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) :
_low(r0,r1,r2,r3), _high(r4,r5,r6,r7) {}
simd8f(const simd8i& val) : _low(val.low()), _high(val.high()) { }
// truncate to 32bit integers
operator simd8i::rawarray() const
{
simd8i::rawarray temp;
temp.low = _low;
temp.high = _high;
return temp;
}
void load_aligned(const type* ptr) { _low.load_aligned(ptr); _high.load_aligned(ptr+4); }
void store_aligned(type* ptr) const { _low.store_aligned(ptr); _high.store_aligned(ptr+4); }
......@@ -83,6 +125,21 @@ namespace dlib
private:
simd4f _low, _high;
};
class simd8f_bool
{
public:
typedef float type;
simd8f_bool() {}
simd8f_bool(const simd4f_bool& low_, const simd4f_bool& high_): _low(low_),_high(high_){}
simd4f_bool low() const { return _low; }
simd4f_bool high() const { return _high; }
private:
simd4f_bool _low,_high;
};
#endif
// ----------------------------------------------------------------------------------------
......@@ -110,6 +167,20 @@ namespace dlib
inline simd8f& operator+= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs + rhs; return lhs;}
// ----------------------------------------------------------------------------------------
inline simd8f operator- (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_sub_ps(lhs, rhs);
#else
return simd8f(lhs.low()-rhs.low(),
lhs.high()-rhs.high());
#endif
}
inline simd8f& operator-= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs - rhs; return lhs;}
// ----------------------------------------------------------------------------------------
inline simd8f operator* (const simd8f& lhs, const simd8f& rhs)
......@@ -124,6 +195,130 @@ namespace dlib
inline simd8f& operator*= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs * rhs; return lhs;}
// ----------------------------------------------------------------------------------------
inline simd8f operator/ (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_div_ps(lhs, rhs);
#else
return simd8f(lhs.low()/rhs.low(),
lhs.high()/rhs.high());
#endif
}
inline simd8f& operator/= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs / rhs; return lhs;}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator== (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_cmp_ps(lhs, rhs, 0);
#else
return simd8f_bool(lhs.low() ==rhs.low(),
lhs.high()==rhs.high());
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator!= (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_cmp_ps(lhs, rhs, 4);
#else
return simd8f_bool(lhs.low() !=rhs.low(),
lhs.high()!=rhs.high());
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator< (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_cmp_ps(lhs, rhs, 1);
#else
return simd8f_bool(lhs.low() <rhs.low(),
lhs.high()<rhs.high());
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator> (const simd8f& lhs, const simd8f& rhs)
{
return rhs < lhs;
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator<= (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_cmp_ps(lhs, rhs, 2);
#else
return simd8f_bool(lhs.low() <=rhs.low(),
lhs.high()<=rhs.high());
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator>= (const simd8f& lhs, const simd8f& rhs)
{
return rhs <= lhs;
}
// ----------------------------------------------------------------------------------------
inline simd8f min (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_min_ps(lhs, rhs);
#else
return simd8f(min(lhs.low(), rhs.low()),
min(lhs.high(),rhs.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f max (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_max_ps(lhs, rhs);
#else
return simd8f(max(lhs.low(), rhs.low()),
max(lhs.high(),rhs.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f reciprocal (const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_rcp_ps(item);
#else
return simd8f(reciprocal(item.low()),
reciprocal(item.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f reciprocal_sqrt (const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_rsqrt_ps(item);
#else
return simd8f(reciprocal_sqrt(item.low()),
reciprocal_sqrt(item.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline float sum(const simd8f& item)
......@@ -144,6 +339,55 @@ namespace dlib
return sum(lhs*rhs);
}
// ----------------------------------------------------------------------------------------
inline simd8f sqrt(const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_sqrt_ps(item);
#else
return simd8f(sqrt(item.low()),
sqrt(item.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f ceil(const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_ceil_ps(item);
#else
return simd8f(ceil(item.low()),
ceil(item.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f floor(const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_floor_ps(item);
#else
return simd8f(floor(item.low()),
floor(item.high()));
#endif
}
// ----------------------------------------------------------------------------------------
// perform cmp ? a : b
inline simd8f select(const simd8f_bool& cmp, const simd8f& a, const simd8f& b)
{
#ifdef DLIB_HAVE_AVX
return _mm256_blendv_ps(b,a,cmp);
#else
return simd8f(select(cmp.low(), a.low(), b.low()),
select(cmp.high(), a.high(), b.high()));
#endif
}
// ----------------------------------------------------------------------------------------
}
......
This diff is collapsed.
......@@ -3,6 +3,7 @@
#ifndef DLIB_SIMd_CHECK_H__
#define DLIB_SIMd_CHECK_H__
//#define DLIB_DO_NOT_USE_SIMD
// figure out which SIMD instructions we can use.
#ifndef DLIB_DO_NOT_USE_SIMD
......@@ -27,29 +28,38 @@
#ifdef __AVX__
#define DLIB_HAVE_AVX
#endif
#ifdef __AVX2__
#define DLIB_HAVE_AVX2
#endif
#endif
#endif
// ----------------------------------------------------------------------------------------
#ifdef DLIB_HAVE_SSE2
#include <xmmintrin.h>
#include <emmintrin.h>
#include <mmintrin.h>
#endif
#ifdef DLIB_HAVE_SSE3
#include <pmmintrin.h> // SSE3
#include <tmmintrin.h>
#endif
#ifdef DLIB_HAVE_SSE41
#include <smmintrin.h> // SSE4
#endif
#ifdef DLIB_HAVE_AVX
#include <immintrin.h> // AVX
#ifdef __GNUC__
#include <x86intrin.h>
#else
#ifdef DLIB_HAVE_SSE2
#include <xmmintrin.h>
#include <emmintrin.h>
#include <mmintrin.h>
#endif
#ifdef DLIB_HAVE_SSE3
#include <pmmintrin.h> // SSE3
#include <tmmintrin.h>
#endif
#ifdef DLIB_HAVE_SSE41
#include <smmintrin.h> // SSE4
#endif
#ifdef DLIB_HAVE_AVX
#include <immintrin.h> // AVX
#endif
#ifdef DLIB_HAVE_AVX2
#include <avx2intrin.h>
#endif
#endif
#endif // DLIB_SIMd_CHECK_H__
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment