跳到主要内容

算术运算与广播

作为一个张量工具包,RSTSR 提供了许多基础的算术运算。

本节我们仅讨论算术运算,下一节将介绍基于映射的计算。

1. 算术运算示例

RSTSR 可以处理 +-*/ 运算:

let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;

let c = &a + &b;
println!("{:}", c);
// output: [ 1 3 5 7 9]

let d = &a / &b;
println!("{:6.3}", d);
// output: [ 0.000 0.500 0.667 0.750 0.800]

let e = 2.0 * &a;
println!("{:}", e);
// output: [ 0 2 4 6 8]

RSTSR 可以通过运算符 % 处理矩阵乘法运算(矩阵-矩阵、矩阵-向量或向量-向量内积,并且在某些设备如 DeviceFaerDeviceOpenBLAS 中进行了优化):

let mat = rt::arange(12).into_shape([3, 4]);
let vec = rt::arange(4).into_shape([4]);

// matrix multiplication
let res = &mat % mat.t();
println!("{:3}", res);
// output:
// [[ 14 38 62]
// [ 38 126 214]
// [ 62 214 366]]

// matrix-vector multiplication
let res = &mat % &vec;
println!("{:}", res);
// output: [ 14 38 62]

// vector-matrix multiplication
let res = &vec % &mat.t();
println!("{:}", res);
// output: [ 14 38 62]

// vector inner dot
let res = &vec % &vec;
println!("{:}", res);
// output: 14

对于一些特殊情况,位运算和移位运算也是可用的:

let a = rt::asarray(vec![true, true, false, false]);
let b = rt::asarray(vec![true, false, true, false]);

// bitwise xor
let c = a ^ b;
println!("{:?}", c);
// output: [false true true false]

let a = rt::asarray(vec![9, 7, 5, 3]);
let b = rt::asarray(vec![5, 6, 7, 8]);

// shift left
let c = a << b;
println!("{:?}", c);
// output: [ 288 448 640 768]

上述示例应该已经涵盖了大多数张量算术运算的用法。本节接下来的文档将讨论一些高级主题。

2. 重载运算符 %

我们已经展示了 % 是矩阵乘法的运算符。这是 RSTSR 特有的用法。这可能会引起一些混淆,我们将讨论这个话题。

首先,我们遵循 NumPy 的惯例,* 始终是数乘,类似于 +,它不会进行矩阵乘法或向量内积。

let mat = rt::arange(12).into_shape([3, 4]);
let vec = rt::arange(4);

// element-wise matrix multiplication
let c = &mat * &mat;
println!("{:3}", c);
// output:
// [[ 0 1 4 9]
// [ 16 25 36 49]
// [ 64 81 100 121]]

// element-wise matrix-vector multiplication (broadcasting involved)
let d = &mat * &vec;
println!("{:2}", d);
// output:
// [[ 0 1 4 9]
// [ 0 5 12 21]
// [ 0 9 20 33]]

// element-wise vector multiplication
let e = &vec * &vec;
println!("{:}", e);
// output: [ 0 1 4 9]

NumPy 在 1.10 版本中通过 PEP 465 引入了 @ 符号用于矩阵乘法。对于 Rust 来说,使用相同的 @ 运算符作为矩阵乘法几乎是不可能的,这在 Rust 内部论坛 中已经充分讨论过(@ 已经被用作 模式绑定 的二元运算符)。从 RSTSR 开发者的角度来看,这非常不幸。

此外,其他类型的运算符(如 R 中的 %*%,Matlab 和 Julia 中的 .*,Mathematica 中的 .)在 Rust 语言中并不存在作为二元运算符。如果我们希望使用这些符号,需要编程语言层面的支持,而这些功能短期内不太可能稳定。

然而,我们认为尽管 % 通常被用作取余运算,但在向量或矩阵计算中使用较少。% 也与 */ 具有相同的运算符优先级。因此,我们决定在适当的情况下将 % 用作矩阵乘法的符号。

我们保留了函数 rem 用于取余运算,函数 matmul 用于矩阵乘法。

let a = rt::arange(6);

// remainder to scalar
let c = rt::rem(&a, 3);
println!("{:}", c);
// output: [ 0 1 2 0 1 2]

