Faiss — source listing of NoTypeTensor.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #pragma once
12 
13 #include "../../FaissAssert.h"
14 #include "Tensor.cuh"
15 #include <initializer_list>
16 
17 namespace faiss { namespace gpu {
18 
19 template <int Dim, bool InnerContig = false, typename IndexT = int>
20 class NoTypeTensor {
21  public:
22  NoTypeTensor()
23  : mem_(nullptr),
24  typeSize_(0) {
25  }
26 
27  template <typename T>
29  : mem_(t.data()),
30  typeSize_(sizeof(T)) {
31  for (int i = 0; i < Dim; ++i) {
32  size_[i] = t.getSize(i);
33  stride_[i] = t.getStride(i);
34  }
35  }
36 
37  NoTypeTensor(void* mem, int typeSize, std::initializer_list<IndexT> sizes)
38  : mem_(mem),
39  typeSize_(typeSize) {
40 
41  int i = 0;
42  for (auto s : sizes) {
43  size_[i++] = s;
44  }
45 
46  stride_[Dim - 1] = (IndexT) 1;
47  for (int j = Dim - 2; j >= 0; --j) {
48  stride_[j] = stride_[j + 1] * size_[j + 1];
49  }
50  }
51 
52  NoTypeTensor(void* mem, int typeSize, int sizes[Dim])
53  : mem_(mem),
54  typeSize_(typeSize) {
55  for (int i = 0; i < Dim; ++i) {
56  size_[i] = sizes[i];
57  }
58 
59  stride_[Dim - 1] = (IndexT) 1;
60  for (int i = Dim - 2; i >= 0; --i) {
61  stride_[i] = stride_[i + 1] * sizes[i + 1];
62  }
63  }
64 
65  NoTypeTensor(void* mem, int typeSize,
66  IndexT sizes[Dim], IndexT strides[Dim])
67  : mem_(mem),
68  typeSize_(typeSize) {
69  for (int i = 0; i < Dim; ++i) {
70  size_[i] = sizes[i];
71  stride_[i] = strides[i];
72  }
73  }
74 
75  int getTypeSize() const {
76  return typeSize_;
77  }
78 
79  IndexT getSize(int dim) const {
80  FAISS_ASSERT(dim < Dim);
81  return size_[dim];
82  }
83 
84  IndexT getStride(int dim) const {
85  FAISS_ASSERT(dim < Dim);
86  return stride_[dim];
87  }
88 
89  template <typename T>
91  FAISS_ASSERT(sizeof(T) == typeSize_);
92 
93  return Tensor<T, Dim, InnerContig, IndexT>((T*) mem_, size_, stride_);
94  }
95 
96  NoTypeTensor<Dim, InnerContig, IndexT> narrowOutermost(IndexT start,
97  IndexT size) {
98  char* newPtr = (char*) mem_;
99 
100  if (start > 0) {
101  newPtr += typeSize_ * start * stride_[0];
102  }
103 
104  IndexT newSize[Dim];
105  for (int i = 0; i < Dim; ++i) {
106  if (i == 0) {
107  assert(start + size <= size_[0]);
108  newSize[i] = size;
109  } else {
110  newSize[i] = size_[i];
111  }
112  }
113 
115  newPtr, typeSize_, newSize, stride_);
116  }
117 
118  private:
119  void* mem_;
120  int typeSize_;
121  IndexT size_[Dim];
122  IndexT stride_[Dim];
123 };
124 
125 } } // namespace
Referenced `Tensor` members (defined in Tensor.cuh):
- `Tensor` — our tensor type. Definition: Tensor.cuh:30
- `__host__ __device__ DataPtrType data()` — returns a raw pointer to the start of our data. Definition: Tensor.cuh:178
- `__host__ __device__ IndexT getSize(int i) const` — Definition: Tensor.cuh:226
- `__host__ __device__ IndexT getStride(int i) const` — Definition: Tensor.cuh:232