Rust const 泛型 (Const Generics)的使用方法:解析const_unit_poc

对于const 泛型,笔者也了解甚少。借写该文章的机会,学习一下Rust的const 泛型(Const Generics)

const_unit_poc :利用const generics实现的物理单位库

const_unit_poc 是几天前推出的,率先使用const 泛型的库。该库的使用方法如下:

#![feature(const_generics, const_evaluatable_checked)]
use const_unit_poc::values::{m, kg, s, N};

let distance = 1.0 * m;
let mass = 18.0 * kg;
let force = distance * mass / (1.8 * s * 2.0 * s);
assert_eq!(force, 5.0 * N);

let mut mutable_distance = 3.2 * m;
mutable_distance -= 0.2 * m;
mutable_distance += 2.0 * m;

assert_eq!(mutable_distance, 5.0 * m);

代码及其简便!易懂!非常神奇!所以就根据该库的源代码看看const 泛型该如何使用。

const_unit_poc 源码分析

./src/lib.rs

#![feature(const_generics, const_evaluatable_checked, doc_cfg)]
#![allow(incomplete_features)]
#![cfg_attr(feature = "non_ascii", feature(non_ascii_idents))]

use std::ops;

pub mod units;
pub mod values;

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[allow(non_snake_case)]
pub struct SiUnit {
    m: i8,
    kg: i8,
    s: i8,
    A: i8,
    K: i8,
    mol: i8,
    cd: i8,
}

// 后面还有内容

这里引入了units, values两个mod,并且声明了SiUnit 这个结构体,里面的成员是SI的7个标准单位, 类型为i8。

看到这儿,笔者推测应该是使用 SiUnit 这个结构体来实例化各种常用的单位(如NN),其中成员的值表示该单位表示成基本单位后的指数的值(如N=ms2N=m*s^{-2}, 则N.m == 1; N.s == -2)。

再看一下引入的模块units的内容

units的实现

./src/units.rs

#![allow(non_upper_case_globals)]
use super::SiUnit;
const NONE: SiUnit = SiUnit { m: 0, kg: 0, s: 0, A: 0, K: 0, mol: 0, cd: 0 };
/// meter
pub const m: SiUnit = SiUnit { m: 1, ..NONE };
/// 中间内容省略
/// lux
pub const lx: SiUnit = SiUnit { m: -2, cd: 1, ..NONE };
/// square meter
pub const sq_m: SiUnit = SiUnit { m: 2, ..NONE };
/// cubic meter
pub const cu_m: SiUnit = SiUnit { m: 3, ..NONE };

的确和猜想一致!注意结构体实例化时候这儿使用了..NONE,这是用已经创建的实例None上更新部分参数,创建新的实例,见使用结构体更新语法从其他实例创建实例。这儿使用了const使得各个单位全局可见且不可变。

现在已经有了单位,但是我们注意到代码中assert_eq!(mutable_distance, 5.0 * m);不仅仅要检测单位的值,还要检测系数的值,结构体与数字相乘后又该如何表示呢?这个问题应该能够从另一个模块得到答案吧。

对单位系数的包装

./src/values.rs

#![allow(non_upper_case_globals)]
use crate::{units, Quantity};
// base
/// 1 meter
pub const m: Quantity<{ units::m }> = Quantity { raw_value: 1.0 };
/// 1 kilogram
pub const kg: Quantity<{ units::kg }> = Quantity { raw_value: 1.0 };
/// 1 second
pub const s: Quantity<{ units::s }> = Quantity { raw_value: 1.0 };
/// 后面还有

注意use crate::{units, Quantity}, 代表的是导入的当前crate中的unitsQuantity
const 泛型: 从这儿的可以大概可以看出,Quantity的泛型参数是一个const常量,包装每一个units.rs中的const常量值。这就是const generics吧! 后面的实例化根据类型推导应该又可以写做:Quantity<{ units::m }> { raw_value: 1.0 }。 所以是用raw_value来表示每一个单位的系数。

所以再来看看Quantity的定义吧。

###结构体 Quantity的定义

./src/lib.rs

#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct Quantity<const U: SiUnit> {
    pub raw_value: f64,
}

简单干脆的使用const 类型作为泛型参数。

加减运算

./src/lib.rs