// remainder to array
let b = rt::asarray(vec![3, 2, 3, 3, 2, 2]);
let c = rt::rem(&a, &b);
println!("{:}", c);
// output: [ 0 1 2 0 0 1]
不要将 rem 作为关联(结构体成员)函数使用

我们已经展示了 rt::rem 是可以用于计算张量的余数:

let a = rt::arange(6);
let b = rt::asarray(vec![3, 2, 3, 3, 2, 2]);

// remainder to array
let c = rt::rem(&a, &b);
println!("{:}", c);
// output: [ 0 1 2 0 0 1]

然而,函数 tensor.rem(other) 并不是 rt::rem 的定义。它被定义为 Rust 的关联函数,通过 core::ops::Rem trait 实现。由于我们通过矩阵乘法重载了这个 trait,tensor.rem(other) 也会调用矩阵乘法运算。

// inner product (due to override to `Rem`)
let c = a.view().rem(&b);
println!("{:}", c);
// output: 35
not_desired_behavior

由于这种代码会引起混淆,我们建议 API 用户不要将 rem 作为关联函数使用。

3. 广播 (broadcasting)

广播 使许多张量操作变得非常简单。RSTSR 采用了 NumPy 或 Python Array API 的大部分广播规则。我们建议感兴趣的用户参考 NumPy 和 Python Array API 文档。

RSTSR 的初始开发者是一名计算化学家。我们将使用化学编程中的一个示例,展示如何在实际情况中使用广播。

3.1 数乘的示例

RI-MP2(resolution-identity Moller-Plesset 二阶微扰)的指数和近似(也称为 LT-OS-MP2)涉及以下计算:

YPia=YPiaϵiϵa\mathcal{Y}_{Pia} = Y_{Pia} \epsilon_{i} \epsilon_{a}
// task definition
let (naux, nocc, nvir) = (8, 2, 4); // subscripts (P, i, a)
let y = rt::arange(naux * nocc * nvir).into_shape([naux, nocc, nvir]);
let ei = rt::arange(nocc);
let ea = rt::arange(nvir);

这是 3-D 张量与 1-D 张量的数乘。在通常情况下,1-D 张量 ϵi\epsilon_{i}ϵa\epsilon_{a} 应该被广播并重复为 3-D 对应 EPiaocc=ϵi(P,a)E^\mathrm{occ}_{Pia} = \epsilon_i (\forall P, a)EPiavir=ϵa(P,i)E^\mathrm{vir}_{Pia} = \epsilon_a (\forall P, i),然后执行乘法:

YPia=YPiaEPiaoccEPiavir\mathcal{Y}_{Pia} = Y_{Pia} E^\mathrm{occ}_{Pia} E^\mathrm{vir}_{Pia}

这既不方便也不高效。通过广播,我们可以在不重复值的情况下为 1-D 张量插入 axis:

YPia=YPiaϵiϵa\mathcal{Y}_{Pia} = Y_{Pia} \epsilon_{\cdot i \cdot} \epsilon_{\cdot \cdot a}
// elementwise multiplication with broadcasting
// `None` means inserting axis, equivalent to `np.newaxis` in NumPy or `NewAxis` in RSTSR
let converted_y = &y * ei.slice((None, .., None)) * ea.slice((None, None, ..));

这种乘法仍然可以简化。根据 NumPy 的广播规则定义,它总是在第一个维度添加省略号。因此,任何在第一个维度插入 axis 的操作都可以被移除:

YPia=YPiaϵiϵa\mathcal{Y}_{Pia} = Y_{Pia} \epsilon_{i \cdot} \epsilon_{a}
// elementwise multiplication with simplified broadcasting
let converted_y = &y * &ei.slice((.., None)) * &ea;

最后,出于内存和效率的考虑,建议先执行 ϵiϵa\epsilon_{i \cdot} \epsilon_{a} 的数乘:

YPia=YPia(ϵiϵa)\mathcal{Y}_{Pia} = Y_{Pia} (\epsilon_{i \cdot} \epsilon_{a})
// optimize for memory access cost
let converted_y = &y * (&ei.slice((.., None)) * &ea);

3.2 矩阵乘法的示例

许多后 HF 方法涉及积分基变换,主要是从原始基(原子基或称为 AO)到分子轨道基(称为 MO):

