diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c0d31bf..62f9ab24 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: strategy: matrix: python-version: ["3.7", "3.8"] - django-version: ["2.1.1", "3.1.4"] + django-version: ["3.1.4"] database-engine: ["postgres", "mysql"] services: diff --git a/binder/plugins/loaded_values.py b/binder/plugins/loaded_values.py index 12d03cf5..a65e4184 100644 --- a/binder/plugins/loaded_values.py +++ b/binder/plugins/loaded_values.py @@ -48,6 +48,7 @@ def field_changed(self, *fields): def get_old_value(self, field): + field = type(self)._meta.get_field(field).name try: return self.__loaded_values[field] except KeyError: diff --git a/binder/stored/__init__.py b/binder/stored/__init__.py new file mode 100644 index 00000000..30209837 --- /dev/null +++ b/binder/stored/__init__.py @@ -0,0 +1,4 @@ +from .base import Stored # noqa + + +default_app_config = 'binder.stored.apps.StoredAppConfig' diff --git a/binder/stored/apps.py b/binder/stored/apps.py new file mode 100644 index 00000000..6520340c --- /dev/null +++ b/binder/stored/apps.py @@ -0,0 +1,11 @@ +from django.apps import AppConfig + +from .signal import apps_ready + + +class StoredAppConfig(AppConfig): + + name = 'binder.stored' + + def ready(self): + apps_ready.send(sender=None) diff --git a/binder/stored/base.py b/binder/stored/base.py new file mode 100644 index 00000000..cb18e25f --- /dev/null +++ b/binder/stored/base.py @@ -0,0 +1,150 @@ +from collections import namedtuple + +from django.db.models import F, Aggregate +from django.db.models.signals import post_save, class_prepared +from django.db.models.expressions import BaseExpression +from django.conf import settings + +from .signal import apps_ready + + +Dep = namedtuple('Dep', ['model', 'fields', 'rev_path', 'rev_field']) + + +def get_deps_base(model, expr): + """ + Given a model and an expr yield all changes that could affect the result + of this expr. + + A change is defined as a 4-tuple of (model, changed, rev_path, rev_field). + """ + from ..plugins.loaded_values import LoadedValuesMixin + + if not issubclass(model, LoadedValuesMixin): + raise ValueError(f'{model} should inherit from LoadedValuesMixin if you want to use it in a stored field') + + if isinstance(expr, Aggregate): + expr, = expr.source_expressions + + if isinstance(expr, F): + head, sep, tail = expr.name.partition('__') + + field = model._meta.get_field(head) + if not sep and field.is_relation: + sep = '__' + tail = 'id' + + if not sep: + if head != 'id': + yield Dep(model, {head}, 'id', 'id') + return + + if not field.is_relation: + raise ValueError(f'expected {model.__name__}.{field} to be a relation') + + if field.one_to_many: + yield Dep(field.related_model, {field.remote_field.name}, 'id', field.remote_field.column) + elif field.many_to_one: + yield Dep(model, {head}, 'id', 'id') + else: + raise ValueError('unsupported type of relation') + + for dep in get_deps(field.related_model, F(tail)): + if dep.rev_path != 'id': + yield dep._replace(rev_path=f'{head}__{dep.rev_path}') + elif field.one_to_many: + yield dep._replace(rev_field=field.remote_field.column) + else: + yield dep._replace(rev_path=head) + + else: + raise ValueError(f'cannot infer deps for {expr!r}') + + +def get_deps(*args, **kwargs): + deps = {} + for dep in get_deps_base(*args, **kwargs): + key = dep._replace(fields=None) + try: + base_dep = deps[key] + except KeyError: + deps[key] = dep + else: + deps[key] = dep._replace(fields=base_dep.fields | dep.fields) + return deps.values() + + +class Stored: + + def __init__(self, expr): + self.expr = expr + + def __set_name__(self, model, name): + from ..views import fix_output_field + + if 'binder.stored' not in settings.INSTALLED_APPS: + raise ValueError('cannot use Stored if \'binder.stored\' is not in INSTALLED_APPS') + + # We dont actually want this to be the attribute + delattr(model, name) + + # Add the field + def add_field(**kwargs): + class_prepared.disconnect(add_field, sender=model) + + # Get field + fix_output_field(self.expr, model) + if isinstance(self.expr, F): + field = self.expr._output_field_or_none + elif isinstance(self.expr, BaseExpression): + field = self.expr.field + else: + raise ValueError( + '{}.{} is not a valid django query expression' + .format(model.__name__, name) + ) + + # Make blank & nullable copy of field + _, _, args, kwargs = field.deconstruct() + kwargs['blank'] = True + kwargs['null'] = True + field = type(field)(*args, **kwargs) + field._binder_stored_expr = self.expr + + model.add_to_class(name, field) + + class_prepared.connect(add_field, sender=model, weak=False) + + # Add triggers for deps + def add_triggers(**kwargs): + apps_ready.disconnect(add_triggers) + + register_init(model, name, self.expr) + for dep in get_deps(model, self.expr): + register_dep(model, name, self.expr, dep) + + apps_ready.connect(add_triggers, weak=False) + + +def update_queryset(queryset, name, expr): + for pk, value in queryset.annotate(value=expr).values_list('pk', 'value'): + queryset.model.objects.filter(pk=pk).update(**{name: value}) + + +def register_init(model, name, expr): + def update_values(instance, **kwargs): + if instance.field_changed('id'): + update_queryset(model.objects.filter(id=instance.id), name, expr) + + post_save.connect(update_values, sender=model, weak=False) + + +def register_dep(model, name, expr, dep): + def update_values(instance, **kwargs): + if instance.field_changed('id', *dep.fields): + ids = [getattr(instance, dep.rev_field)] + if instance.field_changed(dep.rev_field): + ids.append(instance.get_old_value(dep.rev_field)) + update_queryset(model.objects.filter(id__in=ids), name, expr) + + post_save.connect(update_values, sender=dep.model, weak=False) diff --git a/binder/stored/management/commands/autofillstoredexpr.py b/binder/stored/management/commands/autofillstoredexpr.py new file mode 100644 index 00000000..8ad9d00e --- /dev/null +++ b/binder/stored/management/commands/autofillstoredexpr.py @@ -0,0 +1,132 @@ +from argparse import ArgumentTypeError +from datetime import datetime +from importlib import import_module +import re +import os.path + +from django.apps import apps +from django.core.management.base import BaseCommand +from django.db.migrations.loader import MigrationLoader +from django.utils.module_loading import module_dir + + +EXPR_RE = re.compile(r'(\w+)\.(\w+)\.(\w+)') + + +def stored_expr(value): + match = EXPR_RE.fullmatch(value) + + if match is None: + raise ArgumentTypeError('invalid format') + + app, model, field = match.groups() + + try: + field = apps.get_app_config(app).get_model(model)._meta.get_field(field) + except Exception as e: + raise ArgumentTypeError(str(e)) + + if not hasattr(field, '_binder_stored_expr'): + raise ArgumentTypeError(f'{field.model.__name__}.{field.name} is not a stored expr') + + return field + + +def value_to_string(expr): + if not hasattr(expr, 'deconstruct'): + return repr(expr), set() + + parts = [] + modules = set() + + name, args, kwargs = expr.deconstruct() + module = name.rpartition('.')[0] + + parts.append(f'{name}(') + modules.add(module) + + first = True + + for value in args: + if first: + first = False + else: + parts.append(', ') + + substring, submodules = value_to_string(value) + parts.append(substring) + modules.update(submodules) + + for key, value in kwargs.items(): + if first: + first = False + else: + parts.append(', ') + + substring, submodules = value_to_string(value) + parts.append(f'{key}={substring}') + modules.update(submodules) + + parts.append(')') + + return ''.join(parts), modules + + +class Command(BaseCommand): + + def add_arguments(self, parser): + parser.add_argument( + 'exprs', type=stored_expr, nargs='+', + help='stored exprs to autofill', + ) + + def handle(self, exprs, **kwargs): + for expr in exprs: + app = expr.model._meta.app_config + + loader = MigrationLoader(None, ignore_no_migrations=True) + conflicts = loader.detect_conflicts() + + assert not conflicts + leaves = loader.graph.leaf_nodes(app.label) + assert len(leaves) <= 1 + + migrations_module, _ = MigrationLoader.migrations_module(app.label) + migrations_dir = module_dir(import_module(migrations_module)) + + if leaves: + number = int(re.match(r'\d+', leaves[0][1]).group()) + 1 + else: + number = 1 + + migration_path = os.path.join(migrations_dir, f'{number:>04}_autofill_{expr.model.__name__}_{expr.name}'.lower()) + expr_string, modules = value_to_string(expr._binder_stored_expr) + + with open(migration_path, 'w') as f: + f.write(f'# Generated by Django Binder on {datetime.now():%Y-%m-%d %H:%M}\n') + f.write('\n') + f.write('import django.db.migrations\n') + for module in modules: + f.write(f'import {module}\n') + f.write('\n') + f.write('\n') + f.write('def autofill(apps, schema_editor):\n') + f.write(f' {expr.model.__name__} = apps.get_model({app.label!r}, {expr.model.__name__!r})\n') + f.write(f' {expr.model.__name__}.objects.update({expr.name}={expr_string})\n') + f.write('\n') + f.write('\n') + f.write('class Migration(django.migrations.Migration):\n') + f.write('\n') + if leaves: + f.write(' dependencies = [\n') + for dep in leaves: + f.write(' {dep!r},\n') + f.write(' ]\n') + else: + f.write(' initial = True\n') + f.write('\n') + f.write(' dependencies = []\n') + f.write('\n') + f.write(' operations = [\n') + f.write(' django.migrations.RunPython(autofill, migrations.RunPython.noop),\n') + f.write(' ]\n') diff --git a/binder/stored/signal.py b/binder/stored/signal.py new file mode 100644 index 00000000..7a0136d3 --- /dev/null +++ b/binder/stored/signal.py @@ -0,0 +1,4 @@ +from django.dispatch import Signal + + +apps_ready = Signal() diff --git a/binder/views.py b/binder/views.py index 2c419db4..a26bc0f4 100644 --- a/binder/views.py +++ b/binder/views.py @@ -1892,6 +1892,9 @@ def _store_field(self, obj, field, value, request, pk=None): # Regular fields and FKs for f in self.model._meta.fields: if f.name == field: + if hasattr(f, '_binder_stored_expr'): + raise BinderReadOnlyFieldError(self.model.__name__, field) + if isinstance(f, models.ForeignKey): if not (value is None or isinstance(value, int)): raise BinderFieldTypeError(self.model.__name__, field) diff --git a/project/project/settings.py b/project/project/settings.py index c9e56f70..d51bad24 100644 --- a/project/project/settings.py +++ b/project/project/settings.py @@ -41,6 +41,7 @@ 'binder.plugins.token_auth', 'binder.plugins.my_filters', 'testapp', + 'binder.stored', ] MIDDLEWARE = [ diff --git a/tests/__init__.py b/tests/__init__.py index 4dcaf8bd..1ee30811 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -from django import setup +import django from django.conf import settings from django.core.management import call_command import os @@ -55,6 +55,7 @@ 'binder.plugins.token_auth', 'tests', 'tests.testapp', + 'binder.stored', ], 'MIGRATION_MODULES': { 'testapp': None, @@ -111,7 +112,7 @@ } }) -setup() +django.setup() # Do the dance to ensure the models are synched to the DB. # This saves us from having to include migrations diff --git a/tests/test_stored.py b/tests/test_stored.py new file mode 100644 index 00000000..479ed28c --- /dev/null +++ b/tests/test_stored.py @@ -0,0 +1,96 @@ +from contextlib import contextmanager +import json + +from django.db import connection +from django.test import TestCase +from django.contrib.auth.models import User + +from .testapp.models import Zoo, Animal + + +@contextmanager +def collect_queries(): + init_count = len(connection.queries) + queries = [] + try: + yield queries + finally: + queries[:] = connection.queries[init_count:] + + +class StoredTest(TestCase): + + def test_base(self): + zoo = Zoo.objects.create(name='Zoo') + + zoo.refresh_from_db() + self.assertEqual(zoo.stored_animal_count, 0) + + for n in range(1, 11): + Animal.objects.create(zoo=zoo, name=f'Animal {n}') + zoo.refresh_from_db() + self.assertEqual(zoo.stored_animal_count, n) + + def test_id_switch(self): + zoo1 = Zoo.objects.create(name='Zoo 1') + zoo2 = Zoo.objects.create(name='Zoo 2') + + animals = [ + Animal.objects.create( + zoo=zoo1 if n <= 4 else zoo2, + name=f'Animal {n}', + ) + for n in range(1, 7) + ] + + zoo1.refresh_from_db() + self.assertEqual(zoo1.stored_animal_count, 4) + zoo2.refresh_from_db() + self.assertEqual(zoo2.stored_animal_count, 2) + + animals[3].zoo = zoo2 + animals[3].save() + + zoo1.refresh_from_db() + self.assertEqual(zoo1.stored_animal_count, 3) + zoo2.refresh_from_db() + self.assertEqual(zoo2.stored_animal_count, 3) + + def test_only_update_when_needed(self): + zoo = Zoo.objects.create(name='Zoo') + animal = Animal.objects.create(zoo=zoo, name='Animal') + + animal.name = 'Other' + with collect_queries() as queries: + animal.save() + self.assertEqual(len(queries), 1) + + zoo2 = Zoo.objects.create(name='Zoo 2') + animal.zoo = zoo2 + with collect_queries() as queries: + animal.save() + self.assertGreater(len(queries), 1) + + def test_cannot_update_through_api(self): + user = User(username='test', is_superuser=True) + user.set_password('test') + user.save() + + zoo = Zoo.objects.create(name='Zoo') + + zoo.refresh_from_db() + self.assertEqual(zoo.stored_animal_count, 0) + + self.assertTrue(self.client.login(username='test', password='test')) + res = self.client.put( + f'/zoo/{zoo.pk}/', + data={'stored_animal_count': 1337}, + content_type='application/json', + ) + self.assertEqual(res.status_code, 200) + res = json.loads(res.content) + self.assertEqual(res['stored_animal_count'], 0) + self.assertIn('stored_animal_count', res['_meta']['ignored_fields']) + + zoo.refresh_from_db() + self.assertEqual(zoo.stored_animal_count, 0) diff --git a/tests/testapp/models/zoo.py b/tests/testapp/models/zoo.py index a6b1737c..9c58ba1b 100644 --- a/tests/testapp/models/zoo.py +++ b/tests/testapp/models/zoo.py @@ -1,9 +1,14 @@ import os import datetime + from django.core.exceptions import ValidationError from django.db import models +from django.db.models import Count from django.db.models.signals import post_delete + from binder.models import BinderModel, BinderImageField +from binder.stored import Stored +from binder.plugins.loaded_values import LoadedValuesMixin def delete_files(sender, instance=None, **kwargs): for field in sender._meta.fields: @@ -16,7 +21,7 @@ def delete_files(sender, instance=None, **kwargs): # From the api docs: a zoo with a name. It also has a founding date, # which is nullable (representing "unknown"). -class Zoo(BinderModel): +class Zoo(LoadedValuesMixin, BinderModel): name = models.TextField() founding_date = models.DateField(null=True, blank=True) floor_plan = models.ImageField(upload_to='floor-plans', null=True, blank=True) @@ -35,6 +40,8 @@ class Zoo(BinderModel): binder_picture_custom_extensions = BinderImageField(allowed_extensions=['png'], blank=True, null=True) + stored_animal_count = Stored(Count('animals')) + def __str__(self): return 'zoo %d: %s' % (self.pk, self.name)