djangoでの集計は辛いという話 -- ORMは用法・用量を守って正しく使いましょう
djangoのORMの機能の不足にぶち当たり辛いという話。別の言い方をすると、ORMは用法・容量守って正しく使いましょうという感じになるかもしれない。
はじめに
以下のような情報を年齢で丸めた値で集計してヒストグラムのようなものを作りたい。
| 名前 | 年齢 |
|---|---|
| foo | 10 |
| bar | 15 |
| boo | 20 |
結果
| rank | c |
|---|---|
| 1 | 2 |
| 2 | 1 |
SQLでは頑張ればどうにかなる
集計をしたい時など何らかの演算の結果で GROUP BY したい時など結構ある。おそらくきっとある。
例えばヒストグラム的なものを作成したい時など。SQLであれば CASEとWHENを書き連ねることを気にしなければどうにかなる。
sqlite> create table person(name string primary key, age int);
sqlite> insert into person values ('foo', 10);
sqlite> insert into person values ('bar', 15);
sqlite> insert into person values ('boo', 20);
sqlite> select case when age < 10 then 0 when 10 <= age and age < 20 then 1 when 20 <= age and age < 30 then 2 else -1 end as rank, count(*) from person group by case when age < 10 then 0 when 10 <= age and age < 20 then 1 when 20 <= age and age < 30 then 2 else -1 end;
rank = 1
count(*) = 2
rank = 2
count(*) = 1
djangoでどうするのという話
結論から言うとdjangoのORMで書くのは辛い。 以下のようなmodelがあったとして。
from django.db import models class Person(models.Model): name = models.CharField(max_length=32, default="a", blank=False) age = models.PositiveIntegerField(null=False)
COUNT(*) の部分が無ければある程度機能としては揃っていると思い、はじめは、楽観視していた。
CaseもWhenも使うことができるし- 複数条件に関しても
Q objectを渡せばどうにかなる group byの方法 は以前から調べて知っていた。
というわけで頑張ればそれなりにすぐにできるだろうと思っていた。が、意外と大変だった。以下のところまではそれなりにすぐにたどり着ける。
from django.db.models import Count, Case, When, Value, Q # case/when case = Case( When(age__lt=10, then=Value(0)), When(Q(age__gte=10, age__lt=20), then=Value(1)), When(Q(age__gte=20, age__lt=30), then=Value(2)), default=Value(-1), output_field=models.IntegerField() ) qs = ( Person.objects.all() .annotate(rank=case) .values("rank") ) # group by rank qs.query.group_by = ["rank"] # 実は qs.query.group_by = True でも qs.query.set_group_by() でも良い
すると結果として以下のような結果が返るところまではくる。しかしここから先が辛かった。
[{"rank": 1}, {"rank": 2}]
GROUP BY も辛いという話
ところで GROUP BY に関してわざわざ query objectのqueryを触っているのは理由があり、通常は values() のあとに annotate() を書いてあげれば values() で指定したフィールドで GROUP BY されるのだけれど、この values() で設定されるものに関しては modelで定義されたフィールドであることを暗黙の前提としてコードが書かれている。
なので以下の様には書けない。"rank"というフィールドが存在しないと言われてしまう。
qs.values("rank").annotate(rank=case)
COUNT(*) を含めるのが辛いという話 (これが辛い)
そして、そもそも集計結果の値が存在しなければ、つまり COUNT(*) が付加されていなければ何の意味も無いのだけれど、ここから先は結構辛くて、原因は、djangoのORMが暗黙に SELECT句 に来るフィールドと GROUP BY句 に来るフィールドが同じという仮定を要求してくるため(詳しいことが知りたかったら、 django.db.models.query, django.db.models.sql.query, django.db.models.sql.compiler のあたりを行ったり来たりしながら読んでみて下さい)。
右往左往の結果、一応、期待した通りに COUNT(*) を追加するコードを書くことはできた。
バッドノウハウっぽいのでどこかで共有しようと思いこの記事を書いている。
qs.query.values_select.append("c") qs.query.add_select(Count("*"))
これは、django.db.models.ValuesIterable 辺りを見ると良い (djangoのORMのqueryは呼び出すメソッドによって、queryが抱える _iterable_class が代わりSQLの結果はこのクラスに転写される)。
class ValuesIterable(BaseIterable): """ Iterable returned by QuerySet.values() that yields a dict for each row. """ def __iter__(self): queryset = self.queryset query = queryset.query compiler = query.get_compiler(queryset.db) field_names = list(query.values_select) extra_names = list(query.extra_select) annotation_names = list(query.annotation_select) # extra(select=...) cols are always at the start of the row. names = extra_names + field_names + annotation_names for row in compiler.results_iter(): yield dict(zip(names, row))
見ての通り SELECT句 に値を追加しようと思ったら、以下のどれかに値を追加できれば良い。
- field_names
- extra_names
- annotation_names
通常のQueryオブジェクトに用意されているメソッドを利用しての追加を考えると、 annotation_names か extra_names に値を追加しようということになるのだけれど、ここに追加しようとした場合にはGROUP BY句 にも付加されるようなSQLが生成されてしまう。
結果として FROM person GROUP BY <caseを使った式>, COUNT(*) というような謎のGROUP BY を作ろうとして失敗する。(また、djangoのORMは定義の指定に失敗すると、GROUP BY に id を含めたがるような問題もあり注意が必要)
そんなわけで、生成するSQLのSELECT句に追加する処理 と 転写されるIterableクラスの名前に追加する処理 を無理矢理追加してあげると言うことが必要になる。
全体を繋げたコードは以下の様になる。
from django.db.models import Count, Case, When, Value, Q def extra_select(qs, **kwargs): qs = qs.all() for name, col in kwargs.items(): qs.query.values_select.append(name) qs.query.add_select(col) return qs case = Case( When(age__lt=10, then=Value(0)), When(Q(age__gte=10, age__lt=20), then=Value(1)), When(Q(age__gte=20, age__lt=30), then=Value(2)), default=Value(-1), output_field=models.IntegerField() ) qs = ( Person.objects.all() .annotate(rank=case) .values("rank") ) qs = extra_select(qs, c=Count("*")) qs.query.group_by = ["rank"] print(qs) # => [{"c": 2, "rank": 1}, {"c": 1, "rank": 2}]
これは以下のような期待通りSQLを生成してくれる。
SELECT
COUNT(*),
CASE
WHEN "person"."age" < 10 THEN 0
WHEN ("person"."age" < 20 AND "person"."age" >= 10) THEN 1
WHEN ("person"."age" < 30 AND "person"."age" >= 20) THEN 2
ELSE -1
END AS "rank"
FROM "person"
GROUP BY
CASE
WHEN "person"."age" < 10 THEN 0
WHEN ("person"."age" < 20 AND "person"."age" >= 10) THEN 1
WHEN ("person"."age" < 30 AND "person"."age" >= 20) THEN 2
ELSE -1
END
djangoのORMは難しいという印象は消えたことが無いですね。
ところで sqlalchemy であれば...
以下の様に書けます。
import sqlalchemy as sa # Baseとsessionは各自で作成 class Person(Base): __tablename__ = "person" name = sa.Column(sa.String(255), default="", nullable=False, primary_key=True) age = sa.Column(sa.Integer) # query case = sa.case( [ (Person.age < 10, 1), ((10 <= Person.age) & (Person.age < 20), 2), ((20 <= Person.age) & (Person.age < 30), 3) ], else_=-1) qs = session.query(sa.func.count("*"), case).group_by(case) print(qs.all()) # => [(2, 2), (1, 3)]
dictを返したければ
qs = session.query(sa.func.count("*").label("c"), case.label("rank")).group_by(case) print([row._asdict() for row in qs.all()]) # => [{'c': 2, 'rank': 2}, {'c': 1, 'rank': 3}]
ちなみに、djangoのORMとsqlalchemyとを交互に使っているときには qs.all() の意味が両者の間でほとんど真逆なあたりが一番つらい。