diff --git a/CHANGELOG.md b/CHANGELOG.md index d74a299a5b..5b14515173 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to ### Changed +- 🛂(backend) stop throttling collaboration servers #1730 - 🚸(backend) use unaccented full name for user search #1637 - 🌐(backend) internationalize demo #1644 diff --git a/src/backend/core/api/throttling.py b/src/backend/core/api/throttling.py index adfb6d57d1..260a572937 100644 --- a/src/backend/core/api/throttling.py +++ b/src/backend/core/api/throttling.py @@ -1,5 +1,8 @@ """Throttling modules for the API.""" +from django.conf import settings + +from lasuite.drf.throttling import MonitoredScopedRateThrottle from rest_framework.throttling import UserRateThrottle from sentry_sdk import capture_message @@ -19,3 +22,30 @@ class UserListThrottleSustained(UserRateThrottle): """Throttle for the user list endpoint.""" scope = "user_list_sustained" + + +class DocumentThrottle(MonitoredScopedRateThrottle): + """ + Throttle for document-related endpoints, with an exception for requests from the + collaboration server. + """ + + scope = "document" + + def allow_request(self, request, view): + """ + Override to skip throttling for requests from the collaboration server. + + Verifies the X-Y-Provider-Key header contains a valid Y_PROVIDER_API_KEY. + Using a custom header instead of Authorization to avoid triggering + authentication middleware. + """ + + y_provider_header = request.headers.get("X-Y-Provider-Key", "") + + # Check if this is a valid y-provider request and exempt from throttling + y_provider_key = getattr(settings, "Y_PROVIDER_API_KEY", None) + if y_provider_key and y_provider_header == y_provider_key: + return True + + return super().allow_request(request, view) diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 7594770bdd..8b2ed7a675 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -55,7 +55,11 @@ from . import permissions, serializers, utils from .filters import DocumentFilter, ListDocumentFilter, UserSearchFilter -from .throttling import UserListThrottleBurst, UserListThrottleSustained +from .throttling import ( + DocumentThrottle, + UserListThrottleBurst, + UserListThrottleSustained, +) logger = logging.getLogger(__name__) @@ -373,6 +377,7 @@ class DocumentViewSet( permission_classes = [ permissions.DocumentPermission, ] + throttle_classes = [DocumentThrottle] throttle_scope = "document" queryset = models.Document.objects.select_related("creator").all() serializer_class = serializers.DocumentSerializer diff --git a/src/backend/core/tests/test_api_throttling_document_throttle.py b/src/backend/core/tests/test_api_throttling_document_throttle.py new file mode 100644 index 0000000000..c29f01a44b --- /dev/null +++ b/src/backend/core/tests/test_api_throttling_document_throttle.py @@ -0,0 +1,107 @@ +""" +Test DocumentThrottle for regular throttling and y-provider bypass. +""" + +import pytest +from rest_framework.test import APIClient + +from core import factories + +pytestmark = pytest.mark.django_db + + +def test_api_throttling_document_throttle_regular_requests(settings): + """Test that regular requests are throttled normally.""" + + current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = "3/minute" + settings.Y_PROVIDER_API_KEY = "test-y-provider-key" + + user = factories.UserFactory() + client = APIClient() + client.force_login(user) + + document = factories.DocumentFactory() + factories.UserDocumentAccessFactory(document=document, user=user) + + # Make 3 requests without the y-provider key + for _i in range(3): + response = client.get( + f"/api/v1.0/documents/{document.id!s}/", + ) + assert response.status_code == 200 + + # 4th request should be throttled + response = client.get( + f"/api/v1.0/documents/{document.id!s}/", + ) + assert response.status_code == 429 + + # A request with the y-provider key should NOT be throttled + response = client.get( + f"/api/v1.0/documents/{document.id!s}/", + HTTP_X_Y_PROVIDER_KEY="test-y-provider-key", + ) + assert response.status_code == 200 + + # Restore original rate + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = current_rate + + +def test_api_throttling_document_throttle_y_provider_exempted(settings): + """Test that y-provider requests are exempted from throttling.""" + + current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = "3/minute" + settings.Y_PROVIDER_API_KEY = "test-y-provider-key" + + user = factories.UserFactory() + client = APIClient() + client.force_login(user) + + document = factories.DocumentFactory() + factories.UserDocumentAccessFactory(document=document, user=user) + + # Make many requests with the y-provider API key + for _i in range(10): + response = client.get( + f"/api/v1.0/documents/{document.id!s}/", + HTTP_X_Y_PROVIDER_KEY="test-y-provider-key", + ) + assert response.status_code == 200 + + # Restore original rate + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = current_rate + + +def test_api_throttling_document_throttle_invalid_token(settings): + """Test that requests with invalid tokens are throttled.""" + + current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = "3/minute" + settings.Y_PROVIDER_API_KEY = "test-y-provider-key" + + user = factories.UserFactory() + client = APIClient() + client.force_login(user) + + document = factories.DocumentFactory() + factories.UserDocumentAccessFactory(document=document, user=user) + + # Make 3 requests with an invalid token + for _i in range(3): + response = client.get( + f"/api/v1.0/documents/{document.id!s}/", + HTTP_X_Y_PROVIDER_KEY="invalid-token", + ) + assert response.status_code == 200 + + # 4th request should be throttled + response = client.get( + f"/api/v1.0/documents/{document.id!s}/", + HTTP_X_Y_PROVIDER_KEY="invalid-token", + ) + assert response.status_code == 429 + + # Restore original rate + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = current_rate diff --git a/src/frontend/servers/y-provider/__tests__/collaborationBackend.test.ts b/src/frontend/servers/y-provider/__tests__/collaborationBackend.test.ts new file mode 100644 index 0000000000..d5b7678f37 --- /dev/null +++ b/src/frontend/servers/y-provider/__tests__/collaborationBackend.test.ts @@ -0,0 +1,66 @@ +import axios from 'axios'; +import { describe, expect, test, vi } from 'vitest'; + +vi.mock('../src/env', () => ({ + COLLABORATION_BACKEND_BASE_URL: 'http://app-dev:8000', + Y_PROVIDER_API_KEY: 'test-yprovider-key', +})); + +describe('CollaborationBackend', () => { + test('fetchDocument sends X-Y-Provider-Key header', async () => { + const axiosGetSpy = vi.spyOn(axios, 'get').mockResolvedValue({ + status: 200, + data: { + id: 'test-doc-id', + abilities: { retrieve: true, update: true }, + }, + }); + + const { fetchDocument } = await import('@/api/collaborationBackend'); + const documentId = 'test-document-123'; + + await fetchDocument(documentId, { cookie: 'test-cookie' }); + + expect(axiosGetSpy).toHaveBeenCalledWith( + `http://app-dev:8000/api/v1.0/documents/${documentId}/`, + expect.objectContaining({ + headers: expect.objectContaining({ + 'X-Y-Provider-Key': 'test-yprovider-key', + cookie: 'test-cookie', + }), + }), + ); + + axiosGetSpy.mockRestore(); + }); + + test('fetchCurrentUser sends X-Y-Provider-Key header', async () => { + const axiosGetSpy = vi.spyOn(axios, 'get').mockResolvedValue({ + status: 200, + data: { + id: 'test-user-id', + email: 'test@example.com', + }, + }); + + const { fetchCurrentUser } = await import('@/api/collaborationBackend'); + + await fetchCurrentUser({ + cookie: 'test-cookie', + origin: 'http://localhost:3000', + }); + + expect(axiosGetSpy).toHaveBeenCalledWith( + 'http://app-dev:8000/api/v1.0/users/me/', + expect.objectContaining({ + headers: expect.objectContaining({ + 'X-Y-Provider-Key': 'test-yprovider-key', + cookie: 'test-cookie', + origin: 'http://localhost:3000', + }), + }), + ); + + axiosGetSpy.mockRestore(); + }); +}); diff --git a/src/frontend/servers/y-provider/src/api/collaborationBackend.ts b/src/frontend/servers/y-provider/src/api/collaborationBackend.ts index 04892c2bd4..6fca2b84f3 100644 --- a/src/frontend/servers/y-provider/src/api/collaborationBackend.ts +++ b/src/frontend/servers/y-provider/src/api/collaborationBackend.ts @@ -2,7 +2,7 @@ import { IncomingHttpHeaders } from 'http'; import axios from 'axios'; -import { COLLABORATION_BACKEND_BASE_URL } from '@/env'; +import { COLLABORATION_BACKEND_BASE_URL, Y_PROVIDER_API_KEY } from '@/env'; export interface User { id: string; @@ -61,6 +61,7 @@ async function fetch( headers: { cookie: requestHeaders['cookie'], origin: requestHeaders['origin'], + 'X-Y-Provider-Key': Y_PROVIDER_API_KEY, }, }, );