14 #include "Float16.cuh" 
   22 namespace faiss { 
namespace gpu {
 
   28   static inline __device__ T add(T a, T b) {
 
   32   static inline __device__ T sub(T a, T b) {
 
   36   static inline __device__ T mul(T a, T b) {
 
   40   static inline __device__ T neg(T v) {
 
   49   static inline __device__ 
bool lt(T a, T b) {
 
   53   static inline __device__ 
bool gt(T a, T b) {
 
   57   static inline __device__ 
bool eq(T a, T b) {
 
   61   static inline __device__ T zero() {
 
   68   typedef float ScalarType;
 
   70   static inline __device__ float2 add(float2 a, float2 b) {
 
   77   static inline __device__ float2 sub(float2 a, float2 b) {
 
   84   static inline __device__ float2 add(float2 a, 
float b) {
 
   91   static inline __device__ float2 sub(float2 a, 
float b) {
 
   98   static inline __device__ float2 mul(float2 a, float2 b) {
 
  105   static inline __device__ float2 mul(float2 a, 
float b) {
 
  112   static inline __device__ float2 neg(float2 v) {
 
  128   static inline __device__ float2 zero() {
 
  138   typedef float ScalarType;
 
  140   static inline __device__ float4 add(float4 a, float4 b) {
 
  149   static inline __device__ float4 sub(float4 a, float4 b) {
 
  158   static inline __device__ float4 add(float4 a, 
float b) {
 
  167   static inline __device__ float4 sub(float4 a, 
float b) {
 
  176   static inline __device__ float4 mul(float4 a, float4 b) {
 
  185   static inline __device__ float4 mul(float4 a, 
float b) {
 
  194   static inline __device__ float4 neg(float4 v) {
 
  204     return v.x + v.y + v.z + v.w;
 
  212   static inline __device__ float4 zero() {
 
  222 #ifdef FAISS_USE_FLOAT16 
  226   typedef half ScalarType;
 
  228   static inline __device__ half add(half a, half b) {
 
  229 #ifdef FAISS_USE_FULL_FLOAT16 
  232     return __float2half(__half2float(a) + __half2float(b));
 
  236   static inline __device__ half sub(half a, half b) {
 
  237 #ifdef FAISS_USE_FULL_FLOAT16 
  240     return __float2half(__half2float(a) - __half2float(b));
 
  244   static inline __device__ half mul(half a, half b) {
 
  245 #ifdef FAISS_USE_FULL_FLOAT16 
  248     return __float2half(__half2float(a) * __half2float(b));
 
  252   static inline __device__ half neg(half v) {
 
  253 #ifdef FAISS_USE_FULL_FLOAT16 
  256     return __float2half(-__half2float(v));
 
  260   static inline __device__ half 
reduceAdd(half v) {
 
  264   static inline __device__ 
bool lt(half a, half b) {
 
  265 #ifdef FAISS_USE_FULL_FLOAT16 
  268     return __half2float(a) < __half2float(b);
 
  272   static inline __device__ 
bool gt(half a, half b) {
 
  273 #ifdef FAISS_USE_FULL_FLOAT16 
  276     return __half2float(a) > __half2float(b);
 
  280   static inline __device__ 
bool eq(half a, half b) {
 
  281 #ifdef FAISS_USE_FULL_FLOAT16 
  284     return __half2float(a) == __half2float(b);
 
  288   static inline __device__ half zero() {
 
  297   typedef half ScalarType;
 
  299   static inline __device__ half2 add(half2 a, half2 b) {
 
  300 #ifdef FAISS_USE_FULL_FLOAT16 
  301     return __hadd2(a, b);
 
  303   float2 af = __half22float2(a);
 
  304   float2 bf = __half22float2(b);
 
  309   return __float22half2_rn(af);
 
  313   static inline __device__ half2 sub(half2 a, half2 b) {
 
  314 #ifdef FAISS_USE_FULL_FLOAT16 
  315     return __hsub2(a, b);
 
  317   float2 af = __half22float2(a);
 
  318   float2 bf = __half22float2(b);
 
  323   return __float22half2_rn(af);
 
  327   static inline __device__ half2 add(half2 a, half b) {
 
  328 #ifdef FAISS_USE_FULL_FLOAT16 
  329     half2 b2 = __half2half2(b);
 
  330     return __hadd2(a, b2);
 
  332   float2 af = __half22float2(a);
 
  333   float bf = __half2float(b);
 
  338   return __float22half2_rn(af);
 
  342   static inline __device__ half2 sub(half2 a, half b) {
 
  343 #ifdef FAISS_USE_FULL_FLOAT16 
  344     half2 b2 = __half2half2(b);
 
  345     return __hsub2(a, b2);
 
  347   float2 af = __half22float2(a);
 
  348   float bf = __half2float(b);
 
  353   return __float22half2_rn(af);
 
  357   static inline __device__ half2 mul(half2 a, half2 b) {
 
  358 #ifdef FAISS_USE_FULL_FLOAT16 
  359     return __hmul2(a, b);
 
  361   float2 af = __half22float2(a);
 
  362   float2 bf = __half22float2(b);
 
  367   return __float22half2_rn(af);
 
  371   static inline __device__ half2 mul(half2 a, half b) {
 
  372 #ifdef FAISS_USE_FULL_FLOAT16 
  373     half2 b2 = __half2half2(b);
 
  374     return __hmul2(a, b2);
 
  376   float2 af = __half22float2(a);
 
  377   float bf = __half2float(b);
 
  382   return __float22half2_rn(af);
 
  386   static inline __device__ half2 neg(half2 v) {
 
  387 #ifdef FAISS_USE_FULL_FLOAT16 
  390   float2 vf = __half22float2(v);
 
  394   return __float22half2_rn(vf);
 
  398   static inline __device__ half 
reduceAdd(half2 v) {
 
  399 #ifdef FAISS_USE_FULL_FLOAT16 
  400   half hv = __high2half(v);
 
  401   half lv = __low2half(v);
 
  403   return __hadd(hv, lv);
 
  405   float2 vf = __half22float2(v);
 
  408   return __float2half(vf.x);
 
  417   static inline __device__ half2 zero() {
 
  418     return __half2half2(Math<half>::zero());
 
  424   typedef half ScalarType;
 
  426   static inline __device__ Half4 add(Half4 a, Half4 b) {
 
  428     h.a = Math<half2>::add(a.a, b.a);
 
  429     h.b = Math<half2>::add(a.b, b.b);
 
  433   static inline __device__ Half4 sub(Half4 a, Half4 b) {
 
  435     h.a = Math<half2>::sub(a.a, b.a);
 
  436     h.b = Math<half2>::sub(a.b, b.b);
 
  440   static inline __device__ Half4 add(Half4 a, half b) {
 
  442     h.a = Math<half2>::add(a.a, b);
 
  443     h.b = Math<half2>::add(a.b, b);
 
  447   static inline __device__ Half4 sub(Half4 a, half b) {
 
  449     h.a = Math<half2>::sub(a.a, b);
 
  450     h.b = Math<half2>::sub(a.b, b);
 
  454   static inline __device__ Half4 mul(Half4 a, Half4 b) {
 
  456     h.a = Math<half2>::mul(a.a, b.a);
 
  457     h.b = Math<half2>::mul(a.b, b.b);
 
  461   static inline __device__ Half4 mul(Half4 a, half b) {
 
  463     h.a = Math<half2>::mul(a.a, b);
 
  464     h.b = Math<half2>::mul(a.b, b);
 
  468   static inline __device__ Half4 neg(Half4 v) {
 
  470     h.a = Math<half2>::neg(v.a);
 
  471     h.b = Math<half2>::neg(v.b);
 
  475   static inline __device__ half 
reduceAdd(Half4 v) {
 
  478     return Math<half>::add(hx, hy);
 
  486   static inline __device__ Half4 zero() {
 
  488     h.a = Math<half2>::zero();
 
  489     h.b = Math<half2>::zero();
 
  496   typedef half ScalarType;
 
  498   static inline __device__ Half8 add(Half8 a, Half8 b) {
 
  500     h.a = Math<Half4>::add(a.a, b.a);
 
  501     h.b = Math<Half4>::add(a.b, b.b);
 
  505   static inline __device__ Half8 sub(Half8 a, Half8 b) {
 
  507     h.a = Math<Half4>::sub(a.a, b.a);
 
  508     h.b = Math<Half4>::sub(a.b, b.b);
 
  512   static inline __device__ Half8 add(Half8 a, half b) {
 
  514     h.a = Math<Half4>::add(a.a, b);
 
  515     h.b = Math<Half4>::add(a.b, b);
 
  519   static inline __device__ Half8 sub(Half8 a, half b) {
 
  521     h.a = Math<Half4>::sub(a.a, b);
 
  522     h.b = Math<Half4>::sub(a.b, b);
 
  526   static inline __device__ Half8 mul(Half8 a, Half8 b) {
 
  528     h.a = Math<Half4>::mul(a.a, b.a);
 
  529     h.b = Math<Half4>::mul(a.b, b.b);
 
  533   static inline __device__ Half8 mul(Half8 a, half b) {
 
  535     h.a = Math<Half4>::mul(a.a, b);
 
  536     h.b = Math<Half4>::mul(a.b, b);
 
  540   static inline __device__ Half8 neg(Half8 v) {
 
  542     h.a = Math<Half4>::neg(v.a);
 
  543     h.b = Math<Half4>::neg(v.b);
 
  547   static inline __device__ half 
reduceAdd(Half8 v) {
 
  550     return Math<half>::add(hx, hy);
 
  558   static inline __device__ Half8 zero() {
 
  560     h.a = Math<Half4>::zero();
 
  561     h.b = Math<Half4>::zero();
 
  566 #endif // FAISS_USE_FLOAT16 
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add: it returns sum(v_i), the sum of the components v_i of v.
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add: it returns sum(v_i), the sum of the components v_i of v.
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add: it returns sum(v_i), the sum of the components of v (for float4, v.x + v.y + v.z + v.w).