impl<const U: SiUnit> ops::Add for Quantity<U> {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        Self { raw_value: self.raw_value + rhs.raw_value }
    }
}

impl<const U: SiUnit> ops::AddAssign for Quantity<U> {
    fn add_assign(&mut self, rhs: Self) {
        self.raw_value += rhs.raw_value;
    }
}

impl<const U: SiUnit> ops::Sub for Quantity<U> {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        Self { raw_value: self.raw_value - rhs.raw_value }
    }
}

impl<const U: SiUnit> ops::SubAssign for Quantity<U> {
    fn sub_assign(&mut self, rhs: Self) {
        self.raw_value -= rhs.raw_value;
    }
}

由于加减运算并不会变单位,所以就是对系数的加减,非常直接。

乘除运算

  • 一个数字与一个Quantity相乘,单位不变:

./src/lib.rs

impl<const U: SiUnit> ops::Mul<f64> for Quantity<U> {
    type Output = Quantity<U>;

    fn mul(self, rhs: f64) -> Self::Output {
        Quantity { raw_value: self.raw_value * rhs }
    }
}
  • 两个带单位的量相乘,单位发生变化,也就是Quantity<U>的泛型参数发生了变化,此时需要重新实例化一个Quantity<U>结构体:

./src/lib.rs

// Quantity相乘转换为unit相乘,UL.unit_mul(UR)
impl<const UL: SiUnit, const UR: SiUnit> ops::Mul<Quantity<UR>> for Quantity<UL>
where
    Quantity<{ UL.unit_mul(UR) }>: ,
{
    type Output = Quantity<{ UL.unit_mul(UR) }>;

    fn mul(self, rhs: Quantity<UR>) -> Self::Output {
        Quantity { raw_value: self.raw_value * rhs.raw_value }
    }
}
// unit相乘
impl SiUnit {
    const fn unit_mul(self, rhs: Self) -> Self {
        Self {
            m: self.m + rhs.m,
            kg: self.kg + rhs.kg,
            s: self.s + rhs.s,
            A: self.A + rhs.A,
            K: self.K + rhs.K,
            mol: self.mol + rhs.mol,
            cd: self.cd + rhs.cd,
        }
    }
}
  • 这儿对泛型的约束很神奇,竟然是这样写Quantity<{ UL.unit_mul(UR) }>: ,如果去掉这个where,则会报错unconstrained generic constant。虽然现在还不懂是什么意思....
error: unconstrained generic constant
   --> src\lib.rs:156:5
    |
156 |     type Output = Quantity<{ UL.unit_mul(UR) }>;     
    |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^     
    |
help: consider adding a `where` bound for this expression  
   --> src\lib.rs:156:28
    |
156 |     type Output = Quantity<{ UL.unit_mul(UR) }>;     
    |                            ^^^^^^^^^^^^^^^^^^^     
  • 另外一点,泛型中的const是调用一个const fn得到的时候,这个时候需要打大括号,如:Quantity<{ UL.unit_mul(UR) }> ,否则也会出错。这在最近的Stabilization report里面也有提到。

Const arguments

Const parameters are instantiated using const arguments. Any concrete const expression or const parameter as a standalone argument can be used. When applying an expression as const parameter, most expressions must be contained within a block, with two exceptions:

  1. literals and single-segment path expressions
  2. array lengths

This syntactic restriction is necessary to avoid ambiguity, or requiring infinite lookahead when parsing an expression as a generic argument.

In the cases where a generic argument could be resolved as either a type or const argument, we always interpret it as a type. This causes the following test to fail:

type N = u32;
struct Foo<const N: usize>;
fn foo<const N: usize>() -> Foo<N> { todo!() } // ERR

To circumvent this, the user may wrap the const parameter with braces, at which point it is unambiguously accepted.

type N = u32;
struct Foo<const N: usize>;
fn bar<const N: usize>() -> Foo<{ N }> { todo!() } // ok

简单来说,const 泛型一般都需要打大括号(两种情况除外),打大括号是为了防止歧义。

总结

这篇文章主要是了解了const 泛型在做啥,以及如何做的。

通过这个简单的例子,我感觉const 泛型是用在当泛型需要有取值的时候(比如这里的unit就需要取值,而不能每一个单位一个类型...)。