Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
NoTypeTensor.cuh
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #pragma once
10 
11 #include "../../FaissAssert.h"
12 #include "Tensor.cuh"
13 #include <initializer_list>
14 
15 namespace faiss { namespace gpu {
16 
17 template <int Dim, bool InnerContig = false, typename IndexT = int>
18 class NoTypeTensor {
19  public:
20  NoTypeTensor()
21  : mem_(nullptr),
22  typeSize_(0) {
23  }
24 
25  template <typename T>
27  : mem_(t.data()),
28  typeSize_(sizeof(T)) {
29  for (int i = 0; i < Dim; ++i) {
30  size_[i] = t.getSize(i);
31  stride_[i] = t.getStride(i);
32  }
33  }
34 
35  NoTypeTensor(void* mem, int typeSize, std::initializer_list<IndexT> sizes)
36  : mem_(mem),
37  typeSize_(typeSize) {
38 
39  int i = 0;
40  for (auto s : sizes) {
41  size_[i++] = s;
42  }
43 
44  stride_[Dim - 1] = (IndexT) 1;
45  for (int j = Dim - 2; j >= 0; --j) {
46  stride_[j] = stride_[j + 1] * size_[j + 1];
47  }
48  }
49 
50  NoTypeTensor(void* mem, int typeSize, int sizes[Dim])
51  : mem_(mem),
52  typeSize_(typeSize) {
53  for (int i = 0; i < Dim; ++i) {
54  size_[i] = sizes[i];
55  }
56 
57  stride_[Dim - 1] = (IndexT) 1;
58  for (int i = Dim - 2; i >= 0; --i) {
59  stride_[i] = stride_[i + 1] * sizes[i + 1];
60  }
61  }
62 
63  NoTypeTensor(void* mem, int typeSize,
64  IndexT sizes[Dim], IndexT strides[Dim])
65  : mem_(mem),
66  typeSize_(typeSize) {
67  for (int i = 0; i < Dim; ++i) {
68  size_[i] = sizes[i];
69  stride_[i] = strides[i];
70  }
71  }
72 
73  int getTypeSize() const {
74  return typeSize_;
75  }
76 
77  IndexT getSize(int dim) const {
78  FAISS_ASSERT(dim < Dim);
79  return size_[dim];
80  }
81 
82  IndexT getStride(int dim) const {
83  FAISS_ASSERT(dim < Dim);
84  return stride_[dim];
85  }
86 
87  template <typename T>
89  FAISS_ASSERT(sizeof(T) == typeSize_);
90 
91  return Tensor<T, Dim, InnerContig, IndexT>((T*) mem_, size_, stride_);
92  }
93 
94  NoTypeTensor<Dim, InnerContig, IndexT> narrowOutermost(IndexT start,
95  IndexT size) {
96  char* newPtr = (char*) mem_;
97 
98  if (start > 0) {
99  newPtr += typeSize_ * start * stride_[0];
100  }
101 
102  IndexT newSize[Dim];
103  for (int i = 0; i < Dim; ++i) {
104  if (i == 0) {
105  assert(start + size <= size_[0]);
106  newSize[i] = size;
107  } else {
108  newSize[i] = size_[i];
109  }
110  }
111 
113  newPtr, typeSize_, newSize, stride_);
114  }
115 
116  private:
117  void* mem_;
118  int typeSize_;
119  IndexT size_[Dim];
120  IndexT stride_[Dim];
121 };
122 
123 } } // namespace
__host__ __device__ IndexT getSize(int i) const
Definition: Tensor.cuh:222
__host__ __device__ DataPtrType data()
Returns a raw pointer to the start of our data.
Definition: Tensor.cuh:174
Our tensor type.
Definition: Tensor.cuh:28
__host__ __device__ IndexT getStride(int i) const
Definition: Tensor.cuh:228