跳到主要内容

基础切片 (slice) 与索引

从矩阵中提取子矩阵,或通过索引将张量降维为子张量以进行后续计算,是一种非常常见的操作。

RSTSR 提供了 NumPy 中称为“基础索引”(basic indexing) 的大部分功能,它返回的是张量视图 (view) 而非拥有所有权的张量 (owned)。通过这种机制,大多数张量提取操作可以在不进行内存拷贝的情况下完成。对于大型张量来说,与内存分配和张量运算相比,所有基础 slice 和索引操作的成本都很低

由于语言限制,在 Rust 中,通过方括号 [] 进行索引只能返回底层数据的引用 &T,因此技术上无法通过方括号 [] 返回张量视图。在 RSTSR 中,只有当数据以 Vec<T> 类型存储时,通过 [] 进行元素索引才会返回元素的引用 &T[] 索引可以使用的情景非常有限。

然而,通过函数进行索引和 slice 以获取子张量视图 TensorView (或 TensorMut) 是可行的。最重要的 slice 函数和宏包括:

  • slice (等同于 i) :通过传入 slice 参数返回张量视图;
  • slice_mut (等同于 i) :通过传入 slice 参数返回可变的张量视图;
  • slice!((start, ) stop (, slice)):生成 slice 配置,类似于 Python 的内置 slice 函数。
slice! 与函数 slice 不同

如果您对同时使用函数 slice 和宏 slice! 感到不适 (例如 tensor.slice(slice!(1, 5, 2))),您仍然可以使用等效的函数 i 来执行张量索引和 slice (例如 tensor.i(slice!(1, 5, 2)))。

这些函数的命名冲突可能会产生困扰,但它们实际上遵循了一些惯例:

  • 函数 slice 来自 Rust 库 ndarray;
  • 函数 i 来自 Rust 库 candle;
  • slice! 来自 Python 的内置函数。

请注意,我们尚未实现高级索引。高级索引主要是通过整数张量、布尔张量或索引列表进行索引。这些功能在 NumPy 中得到了很好的支持,但在 RSTSR 中实现起来较为困难。在大多数情况下,高级索引需要 (或更高效时) 显式的内存拷贝。我们将在未来努力实现一些高级索引功能。

RSTSR 中的 slice 总是生成动态维度

请注意,通过 slice ,RSTSR 总是生成动态维度 (IxD) 的张量,而不是生成固定维度 (例如 1-D 时为 Ix1,2-D 时为 Ix2 等)。与 ndarray 相比,这是一个退步,因为 ndarray 拥有更复杂的宏系统来处理固定维度 slice。

术语

  • slice (通过 range 或 slice):nn-D 张量到 nn-D 张量的操作,返回较小张量的视图;
  • 索引 (通过整数) :nn-D 张量到 (n1)(n-1)-D 张量,通过选择合并一个维度;
  • 元素索引 (通过整数列表) :返回元素的引用 &T 而不是张量视图。

在 RSTSR 中, slice 和索引的实现方式类似。只要 Rust 允许,用户通常可以同时执行 slice 和索引。

RSTSR 遵循 Rust、C 和 Python 的 0 基索引惯例,这与 Fortran 不同。

1. 通过数字索引

例如,一个 3-D 张量 AijkA_{ijk} 可以通过索引变为 2-D 张量 Bjk=A2jkB_{jk} = A_{2jk}

// generate 3-D tensor A_ijk
let a = rt::arange(24).into_shape([4, 3, 2]);
println!("{:}", a);

// B_jk = A_ijk where i = 2
let b = a.slice(2); // equivalently `a.i(2)`
println!("{:}", b);
// output:
// [[ 12 13]
// [ 14 15]
// [ 16 17]]

更进一步,如果您希望对 i=2,j=0i = 2, j = 0 进行索引,即 Ck=A20kC_k = A_{20k},那么您可以将 [2, 0] 传递给 slice 函数:

// C_k = A_ijk where i = 2, j = 0
// surely, `a.slice(2).slice(0)` works, but we can use `a.slice([2, 0])` instead
let c = a.slice([2, 0]);
println!("{:}", c);
// output: [ 12 13]

RSTSR 也接受负索引以从数组末尾开始索引:

// D_jk = A_ijk where i = -1 = 3 (negative index from the end)
let d = a.slice(-1);
println!("{:}", d);
// output:
// [[ 18 19]
// [ 20 21]
// [ 22 23]]

2. 基础 slice

2.1 通过 slice

例如,我们希望从张量 AijkA_{ijk} 中提取 1i<31 \leq i < 3

