Open3D (C++ API)  0.16.1
AdvancedIndexing.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// The MIT License (MIT)
5//
6// Copyright (c) 2018-2021 www.open3d.org
7//
8// Permission is hereby granted, free of charge, to any person obtaining a copy
9// of this software and associated documentation files (the "Software"), to deal
10// in the Software without restriction, including without limitation the rights
11// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12// copies of the Software, and to permit persons to whom the Software is
13// furnished to do so, subject to the following conditions:
14//
15// The above copyright notice and this permission notice shall be included in
16// all copies or substantial portions of the Software.
17//
18// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24// IN THE SOFTWARE.
25// ----------------------------------------------------------------------------
26
27#pragma once
28
29#include <vector>
30
31#include "open3d/core/Indexer.h"
33#include "open3d/core/Tensor.h"
34
35namespace open3d {
36namespace core {
37
40public:
42 const std::vector<Tensor>& index_tensors)
43 : tensor_(tensor), index_tensors_(ExpandBoolTensors(index_tensors)) {
45 }
46
47 inline Tensor GetTensor() const { return tensor_; }
48
49 inline std::vector<Tensor> GetIndexTensors() const {
50 return index_tensors_;
51 }
52
53 inline SizeVector GetOutputShape() const { return output_shape_; }
54
55 inline SizeVector GetIndexedShape() const { return indexed_shape_; }
56
58
/// Returns whether the advanced-index tensors are split by a slice.
/// NOTE(review): semantics inferred from the name, implementation not visible
/// here. Presumably a 0-dim entry in `index_tensors` stands for a full slice
/// (the AdvancedIndexer constructor skips 0-dim index tensors), and "split"
/// means such an entry lies between two indexed dimensions. Confirm in
/// AdvancedIndexing.cpp.
62 static bool IsIndexSplittedBySlice(
63 const std::vector<Tensor>& index_tensors);
64
/// Returns a pair of (tensor, index tensors).
/// NOTE(review): based on the name, this presumably permutes the indexed
/// dimensions of `tensor` (and the matching entries of `index_tensors`) to
/// the front. Confirm in AdvancedIndexing.cpp.
67 static std::pair<Tensor, std::vector<Tensor>> ShuffleIndexedDimsToFront(
68 const Tensor& tensor, const std::vector<Tensor>& index_tensors);
69
/// Returns a pair of (index tensors, SizeVector).
/// NOTE(review): based on the name, this presumably broadcasts the non-0-dim
/// index tensors to one common shape, leaving 0-dim entries untouched, and
/// returns that common shape as the second element. Confirm in
/// AdvancedIndexing.cpp.
72 static std::pair<std::vector<Tensor>, SizeVector>
73 ExpandToCommonShapeExceptZeroDim(const std::vector<Tensor>& index_tensors);
74
75 // Replace indexed dimensions with stride 0 and the size of the result
76 // tensor.
77 //
78 // The offset in these dimensions is computed by the kernel using
79 // the index tensor's values and the stride of the tensor. The new shape is
80 // not meaningful. It's used to make the shape compatible with the result
81 // tensor.
82 //
83 // Effectively, we throw away the tensor's shape and strides for the sole
84 // purpose of element-wise iteration for the Indexer. The tensor's original
85 // shape and strides of the indexed dimensions are kept in indexed_shape_ and
86 // indexed_strides_, which are passed to fancy indexing kernels.
//
// NOTE(review): parameter meanings are inferred from the names. Presumably
// `dims_before` is the number of leading non-indexed dimensions,
// `dims_indexed` is the number of indexed dimensions being replaced, and
// `replacement_shape` is the shape substituted for them. Confirm in
// AdvancedIndexing.cpp.
87 static Tensor RestrideTensor(const Tensor& tensor,
88 int64_t dims_before,
89 int64_t dims_indexed,
90 SizeVector replacement_shape);
91
92 // Add dimensions of size 1 to an index tensor so that it can be broadcast
93 // to the result shape and iterated over element-wise like the result tensor
94 // and the restrided src.
//
// NOTE(review): presumably `dims_before` and `dims_after` are the numbers of
// size-1 dimensions inserted before and after the index tensor's own
// dimensions. Confirm in AdvancedIndexing.cpp.
95 static Tensor RestrideIndexTensor(const Tensor& index_tensor,
96 int64_t dims_before,
97 int64_t dims_after);
98
99protected:
/// Preprocess tensor and index tensors.
/// NOTE(review): presumably this populates tensor_, index_tensors_,
/// output_shape_, indexed_shape_ and indexed_strides_ consumed by the
/// getters. Confirm in AdvancedIndexing.cpp.
101 void RunPreprocess();
102
/// Expand boolean tensor to integer index.
/// The constructor uses this to build index_tensors_ from the caller-supplied
/// index tensors.
104 static std::vector<Tensor> ExpandBoolTensors(
105 const std::vector<Tensor>& index_tensors);
106
110
/// The processed index tensors. The constructor initializes them with
/// ExpandBoolTensors(index_tensors).
112 std::vector<Tensor> index_tensors_;
113
116
120
124};
125
136public:
/// Operating mode of an AdvancedIndexer; the constructor stores it in mode_.
/// NOTE(review): presumably GET applies the indexed offset to the source
/// pointer (gather) and SET applies it to the destination pointer (scatter).
/// The mode checks in GetInputPtr/GetOutputPtr are not visible here; confirm.
137 enum class AdvancedIndexerMode { SET, GET };
138
140 const Tensor& dst,
141 const std::vector<Tensor>& index_tensors,
142 const SizeVector& indexed_shape,
143 const SizeVector& indexed_strides,
145 : mode_(mode) {
146 if (indexed_shape.size() != indexed_strides.size()) {
148 "Internal error: indexed_shape's ndim {} does not equal to "
149 "indexd_strides' ndim {}",
150 indexed_shape.size(), indexed_strides.size());
151 }
152 num_indices_ = indexed_shape.size();
153
154 // Initialize Indexer
155 std::vector<Tensor> inputs;
156 inputs.push_back(src);
157 for (const Tensor& index_tensor : index_tensors) {
158 if (index_tensor.NumDims() != 0) {
159 inputs.push_back(index_tensor);
160 }
161 }
162 indexer_ = Indexer({inputs}, dst, DtypePolicy::NONE);
163
164 // Fill shape and strides
165 if (num_indices_ != static_cast<int64_t>(indexed_strides.size())) {
167 "Internal error: indexed_shape's ndim {} does not equal to "
168 "indexd_strides' ndim {}",
169 num_indices_, indexed_strides.size());
170 }
171 for (int64_t i = 0; i < num_indices_; ++i) {
172 indexed_shape_[i] = indexed_shape[i];
173 indexed_strides_[i] = indexed_strides[i];
174 }
175
176 // Check dtypes
177 if (src.GetDtype() != dst.GetDtype()) {
179 "src's dtype {} is not the same as dst's dtype {}.",
180 src.GetDtype().ToString(), dst.GetDtype().ToString());
181 }
183 }
184
185 inline OPEN3D_HOST_DEVICE char* GetInputPtr(int64_t workload_idx) const {
186 char* ptr = indexer_.GetInputPtr(0, workload_idx);
187 ptr += GetIndexedOffset(workload_idx) * element_byte_size_ *
189 return ptr;
190 }
191
192 inline OPEN3D_HOST_DEVICE char* GetOutputPtr(int64_t workload_idx) const {
193 char* ptr = indexer_.GetOutputPtr(workload_idx);
194 ptr += GetIndexedOffset(workload_idx) * element_byte_size_ *
196 return ptr;
197 }
198
199 inline OPEN3D_HOST_DEVICE int64_t
200 GetIndexedOffset(int64_t workload_idx) const {
201 int64_t offset = 0;
202 for (int64_t i = 0; i < num_indices_; ++i) {
203 int64_t index = *(reinterpret_cast<int64_t*>(
204 indexer_.GetInputPtr(i + 1, workload_idx)));
205 OPEN3D_ASSERT(index >= -indexed_shape_[i] &&
206 index < indexed_shape_[i] && "Index out of bounds.");
207 index += indexed_shape_[i] * (index < 0);
208 offset += index * indexed_strides_[i];
209 }
210 return offset;
211 }
212
213 int64_t NumWorkloads() const { return indexer_.NumWorkloads(); }
214
215protected:
/// Sizes of the indexed dimensions. Entries [0, num_indices_) are copied from
/// the constructor's `indexed_shape`. GetIndexedOffset uses them for the
/// bounds assert and for wrapping negative indices.
/// NOTE(review): the visible constructor code writes num_indices_ entries
/// without checking num_indices_ <= MAX_DIMS. Confirm that a bound is
/// enforced elsewhere.
220 int64_t indexed_shape_[MAX_DIMS];
/// Strides of the indexed dimensions. Entries [0, num_indices_) are copied
/// from the constructor's `indexed_strides`.
221 int64_t indexed_strides_[MAX_DIMS];
222};
223
224} // namespace core
225} // namespace open3d
#define OPEN3D_HOST_DEVICE
Definition: CUDAUtils.h:63
#define LogError(...)
Definition: Logging.h:67
#define OPEN3D_ASSERT(...)
Definition: Macro.h:67
This class is based on PyTorch's aten/src/ATen/native/Indexing.cpp.
Definition: AdvancedIndexing.h:39
void RunPreprocess()
Preprocess tensor and index tensors.
Definition: AdvancedIndexing.cpp:129
static Tensor RestrideTensor(const Tensor &tensor, int64_t dims_before, int64_t dims_indexed, SizeVector replacement_shape)
Definition: AdvancedIndexing.cpp:104
static bool IsIndexSplittedBySlice(const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.cpp:36
static std::pair< Tensor, std::vector< Tensor > > ShuffleIndexedDimsToFront(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.cpp:60
SizeVector output_shape_
Output shape.
Definition: AdvancedIndexing.h:115
SizeVector GetIndexedStrides() const
Definition: AdvancedIndexing.h:57
std::vector< Tensor > index_tensors_
The processed index tensors.
Definition: AdvancedIndexing.h:112
Tensor tensor_
Definition: AdvancedIndexing.h:109
static std::vector< Tensor > ExpandBoolTensors(const std::vector< Tensor > &index_tensors)
Expand boolean tensor to integer index.
Definition: AdvancedIndexing.cpp:249
static std::pair< std::vector< Tensor >, SizeVector > ExpandToCommonShapeExceptZeroDim(const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.cpp:82
AdvancedIndexPreprocessor(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.h:41
static Tensor RestrideIndexTensor(const Tensor &index_tensor, int64_t dims_before, int64_t dims_after)
Definition: AdvancedIndexing.cpp:119
std::vector< Tensor > GetIndexTensors() const
Definition: AdvancedIndexing.h:49
SizeVector indexed_shape_
Definition: AdvancedIndexing.h:119
SizeVector GetIndexedShape() const
Definition: AdvancedIndexing.h:55
SizeVector GetOutputShape() const
Definition: AdvancedIndexing.h:53
Tensor GetTensor() const
Definition: AdvancedIndexing.h:47
SizeVector indexed_strides_
Definition: AdvancedIndexing.h:123
Definition: AdvancedIndexing.h:135
Indexer indexer_
Definition: AdvancedIndexing.h:216
AdvancedIndexer(const Tensor &src, const Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides, AdvancedIndexerMode mode)
Definition: AdvancedIndexing.h:139
int64_t element_byte_size_
Definition: AdvancedIndexing.h:219
int64_t NumWorkloads() const
Definition: AdvancedIndexing.h:213
int64_t num_indices_
Definition: AdvancedIndexing.h:218
OPEN3D_HOST_DEVICE char * GetOutputPtr(int64_t workload_idx) const
Definition: AdvancedIndexing.h:192
int64_t indexed_strides_[MAX_DIMS]
Definition: AdvancedIndexing.h:221
AdvancedIndexerMode mode_
Definition: AdvancedIndexing.h:217
int64_t indexed_shape_[MAX_DIMS]
Definition: AdvancedIndexing.h:220
OPEN3D_HOST_DEVICE int64_t GetIndexedOffset(int64_t workload_idx) const
Definition: AdvancedIndexing.h:200
AdvancedIndexerMode
Definition: AdvancedIndexing.h:137
OPEN3D_HOST_DEVICE char * GetInputPtr(int64_t workload_idx) const
Definition: AdvancedIndexing.h:185
std::string ToString() const
Definition: Dtype.h:83
int64_t ByteSize() const
Definition: Dtype.h:77
Definition: Indexer.h:280
OPEN3D_HOST_DEVICE char * GetOutputPtr(int64_t workload_idx) const
Definition: Indexer.h:456
int64_t NumWorkloads() const
Definition: Indexer.cpp:425
OPEN3D_HOST_DEVICE char * GetInputPtr(int64_t input_idx, int64_t workload_idx) const
Definition: Indexer.h:424
Definition: SizeVector.h:88
size_t size() const
Definition: SmallVector.h:138
Definition: Tensor.h:51
Dtype GetDtype() const
Definition: Tensor.h:1169
int offset
Definition: FilePCD.cpp:64
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t timeout_in_ms capture_handle capture_handle capture_handle image_handle temperature_c k4a_image_t image_handle uint8_t image_handle image_handle image_handle image_handle image_handle timestamp_usec white_balance image_handle k4a_device_configuration_t config device_handle char size_t serial_number_size bool int32_t int32_t int32_t int32_t k4a_color_control_mode_t default_mode mode
Definition: K4aPlugin.cpp:697
Definition: PinholeCameraIntrinsic.cpp:35