first commit
This commit is contained in:
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Classes to represent the definitions of aggregate functions.
|
||||
"""
|
||||
|
||||
from django.core.exceptions import FieldError, FullResultSet
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import (
|
||||
Case,
|
||||
ColPairs,
|
||||
Func,
|
||||
OrderByList,
|
||||
Star,
|
||||
Value,
|
||||
When,
|
||||
)
|
||||
from django.db.models.fields import IntegerField, TextField
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDurationInputMixin,
|
||||
NumericOutputFieldMixin,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Aggregate",
|
||||
"AnyValue",
|
||||
"Avg",
|
||||
"Count",
|
||||
"Max",
|
||||
"Min",
|
||||
"StdDev",
|
||||
"StringAgg",
|
||||
"Sum",
|
||||
"Variance",
|
||||
]
|
||||
|
||||
|
||||
class AggregateFilter(Func):
|
||||
arity = 1
|
||||
template = " FILTER (WHERE %(expressions)s)"
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.supports_aggregate_filter_clause:
|
||||
raise NotSupportedError(
|
||||
"Aggregate filter clauses are not supported on this database backend."
|
||||
)
|
||||
try:
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
except FullResultSet:
|
||||
return "", ()
|
||||
|
||||
@property
|
||||
def condition(self):
|
||||
return self.source_expressions[0]
|
||||
|
||||
def __str__(self):
|
||||
return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
|
||||
|
||||
|
||||
class AggregateOrderBy(OrderByList):
|
||||
template = " ORDER BY %(expressions)s"
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.supports_aggregate_order_by_clause:
|
||||
raise NotSupportedError(
|
||||
"This database backend does not support specifying an order on "
|
||||
"aggregates."
|
||||
)
|
||||
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Aggregate(Func):
|
||||
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
|
||||
contains_aggregate = True
|
||||
name = None
|
||||
window_compatible = True
|
||||
allow_distinct = False
|
||||
allow_order_by = False
|
||||
empty_result_set_value = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*expressions,
|
||||
distinct=False,
|
||||
filter=None,
|
||||
default=None,
|
||||
order_by=None,
|
||||
**extra,
|
||||
):
|
||||
if distinct and not self.allow_distinct:
|
||||
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
||||
if order_by and not self.allow_order_by:
|
||||
raise TypeError("%s does not allow order_by." % self.__class__.__name__)
|
||||
if default is not None and self.empty_result_set_value is not None:
|
||||
raise TypeError(f"{self.__class__.__name__} does not allow default.")
|
||||
|
||||
self.distinct = distinct
|
||||
self.filter = None if filter is None else AggregateFilter(filter)
|
||||
self.default = default
|
||||
self.order_by = AggregateOrderBy.from_param(
|
||||
f"{self.__class__.__name__}.order_by", order_by
|
||||
)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def get_source_fields(self):
|
||||
# Don't consider filter and order by expression as they have nothing
|
||||
# to do with the output field resolution.
|
||||
return [e._output_field_or_none for e in super().get_source_expressions()]
|
||||
|
||||
def get_source_expressions(self):
|
||||
source_expressions = super().get_source_expressions()
|
||||
return [*source_expressions, self.filter, self.order_by]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
*exprs, self.filter, self.order_by = exprs
|
||||
return super().set_source_expressions(exprs)
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
# Aggregates are not allowed in UPDATE queries, so ignore for_save
|
||||
c = super().resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if summarize:
|
||||
# Summarized aggregates cannot refer to summarized aggregates.
|
||||
for ref in c.get_refs():
|
||||
if query.annotations[ref].is_summary:
|
||||
raise FieldError(
|
||||
f"Cannot compute {c.name}('{ref}'): '{ref}' is an aggregate"
|
||||
)
|
||||
elif not self.is_summary:
|
||||
# Call Aggregate.get_source_expressions() to avoid
|
||||
# returning self.filter and including that in this loop.
|
||||
expressions = super(Aggregate, c).get_source_expressions()
|
||||
for index, expr in enumerate(expressions):
|
||||
if expr.contains_aggregate:
|
||||
before_resolved = self.get_source_expressions()[index]
|
||||
name = (
|
||||
before_resolved.name
|
||||
if hasattr(before_resolved, "name")
|
||||
else repr(before_resolved)
|
||||
)
|
||||
raise FieldError(
|
||||
"Cannot compute %s('%s'): '%s' is an aggregate"
|
||||
% (c.name, name, name)
|
||||
)
|
||||
if (default := c.default) is None:
|
||||
return c
|
||||
if hasattr(default, "resolve_expression"):
|
||||
default = default.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if default._output_field_or_none is None:
|
||||
default.output_field = c._output_field_or_none
|
||||
else:
|
||||
default = Value(default, c._output_field_or_none)
|
||||
c.default = None # Reset the default argument before wrapping.
|
||||
coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
|
||||
coalesce.is_summary = c.is_summary
|
||||
return coalesce
|
||||
|
||||
@property
|
||||
def default_alias(self):
|
||||
expressions = [
|
||||
expr for expr in self.get_source_expressions() if expr is not None
|
||||
]
|
||||
if len(expressions) == 1 and hasattr(expressions[0], "name"):
|
||||
return "%s__%s" % (expressions[0].name, self.name.lower())
|
||||
raise TypeError("Complex expressions require an alias")
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if (
|
||||
self.distinct
|
||||
and not connection.features.supports_aggregate_distinct_multiple_argument
|
||||
and len(super().get_source_expressions()) > 1
|
||||
):
|
||||
raise NotSupportedError(
|
||||
f"{self.name} does not support distinct with multiple expressions on "
|
||||
f"this database backend."
|
||||
)
|
||||
|
||||
distinct_sql = "DISTINCT " if self.distinct else ""
|
||||
order_by_sql = ""
|
||||
order_by_params = []
|
||||
filter_sql = ""
|
||||
filter_params = []
|
||||
|
||||
if (order_by := self.order_by) is not None:
|
||||
order_by_sql, order_by_params = compiler.compile(order_by)
|
||||
|
||||
if self.filter is not None:
|
||||
try:
|
||||
filter_sql, filter_params = compiler.compile(self.filter)
|
||||
except NotSupportedError:
|
||||
# Fallback to a CASE statement on backends that don't support
|
||||
# the FILTER clause.
|
||||
copy = self.copy()
|
||||
copy.filter = None
|
||||
source_expressions = copy.get_source_expressions()
|
||||
condition = When(self.filter.condition, then=source_expressions[0])
|
||||
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
|
||||
return copy.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
extra_context.update(
|
||||
distinct=distinct_sql,
|
||||
filter=filter_sql,
|
||||
order_by=order_by_sql,
|
||||
)
|
||||
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||
return sql, (*params, *order_by_params, *filter_params)
|
||||
|
||||
def _get_repr_options(self):
|
||||
options = super()._get_repr_options()
|
||||
if self.distinct:
|
||||
options["distinct"] = self.distinct
|
||||
if self.filter:
|
||||
options["filter"] = self.filter
|
||||
if self.order_by:
|
||||
options["order_by"] = self.order_by
|
||||
return options
|
||||
|
||||
|
||||
class AnyValue(Aggregate):
|
||||
function = "ANY_VALUE"
|
||||
name = "AnyValue"
|
||||
arity = 1
|
||||
window_compatible = False
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.supports_any_value:
|
||||
raise NotSupportedError(
|
||||
"ANY_VALUE is not supported on this database backend."
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
|
||||
function = "AVG"
|
||||
name = "Avg"
|
||||
allow_distinct = True
|
||||
arity = 1
|
||||
|
||||
|
||||
class Count(Aggregate):
|
||||
function = "COUNT"
|
||||
name = "Count"
|
||||
output_field = IntegerField()
|
||||
allow_distinct = True
|
||||
empty_result_set_value = 0
|
||||
arity = 1
|
||||
allows_composite_expressions = True
|
||||
|
||||
def __init__(self, expression, filter=None, **extra):
|
||||
if expression == "*":
|
||||
expression = Star()
|
||||
if isinstance(expression, Star) and filter is not None:
|
||||
raise ValueError("Star cannot be used with filter. Please specify a field.")
|
||||
super().__init__(expression, filter=filter, **extra)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
result = super().resolve_expression(*args, **kwargs)
|
||||
source_expressions = result.get_source_expressions()
|
||||
|
||||
# In case of composite primary keys, count the first column.
|
||||
if isinstance(expr := source_expressions[0], ColPairs):
|
||||
if self.distinct:
|
||||
raise ValueError(
|
||||
"COUNT(DISTINCT) doesn't support composite primary keys"
|
||||
)
|
||||
|
||||
source_expressions[0] = expr.get_cols()[0]
|
||||
result.set_source_expressions(source_expressions)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Max(Aggregate):
|
||||
function = "MAX"
|
||||
name = "Max"
|
||||
arity = 1
|
||||
|
||||
|
||||
class Min(Aggregate):
|
||||
function = "MIN"
|
||||
name = "Min"
|
||||
arity = 1
|
||||
|
||||
|
||||
class StdDev(NumericOutputFieldMixin, Aggregate):
|
||||
name = "StdDev"
|
||||
arity = 1
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
|
||||
|
||||
|
||||
class StringAggDelimiter(Func):
|
||||
arity = 1
|
||||
template = "%(expressions)s"
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
super().__init__(value)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
template = " SEPARATOR %(expressions)s"
|
||||
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class StringAgg(Aggregate):
|
||||
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
|
||||
function = "STRING_AGG"
|
||||
name = "StringAgg"
|
||||
allow_distinct = True
|
||||
allow_order_by = True
|
||||
output_field = TextField()
|
||||
|
||||
def __init__(self, expression, delimiter, **extra):
|
||||
self.delimiter = StringAggDelimiter(delimiter)
|
||||
super().__init__(expression, self.delimiter, **extra)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
if self.order_by:
|
||||
template = (
|
||||
"%(function)s(%(distinct)s%(expressions)s) WITHIN GROUP (%(order_by)s)"
|
||||
"%(filter)s"
|
||||
)
|
||||
else:
|
||||
template = "%(function)s(%(distinct)s%(expressions)s)%(filter)s"
|
||||
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="LISTAGG",
|
||||
template=template,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
extra_context["function"] = "GROUP_CONCAT"
|
||||
|
||||
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s%(delimiter)s)"
|
||||
extra_context["template"] = template
|
||||
|
||||
c = self.copy()
|
||||
# The creation of the delimiter SQL and the ordering of the parameters
|
||||
# must be handled explicitly, as MySQL puts the delimiter at the end of
|
||||
# the aggregate using the `SEPARATOR` declaration (rather than treating
|
||||
# as an expression like other database backends).
|
||||
delimiter_params = []
|
||||
if c.delimiter:
|
||||
delimiter_sql, delimiter_params = compiler.compile(c.delimiter)
|
||||
# Drop the delimiter from the source expressions.
|
||||
c.source_expressions = c.source_expressions[:-1]
|
||||
extra_context["delimiter"] = delimiter_sql
|
||||
|
||||
sql, params = c.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
return sql, (*params, *delimiter_params)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if connection.get_database_version() < (3, 44):
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="GROUP_CONCAT",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Sum(FixDurationInputMixin, Aggregate):
|
||||
function = "SUM"
|
||||
name = "Sum"
|
||||
allow_distinct = True
|
||||
arity = 1
|
||||
|
||||
|
||||
class Variance(NumericOutputFieldMixin, Aggregate):
|
||||
name = "Variance"
|
||||
arity = 1
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "VAR_SAMP" if sample else "VAR_POP"
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"}
|
||||
Reference in New Issue
Block a user