first commit

This commit is contained in:
Maxim
2025-12-11 18:15:56 +03:00
commit d451ca7d3a
6071 changed files with 786794 additions and 0 deletions

View File

@@ -0,0 +1,445 @@
"""
MySQL database backend for Django.
Requires mysqlclient: https://pypi.org/project/mysqlclient/
"""
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.regex_helper import _lazy_re_compile
try:
import MySQLdb as Database
except ImportError as err:
raise ImproperlyConfigured(
"Error loading MySQLdb module.\nDid you install mysqlclient?"
) from err
from MySQLdb.constants import CLIENT, FIELD_TYPE
from MySQLdb.converters import conversions
# Some of these import MySQLdb, so import them after checking if it's
# installed.
from .client import DatabaseClient
from .creation import DatabaseCreation
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
from .validation import DatabaseValidation
version = Database.version_info
if version < (2, 2, 1):
raise ImproperlyConfigured(
"mysqlclient 2.2.1 or newer is required; you have %s." % Database.__version__
)
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
# terms of actual behavior as they are signed and include days -- and Django
# expects time.
django_conversions = {
**conversions,
**{FIELD_TYPE.TIME: backend_utils.typecast_time},
}
# This should match the numerical portion of the version numbers (we can treat
# versions like 5.0.24 and 5.0.24a as the same).
server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
class CursorWrapper:
"""
A thin wrapper around MySQLdb's normal cursor class that catches particular
exception instances and reraises them with the correct types.
Implemented as a wrapper, rather than a subclass, so that it isn't stuck
to the particular underlying representation returned by
Connection.cursor().
"""
codes_for_integrityerror = (
1048, # Column cannot be null
1690, # BIGINT UNSIGNED value is out of range
3819, # CHECK constraint is violated
4025, # CHECK constraint failed
)
def __init__(self, cursor):
self.cursor = cursor
def execute(self, query, args=None):
try:
# args is None means no string interpolation
return self.cursor.execute(query, args)
except Database.OperationalError as e:
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
raise IntegrityError(*tuple(e.args))
raise
def executemany(self, query, args):
try:
return self.cursor.executemany(query, args)
except Database.OperationalError as e:
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
raise IntegrityError(*tuple(e.args))
raise
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "mysql"
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings;
# they'll be interpolated against the values of Field.__dict__ before being
# output. If a column type is set to None, it won't be included in the
# output.
_data_types = {
"AutoField": "integer AUTO_INCREMENT",
"BigAutoField": "bigint AUTO_INCREMENT",
"BinaryField": "longblob",
"BooleanField": "bool",
"CharField": "varchar(%(max_length)s)",
"DateField": "date",
"DateTimeField": "datetime(6)",
"DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
"DurationField": "bigint",
"FileField": "varchar(%(max_length)s)",
"FilePathField": "varchar(%(max_length)s)",
"FloatField": "double precision",
"IntegerField": "integer",
"BigIntegerField": "bigint",
"IPAddressField": "char(15)",
"GenericIPAddressField": "char(39)",
"JSONField": "json",
"PositiveBigIntegerField": "bigint UNSIGNED",
"PositiveIntegerField": "integer UNSIGNED",
"PositiveSmallIntegerField": "smallint UNSIGNED",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "smallint AUTO_INCREMENT",
"SmallIntegerField": "smallint",
"TextField": "longtext",
"TimeField": "time(6)",
"UUIDField": "char(32)",
}
@cached_property
def data_types(self):
_data_types = self._data_types.copy()
if self.features.has_native_uuid_field:
_data_types["UUIDField"] = "uuid"
return _data_types
# For these data types:
# - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
# as nullable
# - all versions of MySQL and MariaDB don't support full width database
# indexes
_limited_data_types = (
"tinyblob",
"blob",
"mediumblob",
"longblob",
"tinytext",
"text",
"mediumtext",
"longtext",
"json",
)
operators = {
"exact": "= %s",
"iexact": "LIKE %s",
"contains": "LIKE BINARY %s",
"icontains": "LIKE %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": "LIKE BINARY %s",
"endswith": "LIKE BINARY %s",
"istartswith": "LIKE %s",
"iendswith": "LIKE %s",
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an
# expression or the result of a bilateral transformation). In those cases,
# special characters for LIKE operators (e.g. \, *, _) should be escaped on
# database side.
#
# Note: we use str.format() here for readability as '%' is used as a
# wildcard for the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
"contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
"icontains": "LIKE CONCAT('%%', {}, '%%')",
"startswith": "LIKE BINARY CONCAT({}, '%%')",
"istartswith": "LIKE CONCAT({}, '%%')",
"endswith": "LIKE BINARY CONCAT('%%', {})",
"iendswith": "LIKE CONCAT('%%', {})",
}
isolation_levels = {
"read uncommitted",
"read committed",
"repeatable read",
"serializable",
}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
validation_class = DatabaseValidation
def get_database_version(self):
return self.mysql_version
def get_connection_params(self):
kwargs = {
"conv": django_conversions,
"charset": "utf8mb4",
}
settings_dict = self.settings_dict
if settings_dict["USER"]:
kwargs["user"] = settings_dict["USER"]
if settings_dict["NAME"]:
kwargs["database"] = settings_dict["NAME"]
if settings_dict["PASSWORD"]:
kwargs["password"] = settings_dict["PASSWORD"]
if settings_dict["HOST"].startswith("/"):
kwargs["unix_socket"] = settings_dict["HOST"]
elif settings_dict["HOST"]:
kwargs["host"] = settings_dict["HOST"]
if settings_dict["PORT"]:
kwargs["port"] = int(settings_dict["PORT"])
# We need the number of potentially affected rows after an
# "UPDATE", not the number of changed rows.
kwargs["client_flag"] = CLIENT.FOUND_ROWS
# Validate the transaction isolation level, if specified.
options = settings_dict["OPTIONS"].copy()
isolation_level = options.pop("isolation_level", "read committed")
if isolation_level:
isolation_level = isolation_level.lower()
if isolation_level not in self.isolation_levels:
raise ImproperlyConfigured(
"Invalid transaction isolation level '%s' specified.\n"
"Use one of %s, or None."
% (
isolation_level,
", ".join("'%s'" % s for s in sorted(self.isolation_levels)),
)
)
self.isolation_level = isolation_level
kwargs.update(options)
return kwargs
@async_unsafe
def get_new_connection(self, conn_params):
connection = Database.connect(**conn_params)
return connection
def init_connection_state(self):
super().init_connection_state()
assignments = []
if self.features.is_sql_auto_is_null_enabled:
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
# a recently inserted row will return when the field is tested
# for NULL. Disabling this brings this aspect of MySQL in line
# with SQL standards.
assignments.append("SET SQL_AUTO_IS_NULL = 0")
if self.isolation_level:
assignments.append(
"SET SESSION TRANSACTION ISOLATION LEVEL %s"
% self.isolation_level.upper()
)
if assignments:
with self.cursor() as cursor:
cursor.execute("; ".join(assignments))
@async_unsafe
def create_cursor(self, name=None):
cursor = self.connection.cursor()
return CursorWrapper(cursor)
def _rollback(self):
try:
BaseDatabaseWrapper._rollback(self)
except Database.NotSupportedError:
pass
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit(autocommit)
def disable_constraint_checking(self):
"""
Disable foreign key checks, primarily for use in adding rows with
forward references. Always return True to indicate constraint checks
need to be re-enabled.
"""
with self.cursor() as cursor:
cursor.execute("SET foreign_key_checks=0")
return True
def enable_constraint_checking(self):
"""
Re-enable foreign key checks after they have been disabled.
"""
# Override needs_rollback in case constraint_checks_disabled is
# nested inside transaction.atomic.
self.needs_rollback, needs_rollback = False, self.needs_rollback
try:
with self.cursor() as cursor:
cursor.execute("SET foreign_key_checks=1")
finally:
self.needs_rollback = needs_rollback
def check_constraints(self, table_names=None):
"""
Check each table name in `table_names` for rows with invalid foreign
key references. This method is intended to be used in conjunction with
`disable_constraint_checking()` and `enable_constraint_checking()`, to
determine if rows with invalid references were entered while constraint
checks were off.
"""
with self.cursor() as cursor:
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(
cursor, table_name
)
if not primary_key_column_name:
continue
relations = self.introspection.get_relations(cursor, table_name)
for column_name, (
referenced_column_name,
referenced_table_name,
) in relations.items():
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
"""
% (
primary_key_column_name,
column_name,
table_name,
referenced_table_name,
column_name,
referenced_column_name,
column_name,
referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s."
% (
table_name,
bad_row[0],
table_name,
column_name,
bad_row[1],
referenced_table_name,
referenced_column_name,
)
)
def is_usable(self):
try:
self.connection.ping()
except Database.Error:
return False
else:
return True
@cached_property
def display_name(self):
return "MariaDB" if self.mysql_is_mariadb else "MySQL"
@cached_property
def data_type_check_constraints(self):
if self.features.supports_column_check_constraints:
check_constraints = {
"PositiveBigIntegerField": "`%(column)s` >= 0",
"PositiveIntegerField": "`%(column)s` >= 0",
"PositiveSmallIntegerField": "`%(column)s` >= 0",
}
return check_constraints
return {}
@cached_property
def mysql_server_data(self):
with self.temporary_connection() as cursor:
# Select some server variables and test if the time zone
# definitions are installed. CONVERT_TZ returns NULL if 'UTC'
# timezone isn't loaded into the mysql.time_zone table.
cursor.execute(
"""
SELECT VERSION(),
@@sql_mode,
@@default_storage_engine,
@@sql_auto_is_null,
@@lower_case_table_names,
CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
"""
)
row = cursor.fetchone()
return {
"version": row[0],
"sql_mode": row[1],
"default_storage_engine": row[2],
"sql_auto_is_null": bool(row[3]),
"lower_case_table_names": bool(row[4]),
"has_zoneinfo_database": bool(row[5]),
}
@cached_property
def mysql_server_info(self):
return self.mysql_server_data["version"]
@cached_property
def mysql_version(self):
match = server_version_re.match(self.mysql_server_info)
if not match:
raise Exception(
"Unable to determine MySQL version from version string %r"
% self.mysql_server_info
)
return tuple(int(x) for x in match.groups())
@cached_property
def mysql_is_mariadb(self):
return "mariadb" in self.mysql_server_info.lower()
@cached_property
def sql_mode(self):
sql_mode = self.mysql_server_data["sql_mode"]
return set(sql_mode.split(",") if sql_mode else ())

View File

@@ -0,0 +1,72 @@
import signal
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "mysql"
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
env = None
database = settings_dict["OPTIONS"].get(
"database",
settings_dict["OPTIONS"].get("db", settings_dict["NAME"]),
)
user = settings_dict["OPTIONS"].get("user", settings_dict["USER"])
password = settings_dict["OPTIONS"].get(
"password",
settings_dict["OPTIONS"].get("passwd", settings_dict["PASSWORD"]),
)
host = settings_dict["OPTIONS"].get("host", settings_dict["HOST"])
port = settings_dict["OPTIONS"].get("port", settings_dict["PORT"])
server_ca = settings_dict["OPTIONS"].get("ssl", {}).get("ca")
client_cert = settings_dict["OPTIONS"].get("ssl", {}).get("cert")
client_key = settings_dict["OPTIONS"].get("ssl", {}).get("key")
defaults_file = settings_dict["OPTIONS"].get("read_default_file")
charset = settings_dict["OPTIONS"].get("charset")
# Seems to be no good way to set sql_mode with CLI.
if defaults_file:
args += ["--defaults-file=%s" % defaults_file]
if user:
args += ["--user=%s" % user]
if password:
# The MYSQL_PWD environment variable usage is discouraged per
# MySQL's documentation due to the possibility of exposure through
# `ps` on old Unix flavors but --password suffers from the same
# flaw on even more systems. Usage of an environment variable also
# prevents password exposure if the subprocess.run(check=True) call
# raises a CalledProcessError since the string representation of
# the latter includes all of the provided `args`.
env = {"MYSQL_PWD": password}
if host:
if "/" in host:
args += ["--socket=%s" % host]
else:
args += ["--host=%s" % host]
if port:
args += ["--port=%s" % port]
if server_ca:
args += ["--ssl-ca=%s" % server_ca]
if client_cert:
args += ["--ssl-cert=%s" % client_cert]
if client_key:
args += ["--ssl-key=%s" % client_key]
if charset:
args += ["--default-character-set=%s" % charset]
if database:
args += [database]
args.extend(parameters)
return args, env
def runshell(self, parameters):
sigint_handler = signal.getsignal(signal.SIGINT)
try:
# Allow SIGINT to pass to mysql to abort queries.
signal.signal(signal.SIGINT, signal.SIG_IGN)
super().runshell(parameters)
finally:
# Restore the original SIGINT handler.
signal.signal(signal.SIGINT, sigint_handler)

View File

@@ -0,0 +1,72 @@
from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Col
from django.db.models.sql.compiler import SQLAggregateCompiler, SQLCompiler
from django.db.models.sql.compiler import SQLDeleteCompiler as BaseSQLDeleteCompiler
from django.db.models.sql.compiler import SQLInsertCompiler
from django.db.models.sql.compiler import SQLUpdateCompiler as BaseSQLUpdateCompiler
__all__ = [
"SQLAggregateCompiler",
"SQLCompiler",
"SQLDeleteCompiler",
"SQLInsertCompiler",
"SQLUpdateCompiler",
]
class SQLDeleteCompiler(BaseSQLDeleteCompiler):
def as_sql(self):
# Prefer the non-standard DELETE FROM syntax over the SQL generated by
# the SQLDeleteCompiler's default implementation when multiple tables
# are involved since MySQL/MariaDB will generate a more efficient query
# plan than when using a subquery.
where, having, qualify = self.query.where.split_having_qualify(
must_group_by=self.query.group_by is not None
)
if self.single_alias or having or qualify:
# DELETE FROM cannot be used when filtering against aggregates or
# window functions as it doesn't allow for GROUP BY/HAVING clauses
# and the subquery wrapping (necessary to emulate QUALIFY).
return super().as_sql()
result = [
"DELETE %s FROM"
% self.quote_name_unless_alias(self.query.get_initial_alias())
]
from_sql, params = self.get_from_clause()
result.extend(from_sql)
try:
where_sql, where_params = self.compile(where)
except FullResultSet:
pass
else:
result.append("WHERE %s" % where_sql)
params.extend(where_params)
return " ".join(result), tuple(params)
class SQLUpdateCompiler(BaseSQLUpdateCompiler):
def as_sql(self):
update_query, update_params = super().as_sql()
# MySQL and MariaDB support UPDATE ... ORDER BY syntax.
if self.query.order_by:
order_by_sql = []
order_by_params = []
db_table = self.query.get_meta().db_table
try:
for resolved, (sql, params, _) in self.get_order_by():
if (
isinstance(resolved.expression, Col)
and resolved.expression.alias != db_table
):
# Ignore ordering if it contains joined fields, because
# they cannot be used in the ORDER BY clause.
raise FieldError
order_by_sql.append(sql)
order_by_params.extend(params)
update_query += " ORDER BY " + ", ".join(order_by_sql)
update_params += tuple(order_by_params)
except FieldError:
# Ignore ordering if it contains annotations, because they're
# removed in .update() and cannot be resolved.
pass
return update_query, update_params

View File

@@ -0,0 +1,100 @@
import os
import subprocess
import sys
from django.db.backends.base.creation import BaseDatabaseCreation
from .client import DatabaseClient
class DatabaseCreation(BaseDatabaseCreation):
def sql_table_creation_suffix(self):
suffix = []
test_settings = self.connection.settings_dict["TEST"]
if test_settings["CHARSET"]:
suffix.append("CHARACTER SET %s" % test_settings["CHARSET"])
if test_settings["COLLATION"]:
suffix.append("COLLATE %s" % test_settings["COLLATION"])
return " ".join(suffix)
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
if len(e.args) < 1 or e.args[0] != 1007:
# All errors except "database exists" (1007) cancel tests.
self.log("Got an error creating the test database: %s" % e)
sys.exit(2)
else:
raise
def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
test_db_params = {
"dbname": self.connection.ops.quote_name(target_database_name),
"suffix": self.sql_table_creation_suffix(),
}
with self._nodb_cursor() as cursor:
try:
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception:
if keepdb:
# If the database should be kept, skip everything else.
return
try:
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, target_database_name
),
)
)
cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log("Got an error recreating the test database: %s" % e)
sys.exit(2)
self._clone_db(source_database_name, target_database_name)
def _clone_db(self, source_database_name, target_database_name):
cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(
self.connection.settings_dict, []
)
dump_cmd = [
"mysqldump",
*cmd_args[1:-1],
"--routines",
"--events",
source_database_name,
]
dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
load_cmd = cmd_args
load_cmd[-1] = target_database_name
with (
subprocess.Popen(
dump_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=dump_env
) as dump_proc,
subprocess.Popen(
load_cmd,
stdin=dump_proc.stdout,
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
env=load_env,
) as load_proc,
):
# Allow dump_proc to receive a SIGPIPE if the load process exits.
dump_proc.stdout.close()
dump_err = dump_proc.stderr.read().decode(errors="replace")
load_err = load_proc.stderr.read().decode(errors="replace")
if dump_proc.returncode != 0:
self.log(
f"Got an error on mysqldump when cloning the test database: {dump_err}"
)
sys.exit(dump_proc.returncode)
if load_proc.returncode != 0:
self.log(f"Got an error cloning the test database: {load_err}")
sys.exit(load_proc.returncode)

View File

@@ -0,0 +1,295 @@
import operator
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
empty_fetchmany_value = ()
related_fields_match_type = True
# MySQL doesn't support sliced subqueries with IN/ALL/ANY/SOME.
allow_sliced_subqueries_with_in = False
has_select_for_update = True
has_select_for_update_nowait = True
has_select_for_update_skip_locked = True
supports_forward_references = False
supports_regex_backreferencing = False
supports_date_lookup_using_string = False
supports_timezones = False
requires_explicit_null_ordering_when_grouping = True
atomic_transactions = False
can_clone_databases = True
supports_aggregate_order_by_clause = True
supports_comments = True
supports_comments_inline = True
supports_temporal_subtraction = True
supports_slicing_ordering_in_compound = True
supports_index_on_text_field = False
supports_over_clause = True
supports_frame_range_fixed_distance = True
supports_update_conflicts = True
can_rename_index = True
delete_can_self_reference_subquery = False
create_test_procedure_without_params_sql = """
CREATE PROCEDURE test_procedure ()
BEGIN
DECLARE V_I INTEGER;
SET V_I = 1;
END;
"""
create_test_procedure_with_int_param_sql = """
CREATE PROCEDURE test_procedure (P_I INTEGER)
BEGIN
DECLARE V_I INTEGER;
SET V_I = P_I;
END;
"""
# Neither MySQL nor MariaDB support partial indexes.
supports_partial_indexes = False
# COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an
# indexed expression.
collate_as_index_expression = True
insert_test_table_with_defaults = "INSERT INTO {} () VALUES ()"
supports_order_by_nulls_modifier = False
order_by_nulls_first = True
supports_logical_xor = True
supports_stored_generated_columns = True
supports_virtual_generated_columns = True
supports_json_negative_indexing = False
@cached_property
def minimum_database_version(self):
if self.connection.mysql_is_mariadb:
return (10, 6)
else:
return (8, 0, 11)
@cached_property
def test_collations(self):
return {
"ci": "utf8mb4_general_ci",
"non_default": "utf8mb4_esperanto_ci",
"swedish_ci": "utf8mb4_swedish_ci",
"virtual": "utf8mb4_esperanto_ci",
}
test_now_utc_template = "UTC_TIMESTAMP(6)"
@cached_property
def django_test_skips(self):
skips = {
"This doesn't work on MySQL.": {
"db_functions.comparison.test_greatest.GreatestTests."
"test_coalesce_workaround",
"db_functions.comparison.test_least.LeastTests."
"test_coalesce_workaround",
},
"MySQL doesn't support functional indexes on a function that "
"returns JSON": {
"schema.tests.SchemaTests.test_func_index_json_key_transform",
},
"MySQL supports multiplying and dividing DurationFields by a "
"scalar value but it's not implemented (#25287).": {
"expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide",
},
"UPDATE ... ORDER BY syntax on MySQL/MariaDB does not support ordering by"
"related fields.": {
"update.tests.AdvancedTests."
"test_update_ordered_by_inline_m2m_annotation",
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation",
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation_desc",
},
}
if not self.supports_explain_analyze:
skips.update(
{
"MariaDB and MySQL >= 8.0.18 specific.": {
"queries.test_explain.ExplainTests.test_mysql_analyze",
},
}
)
if self.connection.mysql_version < (8, 0, 31):
skips.update(
{
"Nesting of UNIONs at the right-hand side is not supported on "
"MySQL < 8.0.31": {
"queries.test_qs_combinators.QuerySetSetOperationTests."
"test_union_nested"
},
}
)
if not self.connection.mysql_is_mariadb:
skips.update(
{
"MySQL doesn't allow renaming columns referenced by generated "
"columns": {
"migrations.test_operations.OperationTests."
"test_invalid_generated_field_changes_on_rename_stored",
"migrations.test_operations.OperationTests."
"test_invalid_generated_field_changes_on_rename_virtual",
},
}
)
return skips
@cached_property
def _mysql_storage_engine(self):
"""
Internal method used in Django tests. Don't rely on this from your code
"""
return self.connection.mysql_server_data["default_storage_engine"]
@cached_property
def allows_auto_pk_0(self):
"""
Autoincrement primary key can be set to 0 if it doesn't generate new
autoincrement values.
"""
return "NO_AUTO_VALUE_ON_ZERO" in self.connection.sql_mode
@cached_property
def update_can_self_select(self):
return self.connection.mysql_is_mariadb
@cached_property
def can_introspect_foreign_keys(self):
"Confirm support for introspected foreign keys"
return self._mysql_storage_engine != "MyISAM"
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
"BinaryField": "TextField",
"BooleanField": "IntegerField",
"DurationField": "BigIntegerField",
"GenericIPAddressField": "CharField",
}
@cached_property
def can_return_columns_from_insert(self):
return self.connection.mysql_is_mariadb
can_return_rows_from_bulk_insert = property(
operator.attrgetter("can_return_columns_from_insert")
)
@cached_property
def has_zoneinfo_database(self):
return self.connection.mysql_server_data["has_zoneinfo_database"]
@cached_property
def is_sql_auto_is_null_enabled(self):
return self.connection.mysql_server_data["sql_auto_is_null"]
@cached_property
def supports_column_check_constraints(self):
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (8, 0, 16)
supports_table_check_constraints = property(
operator.attrgetter("supports_column_check_constraints")
)
@cached_property
def can_introspect_check_constraints(self):
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (8, 0, 16)
@cached_property
def has_select_for_update_of(self):
return not self.connection.mysql_is_mariadb
@cached_property
def supports_explain_analyze(self):
return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (
8,
0,
18,
)
@cached_property
def supported_explain_formats(self):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other
# backends.
formats = {"JSON", "TEXT", "TRADITIONAL"}
if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
8,
0,
16,
):
formats.add("TREE")
return formats
@cached_property
def supports_transactions(self):
"""
All storage engines except MyISAM support transactions.
"""
return self._mysql_storage_engine != "MyISAM"
@cached_property
def ignores_table_name_case(self):
return self.connection.mysql_server_data["lower_case_table_names"]
@cached_property
def supports_default_in_lead_lag(self):
# To be added in https://jira.mariadb.org/browse/MDEV-12981.
return not self.connection.mysql_is_mariadb
@cached_property
def can_introspect_json_field(self):
if self.connection.mysql_is_mariadb:
return self.can_introspect_check_constraints
return True
@cached_property
def supports_index_column_ordering(self):
if self._mysql_storage_engine != "InnoDB":
return False
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 8)
return True
@cached_property
def supports_expression_indexes(self):
return (
not self.connection.mysql_is_mariadb
and self._mysql_storage_engine != "MyISAM"
and self.connection.mysql_version >= (8, 0, 13)
)
@cached_property
def supports_select_intersection(self):
is_mariadb = self.connection.mysql_is_mariadb
return is_mariadb or self.connection.mysql_version >= (8, 0, 31)
supports_select_difference = property(
operator.attrgetter("supports_select_intersection")
)
@cached_property
def supports_expression_defaults(self):
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (8, 0, 13)
@cached_property
def has_native_uuid_field(self):
is_mariadb = self.connection.mysql_is_mariadb
return is_mariadb and self.connection.mysql_version >= (10, 7)
@cached_property
def allows_group_by_selected_pks(self):
if self.connection.mysql_is_mariadb:
return "ONLY_FULL_GROUP_BY" not in self.connection.sql_mode
return True
@cached_property
def supports_any_value(self):
return not self.connection.mysql_is_mariadb

View File

@@ -0,0 +1,365 @@
from collections import namedtuple
import sqlparse
from MySQLdb.constants import FIELD_TYPE
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo as BaseTableInfo
from django.db.models import Index
from django.utils.datastructures import OrderedSet
FieldInfo = namedtuple(
"FieldInfo",
[
*BaseFieldInfo._fields,
"extra",
"is_unsigned",
"has_json_constraint",
"comment",
"data_type",
],
)
InfoLine = namedtuple(
"InfoLine",
"col_name data_type max_len num_prec num_scale extra column_default "
"collation is_unsigned comment",
)
TableInfo = namedtuple("TableInfo", [*BaseTableInfo._fields, "comment"])
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = {
FIELD_TYPE.BLOB: "TextField",
FIELD_TYPE.CHAR: "CharField",
FIELD_TYPE.DECIMAL: "DecimalField",
FIELD_TYPE.NEWDECIMAL: "DecimalField",
FIELD_TYPE.DATE: "DateField",
FIELD_TYPE.DATETIME: "DateTimeField",
FIELD_TYPE.DOUBLE: "FloatField",
FIELD_TYPE.FLOAT: "FloatField",
FIELD_TYPE.INT24: "IntegerField",
FIELD_TYPE.JSON: "JSONField",
FIELD_TYPE.LONG: "IntegerField",
FIELD_TYPE.LONGLONG: "BigIntegerField",
FIELD_TYPE.SHORT: "SmallIntegerField",
FIELD_TYPE.STRING: "CharField",
FIELD_TYPE.TIME: "TimeField",
FIELD_TYPE.TIMESTAMP: "DateTimeField",
FIELD_TYPE.TINY: "IntegerField",
FIELD_TYPE.TINY_BLOB: "TextField",
FIELD_TYPE.MEDIUM_BLOB: "TextField",
FIELD_TYPE.LONG_BLOB: "TextField",
FIELD_TYPE.VAR_STRING: "CharField",
}
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if "auto_increment" in description.extra:
if field_type == "IntegerField":
return "AutoField"
elif field_type == "BigIntegerField":
return "BigAutoField"
elif field_type == "SmallIntegerField":
return "SmallAutoField"
if description.is_unsigned:
if field_type == "BigIntegerField":
return "PositiveBigIntegerField"
elif field_type == "IntegerField":
return "PositiveIntegerField"
elif field_type == "SmallIntegerField":
return "PositiveSmallIntegerField"
if description.data_type.upper() == "UUID":
return "UUIDField"
# JSON data type is an alias for LONGTEXT in MariaDB, use check
# constraints clauses to introspect JSONField.
if description.has_json_constraint:
return "JSONField"
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute(
"""
SELECT
table_name,
table_type,
table_comment
FROM information_schema.tables
WHERE table_schema = DATABASE()
"""
)
return [
TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]), row[2])
for row in cursor.fetchall()
]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface."
"""
json_constraints = {}
if (
self.connection.mysql_is_mariadb
and self.connection.features.can_introspect_json_field
):
# JSON data type is an alias for LONGTEXT in MariaDB, select
# JSON_VALID() constraints to introspect JSONField.
cursor.execute(
"""
SELECT c.constraint_name AS column_name
FROM information_schema.check_constraints AS c
WHERE
c.table_name = %s AND
LOWER(c.check_clause) =
'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
c.constraint_schema = DATABASE()
""",
[table_name],
)
json_constraints = {row[0] for row in cursor.fetchall()}
# A default collation for the given table.
cursor.execute(
"""
SELECT table_collation
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = %s
""",
[table_name],
)
row = cursor.fetchone()
default_column_collation = row[0] if row else ""
# information_schema database gives more accurate results for some
# figures:
# - varchar length returned by cursor.description is an internal
# length, not visible length (#5725)
# - precision and scale (for decimal fields) (#5014)
# - auto_increment is not available in cursor.description
cursor.execute(
"""
SELECT
column_name, data_type, character_maximum_length,
numeric_precision, numeric_scale, extra, column_default,
CASE
WHEN collation_name = %s THEN NULL
ELSE collation_name
END AS collation_name,
CASE
WHEN column_type LIKE '%% unsigned' THEN 1
ELSE 0
END AS is_unsigned,
column_comment
FROM information_schema.columns
WHERE table_name = %s AND table_schema = DATABASE()
""",
[default_column_collation, table_name],
)
field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
cursor.execute(
"SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
)
def to_int(i):
return int(i) if i is not None else i
fields = []
for line in cursor.description:
info = field_info[line[0]]
fields.append(
FieldInfo(
*line[:2],
to_int(info.max_len) or line[2],
to_int(info.max_len) or line[3],
to_int(info.num_prec) or line[4],
to_int(info.num_scale) or line[5],
line[6],
info.column_default,
info.collation,
info.extra,
info.is_unsigned,
line[0] in json_constraints,
info.comment,
info.data_type,
)
)
return fields
def get_sequences(self, cursor, table_name, table_fields=()):
for field_info in self.get_table_description(cursor, table_name):
if "auto_increment" in field_info.extra:
# MySQL allows only one auto-increment column per table.
return [{"table": table_name, "column": field_info.name}]
return []
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table,
other_table)} representing all foreign keys in the given table.
"""
cursor.execute(
"""
SELECT column_name, referenced_column_name, referenced_table_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL
""",
[table_name],
)
return {
field_name: (other_field, other_table)
for field_name, other_field, other_table in cursor.fetchall()
}
def get_storage_engine(self, cursor, table_name):
"""
Retrieve the storage engine for a given table. Return the default
storage engine if the table doesn't exist.
"""
cursor.execute(
"""
SELECT engine
FROM information_schema.tables
WHERE
table_name = %s AND
table_schema = DATABASE()
""",
[table_name],
)
result = cursor.fetchone()
if not result:
return self.connection.features._mysql_storage_engine
return result[0]
def _parse_constraint_columns(self, check_clause, columns):
check_columns = OrderedSet()
statement = sqlparse.parse(check_clause)[0]
tokens = (token for token in statement.flatten() if not token.is_whitespace)
for token in tokens:
if (
token.ttype == sqlparse.tokens.Name
and self.connection.ops.quote_name(token.value) == token.value
and token.value[1:-1] in columns
):
check_columns.add(token.value[1:-1])
return check_columns
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns.
"""
constraints = {}
# Get the actual constraint names and columns
name_query = """
SELECT kc.`constraint_name`, kc.`column_name`,
kc.`referenced_table_name`, kc.`referenced_column_name`,
c.`constraint_type`
FROM
information_schema.key_column_usage AS kc,
information_schema.table_constraints AS c
WHERE
kc.table_schema = DATABASE() AND
(
kc.referenced_table_schema = DATABASE() OR
kc.referenced_table_schema IS NULL
) AND
c.table_schema = kc.table_schema AND
c.constraint_name = kc.constraint_name AND
c.constraint_type != 'CHECK' AND
kc.table_name = %s
ORDER BY kc.`ordinal_position`
"""
cursor.execute(name_query, [table_name])
for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
if constraint not in constraints:
constraints[constraint] = {
"columns": OrderedSet(),
"primary_key": kind == "PRIMARY KEY",
"unique": kind in {"PRIMARY KEY", "UNIQUE"},
"index": False,
"check": False,
"foreign_key": (ref_table, ref_column) if ref_column else None,
}
if self.connection.features.supports_index_column_ordering:
constraints[constraint]["orders"] = []
constraints[constraint]["columns"].add(column)
# Add check constraints.
if self.connection.features.can_introspect_check_constraints:
unnamed_constraints_index = 0
columns = {
info.name for info in self.get_table_description(cursor, table_name)
}
if self.connection.mysql_is_mariadb:
type_query = """
SELECT c.constraint_name, c.check_clause
FROM information_schema.check_constraints AS c
WHERE
c.constraint_schema = DATABASE() AND
c.table_name = %s
"""
else:
type_query = """
SELECT cc.constraint_name, cc.check_clause
FROM
information_schema.check_constraints AS cc,
information_schema.table_constraints AS tc
WHERE
cc.constraint_schema = DATABASE() AND
tc.table_schema = cc.constraint_schema AND
cc.constraint_name = tc.constraint_name AND
tc.constraint_type = 'CHECK' AND
tc.table_name = %s
"""
cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall():
constraint_columns = self._parse_constraint_columns(
check_clause, columns
)
# Ensure uniqueness of unnamed constraints. Unnamed unique
# and check columns constraints have the same name as
# a column.
if set(constraint_columns) == {constraint}:
unnamed_constraints_index += 1
constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index
constraints[constraint] = {
"columns": constraint_columns,
"primary_key": False,
"unique": False,
"index": False,
"check": True,
"foreign_key": None,
}
# Now add in the indexes
cursor.execute(
"SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)
)
for table, non_unique, index, colseq, column, order, type_ in [
x[:6] + (x[10],) for x in cursor.fetchall()
]:
if index not in constraints:
constraints[index] = {
"columns": OrderedSet(),
"primary_key": False,
"unique": not non_unique,
"check": False,
"foreign_key": None,
}
if self.connection.features.supports_index_column_ordering:
constraints[index]["orders"] = []
constraints[index]["index"] = True
constraints[index]["type"] = (
Index.suffix if type_ == "BTREE" else type_.lower()
)
constraints[index]["columns"].add(column)
if self.connection.features.supports_index_column_ordering:
constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
# Convert the sorted sets to lists
for constraint in constraints.values():
constraint["columns"] = list(constraint["columns"])
return constraints

