Commit 80d36f43 authored by Davis King

Fleshed out the AVX SIMD support

parent 4fec4476
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "simd/simd4f.h" #include "simd/simd4f.h"
#include "simd/simd4i.h" #include "simd/simd4i.h"
#include "simd/simd8f.h" #include "simd/simd8f.h"
#include "simd/simd8i.h"
#endif // DLIB_SIMd_H__ #endif // DLIB_SIMd_H__
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "simd_check.h" #include "simd_check.h"
#include "simd4f.h" #include "simd4f.h"
#include "simd8i.h"
namespace dlib namespace dlib
...@@ -24,6 +25,7 @@ namespace dlib ...@@ -24,6 +25,7 @@ namespace dlib
inline simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) inline simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7)
{ x = _mm256_setr_ps(r0,r1,r2,r3,r4,r5,r6,r7); } { x = _mm256_setr_ps(r0,r1,r2,r3,r4,r5,r6,r7); }
simd8f(const simd8i& val):x(_mm256_cvtepi32_ps(val)) {}
simd8f(const __m256& val):x(val) {} simd8f(const __m256& val):x(val) {}
simd8f& operator=(const __m256& val) simd8f& operator=(const __m256& val)
{ {
...@@ -32,6 +34,9 @@ namespace dlib ...@@ -32,6 +34,9 @@ namespace dlib
} }
inline operator __m256() const { return x; } inline operator __m256() const { return x; }
// truncate to 32bit integers
operator __m256i() const { return _mm256_cvttps_epi32(x); }
void load_aligned(const type* ptr) { x = _mm256_load_ps(ptr); } void load_aligned(const type* ptr) { x = _mm256_load_ps(ptr); }
void store_aligned(type* ptr) const { _mm256_store_ps(ptr, x); } void store_aligned(type* ptr) const { _mm256_store_ps(ptr, x); }
void load(const type* ptr) { x = _mm256_loadu_ps(ptr); } void load(const type* ptr) { x = _mm256_loadu_ps(ptr); }
...@@ -51,6 +56,33 @@ namespace dlib ...@@ -51,6 +56,33 @@ namespace dlib
private: private:
__m256 x; __m256 x;
}; };
// Result type of the simd8f comparison operators when AVX is available.
// The comparison outcome is kept in a single raw __m256 value.
class simd8f_bool
{
public:
    typedef float type;

    // Leaves the stored value uninitialized.
    simd8f_bool() {}

    // Wraps an existing raw 256-bit value.
    simd8f_bool(const __m256& val) : x(val) {}

    // Joins two four-lane results: `low` is widened to 256 bits and `high`
    // is then inserted at 128-bit position 1.
    simd8f_bool(const simd4f_bool& low, const simd4f_bool& high)
        : x(_mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1))
    {}

    // Overwrites the stored value with a raw 256-bit value.
    simd8f_bool& operator=(const __m256& val)
    {
        x = val;
        return *this;
    }

    // Lets the object be passed directly to intrinsics expecting __m256.
    operator __m256() const { return x; }

private:
    __m256 x;
};
#else #else
class simd8f class simd8f
{ {
...@@ -62,6 +94,16 @@ namespace dlib ...@@ -62,6 +94,16 @@ namespace dlib
simd8f(float f) :_low(f),_high(f) {} simd8f(float f) :_low(f),_high(f) {}
simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) : simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) :
_low(r0,r1,r2,r3), _high(r4,r5,r6,r7) {} _low(r0,r1,r2,r3), _high(r4,r5,r6,r7) {}
simd8f(const simd8i& val) : _low(val.low()), _high(val.high()) { }
// truncate to 32bit integers
operator simd8i::rawarray() const
{
simd8i::rawarray temp;
temp.low = _low;
temp.high = _high;
return temp;
}
void load_aligned(const type* ptr) { _low.load_aligned(ptr); _high.load_aligned(ptr+4); } void load_aligned(const type* ptr) { _low.load_aligned(ptr); _high.load_aligned(ptr+4); }
void store_aligned(type* ptr) const { _low.store_aligned(ptr); _high.store_aligned(ptr+4); } void store_aligned(type* ptr) const { _low.store_aligned(ptr); _high.store_aligned(ptr+4); }
...@@ -83,6 +125,21 @@ namespace dlib ...@@ -83,6 +125,21 @@ namespace dlib
private: private:
simd4f _low, _high; simd4f _low, _high;
}; };
// Result type of the simd8f comparison operators when AVX is not available.
// The eight-lane result is kept as two four-lane simd4f_bool halves.
class simd8f_bool
{
public:
    typedef float type;