// generate 3-D tensor A_ijk
let a = rt::arange(24).into_shape([4, 3, 2]);
println!("{:}", a);

// B_ijk = A_ijk where 1 <= i < 3
let b = a.slice(1..3); // equivalently `a.i(1..3)`
println!("{:}", b);
// output:
// [[[ 6 7]
// [ 8 9]
// [10 11]]
//
// [[12 13]
// [14 15]
// [16 17]]]

前两个维度的 slice 也可以通过以下方式实现:

// C_ijk = A_ijk where 1 <= i < 3, 0 <= j < 2
let c = a.slice([1..3, 0..2]);
println!("{:}", c);
// output:
// [[[ 6 7]
// [ 8 9]]
//
// [[12 13]
// [14 15]]]

负索引也适用于这种情况:

let a = rt::arange(24);
// D_i = A_i where i = -5..-2 = 19..22 (negative index from the end given 24 elements)
let d = a.slice(-5..-2);
println!("{:}", d);
// output: [ 19 20 21]

2.2 通过 range

RSTSR 不仅接受 Range 类型 (如 1..3),还接受 RangeTo (..3) 或 RangeFrom (1..)。

let a = rt::arange(24);
// D_i = A_i where i = -5.. or 19..
let d = a.slice(-5..);
println!("{:}", d);
// output: [ 19 20 21 22 23]

但需要注意的是,Rust 不允许将两种不同类型合并为 Rust 数组 [T]

// generate 3-D tensor A_ijk
let a = rt::arange(24).into_shape([4, 3, 2]).into_owned();

// different types can't be merged into rust array
// - `..` is RangeFull
// - `1..3` is Range
// - `..2` is RangeTo
let b = a.slice([.., 1..3, ..2]);
does_not_compile

要解决这个问题,您可以传递元组 (T1, T2) 而不是 Rust 数组 [T]

let a = rt::arange(24).into_shape([4, 3, 2]);
let b = a.slice((.., 1..3, ..2)); // equivalently `a.slice(s![.., 1..3, ..2])`
println!("{:}", b);
// output:
// [[[ 2 3]
// [ 4 5]]
//
// [[ 8 9]
// [ 10 11]]
//
// [[ 14 15]
// [ 16 17]]
//
// [[ 20 21]
// [ 22 23]]]

我们目前只实现了最多 10 个元素的元组;如果您的张量维度非常高,您可能需要使用 s!

3. 特殊索引

3.1 带步长的 slice

要进行带步长的 slice ,您可以使用 slice! 宏。slice! 宏的用法类似于 Python 的内置函数 slice1

  • slice!(stop):类似于范围到 ..stop
  • slice!(start, stop):类似于范围 start..stop
  • slice!(start, stop, step):类似于 Fortran 或 NumPy 的 slice start:stop:step
let a = rt::arange(24);

// first 5 elements
let b = a.slice(slice!(5));
println!("{:}", b);
// output: [ 0 1 2 3 4]

// elements from 5 to -9 (resembles 15 for the given 24 elements)
let b = a.slice(slice!(5, -9));
println!("{:}", b);
// output: [ 5 6 7 ... 12 13 14]

// elements from 5 to -9 with step 2
let b = a.slice(slice!(5, -9, 2));
println!("{:}", b);
// output: [ 5 7 9 11 13]

// reversed step 2
let b = a.slice(slice!(-9, 5, -2));
println!("{:}", b);
// output: [ 15 13 11 9 7]

在许多情况下,None 也是 slice! 的有效输入。实际上,slice! 是通过 Option<T> 的机制实现的,因此使用 Some(val) 也是有效的。

let b = a.slice(slice!(None, 9, Some(2)));
println!("{:}", b);
// output: [ 0 2 4 6 8]

3.2 插入 axis

您可以通过 NoneNewAxis (定义为 Indexer::Insert) 插入 axis。这类似于 NumPy 的 Nonenp.newaxis

let a = rt::arange(24).into_shape([4, 3, 2]);

// insert new axis at the beginning
let b = a.slice(NewAxis);
println!("{:?}", b.layout());
// output: shape: [1, 4, 3, 2], stride: [6, 6, 2, 1], offset: 0

// using `None` is equivalent to `NewAxis`
let b = a.slice(None);
println!("{:?}", b.layout());
// output: shape: [1, 4, 3, 2], stride: [6, 6, 2, 1], offset: 0

// insert new axis at the second position
let b = a.slice((.., None));
println!("{:?}", b.layout());
// output: shape: [4, 1, 3, 2], stride: [6, 2, 2, 1], offset: 0