View File

@@ -0,0 +1,435 @@
import uuid
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.regex_helper import _lazy_re_compile
class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.mysql.compiler"
# MySQL stores positive fields as UNSIGNED ints.
integer_field_ranges = {
**BaseDatabaseOperations.integer_field_ranges,
"PositiveSmallIntegerField": (0, 65535),
"PositiveIntegerField": (0, 4294967295),
"PositiveBigIntegerField": (0, 18446744073709551615),
}
cast_data_types = {
"AutoField": "signed integer",
"BigAutoField": "signed integer",
"SmallAutoField": "signed integer",
"CharField": "char(%(max_length)s)",
"DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
"TextField": "char",
"IntegerField": "signed integer",
"BigIntegerField": "signed integer",
"SmallIntegerField": "signed integer",
"PositiveBigIntegerField": "unsigned integer",
"PositiveIntegerField": "unsigned integer",
"PositiveSmallIntegerField": "unsigned integer",
"DurationField": "signed integer",
}
cast_char_field_without_max_length = "char"
explain_prefix = "EXPLAIN"
# EXTRACT format cannot be passed in parameters.
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
def date_extract_sql(self, lookup_type, sql, params):
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
if lookup_type == "week_day":
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
return f"DAYOFWEEK({sql})", params
elif lookup_type == "iso_week_day":
# WEEKDAY() returns an integer, 0-6, Monday=0.
return f"WEEKDAY({sql}) + 1", params
elif lookup_type == "week":
# Override the value of default_week_format for consistency with
# other database backends.
# Mode 3: Monday, 1-53, with 4 or more days this year.
return f"WEEK({sql}, 3)", params
elif lookup_type == "iso_year":
# Get the year part from the YEARWEEK function, which returns a
# number as year * 100 + week.
return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
else:
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
lookup_type = lookup_type.upper()
if not self._extract_format_re.fullmatch(lookup_type):
raise ValueError(f"Invalid loookup type: {lookup_type!r}")
return f"EXTRACT({lookup_type} FROM {sql})", params
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
fields = {
"year": "%Y-01-01",
"month": "%Y-%m-01",
}
if lookup_type in fields:
format_str = fields[lookup_type]
return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
elif lookup_type == "quarter":
return (
f"MAKEDATE(YEAR({sql}), 1) + "
f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
(*params, *params),
)
elif lookup_type == "week":
return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
else:
return f"DATE({sql})", params
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
return f"{sign}{offset}" if offset else tzname
def _convert_sql_to_tz(self, sql, params, tzname):
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
return f"CONVERT_TZ({sql}, %s, %s)", (
*params,
self.connection.timezone_name,
self._prepare_tzname_delta(tzname),
)
return sql, params
def datetime_cast_date_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"DATE({sql})", params
def datetime_cast_time_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"TIME({sql})", params
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return self.date_extract_sql(lookup_type, sql, params)
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
fields = ["year", "month", "day", "hour", "minute", "second"]
format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
if lookup_type == "quarter":
return (
f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
f"INTERVAL QUARTER({sql}) QUARTER - "
f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
), (*params, *params, "%Y-%m-01 00:00:00")
if lookup_type == "week":
return (
f"CAST(DATE_FORMAT("
f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
), (*params, *params, "%Y-%m-%d 00:00:00")
try:
i = fields.index(lookup_type) + 1
except ValueError:
pass
else:
format_str = "".join(format[:i] + format_def[i:])
return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
return sql, params
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
fields = {
"hour": "%H:00:00",
"minute": "%H:%i:00",
"second": "%H:%i:%s",
}
if lookup_type in fields:
format_str = fields[lookup_type]
return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
else:
return f"TIME({sql})", params
def format_for_duration_arithmetic(self, sql):
return "INTERVAL %s MICROSECOND" % sql
def force_no_ordering(self):
"""
"ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
columns. If no ordering would otherwise be applied, we don't want any
implicit sorting going on.
"""
return [(None, ("NULL", [], False))]
def last_executed_query(self, cursor, sql, params):
# With MySQLdb, cursor objects have an (undocumented) "_executed"
# attribute where the exact query sent to the database is saved.
# See MySQLdb/cursors.py in the source distribution.
# MySQLdb returns string, PyMySQL bytes.
return force_str(getattr(cursor, "_executed", None), errors="replace")
def no_limit_value(self):
# 2**64 - 1, as recommended by the MySQL documentation
return 18446744073709551615
def quote_name(self, name):
if name.startswith("`") and name.endswith("`"):
return name # Quoting once is enough.
return "`%s`" % name
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
return []
sql = ["SET FOREIGN_KEY_CHECKS = 0;"]
if reset_sequences:
# It's faster to TRUNCATE tables that require a sequence reset
# since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE.
sql.extend(
"%s %s;"
% (
style.SQL_KEYWORD("TRUNCATE"),
style.SQL_FIELD(self.quote_name(table_name)),
)
for table_name in tables
)
else:
# Otherwise issue a simple DELETE since it's faster than TRUNCATE
# and preserves sequences.
sql.extend(
"%s %s %s;"
% (
style.SQL_KEYWORD("DELETE"),
style.SQL_KEYWORD("FROM"),
style.SQL_FIELD(self.quote_name(table_name)),
)
for table_name in tables
)
sql.append("SET FOREIGN_KEY_CHECKS = 1;")
return sql
def sequence_reset_by_name_sql(self, style, sequences):
return [
"%s %s %s %s = 1;"
% (
style.SQL_KEYWORD("ALTER"),
style.SQL_KEYWORD("TABLE"),
style.SQL_FIELD(self.quote_name(sequence_info["table"])),
style.SQL_FIELD("AUTO_INCREMENT"),
)
for sequence_info in sequences
]
def validate_autopk_value(self, value):
# Zero in AUTO_INCREMENT field does not work without the
# NO_AUTO_VALUE_ON_ZERO SQL mode.
if value == 0 and not self.connection.features.allows_auto_pk_0:
raise ValueError(
"The database backend does not accept 0 as a value for AutoField."
)
return value
def adapt_datetimefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
return value
# MySQL doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError(
"MySQL backend does not support timezone-aware datetimes when "
"USE_TZ is False."
)
return str(value)
def adapt_timefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
return value
# MySQL doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("MySQL backend does not support timezone-aware times.")
return value.isoformat(timespec="microseconds")
def max_name_length(self):
return 64
def pk_default_value(self):
return "NULL"
def combine_expression(self, connector, sub_expressions):
if connector == "^":
return "POW(%s)" % ",".join(sub_expressions)
# Convert the result to a signed integer since MySQL's binary operators
# return an unsigned integer.
elif connector in ("&", "|", "<<", "#"):
connector = "^" if connector == "#" else connector
return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions)
elif connector == ">>":
lhs, rhs = sub_expressions
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
return super().combine_expression(connector, sub_expressions)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == "BooleanField":
converters.append(self.convert_booleanfield_value)
elif internal_type == "DateTimeField":
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
elif internal_type == "UUIDField":
converters.append(self.convert_uuidfield_value)
return converters
def convert_booleanfield_value(self, value, expression, connection):
if value in (0, 1):
value = bool(value)
return value
def convert_datetimefield_value(self, value, expression, connection):
if value is not None:
value = timezone.make_aware(value, self.connection.timezone)
return value
def convert_uuidfield_value(self, value, expression, connection):
if value is not None:
value = uuid.UUID(value)
return value
def binary_placeholder_sql(self, value):
return (
"_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
)
def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
if internal_type == "TimeField":
if self.connection.mysql_is_mariadb:
# MariaDB includes the microsecond component in TIME_TO_SEC as
# a decimal. MySQL returns an integer without microseconds.
return (
"CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) "
"* 1000000 AS SIGNED)"
) % {
"lhs": lhs_sql,
"rhs": rhs_sql,
}, (
*lhs_params,
*rhs_params,
)
return (
"((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
" (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple(
rhs_params
) * 2
params = (*rhs_params, *lhs_params)
return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
def explain_query_prefix(self, format=None, **options):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other
# backends.
if format and format.upper() == "TEXT":
format = "TRADITIONAL"
elif (
not format and "TREE" in self.connection.features.supported_explain_formats
):
# Use TREE by default (if supported) as it's more informative.
format = "TREE"
analyze = options.pop("analyze", False)
prefix = super().explain_query_prefix(format, **options)
if analyze and self.connection.features.supports_explain_analyze:
# MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
prefix = (
"ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
)
if format and not (analyze and not self.connection.mysql_is_mariadb):
# Only MariaDB supports the analyze option with formats.
prefix += " FORMAT=%s" % format
return prefix
def regex_lookup(self, lookup_type):
# REGEXP_LIKE doesn't exist in MariaDB.
if self.connection.mysql_is_mariadb:
if lookup_type == "regex":
return "%s REGEXP BINARY %s"
return "%s REGEXP %s"
match_option = "c" if lookup_type == "regex" else "i"
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
def insert_statement(self, on_conflict=None):
if on_conflict == OnConflict.IGNORE:
return "INSERT IGNORE INTO"
return super().insert_statement(on_conflict=on_conflict)
def lookup_cast(self, lookup_type, internal_type=None):
lookup = "%s"
if internal_type == "JSONField":
if self.connection.mysql_is_mariadb or lookup_type in (
"iexact",
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
"regex",
"iregex",
):
lookup = "JSON_UNQUOTE(%s)"
return lookup
def conditional_expression_supported_in_where_clause(self, expression):
# MySQL ignores indexes with boolean fields unless they're compared
# directly to a boolean value.
if isinstance(expression, (Exists, Lookup)):
return True
if isinstance(expression, ExpressionWrapper) and expression.conditional:
return self.conditional_expression_supported_in_where_clause(
expression.expression
)
if getattr(expression, "conditional", False):
return False
return super().conditional_expression_supported_in_where_clause(expression)
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if on_conflict == OnConflict.UPDATE:
conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
# The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
# aliases for the new row and its columns available in MySQL
# 8.0.19+.
if not self.connection.mysql_is_mariadb:
if self.connection.mysql_version >= (8, 0, 19):
conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
field_sql = "%(field)s = new.%(field)s"
else:
field_sql = "%(field)s = VALUES(%(field)s)"
# Use VALUE() on MariaDB.
else:
field_sql = "%(field)s = VALUE(%(field)s)"
fields = ", ".join(
[
field_sql % {"field": field}
for field in map(self.quote_name, update_fields)
]
)
return conflict_suffix_sql % {"fields": fields}
return super().on_conflict_suffix_sql(
fields,
on_conflict,
update_fields,
unique_fields,
)

