Common Functions

1. Common Elementwise Functions

In RSTSR, most functions required by the Python Array API have been implemented. Most of them can be called either as regular Rust functions or as associated methods.

For example, under row-major ordering, elementwise comparison between two tensors can be performed following broadcasting rules:

#[rustfmt::skip]
let a = rt::asarray(vec![
    5., 2.,
    3., 6.,
    1., 8.,
]).into_shape([3, 2]);
let b = rt::asarray(vec![3., 4.]);

// broadcasted comparison a >= b
// called by associated method
let c = a.greater_equal(&b);

println!("{:5}", c);
// output:
// [[ true false]
// [ true true ]
// [ false true ]]

Similarly, sine calculation can be applied to a tensor:

let b = rt::asarray(vec![3., 4.]);
let d = rt::sin(&b);
println!("{:6.3}", d);
// output: [ 0.141 -0.757]
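As noted above, most functions are also exposed as associated methods, so the same sine computation can presumably be written as a method call. A minimal sketch, assuming the method name mirrors the free function (not verified against the current API):

let b = rt::asarray(vec![3., 4.]);
// assumed associated-method form of rt::sin
let d = b.sin();
println!("{:6.3}", d);
// output: [ 0.141 -0.757]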

Some binary elementwise functions have shorthand names

Common binary functions include exponentiation pow, floor division floor_divide, greater-than-or-equal greater_equal, etc. Among these, comparison binary functions typically have shorthand names; for example, greater_equal can be abbreviated as ge.

Binary functions with shorthand names generally cannot be called as associated methods (to avoid conflicts with traits like PartialOrd), but can be called as regular Rust functions.

let a = rt::asarray(vec![0, 2, 5, 7]);
let b = rt::asarray(vec![1, 3, 4, 6]);

let c = a.greater_equal(&b);
println!("{:5}", c);
// output: [ false false true true ]

let c = rt::greater_equal(&a, &b);
println!("{:5}", c);
// output: [ false false true true ]

let c = rt::ge(&a, &b);
println!("{:5}", c);
// output: [ false false true true ]

Some unary functions will consume the input Tensor

In RSTSR, almost all functions allow passing &TensorAny or TensorView as input; in such cases, the original tensor remains unchanged and is not consumed.

However, certain computations (including the arithmetic operations covered in the previous section) also accept an owned Tensor as input. Depending on the situation, the underlying data may then be reused or modified, making the tensor unusable afterward. Many unary functions in RSTSR behave this way as well, so ownership must be kept in mind when calling them.

Take the sine function as an example:

let b = rt::asarray(vec![3., 4.]);
let c = rt::sin(b);
let d = rt::cos(b);
// does not compile: `b` is moved into `rt::sin`

This fails to compile, and the compiler's hint is instructive:

error[E0382]: use of moved value: `b`
|
| let b = rt::asarray(vec![3., 4.]);
| - move occurs because `b` has type `...`, which does not implement the `Copy` trait
| let c = rt::sin(b);
| - value moved here
| let d = rt::cos(b);
| ^ value used here after move
|
help: consider borrowing `b`
|
| let c = rt::sin(&b);
| +
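Following the compiler's hint and borrowing `b` keeps the original tensor usable for further computation. A minimal sketch of the corrected code:

let b = rt::asarray(vec![3., 4.]);
// borrowing `b` instead of moving it: both calls compile,
// and `b` remains valid afterwards
let c = rt::sin(&b);
let d = rt::cos(&b);
println!("{:6.3}", c);
// output: [ 0.141 -0.757]
println!("{:6.3}", d);
// output: [-0.990 -0.654]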

2. Mapping Functions

Although RSTSR implements many elementwise functions, it cannot possibly cover them all. For tensors on CPU devices, RSTSR therefore provides mapping functions (whose names contain "map") to support custom elementwise operations.

2.1 Unary Mapping

Here's an example calculating the Gamma function, using the mapv function for the mapping. Note that mapv calls can be chained, but RSTSR does not perform lazy evaluation, so each chained mapv allocates and traverses an intermediate tensor; a chain of mapv calls is therefore no more efficient than a single fused call:

let a: Tensor<f64, _> = rt::asarray(vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]);
let b = a.mapv(libm::lgamma).mapv(libm::exp);
println!("{:6.3}", a);
println!("{:6.3}", b);
// output:
// [ 0.500 1.000 1.500 2.000 2.500 3.000 3.500 4.000]
// [ 1.772 1.000 0.886 1.000 1.329 2.000 3.323 6.000]

// please note that in RSTSR, mapv is not lazily evaluated,
// so the fused call below is expected to be more efficient
let b = a.mapv(|x| libm::exp(libm::lgamma(x)));
println!("{:6.3}", b);

// also, function `libm::tgamma` is equivalent to `libm::exp(libm::lgamma(x))`
// when numerical discrepancy is not a concern
let b = a.mapv(libm::tgamma);
println!("{:6.3}", b);

Correspondence with NumPy

In NumPy, the analogous function is np.vectorize. A similar computation can be written in NumPy as:

import numpy as np
import scipy

a = np.linspace(1.0, 10.0, 4096 * 4096)
f = np.vectorize(scipy.special.gamma)
b = f(a)

Although functionally similar, the motivations behind NumPy's and RSTSR's (or crate ndarray's) implementations differ slightly.

RSTSR's map functions are purely for function mapping, not for any instruction-level vectorization (SIMD). However, RSTSR still performs certain optimizations:

  • Mapping is executed along the most contiguous dimension possible;
  • Parallel processing is enabled for large tensors.

Even without RSTSR, users could achieve similar efficiency by manually implementing parallel loops on Vec<T>, though with slightly more complex code.
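For illustration, a manual version of such a parallel elementwise map might look like the following sketch, which applies the rayon crate directly to a Vec<f64> (rayon is used here as an assumed parallelism library and is not part of RSTSR):

use rayon::prelude::*;

// plain contiguous storage, standing in for a tensor's underlying buffer
let a: Vec<f64> = (1..=1_000_000).map(|i| i as f64 * 1e-6).collect();

// parallel elementwise mapping over the contiguous data,
// similar in spirit to what mapv does for large tensors
let b: Vec<f64> = a.par_iter().map(|&x| libm::exp(libm::lgamma(x))).collect();

println!("{:?}", &b[..4]);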

For NumPy, since Python's native for-loops are very slow, some form of accelerated mapping becomes necessary once the mapping grows moderately complex. Without resorting to Python dialects (such as Numba or JAX) or acceleration strategies like Cython/ctypes, users have few alternatives beyond np.vectorize.

2.2 Mutable Unary Mapping

For mutable Tensor and TensorMut types, RSTSR also provides the mapvi function to perform in-place mapping without allocating new memory:

let mut vec_a = vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0];
let mut a: TensorMut<f64, _> = rt::asarray(&mut vec_a);
a.mapvi(libm::tgamma);

// original vector is also modified
println!("{:6.3?}", vec_a);
// [ 1.772, 1.000, 0.886, 1.000, 1.329, 2.000, 3.323, 6.000]

2.3 Binary Mapping

For binary mapping, RSTSR provides the mapvb function:

#[rustfmt::skip]
let a = rt::asarray(vec![
    5., 2.,
    3., 6.,
    1., 8.,
]).into_shape([3, 2]);
let b = rt::asarray(vec![3., 4.]);
let c = a.mapvb(&b, libm::fmin);
println!("{:6.3}", c);
// output:
// [[ 3.000 2.000]
// [ 3.000 4.000]
// [ 1.000 4.000]]

3. Reduction Operations

RSTSR currently supports several reduction operations, including summation, maximum value, standard deviation, etc. Adding the _axes suffix allows reduction along specific dimensions.

#[rustfmt::skip]
let a = rt::asarray(vec![
    5., 2.,
    3., 6.,
    1., 8.,
]).into_shape([3, 2]);

let b = a.l2_norm();
println!("{:6.3}", b);
// output: 11.790

let b = a.sum_axes(-1);
println!("{:6.3}", b);
// output: [ 7.000 9.000 9.000]

let b = a.argmin_axes(0);
println!("{:6.3}", b);
// output: [ 2 0]

For higher-dimensional tensors, the _axes functions can also accept arrays specifying which dimensions to reduce:

let a = rt::linspace((-1.0, 1.0, 24)).into_shape([2, 3, 4]);
let b = a.mean_axes([0, 2]);
println!("{:6.3}", b);
// output: [ -0.348 -0.000 0.348]

As a special case, Tensor<bool, B, D> can also undergo sum or sum_axes operations, where true counts as 1 and false as 0:

let a = rt::asarray(vec![false, true, false, true, true]);
let b = a.sum();
println!("{:6}", b);
// output: 3
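Combined with the comparison functions from Section 1, this makes it convenient to count how many elements satisfy a condition. A small sketch reusing calls shown earlier on this page:

let a = rt::asarray(vec![0, 2, 5, 7]);
let b = rt::asarray(vec![1, 3, 4, 6]);

// boolean mask from the elementwise comparison, then count the `true` entries
let mask = rt::ge(&a, &b);
let count = mask.sum();
println!("{:6}", count);
// output: 2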

4. Linear Algebra (linalg)

Currently, RSTSR supports part of the linear algebra functionality found in NumPy and SciPy. Typical problems include Hermitian matrix eigenvalue problems, singular value decomposition (SVD), Cholesky decomposition, and so on.

let device = DeviceFaer::new(4);
#[rustfmt::skip]
let a = rt::asarray((vec![
    1.0, 0.5, 1.5,
    0.5, 5.0, 2.0,
    1.5, 2.0, 8.0,
], &device)).into_shape([3, 3]);

let c = rt::linalg::eigh(&a);
let (eigenvalues, eigenvectors) = c.into();

println!("{:8.5}", eigenvalues);
// [ 0.69007 4.01426 9.29567]

println!("{:8.5}", eigenvectors);
// [[ 0.98056 0.06364 -0.18561]
// [ -0.02335 -0.90137 -0.43242]
// [ -0.19482 0.42835 -0.88236]]

Slight differences in linear algebra capabilities across backends

Currently, RSTSR development primarily targets the DeviceOpenBLAS and DeviceFaer backends, with a focus on the former. DeviceOpenBLAS typically implements more functionality, including but not limited to:

  • Generalized eigenvalue problems rt::linalg::eigh(&a, &b);
  • Triangular matrix solving rt::linalg::solve_triangular(&a, &b);
  • Solving eigenvalue problems by reusing memory through mutable references rt::linalg::eigh(a.view_mut()) (similar to SciPy's overwrite_a option).

Although DeviceFaer currently lacks some features, as a pure Rust backend, it offers greater portability compared to DeviceOpenBLAS.