Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
NoTypeTensor.cuh
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #pragma once
13 
14 #include "../../FaissAssert.h"
15 #include "Tensor.cuh"
16 #include <initializer_list>
17 
18 namespace faiss { namespace gpu {
19 
/// Type-erased view of a strided Dim-dimensional array.
///
/// Stores an untyped base pointer, the size in bytes of one element, and
/// per-dimension sizes and strides. Strides are measured in elements, not
/// bytes: narrowOutermost() scales them by the element size to obtain a
/// byte offset. A typed view can be recovered with toTensor<T>().
///
/// This class never allocates or frees the memory it points to; it only
/// records the pointer it is given.
template <int Dim, bool Contig = false, typename IndexT = int>
class NoTypeTensor {
  // The row-major stride setup indexes stride_[Dim - 1], so at least one
  // dimension is required.
  static_assert(Dim > 0, "must have > 0 dimensions");

 public:
  /// Creates an empty view: null pointer, zero element size, and all sizes
  /// and strides set to zero.
  NoTypeTensor()
      : mem_(nullptr),
        typeSize_(0),
        size_(),
        stride_() {
  }

  /// Erases the element type of a typed Tensor. Copies its data pointer and
  /// its per-dimension sizes and strides, and records sizeof(T) as the
  /// element size.
  template <typename T>
  NoTypeTensor(Tensor<T, Dim, Contig, IndexT>& t)
      : mem_(t.data()),
        typeSize_(sizeof(T)) {
    for (int i = 0; i < Dim; ++i) {
      size_[i] = t.getSize(i);
      stride_[i] = t.getStride(i);
    }
  }

  /// Wraps `mem` as a densely packed, row-major array of `typeSize`-byte
  /// elements with the given sizes (the innermost dimension is last and has
  /// stride 1).
  ///
  /// `sizes` must contain exactly Dim entries.
  NoTypeTensor(void* mem, int typeSize, std::initializer_list<IndexT> sizes)
      : mem_(mem),
        typeSize_(typeSize) {
    // More than Dim entries would write past the end of size_; fewer would
    // leave trailing sizes, and the strides derived from them, uninitialized.
    FAISS_ASSERT((int) sizes.size() == Dim);

    int i = 0;
    for (auto s : sizes) {
      size_[i++] = s;
    }

    setContiguousStrides();
  }

  /// Wraps `mem` as a densely packed, row-major array of `typeSize`-byte
  /// elements whose Dim sizes are given as an int array.
  NoTypeTensor(void* mem, int typeSize, const int sizes[Dim])
      : mem_(mem),
        typeSize_(typeSize) {
    for (int i = 0; i < Dim; ++i) {
      size_[i] = sizes[i];
    }

    setContiguousStrides();
  }

  /// Wraps `mem` with explicit per-dimension sizes and strides (strides are
  /// in elements).
  NoTypeTensor(void* mem, int typeSize,
               const IndexT sizes[Dim], const IndexT strides[Dim])
      : mem_(mem),
        typeSize_(typeSize) {
    for (int i = 0; i < Dim; ++i) {
      size_[i] = sizes[i];
      stride_[i] = strides[i];
    }
  }

  /// Returns the size in bytes of one element.
  int getTypeSize() const {
    return typeSize_;
  }

  /// Returns the number of elements along dimension `dim` (0 <= dim < Dim).
  IndexT getSize(int dim) const {
    FAISS_ASSERT(dim >= 0 && dim < Dim);
    return size_[dim];
  }

  /// Returns the stride, in elements, of dimension `dim` (0 <= dim < Dim).
  IndexT getStride(int dim) const {
    FAISS_ASSERT(dim >= 0 && dim < Dim);
    return stride_[dim];
  }

  /// Reinterprets this view as a typed Tensor<T> over the same memory, sizes
  /// and strides. sizeof(T) must equal the stored element size.
  template <typename T>
  Tensor<T, Dim, Contig, IndexT> toTensor() {
    FAISS_ASSERT((int) sizeof(T) == typeSize_);

    return Tensor<T, Dim, Contig, IndexT>((T*) mem_, size_, stride_);
  }

  /// Returns a view of indices [start, start + size) along the outermost
  /// dimension. Inner sizes and all strides are unchanged; the base pointer
  /// is advanced by start * getStride(0) elements.
  NoTypeTensor<Dim, Contig, IndexT> narrowOutermost(IndexT start, IndexT size) {
    // Reject negative arguments and ranges running past the end of
    // dimension 0. FAISS_ASSERT matches the checks used elsewhere here.
    FAISS_ASSERT(start >= 0 && size >= 0);
    FAISS_ASSERT(start + size <= size_[0]);

    // Byte offset of the first selected index. Widened to 64 bits so that
    // typeSize_ * start * stride_[0] cannot overflow a 32-bit int on large
    // tensors.
    char* newPtr = (char*) mem_ +
        (long long) typeSize_ * (long long) start * (long long) stride_[0];

    // Only the outermost size changes.
    IndexT newSize[Dim];
    newSize[0] = size;
    for (int i = 1; i < Dim; ++i) {
      newSize[i] = size_[i];
    }

    return NoTypeTensor<Dim, Contig, IndexT>(
      newPtr, typeSize_, newSize, stride_);
  }

 private:
  /// Fills stride_ from size_ for a densely packed row-major layout: the
  /// innermost dimension has stride 1, and each outer stride is the product
  /// of the next stride and the next size.
  void setContiguousStrides() {
    stride_[Dim - 1] = (IndexT) 1;
    for (int i = Dim - 2; i >= 0; --i) {
      stride_[i] = stride_[i + 1] * size_[i + 1];
    }
  }

  /// Base pointer to the viewed memory (not owned; never freed here).
  void* mem_;
  /// Size in bytes of one element.
  int typeSize_;
  /// Number of elements along each dimension.
  IndexT size_[Dim];
  /// Stride of each dimension, in elements.
  IndexT stride_[Dim];
};
124 
125 } } // namespace
Cross-referenced members from Tensor.cuh (Doxygen tooltips):
- `__host__ __device__ DataPtrType data()` — Returns a raw pointer to the start of our data. Definition: Tensor.cuh:162
- `__host__ __device__ IndexT getStride(int i) const` — Definition: Tensor.cuh:216
- `Tensor` — Our tensor type. Definition: Tensor.cuh:31
- `__host__ __device__ IndexT getSize(int i) const` — Definition: Tensor.cuh:210