View File

@@ -0,0 +1,263 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import NOT_PROVIDED, F, UniqueConstraint
from django.db.models.constants import LOOKUP_SEP
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_rename_table = "RENAME TABLE %(old_table)s TO %(new_table)s"
sql_alter_column_null = "MODIFY %(column)s %(type)s NULL"
sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL"
sql_alter_column_type = "MODIFY %(column)s %(type)s%(collation)s%(comment)s"
sql_alter_column_no_default_null = "ALTER COLUMN %(column)s SET DEFAULT NULL"
sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s"
sql_create_column_inline_fk = (
", ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) "
"REFERENCES %(to_table)s(%(to_column)s)"
)
sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s"
sql_delete_index = "DROP INDEX %(name)s ON %(table)s"
sql_rename_index = "ALTER TABLE %(table)s RENAME INDEX %(old_name)s TO %(new_name)s"
sql_create_pk = (
"ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
)
sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
sql_alter_table_comment = "ALTER TABLE %(table)s COMMENT = %(comment)s"
sql_alter_column_comment = None
@property
def sql_delete_check(self):
if self.connection.mysql_is_mariadb:
# The name of the column check constraint is the same as the field
# name on MariaDB. Adding IF EXISTS clause prevents migrations
# crash. Constraint is removed during a "MODIFY" column statement.
return "ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s"
return "ALTER TABLE %(table)s DROP CHECK %(name)s"
def quote_value(self, value):
self.connection.ensure_connection()
# MySQLdb escapes to string, PyMySQL to bytes.
quoted = self.connection.connection.escape(
value, self.connection.connection.encoders
)
if isinstance(value, str) and isinstance(quoted, bytes):
quoted = quoted.decode()
return quoted
def _is_limited_data_type(self, field):
db_type = field.db_type(self.connection)
return (
db_type is not None
and db_type.lower() in self.connection._limited_data_types
)
def _is_text_or_blob(self, field):
db_type = field.db_type(self.connection)
return db_type and db_type.lower().endswith(("blob", "text"))
def skip_default(self, field):
default_is_empty = self.effective_default(field) in ("", b"")
if default_is_empty and self._is_text_or_blob(field):
return True
if not self._supports_limited_data_type_defaults:
return self._is_limited_data_type(field)
return False
def skip_default_on_alter(self, field):
default_is_empty = self.effective_default(field) in ("", b"")
if default_is_empty and self._is_text_or_blob(field):
return True
if self._is_limited_data_type(field) and not self.connection.mysql_is_mariadb:
# MySQL doesn't support defaults for BLOB and TEXT in the
# ALTER COLUMN statement.
return True
return False
@property
def _supports_limited_data_type_defaults(self):
# MariaDB and MySQL >= 8.0.13 support defaults for BLOB and TEXT.
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (8, 0, 13)
def _column_default_sql(self, field):
if (
not self.connection.mysql_is_mariadb
and self._supports_limited_data_type_defaults
and self._is_limited_data_type(field)
):
# MySQL supports defaults for BLOB and TEXT columns only if the
# default value is written as an expression i.e. in parentheses.
return "(%s)"
return super()._column_default_sql(field)
def add_field(self, model, field):
super().add_field(model, field)
# Simulate the effect of a one-off default.
# field.default may be unhashable, so a set isn't used for "in" check.
if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
effective_default = self.effective_default(field)
self.execute(
"UPDATE %(table)s SET %(column)s = %%s"
% {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(field.column),
},
[effective_default],
)
def remove_constraint(self, model, constraint):
if (
isinstance(constraint, UniqueConstraint)
and constraint.create_sql(model, self) is not None
):
self._create_missing_fk_index(
model,
fields=constraint.fields,
expressions=constraint.expressions,
)
super().remove_constraint(model, constraint)
def remove_index(self, model, index):
self._create_missing_fk_index(
model,
fields=[field_name for field_name, _ in index.fields_orders],
expressions=index.expressions,
)
super().remove_index(model, index)
def _field_should_be_indexed(self, model, field):
if not super()._field_should_be_indexed(model, field):
return False
storage = self.connection.introspection.get_storage_engine(
self.connection.cursor(), model._meta.db_table
)
# No need to create an index for ForeignKey fields except if
# db_constraint=False because the index from that constraint won't be
# created.
if (
storage == "InnoDB"
and field.get_internal_type() == "ForeignKey"
and field.db_constraint
):
return False
return not self._is_limited_data_type(field)
def _create_missing_fk_index(
self,
model,
*,
fields,
expressions=None,
):
"""
MySQL can remove an implicit FK index on a field when that field is
covered by another index like a unique_together. "covered" here means
that the more complex index has the FK field as its first field (see
https://bugs.mysql.com/bug.php?id=37910).
Manually create an implicit FK index to make it possible to remove the
composed index.
"""
first_field_name = None
if fields:
first_field_name = fields[0]
elif (
expressions
and self.connection.features.supports_expression_indexes
and isinstance(expressions[0], F)
and LOOKUP_SEP not in expressions[0].name
):
first_field_name = expressions[0].name
if not first_field_name:
return
first_field = model._meta.get_field(first_field_name)
if first_field.get_internal_type() == "ForeignKey":
column = self.connection.introspection.identifier_converter(
first_field.column
)
with self.connection.cursor() as cursor:
constraint_names = [
name
for name, infodict in self.connection.introspection.get_constraints(
cursor, model._meta.db_table
).items()
if infodict["index"] and infodict["columns"][0] == column
]
# There are no other indexes that starts with the FK field, only
# the index that is expected to be deleted.
if len(constraint_names) == 1:
self.execute(
self._create_index_sql(model, fields=[first_field], suffix="")
)
def _delete_composed_index(self, model, fields, *args):
self._create_missing_fk_index(model, fields=fields)
return super()._delete_composed_index(model, fields, *args)
def _set_field_new_type(self, field, new_type):
"""
Keep the NULL and DEFAULT properties of the old field. If it has
changed, it will be handled separately.
"""
if field.has_db_default():
default_sql, params = self.db_default_sql(field)
default_sql %= tuple(self.quote_value(p) for p in params)
new_type += f" DEFAULT {default_sql}"
if field.null:
new_type += " NULL"
else:
new_type += " NOT NULL"
return new_type
def _alter_column_type_sql(
self, model, old_field, new_field, new_type, old_collation, new_collation
):
new_type = self._set_field_new_type(old_field, new_type)
return super()._alter_column_type_sql(
model, old_field, new_field, new_type, old_collation, new_collation
)
def _field_db_check(self, field, field_db_params):
if self.connection.mysql_is_mariadb:
return super()._field_db_check(field, field_db_params)
# On MySQL, check constraints with the column name as it requires
# explicit recreation when the column is renamed.
return field_db_params["check"]
def _rename_field_sql(self, table, old_field, new_field, new_type):
new_type = self._set_field_new_type(old_field, new_type)
return super()._rename_field_sql(table, old_field, new_field, new_type)
def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment):
# Comment is alter when altering the column type.
return "", []
def _comment_sql(self, comment):
comment_sql = super()._comment_sql(comment)
return f" COMMENT {comment_sql}"
def _alter_column_null_sql(self, model, old_field, new_field):
if not new_field.has_db_default():
return super()._alter_column_null_sql(model, old_field, new_field)
new_db_params = new_field.db_parameters(connection=self.connection)
type_sql = self._set_field_new_type(new_field, new_db_params["type"])
return (
"MODIFY %(column)s %(type)s"
% {
"column": self.quote_name(new_field.column),
"type": type_sql,
},
[],
)

