Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
NoTypeTensor.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #pragma once
11 
12 #include "../../FaissAssert.h"
13 #include "Tensor.cuh"
14 #include <initializer_list>
15 
16 namespace faiss { namespace gpu {
17 
18 template <int Dim, bool InnerContig = false, typename IndexT = int>
19 class NoTypeTensor {
20  public:
21  NoTypeTensor()
22  : mem_(nullptr),
23  typeSize_(0) {
24  }
25 
26  template <typename T>
28  : mem_(t.data()),
29  typeSize_(sizeof(T)) {
30  for (int i = 0; i < Dim; ++i) {
31  size_[i] = t.getSize(i);
32  stride_[i] = t.getStride(i);
33  }
34  }
35 
36  NoTypeTensor(void* mem, int typeSize, std::initializer_list<IndexT> sizes)
37  : mem_(mem),
38  typeSize_(typeSize) {
39 
40  int i = 0;
41  for (auto s : sizes) {
42  size_[i++] = s;
43  }
44 
45  stride_[Dim - 1] = (IndexT) 1;
46  for (int j = Dim - 2; j >= 0; --j) {
47  stride_[j] = stride_[j + 1] * size_[j + 1];
48  }
49  }
50 
51  NoTypeTensor(void* mem, int typeSize, int sizes[Dim])
52  : mem_(mem),
53  typeSize_(typeSize) {
54  for (int i = 0; i < Dim; ++i) {
55  size_[i] = sizes[i];
56  }
57 
58  stride_[Dim - 1] = (IndexT) 1;
59  for (int i = Dim - 2; i >= 0; --i) {
60  stride_[i] = stride_[i + 1] * sizes[i + 1];
61  }
62  }
63 
64  NoTypeTensor(void* mem, int typeSize,
65  IndexT sizes[Dim], IndexT strides[Dim])
66  : mem_(mem),
67  typeSize_(typeSize) {
68  for (int i = 0; i < Dim; ++i) {
69  size_[i] = sizes[i];
70  stride_[i] = strides[i];
71  }
72  }
73 
74  int getTypeSize() const {
75  return typeSize_;
76  }
77 
78  IndexT getSize(int dim) const {
79  FAISS_ASSERT(dim < Dim);
80  return size_[dim];
81  }
82 
83  IndexT getStride(int dim) const {
84  FAISS_ASSERT(dim < Dim);
85  return stride_[dim];
86  }
87 
88  template <typename T>
90  FAISS_ASSERT(sizeof(T) == typeSize_);
91 
92  return Tensor<T, Dim, InnerContig, IndexT>((T*) mem_, size_, stride_);
93  }
94 
95  NoTypeTensor<Dim, InnerContig, IndexT> narrowOutermost(IndexT start,
96  IndexT size) {
97  char* newPtr = (char*) mem_;
98 
99  if (start > 0) {
100  newPtr += typeSize_ * start * stride_[0];
101  }
102 
103  IndexT newSize[Dim];
104  for (int i = 0; i < Dim; ++i) {
105  if (i == 0) {
106  assert(start + size <= size_[0]);
107  newSize[i] = size;
108  } else {
109  newSize[i] = size_[i];
110  }
111  }
112 
114  newPtr, typeSize_, newSize, stride_);
115  }
116 
117  private:
118  void* mem_;
119  int typeSize_;
120  IndexT size_[Dim];
121  IndexT stride_[Dim];
122 };
123 
124 } } // namespace
__host__ __device__ IndexT getSize(int i) const
Definition: Tensor.cuh:223
__host__ __device__ DataPtrType data()
Returns a raw pointer to the start of our data.
Definition: Tensor.cuh:175
Our tensor type.
Definition: Tensor.cuh:29
__host__ __device__ IndexT getStride(int i) const
Definition: Tensor.cuh:229