使用 None 会比较方便,但我们不接受 Some(val) 进行索引。因此,尽管以下代码可以编译,但它实际上不起作用。

let a = rt::arange(24).into_shape([4, 3, 2]);

// insert new axis at the beginning
let b = a.slice(Some(2));
println!("{:?}", b.layout());
// panic: Option<T> should not be used in Indexer.
panics

3.3 省略号

在 RSTSR 中,您可以使用 Ellipsis (定义为 Indexer::Ellipsis) 来跳过一些索引:

let a = rt::arange(24).into_shape([4, 3, 2]);

// using ellipsis to select index from last dimension
// equivallently to `a.slice((.., .., 0))` for 3-D tensor
// same to numpy's `a[..., 0]`
let b = a.slice((Ellipsis, 0));
println!("{:2}", b);
// output:
// [[ 0 2 4]
// [ 6 8 10]
// [ 12 14 16]
// [ 18 20 22]]

3.4 混合索引和 slice

如前所述,使用数组类型 [T] 不适合表示各种类型的索引和 slice 。然而,您可以使用宏 s! 或元组来执行此任务2

let a: Tensor<f64> = rt::zeros([6, 7, 5, 9, 8]);

// mixed indexing
let b = a.slice((slice!(-2, 1, -1), None, None, Ellipsis, 1, ..-2));
println!("{:?}", b.layout());
// output: shape: [3, 1, 1, 7, 5, 6], stride: [-2520, 360, 360, 360, 72, 1], offset: 10088

4. 元素索引

元素索引效率不高

我们也在 RSTSR 中提供了元素索引。但请注意,在大多数情况下,元素索引并不高效。

  • 对于“未检查”的元素索引,它更有可能阻止编译器的内部向量化和 SIMD 优化;
  • 对于“安全”的元素索引,额外的越界检查会进一步阻碍优化。

因此,对于计算密集型任务,建议使用 RSTSR 内部的算术函数或映射函数,或者自己编写高效率的程序,以避免直接进行元素索引。只有在效率不重要或 RSTSR 内部函数无法满足需求时,才使用元素索引。

4.1 安全的元素索引

要执行索引,您可以使用 Rust 的方括号 []

let a = rt::arange(24).into_shape([4, 3, 2]);

let val = a[[2, 2, 1]];
println!("{:}", val);
// output: 17

println!("{:}", std::any::type_name_of_val(&val));
// output: i32

如果您提供的索引越界,RSTSR 会崩溃:

let a = rt::arange(24).into_shape([4, 3, 2]);

let val = a[[2, 2, 3]];
println!("{:}", val);
// panic: Error::ValueOutOfRange : "idx" = 3 not match to pattern 0..(shp as isize) = 0..2
panics

在 RSTSR 中,slice (到张量视图) 和元素索引 (到值的引用) 是不同的。如果您希望得到一个值而不是单个元素的张量,请不要使用函数 slice

let view = a.slice((2, 2, 1));
println!("{:}", view);
// output: 17

// it seems to be a value, but actually it is a tensor view
println!("{:?}", view);
// output:
// === Debug Tensor Print ===
// 17
// DeviceFaer { base: DeviceCpuRayon { num_threads: 0 } }
// 0-Dim (dyn), contiguous: CcFf
// shape: [], stride: [], offset: 17
// ==========================

4.2 未检查的元素索引

未检查的元素索引会比安全的元素索引稍快一些。要执行索引,您可以使用不安全的函数 index_uncheck

let a = rt::arange(24).into_shape([4, 3, 2]);

let val = unsafe { a.index_uncheck([2, 2, 1]) };
println!("{:}", val);
// output: 17

如果您提供的索引越界,但索引指针位置仍然处于合理的底层内存,RSTSR 不会崩溃并返回错误的值:

let a = rt::arange(24).into_shape([4, 3, 2]);

let val = unsafe { a.index_uncheck([2, 2, 3]) };
println!("{:}", val);
// output: 19
// not desired: last dimension index 3 is out of bound
not_desired_behavior

此函数被标记为 unsafe 是为了避免这种越界 (但未超出内存) 的情况。在大多数情况下,它仍然是内存安全的,因为超出内存访问 Vec<T> 会正常地崩溃。

Footnotes

  1. 在 ndarray 中,这是通过 s![start..stop;step] 完成的。ndarray 的解决方案更为简洁。然而,我们坚持使用看似冗长的 slice! 宏来生成带步长的 slice 。

  2. 在大多数情况下,宏 s! 和元组的工作方式相同;然而,它们在程序中的定义不同。s! 应该在更多场景中工作。