Skip to content

Commit

Permalink
Fixed #470 -- Added support for database defaults on fields.
Browse files Browse the repository at this point in the history
Special thanks to Hannes Ljungberg for finding multiple implementation
gaps.

Thanks also to Simon Charette, Adam Johnson, and Mariusz Felisiak for
reviews.
  • Loading branch information
LilyFoote authored and felixxm committed May 12, 2023
1 parent 599f3e2 commit 7414704
Show file tree
Hide file tree
Showing 32 changed files with 1,089 additions and 34 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Expand Up @@ -587,6 +587,7 @@ answer newbie questions, and generally made Django that much better:
lerouxb@gmail.com
Lex Berezhny <lex@damoti.com>
Liang Feng <hutuworm@gmail.com>
Lily Foote
limodou
Lincoln Smith <lincoln.smith@anu.edu.au>
Liu Yijie <007gzs@gmail.com>
Expand Down
12 changes: 12 additions & 0 deletions django/db/backends/base/features.py
Expand Up @@ -201,6 +201,15 @@ class BaseDatabaseFeatures:
# Does the backend require literal defaults, rather than parameterized ones?
requires_literal_defaults = False

# Does the backend support functions in defaults?
supports_expression_defaults = True

# Does the backend support the DEFAULT keyword in insert queries?
supports_default_keyword_in_insert = True

# Does the backend support the DEFAULT keyword in bulk insert queries?
supports_default_keyword_in_bulk_insert = True

# Does the backend require a connection reset after each material schema change?
connection_persists_old_columns = False

Expand Down Expand Up @@ -361,6 +370,9 @@ class BaseDatabaseFeatures:
# SQL template override for tests.aggregation.tests.NowUTC
test_now_utc_template = None

# SQL to create a model instance using the database defaults.
insert_test_table_with_defaults = None

