Self Referential Structs in Rust (Part 2)

Before You Read

I strongly advise you to read Part 1 if you haven't read it yet. It covers the basics of self-referential structs, how a basic solution can be achieved (recapped below), and the flaws in that design.

This article aims to fix the flaws and create a robust system.

Click here to go to Part 1.

TL;DR

  1. Create a wrapper type that stores the wrapped type as a raw pointer. Here, our wrapper type is MeWrapper(*mut Me), wrapping Me.

  2. Manually implement the comparison traits (i.e., PartialEq, Eq, PartialOrd, Ord) and Deref for the wrapper type (here, MeWrapper).

  3. Make sure that the comparison functions compare the actual wrapped types.

  4. Store MeWrapper instead of *mut Me in the holder (here, Holder).

  5. Implement Drop for Me properly: rebuild the MeWrapper from self and use it to remove the entry from Holder.

  6. Use the derivative crate to make Ord "ignore" the my_holder field of Me, via #[derivative(Ord = "ignore")].

Recap

Let us see what we have done till now.

use std::{
    cell::RefCell, collections::BTreeSet, marker::PhantomPinned, pin::Pin,
    rc::Rc,
};
type RcCell<T> = Rc<RefCell<T>>;

#[derive(Debug)]
struct Holder {
    set_of_me: BTreeSet<*mut Me>,
}

impl Holder {
    fn new() -> RcCell<Self> {
        Rc::new(RefCell::new(Self {
            // this is initially empty
            set_of_me: Default::default(),
        }))
    }

    /// Mutate every value of `Me`.
    /// Note how a pinned value is reconstructed.
    fn mutate_value_of_me(&self, val: i32) {
        self.set_of_me.iter().for_each(|a| {
            let a = unsafe { Pin::new_unchecked(&mut **a) };
            a.mutate_me(val);
        })
    }
}

#[derive(Debug)]
struct Me {
    name: String,
    mutate_by_holder: i32,
    my_holder: RcCell<Holder>,
    _pinned: PhantomPinned,
}

impl Me {
    /// Accept a name
    pub fn new(
        holder: RcCell<Holder>,
        name: impl Into<String>,
    ) -> Pin<Box<Self>> {
        let mut this = Box::pin(Self {
            name: name.into(),
            mutate_by_holder: 0,
            my_holder: holder,
            _pinned: PhantomPinned,
        });

        let this_ptr: *mut _ = unsafe { this.as_mut().get_unchecked_mut() };
        this.my_holder.borrow_mut().set_of_me.insert(this_ptr);

        this
    }

    /// Allows you to mutate a value within me.
    /// Run this from `Holder` to see what happens.
    fn mutate_me(self: Pin<&mut Self>, val: i32) {
        let this = unsafe { self.get_unchecked_mut() };
        this.mutate_by_holder += val;
    }
}

impl Drop for Me {
    fn drop(&mut self) {
        println!("Dropping {:#?}", self);
        let this = &(self as *mut _);
        self.my_holder.borrow_mut().set_of_me.remove(this);
    }
}

/// A test function to play with `Holder`s.
fn make_ref_of_holder(holder: RcCell<Holder>) {
    let holder = Rc::clone(&holder);
    println!("Making a ref of {:?}", holder);
    println!("No. of refs = {}", Rc::strong_count(&holder));
}

pub fn main() {
    let holder = Holder::new();

    // be explicit about `Rc`'s cloning
    let a = Me::new(Rc::clone(&holder), "a");
    let b = Me::new(Rc::clone(&holder), "b");

    holder.borrow().mutate_value_of_me(455);
    make_ref_of_holder(Rc::clone(&holder));
}

Discovery of the Flaw

As we saw in Part 1, insert the following line in the main() function above:

// ...
pub fn main() {
    let holder = Holder::new();

    // be explicit about `Rc`'s cloning
    let a = Me::new(Rc::clone(&holder), "a");
    let b = Me::new(Rc::clone(&holder), "b");
+   let c = Me::new(Rc::clone(&holder), "b");

    holder.borrow().mutate_value_of_me(455);
    make_ref_of_holder(Rc::clone(&holder));
}

We might think, "Oh, since we are storing our Me objects in a set, c will not be added."

But notice that the definition of set_of_me is BTreeSet<*mut Me>. This means the set stores unique pointers to Me, not unique Me values! Raw pointers are compared by address, and c lives at a different address from b, so c gets stored in holder even though its contents are identical to b's.
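To see the problem in isolation, here is a minimal std-only sketch. Named is a hypothetical stand-in for Me and entries_for_equal_values is a made-up helper: two boxed values with identical contents still produce two entries in a BTreeSet of raw pointers, because the set orders the pointers by address.

```rust
use std::collections::BTreeSet;

// Hypothetical stand-in for `Me`: only the name matters for this demo.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Named(String);

/// Stores pointers to two *equal* values in a set of raw pointers and
/// returns how many entries the set ends up with.
fn entries_for_equal_values() -> usize {
    let mut b = Box::new(Named(String::from("b")));
    let mut c = Box::new(Named(String::from("b"))); // same contents as `b`
    assert_eq!(*b, *c); // the values compare equal...

    let mut set: BTreeSet<*mut Named> = BTreeSet::new();
    set.insert(&mut *b as *mut Named);
    set.insert(&mut *c as *mut Named); // ...but live at a different address
    set.len()
}

fn main() {
    // Raw pointers are ordered by address, so both entries are kept.
    println!("entries: {}", entries_for_equal_values()); // entries: 2
}
```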

Understanding the Problem

So, we need to write code that compares instances of Me instead of comparing *mut Me.

Now, notice that *mut Me is basically a pointer to an instance of Me. So, we can write a wrapper for *mut Me that allows us to dereference and access the underlying Me.

Finally, we need to implement some traits so that the wrapper can stand in for the underlying Me values whenever they are compared.

Building a Solution

The Wrapper

Now, we need a wrapper. Let's call it MeWrapper. Since it has only one non-zero-sized field, we can use #[repr(transparent)], which guarantees that MeWrapper has the same layout as *mut Me.

#[repr(transparent)]
struct MeWrapper(*mut Me);

Read more about #[repr(transparent)] here, as it is beyond the scope of this article.
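As a quick sanity check of that guarantee, a sketch like the following compares the wrapper's size and alignment with those of the raw pointer. The stripped-down Me is a hypothetical placeholder and same_layout_as_pointer is a made-up helper. A matching size alone could happen by accident; #[repr(transparent)] is what turns it into a guarantee.

```rust
use std::mem::{align_of, size_of};

// Hypothetical stand-in for `Me`; its fields don't affect the wrapper's layout.
#[allow(dead_code)]
struct Me {
    name: String,
}

// Same shape as the article's wrapper: exactly one (raw pointer) field.
#[allow(dead_code)]
#[repr(transparent)]
struct MeWrapper(*mut Me);

/// True if the wrapper has the same size and alignment as the raw pointer.
fn same_layout_as_pointer() -> bool {
    size_of::<MeWrapper>() == size_of::<*mut Me>()
        && align_of::<MeWrapper>() == align_of::<*mut Me>()
}

fn main() {
    println!("same layout: {}", same_layout_as_pointer()); // same layout: true
}
```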

The Dereferencing Abilities

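Since we want to access the underlying Me without writing unsafe every time, we implement Deref once and keep the unsafe dereference inside it.

impl std::ops::Deref for MeWrapper {
    type Target = Me;
    fn deref(&self) -> &Self::Target {
        // SAFETY: every `Me` removes its wrapper from `Holder` in `Drop`,
        // so the stored pointer always refers to a live, pinned `Me`.
        unsafe { &*self.0 }
    }
}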

The Comparison Functions

Now, it's time to write the comparison functions. They are fairly simple: we use the Deref trait to get at the underlying Me, then compare the underlying objects using their own implementations of the same traits.

impl Eq for MeWrapper {}

impl PartialEq for MeWrapper {
    fn eq(&self, other: &Self) -> bool {
        let this = &**self;
        let other = &**other;
        this == other
    }
}

impl Ord for MeWrapper {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let this = &**self;
        let other = &**other;
        this.cmp(&other)
    }
}

impl PartialOrd for MeWrapper {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let this = &**self;
        let other = &**other;
        this.partial_cmp(&other)
    }
}

Note that we still haven't implemented these traits on Holder and Me, and the wrapper's implementations only compile once Me has them. Don't worry though, we will discuss that when the need arises.
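To illustrate the idea before wiring it into Me and Holder, here is a sketch of a generic, read-only variant of the wrapper. ByValue and entries_after_duplicates are hypothetical names, and the sketch assumes every pointee outlives the set. Once the pointee implements Ord, a set of such wrappers collapses equal values even though their addresses differ:

```rust
use std::cmp::Ordering;
use std::collections::BTreeSet;
use std::ops::Deref;

/// Hypothetical, read-only generalisation of `MeWrapper`: it holds a raw
/// pointer but compares by the value behind it, not by address.
struct ByValue<T>(*const T);

impl<T> Deref for ByValue<T> {
    type Target = T;
    fn deref(&self) -> &T {
        // SAFETY (assumption of this sketch): every pointee outlives the set.
        unsafe { &*self.0 }
    }
}

impl<T: Ord> PartialEq for ByValue<T> {
    fn eq(&self, other: &Self) -> bool {
        **self == **other
    }
}
impl<T: Ord> Eq for ByValue<T> {}
impl<T: Ord> PartialOrd for ByValue<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other)) // consistent with `Ord` by construction
    }
}
impl<T: Ord> Ord for ByValue<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        (**self).cmp(&**other)
    }
}

/// Inserts wrappers around "a", "b" and a second "b"; returns the set size.
fn entries_after_duplicates() -> usize {
    let names = vec![String::from("a"), String::from("b"), String::from("b")];
    let mut set = BTreeSet::new();
    for name in &names {
        set.insert(ByValue(name as *const String));
    }
    set.len()
}

fn main() {
    // Equal strings collapse into one entry even though their addresses differ.
    println!("entries: {}", entries_after_duplicates()); // entries: 2
}
```

One design choice differs from the article's version: partial_cmp is written as Some(self.cmp(other)), which keeps PartialOrd consistent with Ord by construction, as the Ord documentation requires.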

Holder holds our Wrappers

This part is simple. We simply replace *mut Me with MeWrapper.

#[derive(Debug)]
struct Holder {
-    set_of_me: BTreeSet<*mut Me>,
+    set_of_me: BTreeSet<MeWrapper>,
}

Hence, it'd look like:

#[derive(Debug)]
struct Holder {
    set_of_me: BTreeSet<MeWrapper>,
}

Easy Conversion of *mut Me to MeWrapper

We simply implement the From trait for MeWrapper, converting from *mut Me.

impl From<*mut Me> for MeWrapper {
    fn from(value: *mut Me) -> Self {
        Self(value)
    }
}

This automatically gives us the ability to convert *mut Me to MeWrapper using .into(), thanks to the standard library's blanket implementation of Into for every From.

Dropping Wrappers Properly

In Drop, we just need to convert self (a &mut Me) into a *mut Me, turn that into a MeWrapper, and remove it from the set. Note that we take advantage of the From trait by calling .into().

impl Drop for Me {
    fn drop(&mut self) {
        println!("Dropping {:#?}", self);
-       let this = &(self as *mut _);
+       let this = (self as *mut Self).into();
-       self.my_holder.borrow_mut().set_of_me.remove(this);
+       self.my_holder.borrow_mut().set_of_me.remove(&this);
    }
}

Hence it'd look like:

impl Drop for Me {
    fn drop(&mut self) {
        println!("Dropping {:#?}", self);
        let this = (self as *mut Self).into();
        self.my_holder.borrow_mut().set_of_me.remove(&this);
    }
}

Inserting Wrappers into Holder Properly

Just as with Drop, we update Me::new(...) to insert wrappers instead of raw pointers.

    pub fn new(
        holder: RcCell<Holder>,
        name: impl Into<String>,
    ) -> Pin<Box<Self>> {
        let mut this = Box::pin(Self {
            name: name.into(),
            mutate_by_holder: 0,
            my_holder: holder,
            _pinned: PhantomPinned,
        });

        let this_ptr: *mut _ = unsafe { this.as_mut().get_unchecked_mut() };
-       this.my_holder.borrow_mut().set_of_me.insert(this_ptr);
+       this.my_holder.borrow_mut().set_of_me.insert(this_ptr.into());

        this
    }

Hence it’d look like:

pub fn new(
    holder: RcCell<Holder>,
    name: impl Into<String>,
) -> Pin<Box<Self>> {
    let mut this = Box::pin(Self {
        name: name.into(),
        mutate_by_holder: 0,
        my_holder: holder,
        _pinned: PhantomPinned,
    });

    let this_ptr: *mut _ = unsafe { this.as_mut().get_unchecked_mut() };
    this.my_holder.borrow_mut().set_of_me.insert(this_ptr.into());

    this
}

More on Comparing Objects

This is where the fun and the complications begin. So buckle up and grab a cup of coffee.

Since MeWrapper's comparison functions delegate to Me, let's say we add the comparison traits to the Me struct via #[derive(...)].

-#[derive(Debug)]
+#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)]
struct Me {
    name: String,
    mutate_by_holder: i32,
    my_holder: RcCell<Holder>,
    _pinned: PhantomPinned,
}

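Me contains an Rc<RefCell<Holder>>, and a derived comparison trait requires every field to implement that trait too. So we also add the same derives to the Holder struct.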

-#[derive(Debug)]
+#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)]
struct Holder {
    set_of_me: BTreeSet<MeWrapper>,
}

Now run the program (with c added, as in the previous section) and see what happens. You'll face something like:

thread 'main' panicked at 'already mutably borrowed: BorrowError'

So, why did that happen?

RefCell and Borrows

To understand that, we first have to understand what BorrowError is. BorrowError is the error RefCell reports when one of its borrow checks fails: RefCell enforces Rust's borrowing rules dynamically, at run time, instead of at compile time.

Normally, Rust complains about borrow issues when we either:

  1. create more than one mutable reference to the same value at the same time, or

  2. create an immutable reference while a mutable reference to the same value is still in use.

With RefCell, we create mutable references with .borrow_mut() and immutable references with .borrow(). If either rule is broken at run time, these calls panic instead of the program being rejected at compile time.

With these two rules in mind, we have to find where one of them is being violated.
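Rule 2 can be observed directly with RefCell::try_borrow, the non-panicking counterpart of .borrow(), which returns Err(BorrowError) instead of panicking. A minimal sketch, where borrow_results is a hypothetical helper:

```rust
use std::cell::RefCell;

/// Tries an immutable borrow (a) while a mutable borrow is alive and
/// (b) after that mutable borrow has ended.
fn borrow_results() -> (bool, bool) {
    let cell = RefCell::new(vec![1, 2, 3]);

    let guard = cell.borrow_mut(); // mutable borrow starts
    let during = cell.try_borrow().is_ok(); // rule 2 broken -> Err(BorrowError)
    drop(guard); // mutable borrow ends

    let after = cell.try_borrow().is_ok();
    (during, after)
}

fn main() {
    let (during, after) = borrow_results();
    // Calling `cell.borrow()` instead of `try_borrow()` in the first case
    // would panic with a BorrowError.
    println!("during: {during}, after: {after}"); // during: false, after: true
}
```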

Search for BorrowError's source

Look at the source code, especially Holder's initialization. We create an Rc<RefCell<Holder>>, and Holder is the only value we put inside a RefCell. That means the issue must be with, or related to, Holder.

Now look for the part where we create a mutable reference to Holder (via .borrow_mut()):

let this_ptr: *mut Self = unsafe { this.as_mut().get_unchecked_mut() };
this.my_holder.borrow_mut().set_of_me.insert(this_ptr.into());

Here we are creating a mutable reference to Holder and inserting a MeWrapper. Now let's check for any violations of the two rules stated above.

  1. Nope. We create only one mutable reference via .borrow_mut() and use it to insert values into set_of_me.

  2. This is where things get tricky. Read on, and make sure you have your seatbelts on!

Relation Between Comparing Objects and Borrowing

Now let us focus on this part of the code:

set_of_me.insert(this_ptr.into());
//        ^^^^^^^^^^^^^^^^^^^^^^

What are we doing here? We are inserting a MeWrapper in a BTreeSet. And how does a BTreeSet check for uniqueness of a value? By comparing with other values.
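We can check which comparison methods BTreeSet::insert actually uses with an instrumented key. Counted, insert_and_count and the thread-local counters are hypothetical helpers made up for this sketch. Exact call counts depend on the set's internals, so only the pattern matters: cmp() runs, eq() does not.

```rust
use std::cell::Cell;
use std::cmp::Ordering;
use std::collections::BTreeSet;

thread_local! {
    // How many times each comparison method has run on this thread.
    static CMP_CALLS: Cell<usize> = Cell::new(0);
    static EQ_CALLS: Cell<usize> = Cell::new(0);
}

/// Hypothetical key type that records which comparison methods get called.
struct Counted(u32);

impl PartialEq for Counted {
    fn eq(&self, other: &Self) -> bool {
        EQ_CALLS.with(|c| c.set(c.get() + 1));
        self.0 == other.0
    }
}
impl Eq for Counted {}
impl PartialOrd for Counted {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Counted {
    fn cmp(&self, other: &Self) -> Ordering {
        CMP_CALLS.with(|c| c.set(c.get() + 1));
        self.0.cmp(&other.0)
    }
}

/// Inserts 1, 2 and a duplicate 2. Returns (entries, cmp calls, eq calls).
fn insert_and_count() -> (usize, usize, usize) {
    let mut set = BTreeSet::new();
    set.insert(Counted(1));
    set.insert(Counted(2));
    let added = set.insert(Counted(2)); // rejected: an equal key exists
    assert!(!added);
    let cmp_calls = CMP_CALLS.with(|c| c.get());
    let eq_calls = EQ_CALLS.with(|c| c.get());
    (set.len(), cmp_calls, eq_calls)
}

fn main() {
    let (entries, cmp_calls, eq_calls) = insert_and_count();
    // `cmp` is called several times, `eq` never.
    println!("entries={entries} cmp={cmp_calls} eq={eq_calls}");
}
```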

So now we have at least some idea of where the error originates. It is definitely related to some sort of comparison function(s).

Now, look for the comparison functions we've implemented for MeWrapper. They are eq(), cmp() and partial_cmp().

PartialEq

PartialEq is a simple one:

this == other

It just checks two references to Me for equality. Besides, BTreeSet never calls eq(): it relies on Ord alone. So, this is not the culprit.

Ord and PartialOrd

Try an experiment. Put a dbg!(...) call in both cmp() and partial_cmp(), right before the comparison is performed and its result returned:

impl Ord for MeWrapper {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let this = &**self;
        let other = &**other;
+       dbg!("did it fail in cmp?");
        this.cmp(&other)
    }
}

impl PartialOrd for MeWrapper {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let this = &**self;
        let other = &**other;
+       dbg!("did it fail in partial_cmp?");
        this.partial_cmp(&other)
    }
}

Execute it and you'll notice that cmp() is called right before the error happens, while partial_cmp() is never called at all! So, what is going on?

Ord, Borrowing and BorrowError

First, let us look at the signature of insert() function of BTreeSet:

pub fn insert(&mut self, value: T) -> bool
where
    T: Ord,
{
    // ...
}

Here, we see that .insert() expects a value that implements Ord. Thus, we were correct about Ord's involvement in the error.

Now the question arises, why does the error happen?

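During a derived .cmp(), the fields of one struct are compared with the corresponding fields of the other, one by one in declaration order, stopping at the first pair that is not equal. When c is inserted, comparing it with b finds the same name ("b") and the same mutate_by_holder (0), so the comparison moves on to my_holder: Rc<T> forwards the comparison to its contents, and RefCell<T> is no exception to the field-by-field rule. Now what's interesting is what actually happens during a .cmp() call for RefCell<T>.

If we look at the source of RefCell::cmp(), we will find that it immutably borrows the value inside it (here, the Holder):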

fn cmp(&self, other: &RefCell<T>) -> Ordering {
    self.borrow().cmp(&*other.borrow())
}

But simultaneously we are also borrowing Holder mutably in this line:

this.my_holder.borrow_mut().set_of_me.insert(this_ptr.into());

And the official docs say this about RefCell's implementations of the Ord, PartialOrd and PartialEq traits:

Panics if the value in either RefCell is currently borrowed.

But this panic is not unique to Ord: RefCell's PartialEq and PartialOrd implementations borrow in the same way. We only hit it through Ord because Ord is what .insert() relies on (look at .insert()'s signature again).

Now that we know where the error comes from and why, we need to somehow stop the my_holder field from being compared. You might ask, "Isn't it daunting to write custom trait implementations just to skip comparing a single field? If only we had something to automatically leave out the fields we don't want compared." Yes, it is daunting, and a total waste of time. And thankfully, we have a pretty nice solution.
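The whole chain can be reproduced in a few lines. In this sketch, Item is a hypothetical, trimmed-down Me whose shared field stands in for my_holder, and cmp_panics_while_borrowed is a made-up helper. std::panic::catch_unwind is used only so the demo can report the panic instead of stopping (the default panic hook still prints the message to stderr). Comparing two items whose names differ never reaches the RefCell; comparing two items whose names tie does, and panics while the cell is mutably borrowed:

```rust
use std::cell::RefCell;
use std::panic;
use std::rc::Rc;

// Hypothetical, trimmed-down `Me`: the derived `Ord` compares fields top to bottom.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct Item {
    name: String,
    shared: Rc<RefCell<Vec<u32>>>, // stands in for `my_holder`
}

/// Compares two items named `left` and `right` while their shared cell is
/// mutably borrowed. Returns true if the comparison panicked.
fn cmp_panics_while_borrowed(left: &str, right: &str) -> bool {
    let shared = Rc::new(RefCell::new(Vec::new()));
    let a = Item { name: left.into(), shared: Rc::clone(&shared) };
    let b = Item { name: right.into(), shared: Rc::clone(&shared) };

    let _guard = shared.borrow_mut(); // like `my_holder.borrow_mut()` in `Me::new`
    let panicked = panic::catch_unwind(panic::AssertUnwindSafe(|| a.cmp(&b))).is_err();
    panicked
}

fn main() {
    // Names differ: `cmp` stops at `name` and never touches the RefCell.
    println!("{}", cmp_panics_while_borrowed("a", "b")); // false
    // Names tie: `cmp` moves on to `shared`, whose RefCell::cmp calls borrow().
    println!("{}", cmp_panics_while_borrowed("b", "b")); // true
}
```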

To #[derive] or not to #[derive]

Meet derivative! It is an amazing crate that automates derivation of several traits and performs the heavy-lifting of conditional derivation for us!

Let us focus on our issue: we need to stop Ord from comparing the my_holder field of Me.

Start by bringing the derivative crate into scope with extern crate, and put #[macro_use] above that statement so that its derive macro can be used. And by the way, don't forget to add derivative to the [dependencies] section of Cargo.toml!

#[macro_use]
extern crate derivative;

Now, we need to disable Ord for the my_holder field of the Me struct. Or, in derivative's words, we need to ignore the field my_holder. As we shall see, it is extremely easy!

-#[derive(Debug)]
+#[derive(Debug, Derivative)]
+#[derivative(PartialEq, Eq, PartialOrd, Ord)]
struct Me {
    name: String,
    mutate_by_holder: i32,
+   #[derivative(Ord = "ignore")]
    my_holder: RcCell<Holder>,
    _pinned: PhantomPinned,
}

We follow these simple steps:

  1. Derive Derivative via #[derive(...)].

  2. Use #[derivative(...)] on struct to derive the required traits.

  3. "Ignore" a particular field by specifying #[derivative(<trait name> = "ignore")].

And boom! That's all it takes to solve the issue!
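For comparison, here is a rough sketch of what ignoring the field by hand involves, without derivative. Item and tied_cmp_while_borrowed are hypothetical names, and pinning is left out for brevity. The sketch leaves my_holder out of all four comparison traits and builds PartialOrd and PartialEq on top of cmp(), so the traits cannot disagree with each other; note that the derivative setup above marks the field as ignored for Ord only.

```rust
use std::cell::RefCell;
use std::cmp::Ordering;
use std::rc::Rc;

// Hypothetical, trimmed-down `Me` (no pinning) to keep the sketch short.
#[allow(dead_code)]
struct Item {
    name: String,
    mutate_by_holder: i32,
    my_holder: Rc<RefCell<Vec<u32>>>, // deliberately left out of every comparison
}

impl Ord for Item {
    fn cmp(&self, other: &Self) -> Ordering {
        // Same field order a derive would use, minus `my_holder`.
        self.name
            .cmp(&other.name)
            .then_with(|| self.mutate_by_holder.cmp(&other.mutate_by_holder))
    }
}
impl PartialOrd for Item {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl PartialEq for Item {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Item {}

/// Compares two items that tie on every compared field while the shared
/// cell is mutably borrowed: the situation that panicked before.
fn tied_cmp_while_borrowed() -> Ordering {
    let shared = Rc::new(RefCell::new(Vec::new()));
    let a = Item { name: "b".into(), mutate_by_holder: 0, my_holder: Rc::clone(&shared) };
    let b = Item { name: "b".into(), mutate_by_holder: 0, my_holder: Rc::clone(&shared) };
    let _guard = shared.borrow_mut();
    a.cmp(&b) // never touches `my_holder`, so no BorrowError
}

fn main() {
    println!("{:?}", tied_cmp_while_borrowed()); // Equal
}
```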

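Speaking of derivative, we can also use its custom Debug derivation for MeWrapper! With Debug = "transparent", a MeWrapper is printed exactly like the raw pointer it wraps, without the MeWrapper(...) around it.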

#[repr(transparent)]
#[derive(Derivative)]
#[derivative(Debug = "transparent")]
struct MeWrapper(*mut Me);
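If you would rather not depend on derivative just for this, a hand-written Debug that forwards to the inner pointer should print the same thing, as far as we understand derivative's transparent mode. The unit struct Me below is a hypothetical placeholder:

```rust
use std::fmt;

// Hypothetical placeholder for `Me`; only the pointer is printed.
struct Me;

#[repr(transparent)]
struct MeWrapper(*mut Me);

// Hand-written "transparent" Debug: forward straight to the inner field.
impl fmt::Debug for MeWrapper {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}

fn main() {
    let mut me = Me;
    let ptr: *mut Me = &mut me;
    // Both lines print the same address.
    println!("{:?}", MeWrapper(ptr));
    println!("{:?}", ptr);
}
```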

The Final Solution

And this is the solution! Run it, play with it, break it!

#[macro_use]
extern crate derivative;

use std::{
    cell::RefCell, collections::BTreeSet, marker::PhantomPinned, pin::Pin,
    rc::Rc,
};
type RcCell<T> = Rc<RefCell<T>>;

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Holder {
    set_of_me: BTreeSet<MeWrapper>,
}

impl Holder {
    fn new() -> RcCell<Self> {
        Rc::new(RefCell::new(Self {
            // this is initially empty
            set_of_me: Default::default(),
        }))
    }

    /// Mutate every value of `Me`.
    /// Note how a pinned value is reconstructed.
    fn mutate_value_of_me(&self, val: i32) {
        self.set_of_me.iter().for_each(|a| {
            let a = unsafe { Pin::new_unchecked(&mut *a.0) };
            a.mutate_me(val);
        })
    }
}

#[derive(Debug, Derivative)]
#[derivative(PartialEq, Eq, PartialOrd, Ord)]
struct Me {
    name: String,
    mutate_by_holder: i32,
    #[derivative(Ord = "ignore")]
    my_holder: RcCell<Holder>,
    _pinned: PhantomPinned,
}

impl Me {
    /// Accept a name
    pub fn new(
        holder: RcCell<Holder>,
        name: impl Into<String>,
    ) -> Pin<Box<Self>> {
        let mut this = Box::pin(Self {
            name: name.into(),
            mutate_by_holder: 0,
            my_holder: holder,
            _pinned: PhantomPinned,
        });

        let this_ptr: *mut Self = unsafe { this.as_mut().get_unchecked_mut() };
        this.my_holder
            .borrow_mut()
            .set_of_me
            .insert(this_ptr.into());

        this
    }

    /// Allows you to mutate a value within me.
    /// Run this from `Holder` to see what happens.
    fn mutate_me(self: Pin<&mut Self>, val: i32) {
        let this = unsafe { self.get_unchecked_mut() };
        this.mutate_by_holder += val;
    }
}

#[repr(transparent)]
#[derive(Derivative)]
#[derivative(Debug = "transparent")]
struct MeWrapper(*mut Me);

impl From<*mut Me> for MeWrapper {
    fn from(value: *mut Me) -> Self {
        Self(value)
    }
}

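impl std::ops::Deref for MeWrapper {
    type Target = Me;
    fn deref(&self) -> &Self::Target {
        // SAFETY: every `Me` removes its wrapper from `Holder` in `Drop`,
        // so the stored pointer always refers to a live, pinned `Me`.
        unsafe { &*self.0 }
    }
}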

impl Eq for MeWrapper {}

impl PartialEq for MeWrapper {
    fn eq(&self, other: &Self) -> bool {
        let this = &**self;
        let other = &**other;
        this == other
    }
}

impl Ord for MeWrapper {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let this = &**self;
        let other = &**other;
        this.cmp(&other)
    }
}

impl PartialOrd for MeWrapper {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let this = &**self;
        let other = &**other;
        this.partial_cmp(&other)
    }
}

impl Drop for Me {
    fn drop(&mut self) {
        println!("Dropping {:#?}", self);
        let this = (self as *mut Self).into();
        self.my_holder.borrow_mut().set_of_me.remove(&this);
    }
}

/// A test function to play with `Holder`s.
fn make_ref_of_holder(holder: RcCell<Holder>) {
    let holder = Rc::clone(&holder);
    println!("Making a ref of {:?}", holder);
    println!("No. of refs = {}", Rc::strong_count(&holder));
}

pub fn main() {
    let holder = Holder::new();

    // be explicit about `Rc`'s cloning
    let a = Me::new(Rc::clone(&holder), "a");
    let b = Me::new(Rc::clone(&holder), "b");
    let c = Me::new(Rc::clone(&holder), "b");

    holder.borrow().mutate_value_of_me(455);
    make_ref_of_holder(Rc::clone(&holder));
}
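If you want to play with the design without the derivative dependency, here is a std-only sketch of the same solution. It differs from the code above in a few deliberate ways, all assumptions of this sketch: the comparisons on Me are written by hand and leave my_holder out of all four traits; the println! calls and the mutate helpers are dropped to stay short; set_sizes is a hypothetical test helper; and Drop checks pointer identity before removing. That last check is our own addition: removal is value-based, so if the rejected duplicate c still compares equal to b when c is dropped, a plain remove would evict b's entry while b is still alive.

```rust
use std::{
    cell::RefCell, cmp::Ordering, collections::BTreeSet, marker::PhantomPinned,
    ops::Deref, pin::Pin, rc::Rc,
};

type RcCell<T> = Rc<RefCell<T>>;

struct Holder {
    set_of_me: BTreeSet<MeWrapper>,
}

struct Me {
    name: String,
    mutate_by_holder: i32,
    my_holder: RcCell<Holder>,
    _pinned: PhantomPinned,
}

// Hand-written comparisons that leave `my_holder` out of all four traits,
// so comparing two `Me`s never borrows the holder's RefCell.
impl Ord for Me {
    fn cmp(&self, other: &Self) -> Ordering {
        self.name
            .cmp(&other.name)
            .then_with(|| self.mutate_by_holder.cmp(&other.mutate_by_holder))
    }
}
impl PartialOrd for Me {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}
impl PartialEq for Me {
    fn eq(&self, other: &Self) -> bool { self.cmp(other) == Ordering::Equal }
}
impl Eq for Me {}

#[repr(transparent)]
struct MeWrapper(*mut Me);

impl Deref for MeWrapper {
    type Target = Me;
    fn deref(&self) -> &Me {
        // SAFETY: each `Me` removes its own wrapper in `Drop`, so stored pointers stay live.
        unsafe { &*self.0 }
    }
}
impl PartialEq for MeWrapper {
    fn eq(&self, other: &Self) -> bool { **self == **other }
}
impl Eq for MeWrapper {}
impl PartialOrd for MeWrapper {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}
impl Ord for MeWrapper {
    fn cmp(&self, other: &Self) -> Ordering { (**self).cmp(&**other) }
}

impl Me {
    fn new(holder: RcCell<Holder>, name: impl Into<String>) -> Pin<Box<Self>> {
        let mut this = Box::pin(Self {
            name: name.into(),
            mutate_by_holder: 0,
            my_holder: holder,
            _pinned: PhantomPinned,
        });
        let this_ptr: *mut Self = unsafe { this.as_mut().get_unchecked_mut() };
        this.my_holder.borrow_mut().set_of_me.insert(MeWrapper(this_ptr));
        this
    }
}

impl Drop for Me {
    fn drop(&mut self) {
        let this = MeWrapper(self as *mut Self);
        let mut holder = self.my_holder.borrow_mut();
        // Remove only our own entry, not an equal value that is still alive.
        if holder.set_of_me.get(&this).map_or(false, |w| w.0 == this.0) {
            holder.set_of_me.remove(&this);
        }
    }
}

// Hypothetical test helper: set sizes with a, b and a duplicate b alive,
// after dropping the duplicate, and after dropping everything.
fn set_sizes() -> (usize, usize, usize) {
    let holder = Rc::new(RefCell::new(Holder { set_of_me: BTreeSet::new() }));
    let a = Me::new(Rc::clone(&holder), "a");
    let b = Me::new(Rc::clone(&holder), "b");
    let c = Me::new(Rc::clone(&holder), "b"); // equal to `b`, so not inserted
    let all_alive = holder.borrow().set_of_me.len();
    drop(c);
    let after_c = holder.borrow().set_of_me.len();
    drop(b);
    drop(a);
    let after_all = holder.borrow().set_of_me.len();
    (all_alive, after_c, after_all)
}

fn main() {
    println!("{:?}", set_sizes()); // (2, 2, 0)
}
```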

Conclusion

I hope this article helps you understand self-referential structs in Rust a bit better.

Questions? Comments? Concerns? Please put them down below and I'd be happy to help you.

Cover Image Source: Manjaro's /usr/share/backgrounds folder 😃

Support Arunanshu's Ramblings by becoming a sponsor. Any amount is appreciated!