Commit 80d36f43 authored by Davis King's avatar Davis King

Fleshed out the AVX SIMD support

parent 4fec4476
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "simd/simd4f.h" #include "simd/simd4f.h"
#include "simd/simd4i.h" #include "simd/simd4i.h"
#include "simd/simd8f.h" #include "simd/simd8f.h"
#include "simd/simd8i.h"
#endif // DLIB_SIMd_H__ #endif // DLIB_SIMd_H__
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "simd_check.h" #include "simd_check.h"
#include "simd4f.h" #include "simd4f.h"
#include "simd8i.h"
namespace dlib namespace dlib
...@@ -24,6 +25,7 @@ namespace dlib ...@@ -24,6 +25,7 @@ namespace dlib
inline simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) inline simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7)
{ x = _mm256_setr_ps(r0,r1,r2,r3,r4,r5,r6,r7); } { x = _mm256_setr_ps(r0,r1,r2,r3,r4,r5,r6,r7); }
simd8f(const simd8i& val):x(_mm256_cvtepi32_ps(val)) {}
simd8f(const __m256& val):x(val) {} simd8f(const __m256& val):x(val) {}
simd8f& operator=(const __m256& val) simd8f& operator=(const __m256& val)
{ {
...@@ -32,6 +34,9 @@ namespace dlib ...@@ -32,6 +34,9 @@ namespace dlib
} }
inline operator __m256() const { return x; } inline operator __m256() const { return x; }
// truncate to 32bit integers
operator __m256i() const { return _mm256_cvttps_epi32(x); }
void load_aligned(const type* ptr) { x = _mm256_load_ps(ptr); } void load_aligned(const type* ptr) { x = _mm256_load_ps(ptr); }
void store_aligned(type* ptr) const { _mm256_store_ps(ptr, x); } void store_aligned(type* ptr) const { _mm256_store_ps(ptr, x); }
void load(const type* ptr) { x = _mm256_loadu_ps(ptr); } void load(const type* ptr) { x = _mm256_loadu_ps(ptr); }
...@@ -51,6 +56,33 @@ namespace dlib ...@@ -51,6 +56,33 @@ namespace dlib
private: private:
__m256 x; __m256 x;
}; };
class simd8f_bool
{
public:
typedef float type;
simd8f_bool() {}
simd8f_bool(const __m256& val):x(val) {}
simd8f_bool(const simd4f_bool& low, const simd4f_bool& high)
{
x = _mm256_insertf128_ps(_mm256_castps128_ps256(low),high,1);
}
simd8f_bool& operator=(const __m256& val)
{
x = val;
return *this;
}
operator __m256() const { return x; }
private:
__m256 x;
};
#else #else
class simd8f class simd8f
{ {
...@@ -62,6 +94,16 @@ namespace dlib ...@@ -62,6 +94,16 @@ namespace dlib
simd8f(float f) :_low(f),_high(f) {} simd8f(float f) :_low(f),_high(f) {}
simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) : simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) :
_low(r0,r1,r2,r3), _high(r4,r5,r6,r7) {} _low(r0,r1,r2,r3), _high(r4,r5,r6,r7) {}
simd8f(const simd8i& val) : _low(val.low()), _high(val.high()) { }
// truncate to 32bit integers
operator simd8i::rawarray() const
{
simd8i::rawarray temp;
temp.low = _low;
temp.high = _high;
return temp;
}
void load_aligned(const type* ptr) { _low.load_aligned(ptr); _high.load_aligned(ptr+4); } void load_aligned(const type* ptr) { _low.load_aligned(ptr); _high.load_aligned(ptr+4); }
void store_aligned(type* ptr) const { _low.store_aligned(ptr); _high.store_aligned(ptr+4); } void store_aligned(type* ptr) const { _low.store_aligned(ptr); _high.store_aligned(ptr+4); }
...@@ -83,6 +125,21 @@ namespace dlib ...@@ -83,6 +125,21 @@ namespace dlib
private: private:
simd4f _low, _high; simd4f _low, _high;
}; };
class simd8f_bool
{
public:
typedef float type;
simd8f_bool() {}
simd8f_bool(const simd4f_bool& low_, const simd4f_bool& high_): _low(low_),_high(high_){}
simd4f_bool low() const { return _low; }
simd4f_bool high() const { return _high; }
private:
simd4f_bool _low,_high;
};
#endif #endif
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -110,6 +167,20 @@ namespace dlib ...@@ -110,6 +167,20 @@ namespace dlib
inline simd8f& operator+= (simd8f& lhs, const simd8f& rhs) inline simd8f& operator+= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs + rhs; return lhs;} { return lhs = lhs + rhs; return lhs;}
// ----------------------------------------------------------------------------------------
inline simd8f operator- (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_sub_ps(lhs, rhs);
#else
return simd8f(lhs.low()-rhs.low(),
lhs.high()-rhs.high());
#endif
}
inline simd8f& operator-= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs - rhs; return lhs;}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline simd8f operator* (const simd8f& lhs, const simd8f& rhs) inline simd8f operator* (const simd8f& lhs, const simd8f& rhs)
...@@ -124,6 +195,130 @@ namespace dlib ...@@ -124,6 +195,130 @@ namespace dlib
inline simd8f& operator*= (simd8f& lhs, const simd8f& rhs) inline simd8f& operator*= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs * rhs; return lhs;} { return lhs = lhs * rhs; return lhs;}
// ----------------------------------------------------------------------------------------
inline simd8f operator/ (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_div_ps(lhs, rhs);
#else
return simd8f(lhs.low()/rhs.low(),
lhs.high()/rhs.high());
#endif
}
inline simd8f& operator/= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs / rhs; return lhs;}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator== (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_cmp_ps(lhs, rhs, 0);
#else
return simd8f_bool(lhs.low() ==rhs.low(),
lhs.high()==rhs.high());
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator!= (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_cmp_ps(lhs, rhs, 4);
#else
return simd8f_bool(lhs.low() !=rhs.low(),
lhs.high()!=rhs.high());
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator< (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_cmp_ps(lhs, rhs, 1);
#else
return simd8f_bool(lhs.low() <rhs.low(),
lhs.high()<rhs.high());
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator> (const simd8f& lhs, const simd8f& rhs)
{
return rhs < lhs;
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator<= (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_cmp_ps(lhs, rhs, 2);
#else
return simd8f_bool(lhs.low() <=rhs.low(),
lhs.high()<=rhs.high());
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f_bool operator>= (const simd8f& lhs, const simd8f& rhs)
{
return rhs <= lhs;
}
// ----------------------------------------------------------------------------------------
inline simd8f min (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_min_ps(lhs, rhs);
#else
return simd8f(min(lhs.low(), rhs.low()),
min(lhs.high(),rhs.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f max (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
return _mm256_max_ps(lhs, rhs);
#else
return simd8f(max(lhs.low(), rhs.low()),
max(lhs.high(),rhs.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f reciprocal (const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_rcp_ps(item);
#else
return simd8f(reciprocal(item.low()),
reciprocal(item.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f reciprocal_sqrt (const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_rsqrt_ps(item);
#else
return simd8f(reciprocal_sqrt(item.low()),
reciprocal_sqrt(item.high()));
#endif
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline float sum(const simd8f& item) inline float sum(const simd8f& item)
...@@ -144,6 +339,55 @@ namespace dlib ...@@ -144,6 +339,55 @@ namespace dlib
return sum(lhs*rhs); return sum(lhs*rhs);
} }
// ----------------------------------------------------------------------------------------
inline simd8f sqrt(const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_sqrt_ps(item);
#else
return simd8f(sqrt(item.low()),
sqrt(item.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f ceil(const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_ceil_ps(item);
#else
return simd8f(ceil(item.low()),
ceil(item.high()));
#endif
}
// ----------------------------------------------------------------------------------------
inline simd8f floor(const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
return _mm256_floor_ps(item);
#else
return simd8f(floor(item.low()),
floor(item.high()));
#endif
}
// ----------------------------------------------------------------------------------------
// perform cmp ? a : b
inline simd8f select(const simd8f_bool& cmp, const simd8f& a, const simd8f& b)
{
#ifdef DLIB_HAVE_AVX
return _mm256_blendv_ps(b,a,cmp);
#else
return simd8f(select(cmp.low(), a.low(), b.low()),
select(cmp.high(), a.high(), b.high()));
#endif
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
This diff is collapsed.
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#ifndef DLIB_SIMd_CHECK_H__ #ifndef DLIB_SIMd_CHECK_H__
#define DLIB_SIMd_CHECK_H__ #define DLIB_SIMd_CHECK_H__
//#define DLIB_DO_NOT_USE_SIMD
// figure out which SIMD instructions we can use. // figure out which SIMD instructions we can use.
#ifndef DLIB_DO_NOT_USE_SIMD #ifndef DLIB_DO_NOT_USE_SIMD
...@@ -27,29 +28,38 @@ ...@@ -27,29 +28,38 @@
#ifdef __AVX__ #ifdef __AVX__
#define DLIB_HAVE_AVX #define DLIB_HAVE_AVX
#endif #endif
#ifdef __AVX2__
#define DLIB_HAVE_AVX2
#endif
#endif #endif
#endif #endif
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
#ifdef DLIB_HAVE_SSE2 #ifdef __GNUC__
#include <xmmintrin.h> #include <x86intrin.h>
#include <emmintrin.h> #else
#include <mmintrin.h> #ifdef DLIB_HAVE_SSE2
#endif #include <xmmintrin.h>
#ifdef DLIB_HAVE_SSE3 #include <emmintrin.h>
#include <pmmintrin.h> // SSE3 #include <mmintrin.h>
#include <tmmintrin.h> #endif
#endif #ifdef DLIB_HAVE_SSE3
#ifdef DLIB_HAVE_SSE41 #include <pmmintrin.h> // SSE3
#include <smmintrin.h> // SSE4 #include <tmmintrin.h>
#endif #endif
#ifdef DLIB_HAVE_AVX #ifdef DLIB_HAVE_SSE41
#include <immintrin.h> // AVX #include <smmintrin.h> // SSE4
#endif
#ifdef DLIB_HAVE_AVX
#include <immintrin.h> // AVX
#endif
#ifdef DLIB_HAVE_AVX2
#include <avx2intrin.h>
#endif
#endif #endif
#endif // DLIB_SIMd_CHECK_H__ #endif // DLIB_SIMd_CHECK_H__
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment