算术运算与广播
作为一个张量工具包,RSTSR 提供了许多基础的算术运算。
本节我们仅讨论算术运算,下一节将介绍基于映射的计算。
1. 算术运算示例
RSTSR 可以处理 +
、-
、*
、/
运算:
let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;
let c = &a + &b;
println!("{:}", c);
// output: [ 1 3 5 7 9]
let d = &a / &b;
println!("{:6.3}", d);
// output: [ 0.000 0.500 0.667 0.750 0.800]
let e = 2.0 * &a;
println!("{:}", e);
// output: [ 0 2 4 6 8]
RSTSR 可以通过运算符 %
处理矩阵乘法运算(矩阵-矩阵、矩阵-向量或向量-向量内积,并且在某些设备如 DeviceFaer
与 DeviceOpenBLAS
中进行了优化):
let mat = rt::arange(12).into_shape([3, 4]);
let vec = rt::arange(4).into_shape([4]);
// matrix multiplication
let res = &mat % mat.t();
println!("{:3}", res);
// output:
// [[ 14 38 62]
// [ 38 126 214]
// [ 62 214 366]]
// matrix-vector multiplication
let res = &mat % &vec;
println!("{:}", res);
// output: [ 14 38 62]
// vector-matrix multiplication
let res = &vec % &mat.t();
println!("{:}", res);
// output: [ 14 38 62]
// vector inner dot
let res = &vec % &vec;
println!("{:}", res);
// output: 14
对于一些特殊情况,位运算和移位运算也是可用的:
let a = rt::asarray(vec![true, true, false, false]);
let b = rt::asarray(vec![true, false, true, false]);
// bitwise xor
let c = a ^ b;
println!("{:?}", c);
// output: [false true true false]
let a = rt::asarray(vec![9, 7, 5, 3]);
let b = rt::asarray(vec![5, 6, 7, 8]);
// shift left
let c = a << b;
println!("{:?}", c);
// output: [ 288 448 640 768]
上述示例应该已经涵盖了大多数张量算术运算的用法。本节接下来的文档将讨论一些高级主题。
2. 重载运算符 %
我们已经展示了 %
是矩阵乘法的运算符。这是 RSTSR 特有的用法。这可能会引起一些混淆,我们将讨论这个话题。
首先,我们遵循 NumPy 的惯例,*
始终是数乘,类似于 +
,它不会进行矩阵乘法或向量内积。
let mat = rt::arange(12).into_shape([3, 4]);
let vec = rt::arange(4);
// element-wise matrix multiplication
let c = &mat * &mat;
println!("{:3}", c);
// output:
// [[ 0 1 4 9]
// [ 16 25 36 49]
// [ 64 81 100 121]]
// element-wise matrix-vector multiplication (broadcasting involved)
let d = &mat * &vec;
println!("{:2}", d);
// output:
// [[ 0 1 4 9]
// [ 0 5 12 21]
// [ 0 9 20 33]]
// element-wise vector multiplication
let e = &vec * &vec;
println!("{:}", e);
// output: [ 0 1 4 9]
NumPy 在 1.10 版本中通过 PEP 465 引入了 @
符号用于矩阵乘法。对于 Rust 来说,使用相同的 @
运算符作为矩阵乘法几乎是不可能的,这在 Rust 内部论坛 中已经充分讨论过(@
已经被用作 模式绑定 的二元运算符)。从 RSTSR 开发者的角度来看,这非常不幸。
此外,其他类型的运算符(如 R 中的 %*%
,Matlab 和 Julia 中的 .*
,Mathematica 中的 .
)在 Rust 语言中并不存在作为二元运算符。如果我们希望使用这些符号,需要编程语言层面的支持,而这些功能短期内不太可能稳定。
然而,我们认为尽管 %
通常被用作取余运算,但在向量或矩阵计算中使用较少。%
也与 *
和 /
具有相同的运算符优先级。因此,我们决定在适当的情况下将 %
用作矩阵乘法的符号。
我们保留了函数 rem
用于取余运算,函数 matmul
用于矩阵乘法。
let a = rt::arange(6);
// remainder to scalar
let c = rt::rem(&a, 3);
println!("{:}", c);
// output: [ 0 1 2 0 1 2]
// remainder to array
let b = rt::asarray(vec![3, 2, 3, 3, 2, 2]);
let c = rt::rem(&a, &b);
println!("{:}", c);
// output: [ 0 1 2 0 0 1]
rem
作为关联(结构体成员)函数使用我们已经展示了 rt::rem
是可以用于计算张量的余数:
let a = rt::arange(6);
let b = rt::asarray(vec![3, 2, 3, 3, 2, 2]);
// remainder to array
let c = rt::rem(&a, &b);
println!("{:}", c);
// output: [ 0 1 2 0 0 1]
然而,函数 tensor.rem(other)
并不是 rt::rem
的定义。它被定义为 Rust 的关联函数,通过 core::ops::Rem
trait 实现。由于我们通过矩阵乘法重载了这个 trait,tensor.rem(other)
也会调用矩阵乘法运算。
由于这种代码会引起混淆,我们建议 API 用户不要将 rem
作为关联函数使用。
3. 广播 (broadcasting)
广播 使许多张量操作变得非常简单。RSTSR 采用了 NumPy 或 Python Array API 的大部分广播规则。我们建议感兴趣的用户参考 NumPy 和 Python Array API 文档。
RSTSR 的初始开发者是一名计算化学家。我们将使用化学编程中的一个示例,展示如何在实际情况中使用广播。
3.1 数乘的示例
RI-MP2(resolution-identity Moller-Plesset 二阶微扰)的指数和近似(也称为 LT-OS-MP2)涉及以下计算:
// task definition
let (naux, nocc, nvir) = (8, 2, 4); // subscripts (P, i, a)
let y = rt::arange(naux * nocc * nvir).into_shape([naux, nocc, nvir]);
let ei = rt::arange(nocc);
let ea = rt::arange(nvir);
这是 3-D 张量与 1-D 张量的数乘。在通常情况下,1-D 张量 和 应该被广播并重复为 3-D 对应 和 ,然后执行乘法:
这既不方便也不高效。通过广播,我们可以在不重复值的情况下为 1-D 张量插入 axis:
// elementwise multiplication with broadcasting
// `None` means inserting axis, equivalent to `np.newaxis` in NumPy or `NewAxis` in RSTSR
let converted_y = &y * ei.slice((None, .., None)) * ea.slice((None, None, ..));
这种乘法仍然可以简化。根据 NumPy 的广播规则定义,它总是在第一个维度添加省略号。因此,任何在第一个维度插入 axis 的操作都可以被移除:
// elementwise multiplication with simplified broadcasting
let converted_y = &y * &ei.slice((.., None)) * &ea;
最后,出于内存和效率的考虑,建议先执行 的数乘:
// optimize for memory access cost
let converted_y = &y * (&ei.slice((.., None)) * &ea);
3.2 矩阵乘法的示例
许多后 HF 方法涉及积分基变换,主要是从原始基(原子基或称为 AO)到分子轨道基(称为 MO):
此操作涉及五个索引 ,其中索引 的数量小于 。
// task definition
let (naux, nocc, nvir, nao, _) = (8, 2, 4, 6, 6); // subscripts (P, i, a, μ, ν)
let y_ao = rt::arange(naux * nao * nao).into_shape([naux, nao, nao]);
let c_occ = rt::arange(nao * nocc).into_shape([nao, nocc]);
let c_vir = rt::arange(nao * nvir).into_shape([nao, nvir]);
矩阵乘法的 广播规则 稍微复杂一些。然而,如果您熟悉广播规则,这个任务可以用非常简单的代码实现:
let y_mo = &c_occ.t() % &y_ao % &c_vir;
println!("{:?}", y_mo.layout());
这段代码简单而优雅。它会在支持 rayon 的设备上正确处理多线程。
然而,它需要多次访问 3-D 张量,并且会生成一个临时的 3-D 张量。这在内存访问和内存成本上都不高效。
为了解决内存效率问题,可以使用并行 axis 迭代器执行此计算。但这种方法的代码的编写有一定难度。
另一种解决方案是通过有限度的 unsafe 代码,并行地对指标 进行迭代:
use rayon::prelude::*;
let y_mo = unsafe { rt::empty([naux, nocc, nvir]) };
(0..naux).into_par_iter().for_each(|p| {
let mut y_mo = unsafe { y_mo.force_mut() };
y_mo.i_mut(p).assign(&c_occ.t() % &y_ao.i(p) % &c_vir);
});
4. 内存问题
这与值如何传递给算术运算有关。
4.1 通过算术运算符 (operator) 进行计算
在 Rust 中,变量的所有权和生命周期规则非常严格。以下代码将导致编译错误:
| let c = a + b;
| - value moved here
| let d = a * b;
| ^ value used here after move
|
help: consider cloning the value if the performance cost is acceptable
|
| let c = a + b.clone();
| ++++++++
然而,在许多情况下,克隆张量的性能和内存成本是不可接受的。因此,更推荐通过以下方式执行计算,以避免内存拷贝和生命周期限制:
- 使用张量的引用,
- 使用张量的视图,
// arithmetic by reference
let c = &a + &b;
// arithmetic by view
let d = a.view() * b.view();
// generating a view is cheap, given tensor is large
let a_view = a.view();
let b_view = b.view();
let e = a_view * b_view;
需要注意的是,除了生命周期限制外,拥有所有权的张量仍然可以传递给算术运算。此外,在可能的情况下会应用就地算术运算(类型约束和广播能力)。例如,对于 1-D 张量加法,变量 c
的内存不会被分配,而是从变量 a
中重用。因此,如果您确定 a
不会再被使用,可以通过值传递 a
,这样会更高效。
let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;
let ptr_a = a.as_ptr();
// if sure that `a` is not used anymore, pass `a` by value instead of reference
let c = a + &b;
let ptr_c = c.as_ptr();
// raw data of `a` is reused in `c`
// similar to `a += &b; let c = a;`
assert_eq!(ptr_a, ptr_c);
4.2 通过关联函数 (associated method) 进行计算
在 RSTSR 中,有三种方式执行算术运算:
- 通过运算符:
&a + &b
; - 通过函数:
rt::add(&a, &b)
; - 通过关联函数:
(&a).add(&b)
或a.view().add(&b)
。
您可能会发现关联函数的使用代码有些奇怪。实际上,a.add(&b)
在 RSTSR 中也是有效的,但这会消耗变量 a
。以下代码由于这个问题将无法编译: