Django Implied Relationship

[This post was updated on 2021-07-28].

A little while ago, Alec McGavin put up a post on the Kogan blog about Custom Relationships in Django. This is a really cool way to get a relationship in Django that exists in the database, but cannot be modelled correctly in Django. For instance, this could be data in the database that does not have a Foreign Key, either because it’s legacy data, or because either side of the relationship might be added to the database before the other, unlike the normal case where the row targeted by the FK must be inserted first.

However, I have another slightly different situation where an implied relationship exists, but should not be stored directly.

Consider the following data structures:

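import datetime

from django.contrib.postgres.constraints import ExclusionConstraint
from django.contrib.postgres.fields import DateRangeField, RangeOperators
from django.db import models

# Relationship is the custom field class shown later in this post.


class Employee(models.Model):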
    name = models.TextField()


class EmploymentPeriod(models.Model):
    employee = models.ForeignKey(
        Employee,
        related_name='employment_periods',
        on_delete=models.CASCADE,
    )
    valid_period = DateRangeField()

    class Meta:
        constraints = [
            ExclusionConstraint(
                name='employment_overlap',
                expressions=[
                    ('employee', RangeOperators.EQUAL),
                    ('valid_period', RangeOperators.OVERLAPS),
                ]
            )
        ]


class Shift(models.Model):
    employee = models.ForeignKey(
        Employee,
        related_name='shifts',
        on_delete=models.CASCADE,
    )
    date = models.DateField()
    start_time = models.TimeField()
    duration = models.DurationField()
    employment_period = Relationship(
        EmploymentPeriod,
        from_fields=['employee', 'date'],
        to_fields=['employee', 'valid_period'],
    )

    @property
    def start(self):
        return datetime.datetime.combine(self.date, self.start_time)

    @property
    def finish(self):
        return self.start + self.duration

Now, there is a direct relationship between Shift and Employee, and also between EmploymentPeriod and Employee, but there could be an inferred relationship between Shift and EmploymentPeriod. Because of the exclusion constraint, we know there will be at most one EmploymentPeriod for a given Shift.

It would be really nice to be able to create this relationship, so we can reference the employment period (or lack thereof) directly. The Relationship class above comes close, but tries to use an equality check between the date and valid_period fields, and a date is never equal to a date range.
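To make the mismatch concrete, here is a small pure-Python sketch. HalfOpenDateRange is a hypothetical stand-in for a PostgreSQL daterange (which is normalised to include the lower bound and exclude the upper bound); it is not part of Django or psycopg2, and the dates are made up for illustration.

```python
import datetime
from dataclasses import dataclass


@dataclass(frozen=True)
class HalfOpenDateRange:
    """Hypothetical stand-in for a PostgreSQL daterange: [lower, upper)."""

    lower: datetime.date
    upper: datetime.date

    def __contains__(self, day):
        # Lower bound inclusive, upper bound exclusive.
        return self.lower <= day < self.upper


period = HalfOpenDateRange(datetime.date(2021, 1, 1), datetime.date(2021, 7, 1))
shift_date = datetime.date(2021, 3, 15)

# An equality comparison between a date and a range can never match.
equality_match = shift_date == period

# The relationship we actually want is containment.
containment_match = shift_date in period
```

The rest of this post is about getting Django to use the second kind of comparison everywhere it would normally use the first.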

It turns out, we can add a bit to that class, and teach it how to handle the various ways this relationship can be accessed:

  • Shift().employment_period
  • EmploymentPeriod().shifts
  • Shift.objects.select_related('employment_period')
  • EmploymentPeriod.objects.prefetch_related('shifts')
  • Shift.objects.filter(employment_period=emp)
  • Shift.objects.filter(employment_period__in=[emp1, emp2])
  • Shift.objects.filter(employment_period__isnull=True)
  • Shift.objects.filter(employment_period=None)
  • Shift.objects.filter(employment_period__isnull=False)
  • EmploymentPeriod.objects.filter(shifts__contains=shift)
  • EmploymentPeriod.objects.filter(shifts__contains=[shift1, shift2])
  • EmploymentPeriod.objects.filter(shifts__isnull=True)
  • EmploymentPeriod.objects.filter(shifts__isnull=False)

…plus there is also the inverse of a bunch of these, i.e. Shift.objects.exclude(employment_period=emp). In some cases an exclude is equivalent to one of the filters above, but that’s not always possible to determine.

Let’s have a look at the original class, and a new subclass for these non-direct relationships:

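from django.core.exceptions import FieldError
from django.db import models
from django.db.models import Exists, OuterRef
from django.db.models.lookups import Lookup
from django.db.models.query_utils import PathInfo
from django.db.models.sql.where import AND, WhereNode
from django.utils.functional import cached_property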


class Relationship(models.ForeignObject):
    """
    Create a django link between models on a field where a foreign key isn't used.
    This class allows that link to be realised through a proper relationship,
    allowing prefetches and select_related.

    https://devblog.kogan.com/blog/custom-relationships-in-django
    """

    def __init__(self, model, from_fields, to_fields, **kwargs):
        super().__init__(
            model,
            on_delete=models.DO_NOTHING,
            from_fields=from_fields,
            to_fields=to_fields,
            null=True,
            blank=True,
            **kwargs,
        )

    def contribute_to_class(self, cls, name, private_only=False, **kwargs):
        # override the default to always make it private
        # this ensures that no additional columns are created
        super().contribute_to_class(cls, name, private_only=True, **kwargs)

There’s not much to this, which is part of the beauty of it. Django pretty much handles composite-key relationships already; it just won’t create actual ForeignKeys based on them. There have been noises about implementing this for years, and maybe eventually it will happen.

But what about a subclass that allows the implicit relationship described above?

class ImplicitRelationship(Relationship):
    """
    Create a django link between two models where at least one of the fields
    uses a containment (or other type of non-direct) relationship.

    Relationship should be used if this is just a composite key (or a single
    key that is not a real ForeignKey in the database).

    """

    def get_path_info(self, filtered_relation=None):
        """Get path from this field to the related model."""
        opts = self.remote_field.model._meta
        from_opts = self.model._meta
        self.related_fields
        return [
            PathInfo(
                from_opts=from_opts,
                to_opts=opts,
                target_fields=[rhs for lhs, rhs in self.other_related_fields],
                join_field=self,
                m2m=False,
                direct=False,
                filtered_relation=filtered_relation,
            )
        ]

    def get_reverse_path_info(self, filtered_relation=None):
        """Get path from the related model to this field's model."""
        opts = self.model._meta
        from_opts = self.remote_field.model._meta
        self.related_fields
        return [
            PathInfo(
                from_opts=from_opts,
                to_opts=opts,
                target_fields=[lhs for lhs, rhs in self.other_related_fields],
                join_field=self.remote_field,
                m2m=False,
                direct=False,
                filtered_relation=filtered_relation,
            )
        ]

    @cached_property
    def other_related_fields(self):
        return self.resolve_related_fields()

    @cached_property
    def related_fields(self):
        return []

    def get_local_related_value(self, instance):
        """
        Given an instance, determine the value that will be used as
        the key for this value in a dict of related items.

        This is where it starts to get tricky. Django only really expects
        keys to match exactly, but we may have a value that contains a
        date, that needs to be checked for inclusion in a DateRange.

        Whilst psycopg2 does not normalise Range values, it will handle
        <date> in <DateRange> correctly, so we can use that as the
        comparison.
        """
        parts = self.get_instance_value_for_fields(
            instance,
            [lhs for lhs, rhs in self.other_related_fields],
        )

        if not hasattr(self, '_known_instance_keys'):
            return parts

        if parts in self._known_instance_keys:
            return parts

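        for keys in self._known_instance_keys:
            # Every part must match, either exactly or by containment;
            # returning on the first matching part would pair a shift with
            # any period belonging to the same employee.
            if all(
                part == key or (hasattr(key, '__contains__') and part in key)
                for part, key in zip(parts, keys)
            ):
                return keys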

    @property
    def get_foreign_related_value(self):
        """
        Because we need to use non-exact matching, we need to set up to store
        instances based on known keys. The Django code that uses this builds up
        a dict of keys/values, but since we need to do containment testing in
        get_local_related_value(instance), we have to store a local set of
        key values, which will be used for that containment checking.

        This is implemented as a property that returns a function, after clearing
        out the cache of known instances so that each queryset will have its own
        cache. Otherwise, instances from the last run through would be matched
        in the next run.

        This uses knowledge of the Django internals, where this method is called
        before get_local_related_value, which really is not ideal, but there does
        not seem to be a better way to handle this.
        """
        self._known_instance_keys = set()

        def get_foreign_related_value(instance):
            values = self.get_instance_value_for_fields(
                instance, 
                [rhs for lhs, rhs in self.other_related_fields]
            )

            self._known_instance_keys.add(values)
            return values

        return get_foreign_related_value

    def get_extra_restriction(self, where_class, alias, remote_alias):
        """
        This method is used to get extra JOIN conditions.

        We don't need to include the exact conditions, only those
        that we filtered out from the regular related_fields.
        The exact conditions will be already applied to the JOIN
        by the regular handling.
        """
        if not alias or not remote_alias:
            return

        if self.other_related_fields:
            cond = where_class()

            for local, remote in self.other_related_fields:
                local, remote = local.get_col(remote_alias), remote.get_col(alias)
                lookup_name = JOIN_LOOKUPS.get(get_key(local, remote), 'exact')
                lookup = local.get_lookup(lookup_name)(local, remote)
                cond.add(lookup, 'AND')

            return cond

    def get_extra_descriptor_filter(self, instance):
        """
        The sibling to get_extra_restriction, this one is used to get the extra
        filters that are required to limit to the correct objects.
        """
        extra_filter = {}
        for lhs, rhs in self.other_related_fields:
            lookup = JOIN_LOOKUPS.get(get_key(rhs.cached_col, lhs.cached_col), 'exact')
            extra_filter[f'{rhs.name}__{lookup}'] = getattr(instance, lhs.attname)
        return extra_filter

    def get_where(self, value, alias=None):
        constraint = WhereNode(connector=AND)

        values = self.get_instance_value_for_fields(
            value, 
            [remote for local, remote in self.other_related_fields]
        )

        for (source, target), value in zip(self.other_related_fields, values):
            key = (source.get_internal_type(), target.get_internal_type())
            lookup_type = JOIN_LOOKUPS.get(key, 'exact')
            lookup_class = source.get_lookup(lookup_type)
            lookup = lookup_class(target.get_col(alias or self.model._meta.db_table, source), value)
            constraint.add(lookup, AND)

        return constraint

    def get_exists_subquery_filters(self, inverted=False):
        filters = {}
        for source, target in self.other_related_fields:
            if inverted:
                source, target = target, source
            key = (target.get_internal_type(), source.get_internal_type())
            lookup = JOIN_LOOKUPS.get(key, 'exact')
            filters[f'{target.attname}__{lookup}'] = OuterRef(source.attname)
        return filters

    def exists_subquery(self, negated=False, inverted=False, **filters):
        if inverted:
            exists = Exists(
                self.model.objects.filter(
                    **self.get_exists_subquery_filters(inverted=True), 
                    **filters
                ).values('pk')
            )
        else:
            exists = Exists(
                self.related_model.objects.filter(
                    **self.get_exists_subquery_filters(), 
                    **filters
                ).values('pk')
            )
        return ~exists if negated else exists

    @property
    def target_field(self):
        raise FieldError()


# We should be able to add more pairs here as we need to handle them.
JOIN_LOOKUPS = {
    ('DateField', 'DateRangeField'): 'contained_by',
    ('DateRangeField', 'DateField'): 'contains',
}

def get_key(lhs, rhs):
    return (lhs.output_field.get_internal_type(), rhs.output_field.get_internal_type())

There’s actually a lot more code there than I really wanted, however it seems mostly to be necessary.
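The least obvious part is the pairing of get_foreign_related_value and get_local_related_value during a prefetch. The following is a pure-Python sketch of that idea only, not the Django code path itself: Period and match_key are hypothetical names, the employee ids and dates are invented, and the sketch assumes that every part of the key has to match, either exactly or by containment.

```python
import datetime
from dataclasses import dataclass


@dataclass(frozen=True)
class Period:
    """Hypothetical stand-in for a half-open date range, [lower, upper)."""

    lower: datetime.date
    upper: datetime.date

    def __contains__(self, day):
        return self.lower <= day < self.upper


def match_key(parts, known_keys):
    """Return the known key that `parts` matches, or None if there is none."""
    if parts in known_keys:
        return parts
    for keys in known_keys:
        # Each part has to match its counterpart, exactly or by containment.
        if all(
            part == key or (hasattr(key, '__contains__') and part in key)
            for part, key in zip(parts, keys)
        ):
            return keys
    return None


first_half = Period(datetime.date(2021, 1, 1), datetime.date(2021, 7, 1))
second_half = Period(datetime.date(2021, 7, 1), datetime.date(2022, 1, 1))

# Step 1: the related objects are stored under (employee_id, valid_period).
periods_by_key = {
    (1, first_half): 'employee 1, first half',
    (1, second_half): 'employee 1, second half',
    (2, first_half): 'employee 2, first half',
}
known_keys = set(periods_by_key)

# Step 2: each shift offers (employee_id, date), which is translated into
# one of the stored keys before the dict lookup happens.
matched = periods_by_key.get(match_key((1, datetime.date(2021, 8, 3)), known_keys))
unmatched = match_key((1, datetime.date(2022, 5, 1)), known_keys)
```

Here matched ends up as the second-half period for employee 1, and unmatched is None because no stored period contains that date. The translation step is what lets a plain dict lookup, which only understands exact keys, stand in for a containment test.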

But wait, there’s more. We also need to teach Django how to handle the various lookups that can be performed on these relationships:

@ImplicitRelationship.register_lookup
class RelatedMultipleExact(Lookup):
    """
    Apply each lookup type from each of the fields in an ImplicitRelationship.

    This is for querysets of the form:

    >>> Foo.objects.filter(relationship=instance)

    This gets the relevant operator for each of the lookups, based on
    the field type of the pair of (to/from) fields.
    """

    lookup_name = 'exact'

    def as_sql(self, compiler, connection):
        field = self.lhs.field

        # If we only have a primary key here, and not an instance, then we
        # will need to push the querying back into the database - normally
        # a direct lookup just uses the value as a PK, but here we need to
        # get the database to do a subquery to get the other values.
        if self.rhs_is_direct_value() and not isinstance(self.rhs, models.Model):
            # We can't unref the alias here, because Django will have also put in an IS NOT NULL
            # on the thing, which is referencing the wrong table.
            compiler.query.alias_map[self.lhs.alias] = compiler.query.alias_map[self.lhs.alias].promote()
            return field.exists_subquery(pk=self.rhs).resolve_expression(compiler.query).as_sql(compiler, connection)

        return field.get_where(self.rhs, alias=self.lhs.alias).as_sql(compiler, connection)


@ImplicitRelationship.register_lookup
class RelatedMultipleIn(Lookup):
    """
    Apply each lookup from each of the fields in an ImplicitRelationship.

    This is for querysets of the form:

    >>> Foo.objects.filter(relationship__in=[instance1, instance2])
    >>> Foo.objects.filter(relationship__not_in=[instance1, instance2])

    This builds an EXISTS() clause that uses a subquery to find
    if each instance matches - this is usually better than a bunch
    of clauses that would use OR, because that would preclude the
    use of indexes.
    """

    lookup_name = 'in'
    negated = False

    def as_sql(self, compiler, connection):
        field = self.lhs.field
        if self.negated:
            # We remove one reference to the joined table, so that if we only
            # have this reference, ie no columns, then we don't even join the
            # table in (as we'll be using an EXISTS in WHERE)
            compiler.query.unref_alias(self.lhs.alias)
        return (
            field.exists_subquery(
                negated=self.negated,
                pk__in=[getattr(x, 'pk', x) for x in self.rhs],
            )
            .resolve_expression(compiler.query)
            .as_sql(compiler, connection)
        )


@ImplicitRelationship.register_lookup
class RelatedMultipleNotIn(RelatedMultipleIn):
    negated = True
    lookup_name = 'not_in'


@ImplicitRelationship.register_lookup
class RelatedMultipleNull(Lookup):
    """
    Apply each lookup from each of the fields in an ImplicitRelationship.

    This is for querysets of the form:

    >>> Foo.objects.filter(relationship=None)
    >>> Foo.objects.exclude(relationship=None)
    >>> Foo.objects.filter(relationship__isnull=True)
    >>> Foo.objects.filter(relationship__isnull=False)

    """

    lookup_name = 'isnull'

    def as_sql(self, compiler, connection):
        field = self.lhs.field

        if not isinstance(field, Relationship):
            pk = field.related_model._meta.pk
            lookup = pk.get_lookup('isnull')(pk.get_col(self.lhs.alias), self.rhs)
            return lookup.as_sql(compiler, connection)

        # We remove one reference to the joined table, so that if we only
        # have this reference, ie no columns, then we don't even join the
        # table in (as we'll be using an EXISTS in WHERE)
        compiler.query.unref_alias(self.lhs.alias)
        return field.exists_subquery(negated=self.rhs).resolve_expression(compiler.query).as_sql(compiler, connection)


@ImplicitRelationship.register_lookup
class RelatedMultipleContains(Lookup):
    """
    Apply each lookup from each of the fields in an ImplicitRelationship.

    This is for querysets of the form:

    >>> Foo.objects.filter(reverse_relationship__contains=[x, y])
    >>> Foo.objects.filter(reverse_relationship__contains=x)

    """

    lookup_name = 'contains'
    negated = False

    def as_sql(self, compiler, connection):
        if isinstance(self.lhs.field, Relationship):
            raise TypeError(f'Unable to perform __{self.lookup_name} queries on Relationship, only on reversed')

        field = self.lhs.field.remote_field

        try:
            iter(self.rhs)
        except TypeError:
            value = getattr(self.rhs, 'pk', self.rhs)
            lookup = 'pk'
        else:
            if self.negated:
                raise ValueError('Unable to perform not_contains=list')
            value = [getattr(x, 'pk', x) for x in self.rhs]
            lookup = 'pk__in'

        # We remove one reference to the joined table, so that if we only
        # have this reference, ie no columns, then we don't even join the
        # table in (as we'll be using an EXISTS in WHERE)
        compiler.query.unref_alias(self.lhs.alias)

        return (
            field.exists_subquery(negated=self.negated, inverted=True, **{lookup: value})
            .resolve_expression(compiler.query)
            .as_sql(compiler, connection)
        )


@ImplicitRelationship.register_lookup
class RelatedMultipleNotContains(RelatedMultipleContains):
    lookup_name = 'not_contains'
    negated = True
    

There is one caveat to this - these lookups (under certain conditions) will change the list of FROM tables that are required to be joined into the query. In some cases this adds a required JOIN, in others it actually removes the need for a JOIN because the WHERE clause is all within a subquery, and should not contain a join to the original table.

However, until https://github.com/django/django/pull/14683 is merged, these lookups will not always work - if the list of tables is mutated by the lookup, the SQL query that is generated will not contain this required mutation. Hopefully I can get that PR merged, but it is possible to use patchy to patch your local installation until that is done.
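As a footnote on the subquery-based lookups: they all rely on get_exists_subquery_filters to pick a lookup name for each pair of fields. The sketch below reproduces just that selection logic in plain Python. Field is a hypothetical namedtuple standing in for a Django model field, subquery_filters is a made-up name, and where the real method wraps the value in OuterRef(...) the sketch keeps the bare column name.

```python
from collections import namedtuple

# Hypothetical stand-in carrying only the two attributes the sketch needs.
Field = namedtuple('Field', 'attname internal_type')

JOIN_LOOKUPS = {
    ('DateField', 'DateRangeField'): 'contained_by',
    ('DateRangeField', 'DateField'): 'contains',
}


def subquery_filters(field_pairs, inverted=False):
    """Map '<target column>__<lookup>' to the outer query's column name."""
    filters = {}
    for source, target in field_pairs:
        if inverted:
            source, target = target, source
        key = (target.internal_type, source.internal_type)
        # Any pair without a special entry falls back to plain equality.
        lookup = JOIN_LOOKUPS.get(key, 'exact')
        filters[f'{target.attname}__{lookup}'] = source.attname
    return filters


# (Shift field, EmploymentPeriod field) pairs for the models in this post.
pairs = [
    (Field('employee_id', 'ForeignKey'), Field('employee_id', 'ForeignKey')),
    (Field('date', 'DateField'), Field('valid_period', 'DateRangeField')),
]

forward = subquery_filters(pairs)
reverse = subquery_filters(pairs, inverted=True)
```

In the forward direction the subquery runs over EmploymentPeriod and asks whether valid_period contains the outer shift's date; inverted, it runs over Shift and asks whether date is contained by the outer period's valid_period.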