diff --git a/netbox_custom_objects/api/serializers.py b/netbox_custom_objects/api/serializers.py index 0df23e4..a18f90e 100644 --- a/netbox_custom_objects/api/serializers.py +++ b/netbox_custom_objects/api/serializers.py @@ -8,6 +8,7 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError from rest_framework.reverse import reverse +from rest_framework.utils import model_meta from netbox_custom_objects import field_types from netbox_custom_objects.models import (CustomObject, CustomObjectType, @@ -253,6 +254,44 @@ def get_display(self, obj): """Get display representation of the object""" return str(obj) + # Stock DRF create() without raise_errors_on_nested_writes guard + def create(self, validated_data): + ModelClass = self.Meta.model + + info = model_meta.get_field_info(ModelClass) + many_to_many = {} + for field_name, relation_info in info.relations.items(): + if relation_info.to_many and (field_name in validated_data): + many_to_many[field_name] = validated_data.pop(field_name) + + instance = ModelClass._default_manager.create(**validated_data) + + if many_to_many: + for field_name, value in many_to_many.items(): + field = getattr(instance, field_name) + field.set(value) + + return instance + + # Stock DRF update() with custom field.set() for M2M + def update(self, instance, validated_data): + info = model_meta.get_field_info(instance) + + m2m_fields = [] + for attr, value in validated_data.items(): + if attr in info.relations and info.relations[attr].to_many: + m2m_fields.append((attr, value)) + else: + setattr(instance, attr, value) + + instance.save() + + for attr, value in m2m_fields: + field = getattr(instance, attr) + field.set(value, clear=True) + + return instance + # Create basic attributes for the serializer attrs = { "Meta": meta, @@ -261,6 +300,8 @@ def get_display(self, obj): "get_url": get_url, "display": serializers.SerializerMethodField(), "get_display": get_display, + "create": create, + "update": update, } for field in model_fields: diff --git a/netbox_custom_objects/tests/base.py b/netbox_custom_objects/tests/base.py index 6633797..20eac84 100644 --- a/netbox_custom_objects/tests/base.py +++ b/netbox_custom_objects/tests/base.py @@ -96,7 +96,8 @@ def create_simple_custom_object_type(self, **kwargs): return custom_object_type - def create_complex_custom_object_type(self, **kwargs): + @classmethod + def create_complex_custom_object_type(cls, **kwargs): """Create a complex custom object type with various field types.""" custom_object_type = CustomObjectsTestCase.create_custom_object_type(**kwargs) choice_set = CustomObjectsTestCase.create_choice_set() @@ -149,6 +150,15 @@ def create_complex_custom_object_type(self, **kwargs): related_object_type=device_object_type ) + # Multi-Object field (devices) + CustomObjectsTestCase.create_custom_object_type_field( + custom_object_type, + name="devices", + label="Devices", + type="multiobject", + related_object_type=device_object_type + ) + return custom_object_type def create_self_referential_custom_object_type(self, **kwargs): diff --git a/netbox_custom_objects/tests/test_api.py b/netbox_custom_objects/tests/test_api.py index db4bea2..0397b7a 100644 --- a/netbox_custom_objects/tests/test_api.py +++ b/netbox_custom_objects/tests/test_api.py @@ -1,10 +1,14 @@ from django.urls import reverse -from users.models import Token from utilities.testing import APIViewTestCases, create_test_user +from rest_framework import status from netbox_custom_objects.models import CustomObjectType from .base import CustomObjectsTestCase +from core.models import ObjectType +from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Site +from users.models import ObjectPermission, Token +from virtualization.models import Cluster, ClusterType class CustomObjectTest(CustomObjectsTestCase, APIViewTestCases.APIViewTestCase): @@ -32,6 +36,72 @@ def setUp(self): @classmethod def setUpTestData(cls): + # Set up some devices to be used in object/multiobject fields + sites = ( + Site(name='Site 1', slug='site-1'), + Site(name='Site 2', slug='site-2'), + ) + Site.objects.bulk_create(sites) + + racks = ( + Rack(name='Rack 1', site=sites[0]), + Rack(name='Rack 2', site=sites[1]), + ) + Rack.objects.bulk_create(racks) + + manufacturer = Manufacturer.objects.create(name='Manufacturer 1', slug='manufacturer-1') + + device_types = ( + DeviceType(manufacturer=manufacturer, model='Device Type 1', slug='device-type-1'), + DeviceType(manufacturer=manufacturer, model='Device Type 2', slug='device-type-2', u_height=2), + ) + DeviceType.objects.bulk_create(device_types) + + roles = ( + DeviceRole(name='Device Role 1', slug='device-role-1', color='ff0000'), + DeviceRole(name='Device Role 2', slug='device-role-2', color='00ff00'), + ) + for role in roles: + role.save() + + cluster_type = ClusterType.objects.create(name='Cluster Type 1', slug='cluster-type-1') + + clusters = ( + Cluster(name='Cluster 1', type=cluster_type), + Cluster(name='Cluster 2', type=cluster_type), + ) + Cluster.objects.bulk_create(clusters) + + devices = ( + Device( + device_type=device_types[0], + role=roles[0], + name='Device 1', + site=sites[0], + rack=racks[0], + cluster=clusters[0], + local_context_data={'A': 1} + ), + Device( + device_type=device_types[0], + role=roles[0], + name='Device 2', + site=sites[0], + rack=racks[0], + cluster=clusters[0], + local_context_data={'B': 2} + ), + Device( + device_type=device_types[0], + role=roles[0], + name='Device 3', + site=sites[0], + rack=racks[0], + cluster=clusters[0], + local_context_data={'C': 3} + ), + ) + Device.objects.bulk_create(devices) # Create test custom object types cls.custom_object_type1 = CustomObjectType.objects.create( @@ -48,6 +118,8 @@ def setUpTestData(cls): slug="test-objects-2", ) + cls.custom_object_type3 = cls.create_complex_custom_object_type(name="ComplexObject") + cls.model = cls.custom_object_type1.get_model() cls.create_custom_object_type_field(cls.custom_object_type1) @@ -103,6 +175,48 @@ def test_delete_object(self): # TODO: ObjectChange causes failure ... + def test_create_with_nested_serializers(self): + """ + POST a single object with a multiobject field's values specified via a list of PKs. + """ + model = self.custom_object_type3.get_model() + + # Set the model for the test class + self.model = model + + # Add object-level permission + obj_perm = ObjectPermission( + name='Test permission', + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + + devices = Device.objects.all() + + data = { + 'test_field': 'Test 004', + 'device': devices[0].id, + 'devices': [devices[1].id, devices[2].id], + } + + initial_count = self._get_queryset().count() + + viewname = 'plugins-api:netbox_custom_objects-api:customobject-list' + list_url = reverse(viewname, kwargs={'custom_object_type': self.custom_object_type3.slug}) + + response = self.client.post(list_url, data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(self._get_queryset().count(), initial_count + 1) + instance = self._get_queryset().get(pk=response.data['id']) + self.assertInstanceEqual( + instance, + self.create_data[0], + exclude=self.validation_excluded_fields, + api=True + ) + # TODO: GraphQL def test_graphql_list_objects(self): ...