Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions netbox_custom_objects/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.reverse import reverse
from rest_framework.utils import model_meta

from netbox_custom_objects import field_types
from netbox_custom_objects.models import (CustomObject, CustomObjectType,
Expand Down Expand Up @@ -253,6 +254,44 @@ def get_display(self, obj):
"""Get display representation of the object"""
return str(obj)

# Stock DRF create() without raise_errors_on_nested_writes guard
def create(self, validated_data):
ModelClass = self.Meta.model

info = model_meta.get_field_info(ModelClass)
many_to_many = {}
for field_name, relation_info in info.relations.items():
if relation_info.to_many and (field_name in validated_data):
many_to_many[field_name] = validated_data.pop(field_name)

instance = ModelClass._default_manager.create(**validated_data)

if many_to_many:
for field_name, value in many_to_many.items():
field = getattr(instance, field_name)
field.set(value)

return instance

# Stock DRF update() with custom field.set() for M2M
def update(self, instance, validated_data):
info = model_meta.get_field_info(instance)

m2m_fields = []
for attr, value in validated_data.items():
if attr in info.relations and info.relations[attr].to_many:
m2m_fields.append((attr, value))
else:
setattr(instance, attr, value)

instance.save()

for attr, value in m2m_fields:
field = getattr(instance, attr)
field.set(value, clear=True)

return instance

# Create basic attributes for the serializer
attrs = {
"Meta": meta,
Expand All @@ -261,6 +300,8 @@ def get_display(self, obj):
"get_url": get_url,
"display": serializers.SerializerMethodField(),
"get_display": get_display,
"create": create,
"update": update,
}

for field in model_fields:
Expand Down
12 changes: 11 additions & 1 deletion netbox_custom_objects/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def create_simple_custom_object_type(self, **kwargs):

return custom_object_type

def create_complex_custom_object_type(self, **kwargs):
@classmethod
def create_complex_custom_object_type(cls, **kwargs):
"""Create a complex custom object type with various field types."""
custom_object_type = CustomObjectsTestCase.create_custom_object_type(**kwargs)
choice_set = CustomObjectsTestCase.create_choice_set()
Expand Down Expand Up @@ -149,6 +150,15 @@ def create_complex_custom_object_type(self, **kwargs):
related_object_type=device_object_type
)

# Multi-Object field (devices)
CustomObjectsTestCase.create_custom_object_type_field(
custom_object_type,
name="devices",
label="Devices",
type="multiobject",
related_object_type=device_object_type
)

return custom_object_type

def create_self_referential_custom_object_type(self, **kwargs):
Expand Down
116 changes: 115 additions & 1 deletion netbox_custom_objects/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from django.urls import reverse

from users.models import Token
from utilities.testing import APIViewTestCases, create_test_user
from rest_framework import status

from netbox_custom_objects.models import CustomObjectType
from .base import CustomObjectsTestCase
from core.models import ObjectType
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Site
from users.models import ObjectPermission, Token
from virtualization.models import Cluster, ClusterType


class CustomObjectTest(CustomObjectsTestCase, APIViewTestCases.APIViewTestCase):
Expand Down Expand Up @@ -32,6 +36,72 @@ def setUp(self):

@classmethod
def setUpTestData(cls):
# Set up some devices to be used in object/multiobject fields
sites = (
Site(name='Site 1', slug='site-1'),
Site(name='Site 2', slug='site-2'),
)
Site.objects.bulk_create(sites)

racks = (
Rack(name='Rack 1', site=sites[0]),
Rack(name='Rack 2', site=sites[1]),
)
Rack.objects.bulk_create(racks)

manufacturer = Manufacturer.objects.create(name='Manufacturer 1', slug='manufacturer-1')

device_types = (
DeviceType(manufacturer=manufacturer, model='Device Type 1', slug='device-type-1'),
DeviceType(manufacturer=manufacturer, model='Device Type 2', slug='device-type-2', u_height=2),
)
DeviceType.objects.bulk_create(device_types)

roles = (
DeviceRole(name='Device Role 1', slug='device-role-1', color='ff0000'),
DeviceRole(name='Device Role 2', slug='device-role-2', color='00ff00'),
)
for role in roles:
role.save()

cluster_type = ClusterType.objects.create(name='Cluster Type 1', slug='cluster-type-1')

clusters = (
Cluster(name='Cluster 1', type=cluster_type),
Cluster(name='Cluster 2', type=cluster_type),
)
Cluster.objects.bulk_create(clusters)

devices = (
Device(
device_type=device_types[0],
role=roles[0],
name='Device 1',
site=sites[0],
rack=racks[0],
cluster=clusters[0],
local_context_data={'A': 1}
),
Device(
device_type=device_types[0],
role=roles[0],
name='Device 2',
site=sites[0],
rack=racks[0],
cluster=clusters[0],
local_context_data={'B': 2}
),
Device(
device_type=device_types[0],
role=roles[0],
name='Device 3',
site=sites[0],
rack=racks[0],
cluster=clusters[0],
local_context_data={'C': 3}
),
)
Device.objects.bulk_create(devices)

# Create test custom object types
cls.custom_object_type1 = CustomObjectType.objects.create(
Expand All @@ -48,6 +118,8 @@ def setUpTestData(cls):
slug="test-objects-2",
)

cls.custom_object_type3 = cls.create_complex_custom_object_type(name="ComplexObject")

cls.model = cls.custom_object_type1.get_model()
cls.create_custom_object_type_field(cls.custom_object_type1)

Expand Down Expand Up @@ -103,6 +175,48 @@ def test_delete_object(self):
# TODO: ObjectChange causes failure
...

def test_create_with_nested_serializers(self):
"""
POST a single object with a multiobject field's values specified via a list of PKs.
"""
model = self.custom_object_type3.get_model()

# Set the model for the test class
self.model = model

# Add object-level permission
obj_perm = ObjectPermission(
name='Test permission',
actions=['add']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))

devices = Device.objects.all()

data = {
'test_field': 'Test 004',
'device': devices[0].id,
'devices': [devices[1].id, devices[2].id],
}

initial_count = self._get_queryset().count()

viewname = 'plugins-api:netbox_custom_objects-api:customobject-list'
list_url = reverse(viewname, kwargs={'custom_object_type': self.custom_object_type3.slug})

response = self.client.post(list_url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(self._get_queryset().count(), initial_count + 1)
instance = self._get_queryset().get(pk=response.data['id'])
self.assertInstanceEqual(
instance,
self.create_data[0],
exclude=self.validation_excluded_fields,
api=True
)

# TODO: GraphQL
def test_graphql_list_objects(self):
...
Expand Down
Loading