基础切片 (slice) 与索引
从矩阵中提取子矩阵,或通过索引将张量降维为子张量以进行后续计算,是一种非常常见的操作。
RSTSR 提供了 NumPy 中称为“基础索引”(basic indexing) 的大部分功能,它返回的是张量视图 (view) 而非拥有所有权的张量 (owned)。通过这种机制,大多数张量提取操作可以在不进行内存拷贝的情况下完成。对于大型张量来说,与内存分配和张量运算相比,所有基础 slice 和索引操作的成本都很低。
由于语言限制,在 Rust 中,通过方括号 []
进行索引只能返回底层数据的引用 &T
,因此技术上无法通过方括号 []
返回张量视图。在 RSTSR 中,只有当数据以 Vec<T>
类型存储时,通过 []
进行元素索引才会返回元素的引用 &T
。[]
索引可以使用的情景非常有限。
然而,通过函数进行索引和 slice 以获取子张量视图 TensorView
(或 TensorMut
) 是可行的。最重要的 slice 函数和宏包括:
slice
(等同于i
) :通过传入 slice 参数返回张量视图;slice_mut
(等同于i
) :通过传入 slice 参数返回可变的张量视图;slice!((start, ) stop (, slice))
:生成 slice 配置,类似于 Python 的内置slice
函数。
slice!
与函数 slice
不同如果您对同时使用函数 slice
和宏 slice!
感到不适 (例如 tensor.slice(slice!(1, 5, 2))
),您仍然可以使用等效的函数 i
来执行张量索引和 slice (例如 tensor.i(slice!(1, 5, 2))
)。
这些函数的命名冲突可能会产生困扰,但它们实际上遵循了一些惯例:
- 函数
slice
来自 Rust 库 ndarray; - 函数
i
来自 Rust 库 candle; - 宏
slice!
来自 Python 的内置函数。
请注意,我们尚未实现高级索引。高级索引主要是通过整数张量、布尔张量或索引列表进行索引。这些功能在 NumPy 中得到了很好的支持,但在 RSTSR 中实现起来较为困难。在大多数情况下,高级索引需要 (或更高效时) 显式的内存拷贝。我们将在未来努力实现一些高级索引功能。
请注意,通过 slice ,RSTSR 总是生成动态维度 (IxD
) 的张量,而不是生成固定维度 (例如 1-D 时为 Ix1
,2-D 时为 Ix2
等)。与 ndarray 相比,这是一个退步,因为 ndarray 拥有更复杂的宏系统来处理固定维度 slice。
术语
- slice (通过 range 或 slice):-D 张量到 -D 张量的操作,返回较小张量的视图;
- 索引 (通过整数) :-D 张量到 -D 张量,通过选择合并一个维度;
- 元素索引 (通过整数列表) :返回元素的引用
&T
而不是张量视图。
在 RSTSR 中, slice 和索引的实现方式类似。只要 Rust 允许,用户通常可以同时执行 slice 和索引。
RSTSR 遵循 Rust、C 和 Python 的 0 基索引惯例,这与 Fortran 不同。
1. 通过数字索引
例如,一个 3-D 张量 可以通过索引变为 2-D 张量 :
// generate 3-D tensor A_ijk
let a = rt::arange(24).into_shape([4, 3, 2]);
println!("{:}", a);
// B_jk = A_ijk where i = 2
let b = a.slice(2); // equivalently `a.i(2)`
println!("{:}", b);
// output:
// [[ 12 13]
// [ 14 15]
// [ 16 17]]
更进一步,如果您希望对 进行索引,即 ,那么您可以将 [2, 0]
传递给 slice
函数:
// C_k = A_ijk where i = 2, j = 0
// surely, `a.slice(2).slice(0)` works, but we can use `a.slice([2, 0])` instead
let c = a.slice([2, 0]);
println!("{:}", c);
// output: [ 12 13]
RSTSR 也接受负索引以从数组末尾开始索引:
// D_jk = A_ijk where i = -1 = 3 (negative index from the end)
let d = a.slice(-1);
println!("{:}", d);
// output:
// [[ 18 19]
// [ 20 21]
// [ 22 23]]
2. 基础 slice
2.1 通过 slice
例如,我们希望从张量 中提取 :
// generate 3-D tensor A_ijk
let a = rt::arange(24).into_shape([4, 3, 2]);
println!("{:}", a);
// B_ijk = A_ijk where 1 <= i < 3
let b = a.slice(1..3); // equivalently `a.i(1..3)`
println!("{:}", b);
// output:
// [[[ 6 7]
// [ 8 9]
// [10 11]]
//
// [[12 13]
// [14 15]
// [16 17]]]
前两个维度的 slice 也可以通过以下方式实现:
// C_ijk = A_ijk where 1 <= i < 3, 0 <= j < 2
let c = a.slice([1..3, 0..2]);
println!("{:}", c);
// output:
// [[[ 6 7]
// [ 8 9]]
//
// [[12 13]
// [14 15]]]
负索引也适用于这种情况:
let a = rt::arange(24);
// D_i = A_i where i = -5..-2 = 19..22 (negative index from the end given 24 elements)
let d = a.slice(-5..-2);
println!("{:}", d);
// output: [ 19 20 21]
2.2 通过 range
RSTSR 不仅接受 Range 类型 (如 1..3
),还接受 RangeTo (..3
) 或 RangeFrom (1..
)。
let a = rt::arange(24);
// D_i = A_i where i = -5.. or 19..
let d = a.slice(-5..);
println!("{:}", d);
// output: [ 19 20 21 22 23]
但需要注意的是,Rust 不允许将两种不同类型合并为 Rust 数组 [T]
:
要解决这个问题,您可以传递元组 (T1, T2)
而不是 Rust 数组 [T]
:
let a = rt::arange(24).into_shape([4, 3, 2]);
let b = a.slice((.., 1..3, ..2)); // equivalently `a.slice(s![.., 1..3, ..2])`
println!("{:}", b);
// output:
// [[[ 2 3]
// [ 4 5]]
//
// [[ 8 9]
// [ 10 11]]
//
// [[ 14 15]
// [ 16 17]]
//
// [[ 20 21]
// [ 22 23]]]
我们目前只实现了最多 10 个元素的元组;如果您的张量维度非常高,您可能需要使用 s!
。
3. 特殊索引
3.1 带步长的 slice
要进行带步长的 slice ,您可以使用 slice!
宏。slice!
宏的用法类似于 Python 的内置函数 slice
1:
slice!(stop)
:类似于范围到..stop
;slice!(start, stop)
:类似于范围start..stop
;slice!(start, stop, step)
:类似于 Fortran 或 NumPy 的 slicestart:stop:step
。
let a = rt::arange(24);
// first 5 elements
let b = a.slice(slice!(5));
println!("{:}", b);
// output: [ 0 1 2 3 4]
// elements from 5 to -9 (resembles 15 for the given 24 elements)
let b = a.slice(slice!(5, -9));
println!("{:}", b);
// output: [ 5 6 7 ... 12 13 14]
// elements from 5 to -9 with step 2
let b = a.slice(slice!(5, -9, 2));
println!("{:}", b);
// output: [ 5 7 9 11 13]
// reversed step 2
let b = a.slice(slice!(-9, 5, -2));
println!("{:}", b);
// output: [ 15 13 11 9 7]
在许多情况下,None
也是 slice!
的有效输入。实际上,slice!
是通过 Option<T>
的机制实现的,因此使用 Some(val)
也是有效的。
let b = a.slice(slice!(None, 9, Some(2)));
println!("{:}", b);
// output: [ 0 2 4 6 8]
3.2 插入 axis
您可以通过 None
或 NewAxis
(定义为 Indexer::Insert
) 插入 axis。这类似于 NumPy 的 None
或 np.newaxis
。
let a = rt::arange(24).into_shape([4, 3, 2]);
// insert new axis at the beginning
let b = a.slice(NewAxis);
println!("{:?}", b.layout());
// output: shape: [1, 4, 3, 2], stride: [6, 6, 2, 1], offset: 0
// using `None` is equivalent to `NewAxis`
let b = a.slice(None);
println!("{:?}", b.layout());
// output: shape: [1, 4, 3, 2], stride: [6, 6, 2, 1], offset: 0
// insert new axis at the second position
let b = a.slice((.., None));
println!("{:?}", b.layout());
// output: shape: [4, 1, 3, 2], stride: [6, 2, 2, 1], offset: 0
使用 None
会比较方便,但我们不接受 Some(val)
进行索引。因此,尽管以下代码可以编译,但它实际上不起作用。
3.3 省略号
在 RSTSR 中,您可以使用 Ellipsis
(定义为 Indexer::Ellipsis
) 来跳过一些索引:
let a = rt::arange(24).into_shape([4, 3, 2]);
// using ellipsis to select index from last dimension
// equivallently to `a.slice((.., .., 0))` for 3-D tensor
// same to numpy's `a[..., 0]`
let b = a.slice((Ellipsis, 0));
println!("{:2}", b);
// output:
// [[ 0 2 4]
// [ 6 8 10]
// [ 12 14 16]
// [ 18 20 22]]
3.4 混合索引和 slice
如前所述,使用数组类型 [T]
不适合表示各种类型的索引和 slice 。然而,您可以使用宏 s!
或元组来执行此任务2。
let a: Tensor<f64> = rt::zeros([6, 7, 5, 9, 8]);
// mixed indexing
let b = a.slice((slice!(-2, 1, -1), None, None, Ellipsis, 1, ..-2));
println!("{:?}", b.layout());
// output: shape: [3, 1, 1, 7, 5, 6], stride: [-2520, 360, 360, 360, 72, 1], offset: 10088
4. 元素索引
我们也在 RSTSR 中提供了元素索引。但请注意,在大多数情况下,元素索引并不高效。
- 对于“未检查”的元素索引,它更有可能阻止编译器的内部向量化和 SIMD 优化;
- 对于“安全”的元素索引,额外的越界检查会进一步阻碍优化。
因此,对于计算密集型任务,建议使用 RSTSR 内部的算术函数或映射函数,或者自己编写高效率的程序,以避免直接进行元素索引。只有在效率不重要或 RSTSR 内部函数无法满足需求时,才使用元素索引。
4.1 安全的元素索引
要执行索引,您可以使用 Rust 的方括号 []
:
let a = rt::arange(24).into_shape([4, 3, 2]);
let val = a[[2, 2, 1]];
println!("{:}", val);
// output: 17
println!("{:}", std::any::type_name_of_val(&val));
// output: i32
如果您提供的索引越界,RSTSR 会崩溃:
在 RSTSR 中,slice (到张量视图) 和元素索引 (到值的引用) 是不同的。如果您希望得到一个值而不是单个元素的张量,请不要使用函数 slice
。
let view = a.slice((2, 2, 1));
println!("{:}", view);
// output: 17
// it seems to be a value, but actually it is a tensor view
println!("{:?}", view);
// output:
// === Debug Tensor Print ===
// 17
// DeviceFaer { base: DeviceCpuRayon { num_threads: 0 } }
// 0-Dim (dyn), contiguous: CcFf
// shape: [], stride: [], offset: 17
// ==========================
4.2 未检查的元素索引
未检查的元素索引会比安全的元素索引稍快一些。要执行索引,您可以使用不安全的函数 index_uncheck
:
let a = rt::arange(24).into_shape([4, 3, 2]);
let val = unsafe { a.index_uncheck([2, 2, 1]) };
println!("{:}", val);
// output: 17
如果您提供的索引越界,但索引指针位置仍然处于合理的底层内存,RSTSR 不会崩溃并返回错误的值:
此函数被标记为 unsafe
是为了避免这种越界 (但未超出内存) 的情况。在大多数情况下,它仍然是内存安全的,因为超出内存访问 Vec<T>
会正常地崩溃。