The castable model mixin

Javier Ayres
August 15, 2017

If you've ever used multi-table inheritance in Django, chances are that you wrote several queries for your parent model, but when processing each instance you wished you could "cast" it to the proper child model. Maybe you even wrote some code to handle it. We've been there as well, and here we'll show you our approach to tackling this particular issue.

Keep in mind that this solution is intended to be used with Python 3.

The problem

If you haven't already, take a look at the Django documentation on multi-table model inheritance. We will extend the example models presented there (Place and its subclass Restaurant) by adding another subclass of Place, called CoffeeShop.

class CoffeeShop(Place):
    serves_decaffeinated = models.BooleanField(default=False)

Now we have two models subclassing Place: Restaurant and CoffeeShop. A typical use case could be a listing of all our places on the homepage of our site. We don't want to run two different queries (one for each subclass, or even three if we also have plain places), because we need to maintain a specific order across all items. We will also need to access fields or logic that belong to the subclasses.

The solution

from django.core.exceptions import ObjectDoesNotExist


class CastableModelMixin:
    """
    Add support to cast an object to its final class
    """

    def cast(self):
        cls = self.__class__
        subclasses = cls.all_subclasses()

        if len(subclasses) == 0:
            return self

        for subclass in subclasses:
            try:
                obj = getattr(self, subclass._meta.model_name, None)
                if obj is not None:
                    # select_related doesn't fill child with parent relateds
                    descriptors = [getattr(cls, field.name)
                                   for field in cls._meta.get_fields()
                                   if field.is_relation and field.many_to_one]
                    for descriptor in descriptors:
                        if descriptor.is_cached(self):
                            setattr(obj,
                                    descriptor.cache_name,
                                    getattr(self, descriptor.cache_name))
                    if hasattr(self, '_prefetched_objects_cache'):
                        obj._prefetched_objects_cache = \
                            self._prefetched_objects_cache
                    return obj
            except ObjectDoesNotExist:
                pass

        return self

    @classmethod
    def all_subclasses_model_names(cls):
        model_names = []
        for subclass in cls.all_subclasses():
            if not (subclass._meta.proxy or subclass._meta.abstract):
                model_names.append(subclass._meta.model_name)
        return model_names

    @classmethod
    def all_subclasses(cls):
        return [g for s in cls.__subclasses__() for g in s.all_subclasses()] + cls.__subclasses__()

    @property
    def model(self):
        return self.cast()._meta.model_name

    @property
    def verbose_name(self):
        return self.cast()._meta.verbose_name.capitalize()

By making our Place model extend this mixin, we can call cast() on its instances and it will return a new object, which is actually an instance of the appropriate subclass or simply the same object in case there is no subclass for that particular instance.

Let's review the code:

    def cast(self):
        cls = self.__class__
        subclasses = cls.all_subclasses()

        if len(subclasses) == 0:
            return self

The cast method starts by storing the instance's class in cls and all of that class's subclasses in subclasses. If there are no subclasses, it simply returns the instance itself. Let's take a look at the method that returns all the subclasses.

    @classmethod
    def all_subclasses(cls):
        return [g for s in cls.__subclasses__() for g in s.all_subclasses()] + cls.__subclasses__()

This method makes use of the built-in class.__subclasses__() method, which, according to the Python documentation, returns a list of references to a class's immediate subclasses. However, if your hierarchy tree has more than two levels, you will want to keep looking further down to get all subclasses. This method recursively calls itself on each immediate subclass to traverse the whole tree, placing the results of each recursive call ahead of the immediate subclasses, so every class's descendants end up listed before the class itself.
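To see the ordering this recursion produces, here is a minimal sketch on plain Python classes, with no Django involved. Node, Place, Restaurant and CoffeeShop are stand-ins rather than the models used later, and PizzaPlace is a hypothetical third-level class added only to exercise the recursion.

```python
class Node:
    @classmethod
    def all_subclasses(cls):
        # Same expression as the mixin: descendants of every immediate
        # subclass first, followed by the immediate subclasses themselves.
        return [g for s in cls.__subclasses__() for g in s.all_subclasses()] + cls.__subclasses__()


# Plain-class stand-ins for the models (these are NOT Django models).
class Place(Node):
    pass


class Restaurant(Place):
    pass


class CoffeeShop(Place):
    pass


class PizzaPlace(Restaurant):  # hypothetical grandchild
    pass


names = [c.__name__ for c in Place.all_subclasses()]
print(names)
```

Whatever order __subclasses__() reports siblings in, PizzaPlace always lands ahead of its parent Restaurant, which is what lets a loop over this list try the more specific class first.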

Now, continuing with the cast method:

        for subclass in subclasses:
            try:
                obj = getattr(self, subclass._meta.model_name, None)
                if obj is not None:
                    # select_related doesn't fill child with parent relateds
                    descriptors = [getattr(cls, field.name)
                                   for field in cls._meta.get_fields()
                                   if field.is_relation and field.many_to_one]
                    for descriptor in descriptors:
                        if descriptor.is_cached(self):
                            setattr(obj,
                                    descriptor.cache_name,
                                    getattr(self, descriptor.cache_name))
                    if hasattr(self, '_prefetched_objects_cache'):
                        obj._prefetched_objects_cache = \
                            self._prefetched_objects_cache
                    return obj
            except ObjectDoesNotExist:
                pass

        return self

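We iterate over our subclasses and look for an attribute on our instance named after each subclass's model_name. Django creates this accessor automatically: when a model is extended through multi-table inheritance, the child gets a OneToOneField pointing to the parent, and the parent gets a reverse accessor named after the child's lowercase model name. Two things can go wrong here: either the attribute doesn't exist, which means the subclass has no table of its own (for example, a proxy model), in which case getattr() returns None; or accessing it raises ObjectDoesNotExist, which means that while the subclass does have a table, no row exists in it for our particular instance (for example, our Place isn't a Restaurant). In both cases nothing happens and we move on to the next iteration. Note that Django's RelatedObjectDoesNotExist also subclasses AttributeError, so in practice getattr() with a None default already absorbs the second case as well, and the except clause acts as a safety net.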
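The control flow of this loop can be tried out without a database. The sketch below uses plain-Python stand-ins: ObjectDoesNotExist, RelatedObjectDoesNotExist, FakePlace and the hypothetical proxyplace name are assumptions for illustration, not Django objects, and the cache copying is left out. The stand-in RelatedObjectDoesNotExist subclasses AttributeError, mirroring Django's reverse one-to-one error, so getattr() falls back to its None default when a child row is missing.

```python
class ObjectDoesNotExist(Exception):
    """Stand-in for django.core.exceptions.ObjectDoesNotExist."""


class RelatedObjectDoesNotExist(ObjectDoesNotExist, AttributeError):
    """Stand-in for Django's reverse one-to-one error (also an AttributeError)."""


class FakePlace:
    """Plain object imitating a Place row.

    child_rows maps a lowercase model name to the child object, if one exists.
    """

    def __init__(self, name, child_rows=None):
        self.name = name
        self._child_rows = child_rows or {}

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails.
        if attr in ("restaurant", "coffeeshop"):  # reverse accessors of concrete children
            if attr in self._child_rows:
                return self._child_rows[attr]
            raise RelatedObjectDoesNotExist(attr)
        raise AttributeError(attr)  # e.g. a proxy subclass has no accessor


def cast(instance, model_names):
    # Mirrors the loop in CastableModelMixin.cast(), minus the cache copying.
    for model_name in model_names:
        try:
            obj = getattr(instance, model_name, None)
            if obj is not None:
                return obj
        except ObjectDoesNotExist:
            # Kept to mirror the mixin; because the stand-in error is also an
            # AttributeError, getattr() swallows it and this is not reached here.
            pass
    return instance


names = ["restaurant", "coffeeshop", "proxyplace"]
shop = FakePlace("The coffee shop", {"coffeeshop": "CoffeeShop row"})
plain = FakePlace("A special place")

print(cast(shop, names))   # the CoffeeShop stand-in
print(cast(plain, names))  # the original FakePlace object
```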

When we do find a match, we populate the new instance (referenced by obj) with the original instance's related objects, that is, the related models pre-populated by earlier calls to select_related or prefetch_related, and finally return it. If no match is found and the list is exhausted, cast() returns the original instance.
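The cache-copying step boils down to moving a few attributes from one object to another. Here is a minimal sketch on plain objects, assuming Django 1.x-style per-field caches stored as instance attributes named like _owner_cache; the owner and city foreign keys are hypothetical, and copy_related_caches is an illustrative helper, not part of the mixin.

```python
from types import SimpleNamespace


def copy_related_caches(source, target, cache_names):
    """Copy cached related objects from source to target.

    cache_names plays the role of descriptor.cache_name in the mixin;
    only caches actually present on source are copied.
    """
    for cache_name in cache_names:
        if hasattr(source, cache_name):
            setattr(target, cache_name, getattr(source, cache_name))
    # Results of prefetch_related() live in a single dict on the instance;
    # like the mixin, we share the same dict rather than copying it.
    if hasattr(source, "_prefetched_objects_cache"):
        target._prefetched_objects_cache = source._prefetched_objects_cache
    return target


# Hypothetical parent instance with one cached FK and some prefetched data.
place = SimpleNamespace(_owner_cache="cached owner",
                        _prefetched_objects_cache={"tags": ["a", "b"]})
restaurant = SimpleNamespace()
copy_related_caches(place, restaurant, ["_owner_cache", "_city_cache"])
```

The uncached _city_cache is skipped, so the child object only receives what the parent had actually loaded.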

The mixin in action

Create a new Django project and set up the following models. You will also need to put our new mixin in a mixins.py file in the same app.

from django.db import models
from .mixins import CastableModelMixin

class Place(models.Model, CastableModelMixin):
    name = models.CharField(max_length=50)
    address = models.CharField(max_length=80)

class Restaurant(Place):
    serves_hot_dogs = models.BooleanField(default=False)
    serves_pizza = models.BooleanField(default=False)

class CoffeeShop(Place):
    serves_decaffeinated = models.BooleanField(default=False)

Now fire up the Django shell and create some instances:

(django) jayres:places jayres$ ./manage.py shell
Python 3.6.1 (v3.6.1:69c0db5050, Mar 21 2017, 01:21:04)
[GCC 4.2.1 (Apple Inc. build 5666) (dot 3)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
(InteractiveConsole)
>>> from app.models import *
>>> CoffeeShop.objects.create(name='The coffee shop', address='Street 1')
<CoffeeShop: CoffeeShop object>
>>> Restaurant.objects.create(name='The hot dog restaurant', address='Street 1')
<Restaurant: Restaurant object>
>>> Restaurant.objects.create(name='The pizza restaurant', address='Street 2')
<Restaurant: Restaurant object>
>>> Place.objects.create(name='A special place', address='Street 3')
<Place: Place object>

Now we have a few different places: two restaurants, a coffee shop and a plain place. We can now query them all together and cast each instance to its appropriate subclass:

>>> [place.cast() for place in Place.objects.all()]
[<CoffeeShop: CoffeeShop object>, <Restaurant: Restaurant object>, <Restaurant: Restaurant object>, <Place: Place object>]

Here's what happens if we don't cast them:

>>> Place.objects.all()
<QuerySet [<Place: Place object>, <Place: Place object>, <Place: Place object>, <Place: Place object>]>

Bonus track

Each time we cast an instance and access the child model through the reverse one-to-one accessor, Django needs to perform a new query to retrieve the subclass fields from the database. In our previous example, fetching our 4 places in what seems to be a single query actually costs us at least 5 queries, and possibly more, since each failed access to a subclass accessor also runs a query that returns no results. Luckily, we can use select_related on these relations just as we would on a regular ForeignKey. Since adding those calls by hand every time would be tedious and easy to forget, we wrote a custom QuerySet that calls select_related for all possible subclasses.

First of all, let's measure the current situation. Keep in mind that Django only records queries in connection.queries when DEBUG is True:

>>> from django.db import connection
>>> from django.db import reset_queries
>>> reset_queries()
>>> [place.cast() for place in Place.objects.all()]
[<CoffeeShop: CoffeeShop object>, <Restaurant: Restaurant object>, <Restaurant: Restaurant object>, <Place: Place object>]
>>> len(connection.queries)
7
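The 7 can be accounted for with a little arithmetic. This sketch assumes that all_subclasses() returns Restaurant before CoffeeShop (their definition order) and that every accessor attempt without select_related costs exactly one query; the names below are illustrative labels for the rows created earlier.

```python
# Accessor order assumed to match Place.all_subclasses() -> [Restaurant, CoffeeShop].
accessors = ["restaurant", "coffeeshop"]

# Which child row (if any) each Place row has.
places = [
    ("The coffee shop", "coffeeshop"),
    ("The hot dog restaurant", "restaurant"),
    ("The pizza restaurant", "restaurant"),
    ("A special place", None),
]


def lookups_for(child):
    """Accessor queries cast() issues for one row without select_related."""
    if child is None:
        return len(accessors)  # every accessor is tried and misses
    return accessors.index(child) + 1  # stops at the first hit


total = 1 + sum(lookups_for(child) for _, child in places)  # 1 for Place.objects.all()
print(total)  # → 7
```

The row order returned by the database does not change the total; only the subclass order does, since it decides how many misses happen before each hit.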

Not so good. Write the custom QuerySet mixin in mixins.py using the following code:

class CastableQuerySetMixin:

    def select_related_subclasses(self):
        return self.select_related(*[subclass._meta.model_name for subclass in self.model.all_subclasses()
                                     if not subclass._meta.proxy])

...and add it to our Place model as a manager.

from django.db import models
from .mixins import CastableModelMixin, CastableQuerySetMixin

class CustomQuerySet(CastableQuerySetMixin, models.QuerySet):
    pass

class Place(models.Model, CastableModelMixin):
    name = models.CharField(max_length=50)
    address = models.CharField(max_length=80)
    objects = CustomQuerySet.as_manager()

Let's see what happens now:

(django) jayres:places jayres$ ./manage.py shell
Python 3.6.1 (v3.6.1:69c0db5050, Mar 21 2017, 01:21:04)
[GCC 4.2.1 (Apple Inc. build 5666) (dot 3)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
(InteractiveConsole)
>>> from app.models import *
>>> from django.db import connection
>>> [place.cast() for place in Place.objects.all().select_related_subclasses()]
[<CoffeeShop: CoffeeShop object>, <Restaurant: Restaurant object>, <Restaurant: Restaurant object>, <Place: Place object>]
>>> len(connection.queries)
1
>>>

That's more like it!
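One caveat worth checking if your hierarchy is deeper than two levels: as far as we understand Django's reverse one-to-one accessors, a grandchild such as a hypothetical PizzaPlace(Restaurant) is reachable from Place only through restaurant, so select_related() would need the nested lookup restaurant__pizzaplace rather than pizzaplace. The sketch below builds such paths on plain classes, using __name__.lower() as a stand-in for _meta.model_name and ignoring proxy models; subclass_lookup_paths is an illustrative helper, not part of the mixin.

```python
def subclass_lookup_paths(cls, prefix=""):
    """Return select_related()-style lookup paths for every subclass of cls.

    Each subclass is reached through its direct parent, so a grandchild's
    path is its parent's path joined with "__".
    """
    paths = []
    for sub in cls.__subclasses__():
        path = prefix + sub.__name__.lower()  # stand-in for sub._meta.model_name
        paths.append(path)
        paths.extend(subclass_lookup_paths(sub, path + "__"))
    return paths


# Plain-class stand-ins (NOT Django models); PizzaPlace is hypothetical.
class Place:
    pass


class Restaurant(Place):
    pass


class CoffeeShop(Place):
    pass


class PizzaPlace(Restaurant):
    pass


paths = subclass_lookup_paths(Place)
print(paths)
```

In a real queryset method, a list like this would be passed to select_related(*paths), after filtering out proxy models as select_related_subclasses() already does.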

"The castable model mixin" by Javier Ayres is licensed under CC BY SA. Source code examples are licensed under MIT.



We are Sophilabs

A software design and development agency that helps companies build and grow products by delivering high-quality software through agile practices and perfectionist teams.