# A set of dotted paths to tests in Django's test suite that are expected
# to fail on this database.
django_test_expected_failures = set()
Expand Down
88 changes: 81 additions & 7 deletions django/db/backends/base/schema.py
Expand Up @@ -12,7 +12,7 @@
Table,
)
from django.db.backends.utils import names_digest, split_identifier, truncate_name
from django.db.models import Deferrable, Index
from django.db.models import NOT_PROVIDED, Deferrable, Index
from django.db.models.sql import Query
from django.db.transaction import TransactionManagementError, atomic
from django.utils import timezone
Expand Down Expand Up @@ -296,6 +296,12 @@ def _iter_column_sql(
yield self._comment_sql(field.db_comment)
# Work out nullability.
null = field.null
# Add database default.
if field.db_default is not NOT_PROVIDED:
default_sql, default_params = self.db_default_sql(field)
yield f"DEFAULT {default_sql}"
params.extend(default_params)
include_default = False
# Include a default value, if requested.
include_default = (
include_default
Expand Down Expand Up @@ -400,6 +406,22 @@ def _column_default_sql(self, field):
"""
return "%s"

def db_default_sql(self, field):
    """
    Compile ``field.db_default`` into a ``(sql, params)`` pair.

    A plain ``Value`` default is emitted as-is; any other expression is
    wrapped in parentheses. When the backend's features set
    ``requires_literal_defaults``, every parameter is inlined through
    ``prepare_default()`` and the returned params list is empty.
    """
    from django.db.models.expressions import Value

    db_default = field.db_default
    wrapper = "%s" if isinstance(db_default, Value) else "(%s)"
    compiler = Query(model=field.model).get_compiler(connection=self.connection)
    compiled_sql, compiled_params = compiler.compile(db_default)
    if self.connection.features.requires_literal_defaults:
        # Some databases don't support parameterized defaults (Oracle,
        # SQLite). If this is the case, the individual schema backend
        # should implement prepare_default().
        literals = tuple(self.prepare_default(param) for param in compiled_params)
        compiled_sql = compiled_sql % literals
        compiled_params = []
    return wrapper % compiled_sql, compiled_params

@staticmethod
def _effective_default(field):
# This method allows testing its logic without a connection.
Expand Down Expand Up @@ -1025,6 +1047,21 @@ def _alter_field(
)
actions.append(fragment)
post_actions.extend(other_actions)

if new_field.db_default is not NOT_PROVIDED:
if (
old_field.db_default is NOT_PROVIDED
or new_field.db_default != old_field.db_default
):
actions.append(
self._alter_column_database_default_sql(model, old_field, new_field)
)
elif old_field.db_default is not NOT_PROVIDED:
actions.append(
self._alter_column_database_default_sql(
model, old_field, new_field, drop=True
)
)
# When changing a column NULL constraint to NOT NULL with a given
# default value, we need to perform 4 steps:
# 1. Add a default for new incoming writes
Expand All @@ -1033,7 +1070,11 @@ def _alter_field(
# 4. Drop the default again.
# Default change?
needs_database_default = False
if old_field.null and not new_field.null:
if (
old_field.null
and not new_field.null
and new_field.db_default is NOT_PROVIDED
):
old_default = self.effective_default(old_field)
new_default = self.effective_default(new_field)
if (
Expand All @@ -1051,9 +1092,9 @@ def _alter_field(
if fragment:
null_actions.append(fragment)
# Only if we have a default and there is a change from NULL to NOT NULL
four_way_default_alteration = new_field.has_default() and (
old_field.null and not new_field.null
)
four_way_default_alteration = (
new_field.has_default() or new_field.db_default is not NOT_PROVIDED
) and (old_field.null and not new_field.null)
if actions or null_actions:
if not four_way_default_alteration:
# If we don't have to do a 4-way default alteration we can
Expand All @@ -1074,15 +1115,20 @@ def _alter_field(
params,
)
if four_way_default_alteration:
if new_field.db_default is NOT_PROVIDED:
default_sql = "%s"
params = [new_default]
else:
default_sql, params = self.db_default_sql(new_field)
# Update existing rows with default value
self.execute(
self.sql_update_with_default
% {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(new_field.column),
"default": "%s",
"default": default_sql,
},
[new_default],
params,
)
# Since we didn't run a NOT NULL change before we need to do it
# now
Expand Down Expand Up @@ -1264,6 +1310,34 @@ def _alter_column_default_sql(self, model, old_field, new_field, drop=False):
params,
)

def _alter_column_database_default_sql(
self, model, old_field, new_field, drop=False
):
"""
Hook to specialize column database default alteration.
Return a (sql, params) fragment to add or drop (depending on the drop
argument) a default to new_field's column.
"""
if drop:
sql = self.sql_alter_column_no_default
default_sql = ""
params = []
else:
sql = self.sql_alter_column_default
default_sql, params = self.db_default_sql(new_field)

new_db_params = new_field.db_parameters(connection=self.connection)
return (
sql
% {
"column": self.quote_name(new_field.column),
"type": new_db_params["type"],
"default": default_sql,
},
params,
)

def _alter_column_type_sql(
self, model, old_field, new_field, new_type, old_collation, new_collation
):
Expand Down
7 changes: 7 additions & 0 deletions django/db/backends/mysql/features.py
Expand Up @@ -51,6 +51,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an
# indexed expression.
collate_as_index_expression = True
insert_test_table_with_defaults = "INSERT INTO {} () VALUES ()"

supports_order_by_nulls_modifier = False
order_by_nulls_first = True
Expand Down Expand Up @@ -342,3 +343,9 @@ def can_rename_index(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 5, 2)
return True

@cached_property
def supports_expression_defaults(self):
    """
    Does the connected server support functions in defaults?

    Always True on MariaDB; on MySQL only for version 8.0.13 and later.
    """
    connection = self.connection
    return (
        True
        if connection.mysql_is_mariadb
        else connection.mysql_version >= (8, 0, 13)
    )
29 changes: 24 additions & 5 deletions django/db/backends/mysql/schema.py
Expand Up @@ -209,11 +209,15 @@ def _delete_composed_index(self, model, fields, *args):
self._create_missing_fk_index(model, fields=fields)
return super()._delete_composed_index(model, fields, *args)

def _set_field_new_type_null_status(self, field, new_type):
def _set_field_new_type(self, field, new_type):
"""
Keep the null property of the old field. If it has changed, it will be
handled separately.
Keep the NULL and DEFAULT properties of the old field. If either has
changed, it will be handled separately.
"""
if field.db_default is not NOT_PROVIDED:
default_sql, params = self.db_default_sql(field)
default_sql %= tuple(self.quote_value(p) for p in params)
new_type += f" DEFAULT {default_sql}"
if field.null:
new_type += " NULL"
else:
Expand All @@ -223,7 +227,7 @@ def _set_field_new_type_null_status(self, field, new_type):
def _alter_column_type_sql(
self, model, old_field, new_field, new_type, old_collation, new_collation
):
new_type = self._set_field_new_type_null_status(old_field, new_type)
new_type = self._set_field_new_type(old_field, new_type)
return super()._alter_column_type_sql(
model, old_field, new_field, new_type, old_collation, new_collation
)
Expand All @@ -242,7 +246,7 @@ def _field_db_check(self, field, field_db_params):
return field_db_params["check"]

def _rename_field_sql(self, table, old_field, new_field, new_type):
new_type = self._set_field_new_type_null_status(old_field, new_type)
new_type = self._set_field_new_type(old_field, new_type)
return super()._rename_field_sql(table, old_field, new_field, new_type)

def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment):
Expand All @@ -252,3 +256,18 @@ def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment):
def _comment_sql(self, comment):
    """Prefix the parent implementation's comment SQL with the COMMENT keyword."""
    return f" COMMENT {super()._comment_sql(comment)}"

def _alter_column_null_sql(self, model, old_field, new_field):
    """
    Return the (sql, params) fragment for changing a column's nullability.

    Fields without a database default defer to the parent implementation.
    For fields with one, a MODIFY clause is built whose type string comes
    from ``_set_field_new_type()``, which appends the DEFAULT clause.
    """
    if new_field.db_default is not NOT_PROVIDED:
        db_params = new_field.db_parameters(connection=self.connection)
        column_definition = self._set_field_new_type(new_field, db_params["type"])
        modify_sql = "MODIFY %(column)s %(type)s" % {
            "column": self.quote_name(new_field.column),
            "type": column_definition,
        }
        return modify_sql, []
    return super()._alter_column_null_sql(model, old_field, new_field)
4 changes: 4 additions & 0 deletions django/db/backends/oracle/features.py
Expand Up @@ -32,6 +32,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
atomic_transactions = False
nulls_order_largest = True
requires_literal_defaults = True
supports_default_keyword_in_bulk_insert = False
closed_cursor_error_class = InterfaceError
bare_select_suffix = " FROM DUAL"
# Select for update with limit can be achieved on Oracle, but not with the
Expand Down Expand Up @@ -130,6 +131,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"annotations.tests.NonAggregateAnnotationTestCase."
"test_custom_functions_can_ref_other_functions",
}
insert_test_table_with_defaults = (
"INSERT INTO {} VALUES (DEFAULT, DEFAULT, DEFAULT)"
)

@cached_property
def introspected_field_types(self):
Expand Down
2 changes: 1 addition & 1 deletion django/db/backends/oracle/introspection.py
Expand Up @@ -156,7 +156,7 @@ def get_table_description(self, cursor, table_name):
field_map = {
column: (
display_size,
default if default != "NULL" else None,
default.rstrip() if default and default != "NULL" else None,
collation,
is_autofield,
is_json,
Expand Down
4 changes: 3 additions & 1 deletion django/db/backends/oracle/schema.py
Expand Up @@ -198,7 +198,9 @@ def _generate_temp_name(self, for_name):
return self.normalize_name(for_name + "_" + suffix)

def prepare_default(self, value):
return self.quote_value(value)
# Replace % with %% as %-formatting is applied in
# FormatStylePlaceholderCursor._fix_for_params().
return self.quote_value(value).replace("%", "%%")

def _field_should_be_indexed(self, model, field):
create_index = super()._field_should_be_indexed(model, field)
Expand Down
1 change: 1 addition & 0 deletions django/db/backends/postgresql/features.py
Expand Up @@ -76,6 +76,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"swedish_ci": "sv-x-icu",
}
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
insert_test_table_with_defaults = "INSERT INTO {} DEFAULT VALUES"

django_test_skips = {
"opclasses are PostgreSQL only.": {
Expand Down
2 changes: 2 additions & 0 deletions django/db/backends/sqlite3/features.py
Expand Up @@ -59,6 +59,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
PRIMARY KEY(column_1, column_2)
)
"""
insert_test_table_with_defaults = 'INSERT INTO {} ("null") VALUES (1)'
supports_default_keyword_in_insert = False

@cached_property
def django_test_skips(self):
Expand Down
24 changes: 20 additions & 4 deletions django/db/backends/sqlite3/schema.py
Expand Up @@ -6,7 +6,7 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import Statement
from django.db.backends.utils import strip_quotes
from django.db.models import UniqueConstraint
from django.db.models import NOT_PROVIDED, UniqueConstraint
from django.db.transaction import atomic


Expand Down Expand Up @@ -233,9 +233,13 @@ def is_self_referential(f):
if create_field:
body[create_field.name] = create_field
# Choose a default and insert it into the copy map
if not create_field.many_to_many and create_field.concrete:
if (
create_field.db_default is NOT_PROVIDED
and not create_field.many_to_many
and create_field.concrete
):
mapping[create_field.column] = self.prepare_default(
self.effective_default(create_field),
self.effective_default(create_field)
)
# Add in any altered fields
for alter_field in alter_fields:
Expand All @@ -244,9 +248,13 @@ def is_self_referential(f):
mapping.pop(old_field.column, None)
body[new_field.name] = new_field
if old_field.null and not new_field.null:
if new_field.db_default is NOT_PROVIDED:
default = self.prepare_default(self.effective_default(new_field))
else:
default, _ = self.db_default_sql(new_field)
case_sql = "coalesce(%(col)s, %(default)s)" % {
"col": self.quote_name(old_field.column),
"default": self.prepare_default(self.effective_default(new_field)),
"default": default,
}
mapping[new_field.column] = case_sql
else:
Expand Down Expand Up @@ -381,6 +389,8 @@ def delete_model(self, model, handle_autom2m=True):

def add_field(self, model, field):
"""Create a field on a model."""
from django.db.models.expressions import Value

# Special-case implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created:
self.create_model(field.remote_field.through)
Expand All @@ -394,6 +404,12 @@ def add_field(self, model, field):
# COLUMN statement because DROP DEFAULT is not supported in
# ALTER TABLE.
or self.effective_default(field) is not None
# Fields with non-constant defaults cannot be handled by ALTER
# TABLE ADD COLUMN statement.
or (
field.db_default is not NOT_PROVIDED
and not isinstance(field.db_default, Value)
)
):
self._remake_table(model, create_field=field)
else:
Expand Down
2 changes: 2 additions & 0 deletions django/db/migrations/autodetector.py
Expand Up @@ -1040,6 +1040,7 @@ def _generate_added_field(self, app_label, model_name, field_name):
preserve_default = (
field.null
or field.has_default()
or field.db_default is not models.NOT_PROVIDED
or field.many_to_many
or (field.blank and field.empty_strings_allowed)
or (isinstance(field, time_fields) and field.auto_now)
Expand Down Expand Up @@ -1187,6 +1188,7 @@ def generate_altered_fields(self):
old_field.null
and not new_field.null
and not new_field.has_default()
and new_field.db_default is models.NOT_PROVIDED
and not new_field.many_to_many
):
field = new_field.clone()
Expand Down

0 comments on commit 7414704

Please sign in to comment.