View File

@@ -0,0 +1,77 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation
from django.utils.version import get_docs_version
class DatabaseValidation(BaseDatabaseValidation):
def check(self, **kwargs):
issues = super().check(**kwargs)
issues.extend(self._check_sql_mode(**kwargs))
return issues
def _check_sql_mode(self, **kwargs):
if not (
self.connection.sql_mode & {"STRICT_TRANS_TABLES", "STRICT_ALL_TABLES"}
):
return [
checks.Warning(
"%s Strict Mode is not set for database connection '%s'"
% (self.connection.display_name, self.connection.alias),
hint=(
"%s's Strict Mode fixes many data integrity problems in "
"%s, such as data truncation upon insertion, by "
"escalating warnings into errors. It is strongly "
"recommended you activate it. See: "
"https://docs.djangoproject.com/en/%s/ref/databases/"
"#mysql-sql-mode"
% (
self.connection.display_name,
self.connection.display_name,
get_docs_version(),
),
),
id="mysql.W002",
)
]
return []
def check_field_type(self, field, field_type):
"""
MySQL has the following field length restriction:
No character (varchar) fields can have a length exceeding 255
characters if they have a unique index on them.
MySQL doesn't support a database index on some data types.
"""
errors = []
if (
field_type.startswith("varchar")
and field.unique
and (field.max_length is None or int(field.max_length) > 255)
):
errors.append(
checks.Warning(
"%s may not allow unique CharFields to have a max_length "
"> 255." % self.connection.display_name,
obj=field,
hint=(
"See: https://docs.djangoproject.com/en/%s/ref/"
"databases/#mysql-character-fields" % get_docs_version()
),
id="mysql.W003",
)
)
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
"%s does not support a database index on %s columns."
% (self.connection.display_name, field_type),
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
id="fields.W162",
)
)
return errors