    simd8f_bool() {}

    // Stores copies of the two halves.
    simd8f_bool(const simd4f_bool& low_, const simd4f_bool& high_)
        : _low(low_), _high(high_)
    {}

    // Each accessor returns a copy of the corresponding half.
    simd4f_bool low() const  { return _low; }
    simd4f_bool high() const { return _high; }

private:
    simd4f_bool _low;
    simd4f_bool _high;
};
#endif #endif
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -110,6 +167,20 @@ namespace dlib ...@@ -110,6 +167,20 @@ namespace dlib
inline simd8f& operator+= (simd8f& lhs, const simd8f& rhs) inline simd8f& operator+= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs + rhs; return lhs;} { return lhs = lhs + rhs; return lhs;}
// ----------------------------------------------------------------------------------------
// Element-wise subtraction of two simd8f values (lhs - rhs).
inline simd8f operator- (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
    // One 256-bit intrinsic handles the whole vector.
    const __m256 diff = _mm256_sub_ps(lhs, rhs);
    return diff;
#else
    // No AVX: subtract each simd4f half and rebuild the simd8f.
    const simd4f lower = lhs.low() - rhs.low();
    const simd4f upper = lhs.high() - rhs.high();
    return simd8f(lower, upper);
#endif
}
// Compound subtraction: stores lhs - rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8f& operator-= (simd8f& lhs, const simd8f& rhs)
{
    lhs = lhs - rhs;
    return lhs;
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline simd8f operator* (const simd8f& lhs, const simd8f& rhs) inline simd8f operator* (const simd8f& lhs, const simd8f& rhs)
...@@ -124,6 +195,130 @@ namespace dlib ...@@ -124,6 +195,130 @@ namespace dlib
inline simd8f& operator*= (simd8f& lhs, const simd8f& rhs) inline simd8f& operator*= (simd8f& lhs, const simd8f& rhs)
{ return lhs = lhs * rhs; return lhs;} { return lhs = lhs * rhs; return lhs;}
// ----------------------------------------------------------------------------------------
// Element-wise division of two simd8f values (lhs / rhs).
inline simd8f operator/ (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
    // One 256-bit intrinsic handles the whole vector.
    const __m256 quotient = _mm256_div_ps(lhs, rhs);
    return quotient;
#else
    // No AVX: divide each simd4f half and rebuild the simd8f.
    const simd4f lower = lhs.low() / rhs.low();
    const simd4f upper = lhs.high() / rhs.high();
    return simd8f(lower, upper);
#endif
}
// Compound division: stores lhs / rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8f& operator/= (simd8f& lhs, const simd8f& rhs)
{
    lhs = lhs / rhs;
    return lhs;
}
// ----------------------------------------------------------------------------------------
// Element-wise equality comparison; the outcome is returned as a simd8f_bool.
inline simd8f_bool operator== (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
    // _CMP_EQ_OQ is the named spelling of comparison predicate 0.
    const __m256 mask = _mm256_cmp_ps(lhs, rhs, _CMP_EQ_OQ);
    return mask;
#else
    // No AVX: compare each simd4f half and pair up the results.
    const simd4f_bool lower = (lhs.low()  == rhs.low());
    const simd4f_bool upper = (lhs.high() == rhs.high());
    return simd8f_bool(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise inequality comparison; the outcome is returned as a simd8f_bool.
inline simd8f_bool operator!= (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
    // _CMP_NEQ_UQ is the named spelling of comparison predicate 4.
    const __m256 mask = _mm256_cmp_ps(lhs, rhs, _CMP_NEQ_UQ);
    return mask;
#else
    // No AVX: compare each simd4f half and pair up the results.
    const simd4f_bool lower = (lhs.low()  != rhs.low());
    const simd4f_bool upper = (lhs.high() != rhs.high());
    return simd8f_bool(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise "less than" comparison; the outcome is returned as a simd8f_bool.
inline simd8f_bool operator< (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
    // _CMP_LT_OS is the named spelling of comparison predicate 1.
    const __m256 mask = _mm256_cmp_ps(lhs, rhs, _CMP_LT_OS);
    return mask;
#else
    // No AVX: compare each simd4f half and pair up the results.
    const simd4f_bool lower = (lhs.low()  < rhs.low());
    const simd4f_bool upper = (lhs.high() < rhs.high());
    return simd8f_bool(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise "greater than", expressed as "less than" with the operands swapped.
inline simd8f_bool operator> (const simd8f& lhs, const simd8f& rhs)
{
    const simd8f_bool swapped_less = (rhs < lhs);
    return swapped_less;
}
// ----------------------------------------------------------------------------------------
// Element-wise "less than or equal" comparison; the outcome is returned as a
// simd8f_bool.
inline simd8f_bool operator<= (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
    // _CMP_LE_OS is the named spelling of comparison predicate 2.
    const __m256 mask = _mm256_cmp_ps(lhs, rhs, _CMP_LE_OS);
    return mask;
#else
    // No AVX: compare each simd4f half and pair up the results.
    const simd4f_bool lower = (lhs.low()  <= rhs.low());
    const simd4f_bool upper = (lhs.high() <= rhs.high());
    return simd8f_bool(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise "greater than or equal", expressed as "less than or equal"
// with the operands swapped.
inline simd8f_bool operator>= (const simd8f& lhs, const simd8f& rhs)
{
    const simd8f_bool swapped_less_equal = (rhs <= lhs);
    return swapped_less_equal;
}
// ----------------------------------------------------------------------------------------
// Element-wise minimum of two simd8f values.
inline simd8f min (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
    const __m256 smallest = _mm256_min_ps(lhs, rhs);
    return smallest;
#else
    // No AVX: take the minimum of each simd4f half and rebuild the simd8f.
    const simd4f lower = min(lhs.low(),  rhs.low());
    const simd4f upper = min(lhs.high(), rhs.high());
    return simd8f(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise maximum of two simd8f values.
inline simd8f max (const simd8f& lhs, const simd8f& rhs)
{
#ifdef DLIB_HAVE_AVX
    const __m256 largest = _mm256_max_ps(lhs, rhs);
    return largest;
#else
    // No AVX: take the maximum of each simd4f half and rebuild the simd8f.
    const simd4f lower = max(lhs.low(),  rhs.low());
    const simd4f upper = max(lhs.high(), rhs.high());
    return simd8f(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise reciprocal of a simd8f.
// NOTE(review): the AVX path uses the rcp intrinsic, which is presumably an
// approximation rather than an exact 1/x -- confirm if precision matters.
inline simd8f reciprocal (const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
    const __m256 inverse = _mm256_rcp_ps(item);
    return inverse;
#else
    // No AVX: defer to the simd4f overload for each half.
    const simd4f lower = reciprocal(item.low());
    const simd4f upper = reciprocal(item.high());
    return simd8f(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise reciprocal square root of a simd8f.
// NOTE(review): the AVX path uses the rsqrt intrinsic, which is presumably an
// approximation -- confirm if precision matters.
inline simd8f reciprocal_sqrt (const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
    const __m256 inverse_root = _mm256_rsqrt_ps(item);
    return inverse_root;
#else
    // No AVX: defer to the simd4f overload for each half.
    const simd4f lower = reciprocal_sqrt(item.low());
    const simd4f upper = reciprocal_sqrt(item.high());
    return simd8f(lower, upper);
#endif
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline float sum(const simd8f& item) inline float sum(const simd8f& item)
...@@ -144,6 +339,55 @@ namespace dlib ...@@ -144,6 +339,55 @@ namespace dlib
return sum(lhs*rhs); return sum(lhs*rhs);
} }
// ----------------------------------------------------------------------------------------
// Element-wise square root of a simd8f.
inline simd8f sqrt(const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
    const __m256 root = _mm256_sqrt_ps(item);
    return root;
#else
    // No AVX: defer to the simd4f overload for each half.
    const simd4f lower = sqrt(item.low());
    const simd4f upper = sqrt(item.high());
    return simd8f(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise ceiling of a simd8f.
inline simd8f ceil(const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
    const __m256 rounded_up = _mm256_ceil_ps(item);
    return rounded_up;
#else
    // No AVX: defer to the simd4f overload for each half.
    const simd4f lower = ceil(item.low());
    const simd4f upper = ceil(item.high());
    return simd8f(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise floor of a simd8f.
inline simd8f floor(const simd8f& item)
{
#ifdef DLIB_HAVE_AVX
    const __m256 rounded_down = _mm256_floor_ps(item);
    return rounded_down;
#else
    // No AVX: defer to the simd4f overload for each half.
    const simd4f lower = floor(item.low());
    const simd4f upper = floor(item.high());
    return simd8f(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise conditional: computes cmp ? a : b.
inline simd8f select(const simd8f_bool& cmp, const simd8f& a, const simd8f& b)
{
#ifdef DLIB_HAVE_AVX
    // The blend intrinsic takes the "else" operand first, hence (b, a, cmp).
    const __m256 chosen = _mm256_blendv_ps(b, a, cmp);
    return chosen;
#else
    // No AVX: select within each simd4f half and rebuild the simd8f.
    const simd4f lower = select(cmp.low(),  a.low(),  b.low());
    const simd4f upper = select(cmp.high(), a.high(), b.high());
    return simd8f(lower, upper);
#endif
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_sIMD8I_H__
#define DLIB_sIMD8I_H__
#include "simd_check.h"
#include "../uintn.h"
namespace dlib
{
#ifdef DLIB_HAVE_AVX
class simd8i
{
public:
typedef int32 type;
simd8i() {}
simd8i(int32 f) { x = _mm256_set1_epi32(f); }
simd8i(int32 r0, int32 r1, int32 r2, int32 r3,
int32 r4, int32 r5, int32 r6, int32 r7 )
{ x = _mm256_setr_epi32(r0,r1,r2,r3,r4,r5,r6,r7); }
simd8i(const __m256i& val):x(val) {}
simd8i(const simd4i& low, const simd4i& high)
{
x = _mm256_insertf128_si256(_mm256_castsi128_si256(low),high,1);
}
simd8i& operator=(const __m256i& val)
{
x = val;
return *this;
}
operator __m256i() const { return x; }
void load_aligned(const type* ptr) { x = _mm256_load_si256((const __m256i*)ptr); }
void store_aligned(type* ptr) const { _mm256_store_si256((__m256i*)ptr, x); }
void load(const type* ptr) { x = _mm256_loadu_si256((const __m256i*)ptr); }
void store(type* ptr) const { _mm256_storeu_si256((__m256i*)ptr, x); }
simd4i low() const { return _mm256_castsi256_si128(x); }
simd4i high() const { return _mm256_extractf128_si256(x,1); }
unsigned int size() const { return 4; }
int32 operator[](unsigned int idx) const
{
int32 temp[8];
store(temp);
return temp[idx];
}
private:
__m256i x;
};
#else
// Fallback simd8i used when DLIB_HAVE_AVX is not defined: eight 32-bit
// integers kept as two simd4i halves.
class simd8i
{
public:
    typedef int32 type;

    // Plain pair of halves accepted by the converting constructor below.
    struct rawarray
    {
        simd4i low, high;
    };

    simd8i() {}

    // Sets every element to f.
    simd8i(int32 f) : _low(f), _high(f) {}

    // r0..r3 form the low half, r4..r7 the high half.
    simd8i(int32 r0, int32 r1, int32 r2, int32 r3,
           int32 r4, int32 r5, int32 r6, int32 r7)
        : _low(r0,r1,r2,r3), _high(r4,r5,r6,r7)
    {}

    // Builds the value from two existing halves.
    simd8i(const simd4i& low_, const simd4i& high_)
        : _low(low_), _high(high_)
    {}

    // Copies both halves out of a rawarray.
    simd8i(const rawarray& a)
        : _low(a.low), _high(a.high)
    {}

    // Aligned load/store: elements 0-3 go through the low half, 4-7 through
    // the high half.
    void load_aligned(const type* ptr)
    {
        _low.load_aligned(ptr);
        _high.load_aligned(ptr+4);
    }
    void store_aligned(type* ptr) const
    {
        _low.store_aligned(ptr);
        _high.store_aligned(ptr+4);
    }

    // Unaligned counterparts of the functions above.
    void load(const type* ptr)
    {
        _low.load(ptr);
        _high.load(ptr+4);
    }
    void store(type* ptr) const
    {
        _low.store(ptr);
        _high.store(ptr+4);
    }

    // Number of int32 elements held.
    unsigned int size() const { return 8; }

    // Indices 0-3 read from the low half; anything else reads from the high
    // half at idx-4.
    int32 operator[](unsigned int idx) const
    {
        return (idx < 4) ? _low[idx] : _high[idx-4];
    }

    simd4i low() const  { return _low; }
    simd4i high() const { return _high; }

private:
    simd4i _low, _high;
};
#endif
// ----------------------------------------------------------------------------------------
inline std::ostream& operator<<(std::ostream& out, const simd8i& item)
{
int32 temp[8];
item.store(temp);
out << "(" << temp[0] << ", " << temp[1] << ", " << temp[2] << ", " << temp[3] << ", "
<< temp[4] << ", " << temp[5] << ", " << temp[6] << ", " << temp[7] << ")";
return out;
}
// ----------------------------------------------------------------------------------------
// Element-wise addition of two simd8i values.
inline simd8i operator+ (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    const __m256i total = _mm256_add_epi32(lhs, rhs);
    return total;
#else
    // No AVX2: add each simd4i half and rebuild the simd8i.
    const simd4i lower = lhs.low()  + rhs.low();
    const simd4i upper = lhs.high() + rhs.high();
    return simd8i(lower, upper);
#endif
}
// Compound addition: stores lhs + rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8i& operator+= (simd8i& lhs, const simd8i& rhs)
{
    lhs = lhs + rhs;
    return lhs;
}
// ----------------------------------------------------------------------------------------
// Element-wise subtraction of two simd8i values (lhs - rhs).
inline simd8i operator- (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    const __m256i diff = _mm256_sub_epi32(lhs, rhs);
    return diff;
#else
    // No AVX2: subtract each simd4i half and rebuild the simd8i.
    const simd4i lower = lhs.low()  - rhs.low();
    const simd4i upper = lhs.high() - rhs.high();
    return simd8i(lower, upper);
#endif
}
// Compound subtraction: stores lhs - rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8i& operator-= (simd8i& lhs, const simd8i& rhs)
{
    lhs = lhs - rhs;
    return lhs;
}
// ----------------------------------------------------------------------------------------
// Element-wise multiplication of two simd8i values.
inline simd8i operator* (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    // The mullo intrinsic is used for the 32-bit element products.
    const __m256i product = _mm256_mullo_epi32(lhs, rhs);
    return product;
#else
    // No AVX2: multiply each simd4i half and rebuild the simd8i.
    const simd4i lower = lhs.low()  * rhs.low();
    const simd4i upper = lhs.high() * rhs.high();
    return simd8i(lower, upper);
#endif
}
// Compound multiplication: stores lhs * rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8i& operator*= (simd8i& lhs, const simd8i& rhs)
{
    lhs = lhs * rhs;
    return lhs;
}
// ----------------------------------------------------------------------------------------
// Bitwise AND of two simd8i values.
inline simd8i operator& (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    const __m256i bits = _mm256_and_si256(lhs, rhs);
    return bits;
#else
    // No AVX2: AND each simd4i half and rebuild the simd8i.
    const simd4i lower = lhs.low()  & rhs.low();
    const simd4i upper = lhs.high() & rhs.high();
    return simd8i(lower, upper);
#endif
}
// Compound bitwise AND: stores lhs & rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8i& operator&= (simd8i& lhs, const simd8i& rhs)
{
    lhs = lhs & rhs;
    return lhs;
}
// ----------------------------------------------------------------------------------------
// Bitwise OR of two simd8i values.
inline simd8i operator| (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    const __m256i bits = _mm256_or_si256(lhs, rhs);
    return bits;
#else
    // No AVX2: OR each simd4i half and rebuild the simd8i.
    const simd4i lower = lhs.low()  | rhs.low();
    const simd4i upper = lhs.high() | rhs.high();
    return simd8i(lower, upper);
#endif
}
// Compound bitwise OR: stores lhs | rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8i& operator|= (simd8i& lhs, const simd8i& rhs)
{
    lhs = lhs | rhs;
    return lhs;
}
// ----------------------------------------------------------------------------------------
// Bitwise XOR of two simd8i values.
inline simd8i operator^ (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    const __m256i bits = _mm256_xor_si256(lhs, rhs);
    return bits;
#else
    // No AVX2: XOR each simd4i half and rebuild the simd8i.
    const simd4i lower = lhs.low()  ^ rhs.low();
    const simd4i upper = lhs.high() ^ rhs.high();
    return simd8i(lower, upper);
#endif
}
// Compound bitwise XOR: stores lhs ^ rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8i& operator^= (simd8i& lhs, const simd8i& rhs)
{
    lhs = lhs ^ rhs;
    return lhs;
}
// ----------------------------------------------------------------------------------------
// Bitwise complement of a simd8i.
inline simd8i operator~ (const simd8i& lhs)
{
#ifdef DLIB_HAVE_AVX2
    // XOR against a vector whose elements are all 0xFFFFFFFF flips every bit.
    const __m256i all_ones = _mm256_set1_epi32(0xFFFFFFFF);
    return _mm256_xor_si256(lhs, all_ones);
#else
    // No AVX2: complement each simd4i half and rebuild the simd8i.
    const simd4i lower = ~lhs.low();
    const simd4i upper = ~lhs.high();
    return simd8i(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Shifts every element of lhs left by rhs bits.
inline simd8i operator<< (const simd8i& lhs, const int& rhs)
{
#ifdef DLIB_HAVE_AVX2
    // The sll intrinsic takes its shift count packed in a 128-bit value.
    const __m128i count = _mm_cvtsi32_si128(rhs);
    return _mm256_sll_epi32(lhs, count);
#else
    // No AVX2: shift each simd4i half and rebuild the simd8i.
    const simd4i lower = lhs.low()  << rhs;
    const simd4i upper = lhs.high() << rhs;
    return simd8i(lower, upper);
#endif
}
// Compound left shift: stores lhs << rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8i& operator<<= (simd8i& lhs, const int& rhs)
{
    lhs = lhs << rhs;
    return lhs;
}
// ----------------------------------------------------------------------------------------
// Shifts every element of lhs right by rhs bits.  The AVX2 path uses the
// sra (arithmetic shift) intrinsic.
inline simd8i operator>> (const simd8i& lhs, const int& rhs)
{
#ifdef DLIB_HAVE_AVX2
    // The sra intrinsic takes its shift count packed in a 128-bit value.
    const __m128i count = _mm_cvtsi32_si128(rhs);
    return _mm256_sra_epi32(lhs, count);
#else
    // No AVX2: shift each simd4i half and rebuild the simd8i.
    const simd4i lower = lhs.low()  >> rhs;
    const simd4i upper = lhs.high() >> rhs;
    return simd8i(lower, upper);
#endif
}
// Compound right shift: stores lhs >> rhs into lhs and returns lhs.
// Written as one assignment followed by one return; the original carried a
// second, unreachable return statement.
inline simd8i& operator>>= (simd8i& lhs, const int& rhs)
{
    lhs = lhs >> rhs;
    return lhs;
}
// ----------------------------------------------------------------------------------------
// Element-wise equality comparison; the outcome is returned as a simd8i.
inline simd8i operator== (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    const __m256i mask = _mm256_cmpeq_epi32(lhs, rhs);
    return mask;
#else
    // No AVX2: compare each simd4i half and rebuild the simd8i.
    const simd4i lower = (lhs.low()  == rhs.low());
    const simd4i upper = (lhs.high() == rhs.high());
    return simd8i(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise inequality, computed as the bitwise complement of operator==.
inline simd8i operator!= (const simd8i& lhs, const simd8i& rhs)
{
    const simd8i equal_mask = (lhs == rhs);
    return ~equal_mask;
}
// ----------------------------------------------------------------------------------------
// Element-wise "greater than" comparison; the outcome is returned as a simd8i.
inline simd8i operator> (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    const __m256i mask = _mm256_cmpgt_epi32(lhs, rhs);
    return mask;
#else
    // No AVX2: compare each simd4i half and rebuild the simd8i.
    const simd4i lower = (lhs.low()  > rhs.low());
    const simd4i upper = (lhs.high() > rhs.high());
    return simd8i(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise "less than", expressed as "greater than" with the operands
// swapped.
inline simd8i operator< (const simd8i& lhs, const simd8i& rhs)
{
    const simd8i swapped_greater = (rhs > lhs);
    return swapped_greater;
}
// ----------------------------------------------------------------------------------------
// Element-wise "less than or equal", computed as the bitwise complement of
// "greater than".
inline simd8i operator<= (const simd8i& lhs, const simd8i& rhs)
{
    const simd8i greater_mask = (lhs > rhs);
    return ~greater_mask;
}
// ----------------------------------------------------------------------------------------
// Element-wise "greater than or equal", expressed as "less than or equal"
// with the operands swapped.
inline simd8i operator>= (const simd8i& lhs, const simd8i& rhs)
{
    const simd8i swapped_less_equal = (rhs <= lhs);
    return swapped_less_equal;
}
// ----------------------------------------------------------------------------------------
// Element-wise minimum of two simd8i values.
inline simd8i min (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    const __m256i smallest = _mm256_min_epi32(lhs, rhs);
    return smallest;
#else
    // No AVX2: take the minimum of each simd4i half and rebuild the simd8i.
    const simd4i lower = min(lhs.low(),  rhs.low());
    const simd4i upper = min(lhs.high(), rhs.high());
    return simd8i(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Element-wise maximum of two simd8i values.
inline simd8i max (const simd8i& lhs, const simd8i& rhs)
{
#ifdef DLIB_HAVE_AVX2
    const __m256i largest = _mm256_max_epi32(lhs, rhs);
    return largest;
#else
    // No AVX2: take the maximum of each simd4i half and rebuild the simd8i.
    const simd4i lower = max(lhs.low(),  rhs.low());
    const simd4i upper = max(lhs.high(), rhs.high());
    return simd8i(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
// Adds up all eight elements: the two simd4i halves are added element-wise
// first, then the simd4i sum() reduces the four partial sums.
inline int32 sum(const simd8i& item)
{
    const simd4i folded = item.low() + item.high();
    return sum(folded);
}
// ----------------------------------------------------------------------------------------
// Element-wise conditional: computes cmp ? a : b.
inline simd8i select(const simd8i& cmp, const simd8i& a, const simd8i& b)
{
#ifdef DLIB_HAVE_AVX2
    // The blend intrinsic takes the "else" operand first, hence (b, a, cmp).
    const __m256i chosen = _mm256_blendv_epi8(b, a, cmp);
    return chosen;
#else
    // No AVX2: select within each simd4i half and rebuild the simd8i.
    const simd4i lower = select(cmp.low(),  a.low(),  b.low());
    const simd4i upper = select(cmp.high(), a.high(), b.high());
    return simd8i(lower, upper);
#endif
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_sIMD8I_H__
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#ifndef DLIB_SIMd_CHECK_H__ #ifndef DLIB_SIMd_CHECK_H__
#define DLIB_SIMd_CHECK_H__ #define DLIB_SIMd_CHECK_H__
//#define DLIB_DO_NOT_USE_SIMD
// figure out which SIMD instructions we can use. // figure out which SIMD instructions we can use.
#ifndef DLIB_DO_NOT_USE_SIMD #ifndef DLIB_DO_NOT_USE_SIMD
...@@ -27,29 +28,38 @@ ...@@ -27,29 +28,38 @@
#ifdef __AVX__ #ifdef __AVX__
#define DLIB_HAVE_AVX #define DLIB_HAVE_AVX
#endif #endif
#ifdef __AVX2__
#define DLIB_HAVE_AVX2
#endif
#endif #endif
#endif #endif
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
#ifdef DLIB_HAVE_SSE2 #ifdef __GNUC__
#include <xmmintrin.h> #include <x86intrin.h>
#include <emmintrin.h> #else
#include <mmintrin.h> #ifdef DLIB_HAVE_SSE2
#endif #include <xmmintrin.h>
#ifdef DLIB_HAVE_SSE3 #include <emmintrin.h>
#include <pmmintrin.h> // SSE3 #include <mmintrin.h>
#include <tmmintrin.h> #endif
#endif #ifdef DLIB_HAVE_SSE3
#ifdef DLIB_HAVE_SSE41 #include <pmmintrin.h> // SSE3
#include <smmintrin.h> // SSE4 #include <tmmintrin.h>
#endif #endif
#ifdef DLIB_HAVE_AVX #ifdef DLIB_HAVE_SSE41
#include <immintrin.h> // AVX #include <smmintrin.h> // SSE4
#endif
#ifdef DLIB_HAVE_AVX
#include <immintrin.h> // AVX
#endif
#ifdef DLIB_HAVE_AVX2
#include <avx2intrin.h>
#endif
#endif #endif
#endif // DLIB_SIMd_CHECK_H__ #endif // DLIB_SIMd_CHECK_H__
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment