rust December 22, 2024

What would it take to add refinement types to Rust?

A few years ago, on a whim, I wrote YAIOUOM. YAIOUOM was a static analyzer for Rust that checked that code used units of measure correctly: e.g. a distance in meters is not a distance in centimeters, and dividing meters by seconds gives you a value in m / s (aka m * s^-1).

YAIOUOM was an example of a refinement type system, i.e. a type system that does its work after another type system has already done its work. It was purely static, users could add new units in about one line of code, and it was actually surprisingly easy to write. It also couldn’t be written within the Rust type system, in part because I wanted legible error messages, and in part because Rust doesn’t offer a very good way to specify that (m / s) * s is actually the same type as m.

Sadly, it also worked only on a specific version of Rust Nightly, and the code broke with every new version of Rust. It’s a shame, because I believe that there’s a lot we could do with refinement types: simple things such as units of measure, as above, but also, I suspect, much better error messages for complex type-level programming, such as what Diesel is doing.

This got me wondering how we could extend Rust so that refinement types could easily be added to the language.

Starting with Rust’s type system

As I’ve already implemented one refinement type system for Rust, I’ll use it as a benchmark for the features we would need in order to implement refinement types in Rust.

Let’s assume that we have defined the types Value, Unit, Meter, Second to make the following possible:

struct Value<T, U> where U: Unit { value: T, unit: std::marker::PhantomData<U> } // a value of type T (e.g. f64) with a unit of type U
let a = Value::<f64, Meter>::new(9.81); // f64 Meter aka Value<f64, Meter>
let b = Value::<f64, Second>::new(1.0); // f64 Second aka Value<f64, Second>

Furthermore, we have defined operations std::ops::Add and std::ops::Sub in such a way that we can write

let c = b + b; // f64 Second aka Value<f64, Second>
let d = b - b; // f64 Second aka Value<f64, Second>

let _ = a - b; // Error: cannot subtract `Value<f64, Second>` from `Value<f64, Meter>`

Multiplication and division are a bit more complicated, because they introduce new units, so let’s further assume that we have defined std::ops::Mul, std::ops::Div and two Unit combinators Mul<Left, Right> and Div<Left, Right> in such a way that we can write

let g = a / (b * b); // f64 Meter / (Second * Second) aka Value<f64, Div<Meter, Mul<Second, Second>>>
let h = a / b / b;   // f64 Meter / (Second * Second) aka Value<f64, Div<Div<Meter, Second>, Second>>
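Here is a minimal sketch of the definitions assumed above, for readers who want something that compiles. Only `Value`, `Unit`, `Meter`, `Second`, `Mul` and `Div` come from the text. The field names follow the `unify` method shown later. The `unit_name` helper is a hypothetical addition that displays which unit type the compiler inferred.

```rust
use std::any::type_name;
use std::marker::PhantomData;
use std::ops;

/// Marker trait for units of measure.
pub trait Unit {}

pub struct Meter;
pub struct Second;
impl Unit for Meter {}
impl Unit for Second {}

/// Unit combinators: `Mul<L, R>` stands for L * R and `Div<L, R>` for L / R.
pub struct Mul<L, R>(PhantomData<(L, R)>);
pub struct Div<L, R>(PhantomData<(L, R)>);
impl<L: Unit, R: Unit> Unit for Mul<L, R> {}
impl<L: Unit, R: Unit> Unit for Div<L, R> {}

/// A value of numeric type `T` (e.g. f64) with a unit of type `U`.
pub struct Value<T, U: Unit> {
    pub value: T,
    unit: PhantomData<U>,
}

impl<T, U: Unit> Value<T, U> {
    pub fn new(value: T) -> Self {
        Value { value, unit: PhantomData }
    }
}

// Manual Clone/Copy so that `b + b` does not move `b` (derive would also require `U: Copy`).
impl<T: Copy, U: Unit> Clone for Value<T, U> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<T: Copy, U: Unit> Copy for Value<T, U> {}

// Addition and subtraction require exactly the same unit type on both sides.
impl<T: ops::Add<Output = T>, U: Unit> ops::Add for Value<T, U> {
    type Output = Value<T, U>;
    fn add(self, rhs: Self) -> Self::Output {
        Value::new(self.value + rhs.value)
    }
}
impl<T: ops::Sub<Output = T>, U: Unit> ops::Sub for Value<T, U> {
    type Output = Value<T, U>;
    fn sub(self, rhs: Self) -> Self::Output {
        Value::new(self.value - rhs.value)
    }
}

// Multiplication and division build new unit types out of the combinators.
impl<T: ops::Mul<Output = T>, L: Unit, R: Unit> ops::Mul<Value<T, R>> for Value<T, L> {
    type Output = Value<T, Mul<L, R>>;
    fn mul(self, rhs: Value<T, R>) -> Self::Output {
        Value::new(self.value * rhs.value)
    }
}
impl<T: ops::Div<Output = T>, L: Unit, R: Unit> ops::Div<Value<T, R>> for Value<T, L> {
    type Output = Value<T, Div<L, R>>;
    fn div(self, rhs: Value<T, R>) -> Self::Output {
        Value::new(self.value / rhs.value)
    }
}

/// Name of the unit type carried by a value (for display only).
pub fn unit_name<T, U: Unit>(_: &Value<T, U>) -> &'static str {
    type_name::<U>()
}

fn main() {
    let a = Value::<f64, Meter>::new(9.81);
    let b = Value::<f64, Second>::new(1.0);
    let c = b + b;
    let d = b - b;
    let g = a / (b * b);
    let h = a / b / b;
    // let _ = a - b; // rejected: Meter is not Second
    // let _ = g - h; // also rejected, although both units mean m / s^2
    println!("c = {}, d = {}", c.value, d.value);
    println!("g: {} = {}", unit_name(&g), g.value);
    println!("h: {} = {}", unit_name(&h), h.value);
}
```

Note that the output of `std::any::type_name` is meant for diagnostics, and its exact format is not guaranteed.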

So far, so good. Unless you want to compare g and h:

let diff = g - h; // Error: cannot subtract `Value<f64, Div<Div<Meter, Second>, Second>>` from `Value<f64, Div<Meter, Mul<Second, Second>>>`

which is a shame, because g and h represent the same value and Div<Div<Meter, Second>, Second> and Div<Meter, Mul<Second, Second>> represent the same unit of measure.

Perhaps we could be smarter about it?

In the above encoding, we have used a definition of - with a type signature

impl<V, U> std::ops::Sub for Value<V, U> where V: std::ops::Sub, U: Unit {
    type Output = Value<<V as std::ops::Sub>::Output, U>;
    // ...
}

i.e.

for any number type V and any unit U, - takes two arguments of type Value<V, U> and produces a result of type Value<V, U>.

But what we’d like here would be to express

for any number type V and any unit types Left and Right, - takes two arguments of types Value<V, Left> and Value<V, Right> and produces a result of type Value<V, Left> if Left and Right are equivalent.

In other words, in terms of type-level programming, what we need here is some form of type-level oracle Equivalent<Left, Right> that is defined if and only if Left and Right are indeed equivalent. Unfortunately, as far as I can tell¹, there is no way in Rust to define an Equivalent<Left, Right> that combines all three of the following properties:

  1. Equivalent<Left, Right> is defined iff Left and Right are equivalent;
  2. the system remains extensible, so that other crates can create new units;
  3. the user does not need to manually express equivalences between units.
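To see why property 3 is the sticking point, here is a sketch that simply gives it up. The `Equivalent` trait below is hypothetical. A blanket impl provides reflexivity, and every other equivalence has to be declared by hand, so each crate that adds units would also have to declare equivalences for every pair of spellings it cares about.

```rust
use std::marker::PhantomData;
use std::ops::Sub;

pub trait Unit {}
pub struct Meter;
pub struct Second;
pub struct Mul<L, R>(PhantomData<(L, R)>);
pub struct Div<L, R>(PhantomData<(L, R)>);
impl Unit for Meter {}
impl Unit for Second {}
impl<L: Unit, R: Unit> Unit for Mul<L, R> {}
impl<L: Unit, R: Unit> Unit for Div<L, R> {}

pub struct Value<T, U: Unit> {
    pub value: T,
    unit: PhantomData<U>,
}

impl<T, U: Unit> Value<T, U> {
    pub fn new(value: T) -> Self {
        Value { value, unit: PhantomData }
    }
}

/// Hypothetical oracle: `Left: Equivalent<Right>` should hold iff both denote the same unit.
pub trait Equivalent<Right: Unit>: Unit {}

// Reflexivity comes for free with a blanket impl...
impl<U: Unit> Equivalent<U> for U {}

// ...but every other equivalence must be written by hand (property 3 is lost),
// and `g - h` would need the mirrored impl as well.
impl Equivalent<Div<Meter, Mul<Second, Second>>> for Div<Div<Meter, Second>, Second> {}

// Subtraction accepts any right-hand unit that has been declared equivalent.
impl<T, Left, Right> Sub<Value<T, Right>> for Value<T, Left>
where
    T: Sub<Output = T>,
    Left: Unit + Equivalent<Right>,
    Right: Unit,
{
    type Output = Value<T, Left>;
    fn sub(self, rhs: Value<T, Right>) -> Self::Output {
        Value::new(self.value - rhs.value)
    }
}

fn main() {
    let g = Value::<f64, Div<Meter, Mul<Second, Second>>>::new(9.81);
    let h = Value::<f64, Div<Div<Meter, Second>, Second>>::new(9.81);
    let diff = h - g; // accepted thanks to the hand-written impl
    println!("h - g = {}", diff.value);
    // let _ = g - h; // rejected: no impl in that direction
}
```

The number of ways to spell a given unit is unbounded, so no finite list of such impls can satisfy property 1.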

So… what if we could expand the Rust type system to help us here?

…then expanding it

Our life is complicated because programming with types is complicated. However, if instead of implementing Equivalent<Left, Right>, we only needed to implement equivalent(left: UnitRepr, right: UnitRepr), where UnitRepr is an algebraic data type representing our units of measure, our life would be much easier.

If you’re curious, Andrew Kennedy formalized the algorithm that does this ~15 years ago, and then implemented it as part of the F# compiler (my own work on YAIOUOM was essentially a port of his work to Rust), so let’s take it for granted that such an algorithm exists and that it’s possible to implement equivalent or something like it.
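As a hedged illustration of what `equivalent` could look like, here is a sketch with a hypothetical `UnitRepr`. It is not Kennedy's algorithm. His inference also handles unit variables, which requires unification modulo the laws of abelian groups. For ground units, which is all a checker sees once inference is done, it is enough to normalize each unit to the integer exponent of each base unit and compare the results.

```rust
use std::collections::BTreeMap;

/// Hypothetical term-level mirror of the type-level `Meter`, `Mul`, `Div`, ...
#[derive(Clone, Debug)]
pub enum UnitRepr {
    Base(&'static str), // a base unit such as "m" or "s"
    Mul(Box<UnitRepr>, Box<UnitRepr>),
    Div(Box<UnitRepr>, Box<UnitRepr>),
}

pub fn base(name: &'static str) -> UnitRepr {
    UnitRepr::Base(name)
}
pub fn mul(l: UnitRepr, r: UnitRepr) -> UnitRepr {
    UnitRepr::Mul(Box::new(l), Box::new(r))
}
pub fn div(l: UnitRepr, r: UnitRepr) -> UnitRepr {
    UnitRepr::Div(Box::new(l), Box::new(r))
}

/// Adds `sign` times the exponents of `unit` into `out`.
fn accumulate(unit: &UnitRepr, sign: i32, out: &mut BTreeMap<&'static str, i32>) {
    match unit {
        UnitRepr::Base(name) => *out.entry(*name).or_insert(0) += sign,
        UnitRepr::Mul(l, r) => {
            accumulate(l, sign, out);
            accumulate(r, sign, out);
        }
        UnitRepr::Div(l, r) => {
            accumulate(l, sign, out);
            accumulate(r, -sign, out);
        }
    }
}

/// Normal form: base unit -> non-zero integer exponent, sorted by name.
pub fn exponents(unit: &UnitRepr) -> BTreeMap<&'static str, i32> {
    let mut out = BTreeMap::new();
    accumulate(unit, 1, &mut out);
    out.retain(|_, e| *e != 0);
    out
}

/// Two ground units are equivalent iff their normal forms coincide.
pub fn equivalent(left: &UnitRepr, right: &UnitRepr) -> bool {
    exponents(left) == exponents(right)
}

/// A canonical spelling of the normal form, e.g. "m * s^-2".
pub fn canonical(unit: &UnitRepr) -> String {
    let parts: Vec<String> = exponents(unit)
        .into_iter()
        .map(|(name, e)| if e == 1 { name.to_string() } else { format!("{name}^{e}") })
        .collect();
    if parts.is_empty() { "1".to_string() } else { parts.join(" * ") }
}

fn main() {
    let g = div(base("m"), mul(base("s"), base("s")));
    let h = div(div(base("m"), base("s")), base("s"));
    println!("{} / {} / equivalent: {}", canonical(&g), canonical(&h), equivalent(&g, &h));
}
```

The normal form is also a natural candidate for the canonical unit `C` that appears in the trait-resolution option below.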

There are a few places where we could imagine plugging in equivalent, or a variant thereof.

Option: Trait resolution

We could somehow plug into rustc_infer::traits, rustc’s implementation of trait resolution, whenever we attempt to resolve Equivalent<Left, Right>.

This would let us write

impl<V, Left, Right, C> std::ops::Sub<Value<V, Right>> for Value<V, Left>
where V: std::ops::Sub,
      Equivalence: Equivalent<Left, Right, Canonical=C>
{
    type Output = Value<<V as std::ops::Sub>::Output, C>;
    // ...
}

let diff = h - g; // Value<f64, C> where C is the canonicalized version of Div<Div<Meter, Second>, Second>
                  // and Div<Meter, Mul<Second, Second>>

As of rustc 1.85, this might look like an additional variant in SelectionCandidate.

I am not convinced by this approach, for several reasons:

  1. trait resolution is already quite complex; I’m afraid that this would make it even more complex, with the added risk of slowing down the compiler or causing it to never return;
  2. it’s not entirely clear to me what should happen in case our plugged-in equivalent returns an error stating that no solution could be found;
  3. it’s not entirely clear to me what should happen in case our plugged-in equivalent can find more than one equivalence – I suspect that this would probably be caused by an error in the implementation of equivalent, so perhaps this question doesn’t need an answer;
  4. somehow, this feels too powerful.

Option: TypeVar unification

We could somehow plug into rustc_infer::infer, rustc’s implementation of type inference, wherever we attempt to unify a Value<V, T> and a Value<V, U> where either T or U contains a type variable. This would let us keep the definition of subtraction that requires a Value<V, U> on both sides.

I am also not convinced by this approach, for several reasons:

  1. type unification is already non-trivial, and as the author of a refinement type system, I don’t want to have to reimplement part of unification;
  2. type unification happens all over the place, so there is a strong chance that this would slow down type checking considerably;
  3. somehow, this feels too powerful.

Option: Trait resolution (optimistic)

Instead of this, we could offer a lower-powered solution that takes place in two steps.

Step 1: During trait resolution, we entirely discard the unit information and assume that Left and Right are equivalent.

impl<V, Left, Right> std::ops::Sub<Value<V, Right>> for Value<V, Left>
where V: std::ops::Sub,
      record!(yaoioum, Equivalent<Left, Right>) // Hypothetical syntax: record this assumption for later post-processing
{
    type Output = Value<<V as std::ops::Sub>::Output, Left>;
    // ...
}

In general, this assumption is of course false: on its own, it would let us subtract seconds from meters.

Step 2: At some point before code generation, the compiler calls the plug-in, passes it the list of equivalences that were assumed, and lets the plug-in perform late rejection of types that are not actually equivalent.

Critically, step 2 takes place after both trait selection and unification, which means that type variables are already either resolved or truly generic. Intuitively, this means that the post-processing step only needs to perform type checking, rather than any kind of type inference.
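The two steps can be simulated in plain Rust. In this sketch, `Assumption` is a hypothetical record of what `record!` would capture during step 1: the two units plus a source location, here just a string. `post_process` plays the role of the plug-in in step 2. Because units are fully known by then, it only needs to normalize and compare, with no inference.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for the unit part of a fully inferred type.
#[derive(Debug)]
pub enum UnitRepr {
    Base(&'static str),
    Mul(Box<UnitRepr>, Box<UnitRepr>),
    Div(Box<UnitRepr>, Box<UnitRepr>),
}

pub fn base(n: &'static str) -> UnitRepr { UnitRepr::Base(n) }
pub fn mul(l: UnitRepr, r: UnitRepr) -> UnitRepr { UnitRepr::Mul(Box::new(l), Box::new(r)) }
pub fn div(l: UnitRepr, r: UnitRepr) -> UnitRepr { UnitRepr::Div(Box::new(l), Box::new(r)) }

/// Normal form: base unit -> non-zero exponent.
fn normalize(unit: &UnitRepr) -> BTreeMap<&'static str, i32> {
    fn go(u: &UnitRepr, sign: i32, out: &mut BTreeMap<&'static str, i32>) {
        match u {
            UnitRepr::Base(n) => *out.entry(*n).or_insert(0) += sign,
            UnitRepr::Mul(l, r) => { go(l, sign, out); go(r, sign, out); }
            UnitRepr::Div(l, r) => { go(l, sign, out); go(r, -sign, out); }
        }
    }
    let mut out = BTreeMap::new();
    go(unit, 1, &mut out);
    out.retain(|_, e| *e != 0);
    out
}

/// Step 1 output: an equivalence that trait resolution assumed without checking.
pub struct Assumption {
    pub span: &'static str, // stand-in for a source location
    pub left: UnitRepr,
    pub right: UnitRepr,
}

/// Step 2: after inference every unit is known, so this is pure checking.
pub fn post_process(assumptions: &[Assumption]) -> Result<(), Vec<String>> {
    let errors: Vec<String> = assumptions
        .iter()
        .filter(|a| normalize(&a.left) != normalize(&a.right))
        .map(|a| format!("{}: `{:?}` is not equivalent to `{:?}`", a.span, a.left, a.right))
        .collect();
    if errors.is_empty() { Ok(()) } else { Err(errors) }
}

fn main() {
    let recorded = vec![
        // h - g: assumed during step 1, confirmed in step 2.
        Assumption { span: "line 1", left: div(div(base("m"), base("s")), base("s")), right: div(base("m"), mul(base("s"), base("s"))) },
        // a - b: also assumed during step 1, rejected in step 2.
        Assumption { span: "line 2", left: base("m"), right: base("s") },
    ];
    match post_process(&recorded) {
        Ok(()) => println!("all assumptions confirmed"),
        Err(errors) => errors.iter().for_each(|e| println!("error: {e}")),
    }
}
```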

This variant feels faster and much less powerful than plugging an entirely new trait selection mechanism into rustc_infer; its behavior in case of error is clearer; and the plug-in has access to better information when displaying error messages. Also, it’s possible that the plug-in would only need rustdoc::clean::types::Type type information, which is easier to manipulate (and more stable) than rustc_infer::traits::select::{SelectionCandidate, EvaluationResult}.

My intuition, which would need to be confirmed, is that the post-processor would not always have access to enough information to decide whether an addition/subtraction is correct, a bit like the compiler cannot always determine whether 1.0 is an f32 or an f64. Just as is the case for number literals, I suspect that we could live with that and reject anything that cannot be proven correct, as long as error messages are clear enough that a developer can add the missing annotations.

Option: TypeVar unification (optimistic)

Similarly, we could perform optimistic typevar unification, e.g.

  1. mark Unit unification as something that must be recorded by the compiler;
  2. assume that Value<T, U> and Value<T, V> are always unifiable if U or V is a type variable;
  3. perform post-processing to confirm unifications.
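Here is a toy model of these three steps. It rests on loud assumptions: units are already in exponent form, type variables are plain integers, and `solutions` stands for whatever the rest of inference eventually learns about each variable. None of these names correspond to rustc internals.

```rust
use std::collections::BTreeMap;

/// Units in normal form (base unit -> exponent), to keep the sketch short.
pub type Exponents = BTreeMap<&'static str, i32>;

pub fn exps(pairs: &[(&'static str, i32)]) -> Exponents {
    pairs.iter().copied().collect()
}

/// The unit parameter of a `Value<T, _>` as seen during inference.
#[derive(Clone, Debug)]
pub enum UnitTerm {
    Var(usize),       // a not-yet-resolved type variable
    Known(Exponents), // a concrete unit
}

#[derive(Default)]
pub struct Inference {
    /// Step 1: unit unifications involving variables are recorded, not checked.
    pub recorded: Vec<(UnitTerm, UnitTerm)>,
    /// Hypothetical: what the rest of inference eventually learns about each variable.
    pub solutions: BTreeMap<usize, Exponents>,
}

impl Inference {
    /// Step 2: optimistically succeed whenever either side is a variable;
    /// two known units are still compared eagerly.
    pub fn unify(&mut self, left: UnitTerm, right: UnitTerm) -> bool {
        if let (UnitTerm::Known(l), UnitTerm::Known(r)) = (&left, &right) {
            return l == r;
        }
        self.recorded.push((left, right));
        true
    }

    fn resolve(&self, t: &UnitTerm) -> Option<Exponents> {
        match t {
            UnitTerm::Known(e) => Some(e.clone()),
            UnitTerm::Var(v) => self.solutions.get(v).cloned(),
        }
    }

    /// Step 3: once inference is over, confirm every recorded unification.
    /// A variable that was never resolved is rejected: the user must annotate.
    pub fn confirm(&self) -> Result<(), String> {
        for (i, (l, r)) in self.recorded.iter().enumerate() {
            match (self.resolve(l), self.resolve(r)) {
                (Some(l), Some(r)) if l == r => {}
                (Some(l), Some(r)) => return Err(format!("unification #{i}: {l:?} is not {r:?}")),
                _ => return Err(format!("unification #{i}: unit unknown, add an annotation")),
            }
        }
        Ok(())
    }
}

fn main() {
    let mut inference = Inference::default();
    // The unit on one side is still an inference variable: accepted for now.
    inference.unify(UnitTerm::Var(0), UnitTerm::Known(exps(&[("m", 1), ("s", -2)])));
    // Later, inference learns what variable 0 is.
    inference.solutions.insert(0, exps(&[("m", 1), ("s", -2)]));
    println!("{:?}", inference.confirm());
}
```

Rejecting unresolved variables in step 3 follows the number-literal analogy above: when the unit cannot be determined, the user adds an annotation.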

While this shares some of the benefits of optimistic trait resolution, I suspect that this would be harder to debug both for the author of the plug-in and for its users.

Option: Optimistic, pluggable keyword

In fact, in YAIOUOM, I implemented a variant of optimistic unification. When you write a method in Rust, you can mark it with an attribute, e.g.:

use std::marker::PhantomData;

impl<T, U> Value<T, U> where U: Unit {
    #[allow(unused_attributes)]
    #[yaoioum(method_unify)]
    pub fn unify<V: Unit>(self) -> Value<T, V> {
        Self {
            value: self.value,
            unit: PhantomData,
        }
    }
}

When writing a compiler plug-in, you can then walk the tree, looking for all instances of this attribute, and apply some post-processing to the method. In effect, you are defining something akin to a new keyword.

With this mechanism, the following definition still fails:

let diff = h - g; // Error: cannot subtract `Value<f64, Div<Meter, Mul<Second, Second>>>` from `Value<f64, Div<Div<Meter, Second>, Second>>`

But the following two definitions pass type-checking

let diff_1 = h.unify() - g; // f64 Meter / (Second * Second) aka Value<f64, Div<Meter, Mul<Second, Second>>>
let diff_2 = h - g.unify(); // f64 Meter / (Second * Second) aka Value<f64, Div<Div<Meter, Second>, Second>>

Since calls to unify can be detected by the plug-in, they can be accepted or rejected during the compilation phase.
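The same pattern can be tried without any plug-in if we are willing to move the check to run time. In this sketch, each unit type hypothetically reports its exponents through `Unit::exponents`, and `unify` asserts that the source and target units agree. This is only a stand-in. The point of the plug-in is to make the same decision at compile time, for each call site.

```rust
use std::collections::BTreeMap;
use std::marker::PhantomData;
use std::ops::Sub;

pub type Exponents = BTreeMap<&'static str, i32>;

/// Each unit can describe itself as base-unit exponents (an assumption of this sketch).
pub trait Unit {
    fn exponents() -> Exponents;
}

pub struct Meter;
pub struct Second;
pub struct Mul<L, R>(PhantomData<(L, R)>);
pub struct Div<L, R>(PhantomData<(L, R)>);

impl Unit for Meter {
    fn exponents() -> Exponents {
        BTreeMap::from([("m", 1)])
    }
}
impl Unit for Second {
    fn exponents() -> Exponents {
        BTreeMap::from([("s", 1)])
    }
}

/// Adds `sign` times `right` into `left`, dropping exponents that cancel out.
fn combine(mut left: Exponents, right: Exponents, sign: i32) -> Exponents {
    for (name, e) in right {
        *left.entry(name).or_insert(0) += sign * e;
    }
    left.retain(|_, e| *e != 0);
    left
}

impl<L: Unit, R: Unit> Unit for Mul<L, R> {
    fn exponents() -> Exponents {
        combine(L::exponents(), R::exponents(), 1)
    }
}
impl<L: Unit, R: Unit> Unit for Div<L, R> {
    fn exponents() -> Exponents {
        combine(L::exponents(), R::exponents(), -1)
    }
}

pub struct Value<T, U: Unit> {
    pub value: T,
    unit: PhantomData<U>,
}

impl<T, U: Unit> Value<T, U> {
    pub fn new(value: T) -> Self {
        Value { value, unit: PhantomData }
    }

    /// Same signature as `unify` above. Where the plug-in would accept or reject
    /// each call at compile time, this stand-in can only panic at run time.
    pub fn unify<V: Unit>(self) -> Value<T, V> {
        assert_eq!(U::exponents(), V::exponents(), "cannot unify these units");
        Value { value: self.value, unit: PhantomData }
    }
}

// Subtraction still requires syntactically identical unit types.
impl<T: Sub<Output = T>, U: Unit> Sub for Value<T, U> {
    type Output = Value<T, U>;
    fn sub(self, rhs: Self) -> Self::Output {
        Value::new(self.value - rhs.value)
    }
}

fn main() {
    let g = Value::<f64, Div<Meter, Mul<Second, Second>>>::new(9.81);
    let h = Value::<f64, Div<Div<Meter, Second>, Second>>::new(9.81);
    // let diff = h - g; // still a type error
    let diff_1 = h.unify() - g; // `V` is inferred as g's unit, then checked
    println!("diff_1 = {}", diff_1.value);
}
```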

While this variant is even more limited than the previous ones, and clearly less transparent for users, a few dozen minutes of experimenting with a compiler variant using it turned up no particularly annoying cases.

What about error messages?

As mentioned above, I expect that refinement types can have two benefits:

  • when they complement type-level programming, they make the type system more powerful;
  • when they replace type-level programming, it becomes possible to make error messages much clearer.

In the case of YAIOUOM, where the Value type above is called Measure, an error message looks like

let distance = Measure::<f64, Km>::new(300_000.);
let duration = Measure::<f64, S>::new(1.);
let _ = distance.unify() - duration;
// Error:
//         ----------------- in this unification
//    = note: expected unit of measure: `S`
//               found unit of measure: `Km`

which feels pretty neat.

What about an API?

Let’s look a bit further into the future, at a possible stable implementation. I believe that we’d like:

  • as many types as possible defined in regular, userland crates (in this case, Unit, Value, Mul, Div, etc.);
  • any checks that cannot be expressed within Rust’s regular type system implemented in a post-processing crate, executed during compilation.

This feels dual to the current design of procedural macros, which I think is a good sign. So, this could look like one of:

use refinement::Type;  // Functionally equivalent to rustdoc::clean::types::Type, but doesn't need `rustc_private`.
use refinement::Error; // Somewhat similar to rustc_errors, to be designed.

// If we go for post-processing trait selection.
//
// The directive means that we're only ever called when `selection` is an instance
// of `crate::yaoioum::Equivalent<Left, Right>`.
#[postprocess::trait(crate::yaoioum::Equivalent)]
fn confirm_trait_selection(span: Span, selection: &Type) -> Result<(), Error> {
    // ...
}

// If we go for post-processing type variable unification.
//
// The directive means that we're only ever called when both `left` and `right`
// are instances of `crate::yaoioum::Value<T, _>`.
#[postprocess::unify(crate::yaoioum::Value)]
fn confirm_typevar_unification(span: Span, left: &Type, right: &Type) -> Result<(), Error> {
    // ...
}

// If we go for post-processing a pseudo-keyword.
//
// The directive means that we're only ever called upon a call to `Value::unify`, which
// entails that `args[0]` is an instance of `Value<T, U>` and `args[1]` is an instance of
// `Value<T, V>`.
#[postprocess::keyword(yaoioum(unify))]
fn confirm_keyword_unification(span: Span, args: &[&Type]) -> Result<(), Error> {
    // ...
}
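To give an idea of what a plug-in author might write in the third variant, here is a possible body for `confirm_keyword_unification`. The `refinement` crate does not exist, so `Type`, `Span` and `Error` below are hypothetical minimal stand-ins. The diagnostic text only imitates the YAIOUOM message shown earlier.

```rust
use std::collections::BTreeMap;

// Stand-ins for the hypothetical `refinement::{Type, Error}` and a span type.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Span {
    pub line: u32,
    pub column: u32,
}

#[derive(Clone, Debug, PartialEq)]
pub struct Type {
    pub name: &'static str, // e.g. "Value", "Div", "Meter", "f64"
    pub args: Vec<Type>,    // generic arguments, in order
}

#[derive(Debug, PartialEq)]
pub struct Error {
    pub span: Span,
    pub message: String,
}

pub fn ty(name: &'static str, args: Vec<Type>) -> Type {
    Type { name, args }
}

/// Renders a type the way a user would write it, e.g. `Div<Meter, Second>`.
fn render(t: &Type) -> String {
    if t.args.is_empty() {
        t.name.to_string()
    } else {
        let args: Vec<String> = t.args.iter().map(render).collect();
        format!("{}<{}>", t.name, args.join(", "))
    }
}

/// Normal form of a unit: base unit -> non-zero exponent. `Mul`/`Div` are combinators;
/// any other type without arguments is treated as a base unit.
fn normalize(t: &Type) -> Option<BTreeMap<&'static str, i32>> {
    fn go(t: &Type, sign: i32, out: &mut BTreeMap<&'static str, i32>) -> Option<()> {
        match (t.name, t.args.as_slice()) {
            ("Mul", [l, r]) => { go(l, sign, out)?; go(r, sign, out) }
            ("Div", [l, r]) => { go(l, sign, out)?; go(r, -sign, out) }
            (name, []) => { *out.entry(name).or_insert(0) += sign; Some(()) }
            _ => None, // not a unit we understand
        }
    }
    let mut out = BTreeMap::new();
    go(t, 1, &mut out)?;
    out.retain(|_, e| *e != 0);
    Some(out)
}

/// The `U` of a `Value<T, U>`.
fn unit_param(t: &Type) -> Option<&Type> {
    if t.name == "Value" { t.args.get(1) } else { None }
}

/// Possible body: `args[0]` is the receiver `Value<T, U>`, `args[1]` the result `Value<T, V>`.
pub fn confirm_keyword_unification(span: Span, args: &[&Type]) -> Result<(), Error> {
    let found = args.first().and_then(|t| unit_param(t));
    let expected = args.get(1).and_then(|t| unit_param(t));
    let (Some(found), Some(expected)) = (found, expected) else {
        return Err(Error { span, message: "expected two `Value<T, U>` types".to_string() });
    };
    match (normalize(found), normalize(expected)) {
        (Some(f), Some(e)) if f == e => Ok(()),
        (Some(_), Some(_)) => Err(Error {
            span,
            message: format!(
                "expected unit of measure: `{}`, found unit of measure: `{}`",
                render(expected),
                render(found)
            ),
        }),
        _ => Err(Error { span, message: "not a unit of measure".to_string() }),
    }
}

fn main() {
    let value = |unit: Type| ty("Value", vec![ty("f64", vec![]), unit]);
    let distance = value(ty("Km", vec![]));
    let duration = value(ty("S", vec![]));
    // Mirrors `distance.unify() - duration`: U = Km, V = S.
    let result = confirm_keyword_unification(Span { line: 3, column: 9 }, &[&distance, &duration]);
    println!("{result:?}");
}
```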

What’s next?

From here, the following steps would involve:

  • gathering feedback;
  • writing a rustc driver that supports plug-ins for the three optimistic interfaces mentioned above;
  • implementing a few refinement types using these plug-ins, including a new version of YAIOUOM and ideally some subset of Flux or Liquid Haskell or SQL.

I’m a bit busy working on analog quantum programming these days, so I’m not sure when or if I’ll find time to do that. We’ll see :)


  1. If you find a way, drop me a line :) However, I suspect that any error messages would be fairly hard to read.

Copyright: Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)

Author: David Teller

Posted on: December 22, 2024