YPai=μνYPμνCμiCνaY_{P ai} = \sum_{\mu \nu} Y_{P \mu \nu} C_{\mu i} C_{\nu a}

此操作涉及五个索引 P,μ,ν,a,iP, \mu, \nu, a, i,其中索引 a,ia, i 的数量小于 μ,ν\mu, \nu

// task definition
let (naux, nocc, nvir, nao, _) = (8, 2, 4, 6, 6); // subscripts (P, i, a, μ, ν)
let y_ao = rt::arange(naux * nao * nao).into_shape([naux, nao, nao]);
let c_occ = rt::arange(nao * nocc).into_shape([nao, nocc]);
let c_vir = rt::arange(nao * nvir).into_shape([nao, nvir]);

矩阵乘法的 广播规则 稍微复杂一些。然而,如果您熟悉广播规则,这个任务可以用非常简单的代码实现:

let y_mo = &c_occ.t() % &y_ao % &c_vir;
println!("{:?}", y_mo.layout());
此操作在效率上可以进一步优化

这段代码简单而优雅。它会在支持 rayon 的设备上正确处理多线程。

然而,它需要多次访问 3-D 张量,并且会生成一个临时的 3-D 张量。这在内存访问和内存成本上都不高效。

为了解决内存效率问题,可以使用并行 axis 迭代器执行此计算。但这种方法的代码的编写有一定难度。

另一种解决方案是通过有限度的 unsafe 代码,并行地对指标 PP 进行迭代:

use rayon::prelude::*;
let y_mo = unsafe { rt::empty([naux, nocc, nvir]) };
(0..naux).into_par_iter().for_each(|p| {
let mut y_mo = unsafe { y_mo.force_mut() };
y_mo.i_mut(p).assign(&c_occ.t() % &y_ao.i(p) % &c_vir);
});

4. 内存问题

这与值如何传递给算术运算有关。

4.1 通过算术运算符 (operator) 进行计算

在 Rust 中,变量的所有权和生命周期规则非常严格。以下代码将导致编译错误:

let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;

let c = a + b;
let d = a * b;
does_not_compile
    |     let c = a + b;
| - value moved here
| let d = a * b;
| ^ value used here after move
|
help: consider cloning the value if the performance cost is acceptable
|
| let c = a + b.clone();
| ++++++++

然而,在许多情况下,克隆张量的性能和内存成本是不可接受的。因此,更推荐通过以下方式执行计算,以避免内存拷贝和生命周期限制:

  • 使用张量的引用,
  • 使用张量的视图,
// arithmetic by reference
let c = &a + &b;

// arithmetic by view
let d = a.view() * b.view();

// generating a view is cheap, given tensor is large
let a_view = a.view();
let b_view = b.view();
let e = a_view * b_view;

需要注意的是,除了生命周期限制外,拥有所有权的张量仍然可以传递给算术运算。此外,在可能的情况下会应用就地算术运算(类型约束和广播能力)。例如,对于 1-D 张量加法,变量 c 的内存不会被分配,而是从变量 a 中重用。因此,如果您确定 a 不会再被使用,可以通过值传递 a,这样会更高效。

let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;
let ptr_a = a.as_ptr();
// if sure that `a` is not used anymore, pass `a` by value instead of reference
let c = a + &b;
let ptr_c = c.as_ptr();
// raw data of `a` is reused in `c`
// similar to `a += &b; let c = a;`
assert_eq!(ptr_a, ptr_c);

4.2 通过关联函数 (associated method) 进行计算

在 RSTSR 中,有三种方式执行算术运算:

  • 通过运算符:&a + &b
  • 通过函数:rt::add(&a, &b)
  • 通过关联函数:(&a).add(&b)a.view().add(&b)

您可能会发现关联函数的使用代码有些奇怪。实际上,a.add(&b) 在 RSTSR 中也是有效的,但这会消耗变量 a。以下代码由于这个问题将无法编译:

let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;

// below is valid, however `a` is moved
let c = a.add(&b);

// below is invalid
let d = a.div(&b);
// ^ value used here after move
// note: `std::ops::Add::add` takes ownership of the receiver `self`, which moves `a`
does_not_compile