What You Should Know About DRF, Part 1: ModelViewSet attributes and methods

I gave this talk at PyCascades 2021 and decided to turn it into a series of blog posts so it's available to folks who didn't attend the conference or don't like to watch videos. Here are the slides and the video if you want to see them.


One of the things I hear people say about Django is that it's a "batteries included" framework, and Django REST Framework is no different. One of the most powerful of these "batteries" is the ModelViewSet class, which is more of a "battery pack," in that it contains several different batteries. If you have any experience with Django's class-based views, then DRF's viewsets will hopefully look familiar to you.

The ModelViewSet is what it sounds like: a set of views that lets you take a series of actions on model objects. The DRF docs define it as "a type of class-based View, that does not provide any method handlers such as .get() or .post(), and instead provides actions such as .list() and .create()."

class ModelViewSet(mixins.CreateModelMixin,
                   mixins.RetrieveModelMixin,
                   mixins.UpdateModelMixin,
                   mixins.DestroyModelMixin,
                   mixins.ListModelMixin,
                   GenericViewSet):
    """
    A viewset that provides default `create()`, 
    `retrieve()`, `update()`, `partial_update()`, 
    `destroy()` and `list()` actions.
    """
    pass

You can see how the ModelViewSet is constructed: it combines a class called GenericViewSet with five mixins that have names like CreateModelMixin. Each *ModelMixin class has its own methods that perform actions related to the name of the mixin. For example, CreateModelMixin has a create() method. It does not, however, have a post() method. This is what the DRF docs mean when they say a viewset "does not provide any method handlers such as .get() or .post()." If you've used Django's CBVs, you have probably dealt with the .get() and .post() methods there. DRF's ModelViewSet skips these methods and replaces them with more specific methods named after actions.

For the CreateModelMixin, which is a set of methods that helps you create new objects, you might expect to work with a .post() method, since creating new records in your database is generally handled by an HTTP POST request. But CreateModelMixin gives you a .create() method instead. This comes in handy later on, because it makes it easier to handle adding new objects separately from updating existing ones. You don't need any conditional logic in a .post() method to tell the difference: that logic already lives in the separate .create() and .update() methods.
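
To make the idea of "actions instead of method handlers" concrete, here is a simplified, pure-Python model of a viewset that is assembled from action mixins and then wired to HTTP methods through an {http_method: action} mapping. Every class here is a hypothetical stand-in, not DRF code. DRF's real viewset as_view() does take a similar actions mapping (for example {"get": "list", "post": "create"}), which routers normally supply for you.

```python
# Simplified model of a viewset built from action mixins and wired to HTTP
# methods. These are hypothetical stand-ins, not DRF's real classes.


class ListMixin:
    def list(self):
        return "listing all books"


class CreateMixin:
    def create(self):
        return "creating a new book"


class UpdateMixin:
    def update(self):
        return "updating an existing book"


class MiniViewSet(CreateMixin, UpdateMixin, ListMixin):
    """Has list(), create() and update() actions, but no get() or post()."""

    @classmethod
    def as_view(cls, actions):
        """Return a view function bound to an {http_method: action} map."""
        def view(http_method):
            action_name = actions.get(http_method.lower())
            if action_name is None:
                return "405 Method Not Allowed"
            # Look up the action method by name and call it
            return getattr(cls(), action_name)()
        return view


# POST on the collection URL runs create(); PUT on a detail URL runs
# update(), so neither action needs an "is this new or existing?" branch.
collection_view = MiniViewSet.as_view({"get": "list", "post": "create"})
detail_view = MiniViewSet.as_view({"put": "update"})

print(collection_view("POST"))       # creating a new book
print(detail_view("PUT"))            # updating an existing book
print(hasattr(MiniViewSet, "post"))  # False
```

The design choice to notice is that the HTTP method only selects which action to run. The actions themselves never need to inspect the method.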

Example

Let's say you're building a library app, and you want to create a set of endpoints to deal with books. Using a ModelViewSet means that creating endpoints to add, update, delete, retrieve, and list all books requires just 6 lines of code in your views.py.

# views.py 
from rest_framework.viewsets import ModelViewSet 

from .models import Book
from .serializers import BookSerializer

class BookViewSet(ModelViewSet):
    queryset = Book.objects.all()
    serializer_class = BookSerializer

This gets you these endpoints:

  • List all books: GET /books/
  • Retrieve a specific book: GET /books/{id}/
  • Add a new book: POST /books/
  • Update an existing book: PUT /books/{id}/
  • Update part of an existing book: PATCH /books/{id}/
  • Remove a book: DELETE /books/{id}/

You would also need to write the BookSerializer and hook these endpoints up in your urls.py, but you can see how to do that in the docs.
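
To see where those six endpoints come from, here is a pure-Python sketch of the route table a router builds for a viewset registered under a "books" prefix. It assumes the same method-to-action mappings that DRF's SimpleRouter uses for its list and detail routes. The build_routes() helper is hypothetical and only prints the table. It does not create real Django URL patterns.

```python
# Hypothetical sketch of the routes a router generates for one viewset.
# The mappings mirror DRF's SimpleRouter list and detail routes.

LIST_ROUTE = {"get": "list", "post": "create"}
DETAIL_ROUTE = {
    "get": "retrieve",
    "put": "update",
    "patch": "partial_update",
    "delete": "destroy",
}


def build_routes(prefix):
    """Return (HTTP method, URL path, action) tuples for a URL prefix."""
    routes = []
    for method, action in LIST_ROUTE.items():
        routes.append((method.upper(), f"/{prefix}/", action))
    for method, action in DETAIL_ROUTE.items():
        # {{id}} renders as a literal "{id}" placeholder in the f-string
        routes.append((method.upper(), f"/{prefix}/{{id}}/", action))
    return routes


for method, path, action in build_routes("books"):
    print(f"{method:<6} {path:<13} -> {action}")
```

In a real urls.py you would not write this yourself. Instead you would create a router (for example router = DefaultRouter()), call router.register(r"books", BookViewSet), and add router.urls to your urlpatterns.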

Six lines of code and you're done!

Except that most of the time, your project requirements are a little more complex than "write 6 lines of code and let DRF take it from there." That's where this talk (and set of blog posts) comes in. You can do a lot to customize DRF's functionality while still using the convenience methods that DRF includes for you. This can save you time, lines of code, testing, and headaches.

ModelViewSet Attributes

There are three attributes on your ModelViewSet that you should set.

The queryset attribute answers the question, "What objects are you working with?" It takes a (you guessed it) queryset. Below, I've set mine to Book.objects.all(), but you can set yours to a model manager or a queryset with some filtering.

The serializer_class attribute addresses the question, "How should the data you are dealing with be serialized?" I've set mine to BookSerializer. A serializer is the class that defines how the data should be formatted. If you're not super familiar with APIs at this point, the basic idea is that an API sends data back as JSON blobs. Your serializer defines how you want to transform your model objects into JSON and which fields you want to include.

The permission_classes attribute defines who is allowed to access the endpoints created by this viewset, and it takes a list or tuple of permission classes. I've set mine to [AllowAny] using a built-in permission class from DRF. If you don't set this attribute, DRF falls back to the DEFAULT_PERMISSION_CLASSES setting, which is [AllowAny] unless you override it in the REST_FRAMEWORK dictionary in your settings. I always prefer to set mine explicitly, though.

from rest_framework.permissions import AllowAny
from rest_framework.viewsets import ModelViewSet 

from .models import Book
from .serializers import BookSerializer

class BookViewSet(ModelViewSet):
    queryset = Book.objects.all()
    serializer_class = BookSerializer
    permission_classes = [AllowAny]
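
As a rough model of what permission_classes does at request time, consider how the view handles the list. It instantiates each class, calls each instance's has_permission(request, view) method, and allows the request only if every one returns True. The sketch below uses hypothetical stand-in classes and a plain dict for the request so it runs without Django. Real DRF raises an exception (such as PermissionDenied or NotAuthenticated) when a check fails, rather than returning False.

```python
# Simplified model of permission checks. The classes below are stand-ins
# for DRF's AllowAny and IsAuthenticated, not the real ones.


class AllowAnyStandIn:
    def has_permission(self, request, view):
        return True


class IsAuthenticatedStandIn:
    def has_permission(self, request, view):
        # A plain dict stands in for the request in this sketch
        return request.get("user") is not None


def check_permissions(permission_classes, request, view=None):
    """Allow the request only if every permission class grants it."""
    for permission_class in permission_classes:
        permission = permission_class()  # each class is instantiated
        if not permission.has_permission(request, view):
            return False
    return True


anonymous = {"user": None}
logged_in = {"user": "alice"}
print(check_permissions([AllowAnyStandIn], anonymous))         # True
print(check_permissions([IsAuthenticatedStandIn], anonymous))  # False
```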

ModelViewSet methods that come from GenericViewSet

get_queryset()

The get_queryset() method mostly just returns whatever you set in your queryset attribute.

def get_queryset(self):
    assert self.queryset is not None, (
        "'%s' should either include a `queryset` attribute, "
        "or override the `get_queryset()` method."
        % self.__class__.__name__
    )

    queryset = self.queryset
    if isinstance(queryset, QuerySet):
        queryset = queryset.all()
    return queryset

Why it's useful: This method is the place to change your queryset using data you don't have until a request comes in to your API. I often override it in my own viewsets so I can filter the queryset based on the user who made the request.

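class BookViewSet(ModelViewSet):
    queryset = Book.objects.all()
    serializer_class = BookSerializer

    def get_queryset(self):
        queryset = super().get_queryset()
        # Only return books owned by the requesting user
        return queryset.filter(owner=self.request.user)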
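
You may wonder why the default get_queryset() calls queryset.all() instead of returning the attribute directly. Django querysets cache their results once they are evaluated, and a class-level queryset attribute is shared by every request, so DRF hands each request a fresh copy. The toy class below is a hypothetical stand-in for a Django QuerySet (just a list with a result cache). It is only meant to show the stale-cache problem that calling .all() avoids.

```python
# Toy stand-in for a Django QuerySet: it caches rows the first time it is
# evaluated, and .all() returns a fresh, unevaluated copy.


class ToyQuerySet:
    def __init__(self, table):
        self.table = table          # shared "database table" (a list)
        self._result_cache = None   # filled on first evaluation

    def all(self):
        return ToyQuerySet(self.table)

    def evaluate(self):
        if self._result_cache is None:
            self._result_cache = list(self.table)
        return self._result_cache


books_table = ["Dune"]
class_level_queryset = ToyQuerySet(books_table)  # like `queryset = ...`

# An earlier request evaluates (and caches) the class-level queryset.
class_level_queryset.evaluate()

books_table.append("Emma")  # a new book is saved between requests

stale = class_level_queryset.evaluate()        # reuses the old cache
fresh = class_level_queryset.all().evaluate()  # what get_queryset() does
print(stale)  # ['Dune']
print(fresh)  # ['Dune', 'Emma']
```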

get_object()

The get_object() method is used in endpoints that deal with a specific object, that is, any endpoint that uses an identifier in the URL (PUT /books/{id}/, for example).

def get_object(self):
    queryset = self.filter_queryset(self.get_queryset())

    lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field

    assert lookup_url_kwarg in self.kwargs, (
        'Expected view %s to be called with a URL keyword argument '
        'named "%s". Fix your URL conf, or set the `.lookup_field` '
        'attribute on the view correctly.' %
        (self.__class__.__name__, lookup_url_kwarg)
    )
    # Uses the lookup_field attribute, which defaults to `pk`
    filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
    obj = get_object_or_404(queryset, **filter_kwargs)

    # May raise a permission denied
    self.check_object_permissions(self.request, obj)
    return obj

Why it's useful: I don't need to override this method very often, but it's really useful to know about because of all the steps it takes for you.

  • First, it filters the queryset for you.
  • Then, it makes sure the URL provided the value it needs to look up your object, using lookup_url_kwarg. (If you don't set lookup_url_kwarg, it falls back to lookup_field, which defaults to pk. You can set either one to something else if you need to.)
  • Then, it tries to retrieve your object for you and will raise a 404 error on your behalf if it can't find it in your queryset using get_object_or_404().
  • Finally, before it returns the object, it checks to make sure that the user who made this request has adequate permissions for this object.
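
Here is a pure-Python sketch of those four steps, meant to show the order of operations rather than DRF's actual code. All the names are hypothetical. A list of dicts stands in for the queryset. Http404Error and PermissionDeniedError stand in for the exceptions DRF raises. The can_access function stands in for check_object_permissions().

```python
# Sketch of the get_object() steps using plain-Python stand-ins.


class Http404Error(Exception):
    """Stands in for the 404 error raised via get_object_or_404()."""


class PermissionDeniedError(Exception):
    """Stands in for DRF's PermissionDenied exception."""


def get_object(rows, url_kwargs, user, lookup_field="pk",
               lookup_url_kwarg=None, filter_rows=None, can_access=None):
    """Mimic the four get_object() steps on a list of dicts."""
    # 1. Filter the "queryset" (no-op unless a filter function is given)
    if filter_rows is not None:
        rows = filter_rows(rows)

    # 2. Make sure the URL captured the value we need for the lookup
    lookup_url_kwarg = lookup_url_kwarg or lookup_field
    assert lookup_url_kwarg in url_kwargs, f"missing {lookup_url_kwarg!r}"
    value = url_kwargs[lookup_url_kwarg]

    # 3. Fetch the matching object, or raise a 404-style error
    matches = [row for row in rows if row.get(lookup_field) == value]
    if not matches:
        raise Http404Error(f"no object with {lookup_field}={value!r}")
    obj = matches[0]

    # 4. Check object-level permissions before returning the object
    if can_access is not None and not can_access(user, obj):
        raise PermissionDeniedError("not allowed to access this object")
    return obj


def owner_only(user, book):
    """Hypothetical object-level rule: only the owner may see a book."""
    return book["owner"] == user


books = [{"pk": 1, "owner": "alice"}, {"pk": 2, "owner": "bob"}]
print(get_object(books, {"pk": 1}, "alice", can_access=owner_only))
# {'pk': 1, 'owner': 'alice'}
```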

This is a lot of tedious work. If you write custom endpoints for your viewset, or you're in the update() method doing some special work for your project's requirements, you will probably need the object itself at some point. If you grab the id from the request and try to get the object from there, you then have to worry about permissions, what to do if the object doesn't exist, etc.

Instead, you can run

obj = self.get_object()

from inside your method and let DRF take care of those important steps for you!

The serializer methods

There are three methods that deal with the serializer:

  • get_serializer_class()
  • get_serializer_context()
  • get_serializer()

These three methods work together to return a serializer that's ready for you to work with.

get_serializer_class() returns whatever you set in your serializer_class attribute.

def get_serializer_class(self):
    assert self.serializer_class is not None, (
        "'%s' should either include a `serializer_class` attribute, "
        "or override the `get_serializer_class()` method."
        % self.__class__.__name__
    )

    return self.serializer_class

Why it's useful: If you want to use a different serializer in different situations, you can override get_serializer_class() to add that logic. You might want to use different serializers for list requests and detail requests, for example. We'll go over that in the next post.
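
As a quick preview, here is one common shape for that override, modeled in plain Python so it runs without Django. DRF sets self.action on viewsets (to values like "list" or "retrieve"), and the override branches on it. The base class and both serializer classes here are hypothetical stand-ins.

```python
# Plain-Python model of choosing a serializer per action. The classes are
# hypothetical stand-ins for GenericViewSet and two DRF serializers.


class BookListSerializer:
    fields = ("id", "title")


class BookDetailSerializer:
    fields = ("id", "title", "author", "summary")


class StandInGenericViewSet:
    serializer_class = None
    action = None  # DRF sets this per request, e.g. "list"

    def get_serializer_class(self):
        assert self.serializer_class is not None
        return self.serializer_class


class BookViewSet(StandInGenericViewSet):
    serializer_class = BookDetailSerializer

    def get_serializer_class(self):
        # Use the lighter serializer for the list endpoint only
        if self.action == "list":
            return BookListSerializer
        return super().get_serializer_class()


view = BookViewSet()
view.action = "list"
print(view.get_serializer_class().__name__)  # BookListSerializer
view.action = "retrieve"
print(view.get_serializer_class().__name__)  # BookDetailSerializer
```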

get_serializer() calls get_serializer_class(), instantiates that class with whatever arguments you passed in, and returns the serializer instance.

def get_serializer(self, *args, **kwargs):
    serializer_class = self.get_serializer_class()

    # The context is where the request is added 
    # to the serializer
    kwargs['context'] = self.get_serializer_context()

    return serializer_class(*args, **kwargs)

Before instantiating the class, it calls get_serializer_context() and passes the result into the serializer as its context.

def get_serializer_context(self):
    return {
        'request': self.request,
        'format': self.format_kwarg,
        'view': self
    }

Why it's useful: You can override get_serializer_context() to add more information to your serializer if you need to. If you've ever been in one of your serializer methods and used self.context["request"].user, get_serializer_context() is the reason you can access the user from the request in your serializer.

I recently had a situation where I needed to do a lot of math calculations in my serializer for each object. It was more efficient to compute some of the values I needed up front and pass them into the serializer context by overriding this method, rather than recalculating those values for each object I was dealing with.
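
Here is a minimal sketch of that pattern, using plain-Python stand-ins for the view and serializer so it runs on its own. The subclass extends the base context once per request with a precomputed value, and the serializer reads that value from self.context for every object. The tax_rate key and the compute_tax_rate() helper are hypothetical. With real DRF you would override get_serializer_context() the same way, calling super() first.

```python
# Plain-Python stand-ins for a DRF view and serializer, showing how a value
# added in get_serializer_context() reaches the serializer.

tax_rate_lookups = []  # records how often the "expensive" lookup runs


def compute_tax_rate():
    """Hypothetical expensive lookup we want to run once per request."""
    tax_rate_lookups.append(1)
    return 0.1


class StandInSerializer:
    def __init__(self, instance=None, many=False, context=None):
        self.instance = instance
        self.many = many
        self.context = context or {}

    @property
    def data(self):
        rate = self.context["tax_rate"]  # read the precomputed value
        books = self.instance if self.many else [self.instance]
        rows = [
            {"title": b["title"],
             "price_with_tax": round(b["price"] * (1 + rate), 2)}
            for b in books
        ]
        return rows if self.many else rows[0]


class StandInView:
    serializer_class = StandInSerializer

    def __init__(self, request):
        self.request = request

    def get_serializer_context(self):
        return {"request": self.request, "view": self}

    def get_serializer(self, *args, **kwargs):
        # Same shape as the method above: build context, then instantiate
        kwargs["context"] = self.get_serializer_context()
        return self.serializer_class(*args, **kwargs)


class BookView(StandInView):
    def get_serializer_context(self):
        context = super().get_serializer_context()
        context["tax_rate"] = compute_tax_rate()  # once per request
        return context


books = [{"title": "Dune", "price": 10.0}, {"title": "Emma", "price": 20.0}]
serializer = BookView(request="fake-request").get_serializer(books, many=True)
print(serializer.data)
print(len(tax_rate_lookups))  # 1
```

If you built the serializer directly, as in StandInSerializer(books, many=True), it would have an empty context and reading .data would raise a KeyError. That is the stand-in version of losing access to the request when you construct a serializer by hand.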

I don't often need to override get_serializer(), but knowing what it does (it gets your serializer class, passes your serializer context into it, and hands you the resulting serializer) means that you can run

serializer = self.get_serializer()

in your viewset methods as a shortcut. Like with get_object(), this ensures that you're getting the serializer you want, with the data you want in it, without having to do any extra or duplicate work. If you construct your serializer manually in your methods, like serializer = BookSerializer(instance=obj), then you skip that context and lose the chance to have access to the request (and therefore the user) in your serializer.

ModelViewSet methods that come from the action mixins

I'm not going to go into the methods that come with all five of the mixins that are included with ModelViewSet, but I'll go through the ones that come with CreateModelMixin as an example, and hopefully you can extrapolate from there.

CreateModelMixin comes with three methods: create(), perform_create(), and get_success_headers(). I won't go over get_success_headers() because I never need to mess with it, but you can explore what it does on your own.

The create() method does several things:

  • Gets the serializer from get_serializer() and passes the data from the request into it
  • Checks that the serializer is valid, and raises an exception for you if it isn't
  • Calls perform_create()
  • Gets the success headers from get_success_headers()
  • Returns the serializer data in the response with those headers and an HTTP status code

from rest_framework import status
from rest_framework.response import Response

class CreateModelMixin:
    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return Response(
            serializer.data, status=status.HTTP_201_CREATED, headers=headers
        )

The perform_create() method calls the save() method from the serializer, but doesn't return anything.

class CreateModelMixin:
    def perform_create(self, serializer):
        serializer.save()

Why it's useful: It's useful to know about these because sometimes you need to do some custom processing either before or after you have performed the action in the request. For creating a new object, maybe you need to call a task that does some other processing, or you need to publish a message to a message queue or event bus so another system can take some action.

Knowing where the action is happening, so to speak, lets you override exactly the method where you want to inject your custom behavior. For example, to fire off a special task after you've created a new object, you could override the perform_create() method:

from .tasks import special_new_book_task 

class BookViewSet(ModelViewSet):
    def perform_create(self, serializer):
        super().perform_create(serializer)
        special_new_book_task.delay(serializer.instance.id)

This lets you fire off the task after the new object has been saved, without having to rewrite the whole create() method to make it happen.
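
To show why overriding perform_create() is enough, here is a plain-Python model of the create()/perform_create() split in CreateModelMixin. All the names are hypothetical stand-ins. A dict-backed serializer replaces a DRF serializer. A ValueError stands in for DRF's ValidationError. A queued_tasks list takes the place of a Celery task's .delay(). The base create() never changes, and the subclass only hooks in after the save.

```python
# Plain-Python model of the create()/perform_create() split.
# All classes here are hypothetical stand-ins, not DRF's.

queued_tasks = []  # stands in for special_new_book_task.delay(...)
database = []      # stands in for the Book table


class StandInSerializer:
    def __init__(self, data):
        self.initial_data = data
        self.instance = None

    def is_valid(self, raise_exception=False):
        valid = bool(self.initial_data.get("title"))
        if not valid and raise_exception:
            raise ValueError("title is required")
        return valid

    def save(self):
        self.instance = {"id": len(database) + 1, **self.initial_data}
        database.append(self.instance)
        return self.instance


class StandInCreateMixin:
    def create(self, request_data):
        serializer = StandInSerializer(request_data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        return 201, serializer.instance  # status code and saved object

    def perform_create(self, serializer):
        serializer.save()


class BookViewSet(StandInCreateMixin):
    def perform_create(self, serializer):
        super().perform_create(serializer)              # save first...
        queued_tasks.append(serializer.instance["id"])  # ...then queue work


status, body = BookViewSet().create({"title": "Dune"})
print(status, body)  # 201 {'id': 1, 'title': 'Dune'}
print(queued_tasks)  # [1]
```

Because validation happens in create() before perform_create() runs, an invalid request never reaches the override, so no task is queued for it.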


In Part 2: Customizing built-in methods, I'll go through some real-world examples for when you might want to override some of these built-in methods.

In Part 3: Adding custom endpoints, I'll show you how to add your own custom endpoints to your viewset without having to write a whole new view or add anything new to your urls.py.