Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ and this project adheres to

- ✨(backend) allow to create a new user in a marketing system

### Changed

- 🛂(backend) stop throttling collaboration servers #1730

## [4.1.0] - 2025-12-09

### Added
Expand Down
30 changes: 30 additions & 0 deletions src/backend/core/api/throttling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Throttling modules for the API."""

from django.conf import settings

from lasuite.drf.throttling import MonitoredScopedRateThrottle
from rest_framework.throttling import UserRateThrottle
from sentry_sdk import capture_message

Expand All @@ -19,3 +22,30 @@ class UserListThrottleSustained(UserRateThrottle):
"""Throttle for the user list endpoint."""

scope = "user_list_sustained"


class DocumentThrottle(MonitoredScopedRateThrottle):
"""
Throttle for document-related endpoints, with an exception for requests from the
collaboration server.
"""

scope = "document"

def allow_request(self, request, view):
"""
Override to skip throttling for requests from the collaboration server.

Verifies the X-Y-Provider-Key header contains a valid Y_PROVIDER_API_KEY.
Using a custom header instead of Authorization to avoid triggering
authentication middleware.
"""

y_provider_header = request.headers.get("X-Y-Provider-Key", "")

# Check if this is a valid y-provider request and exempt from throttling
y_provider_key = getattr(settings, "Y_PROVIDER_API_KEY", None)
if y_provider_key and y_provider_header == y_provider_key:
return True

return super().allow_request(request, view)
7 changes: 6 additions & 1 deletion src/backend/core/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@

from . import permissions, serializers, utils
from .filters import DocumentFilter, ListDocumentFilter, UserSearchFilter
from .throttling import UserListThrottleBurst, UserListThrottleSustained
from .throttling import (
DocumentThrottle,
UserListThrottleBurst,
UserListThrottleSustained,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -365,6 +369,7 @@ class DocumentViewSet(
permission_classes = [
permissions.DocumentPermission,
]
throttle_classes = [DocumentThrottle]
throttle_scope = "document"
queryset = models.Document.objects.select_related("creator").all()
serializer_class = serializers.DocumentSerializer
Expand Down
100 changes: 100 additions & 0 deletions src/backend/core/tests/test_api_throttling_document_throttle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Test DocumentThrottle for regular throttling and y-provider bypass.
"""

import pytest
from rest_framework.test import APIClient

from core import factories

pytestmark = pytest.mark.django_db


def test_api_throttling_document_throttle_regular_requests(settings):
"""Test that regular requests are throttled normally."""

current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"]
settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = "3/minute"
settings.Y_PROVIDER_API_KEY = "test-y-provider-key"

user = factories.UserFactory()
client = APIClient()
client.force_login(user)

document = factories.DocumentFactory()
factories.UserDocumentAccessFactory(document=document, user=user)

# Make 3 requests without the y-provider key
for _i in range(3):
response = client.get(
f"/api/v1.0/documents/{document.id!s}/",
)
assert response.status_code == 200

# 4th request should be throttled
response = client.get(
f"/api/v1.0/documents/{document.id!s}/",
)
assert response.status_code == 429

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can maybe here make a new request using the special header to test that this one will not be throttled ?

# Restore original rate
settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = current_rate


def test_api_throttling_document_throttle_y_provider_exempted(settings):
"""Test that y-provider requests are exempted from throttling."""

current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"]
settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = "3/minute"
settings.Y_PROVIDER_API_KEY = "test-y-provider-key"

user = factories.UserFactory()
client = APIClient()
client.force_login(user)

document = factories.DocumentFactory()
factories.UserDocumentAccessFactory(document=document, user=user)

# Make many requests with the y-provider API key
for _i in range(100):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can test with 10 ? To save time in the test suite

response = client.get(
f"/api/v1.0/documents/{document.id!s}/",
HTTP_X_Y_PROVIDER_KEY="test-y-provider-key",
)
assert response.status_code == 200

# Restore original rate
settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = current_rate


def test_api_throttling_document_throttle_invalid_token(settings):
"""Test that requests with invalid tokens are throttled."""

current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"]
settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = "3/minute"
settings.Y_PROVIDER_API_KEY = "test-y-provider-key"

user = factories.UserFactory()
client = APIClient()
client.force_login(user)

document = factories.DocumentFactory()
factories.UserDocumentAccessFactory(document=document, user=user)

# Make 3 requests with an invalid token
for _i in range(3):
response = client.get(
f"/api/v1.0/documents/{document.id!s}/",
HTTP_X_Y_PROVIDER_KEY="invalid-token",
)
assert response.status_code == 200

# 4th request should be throttled
response = client.get(
f"/api/v1.0/documents/{document.id!s}/",
HTTP_X_Y_PROVIDER_KEY="invalid-token",
)
assert response.status_code == 429

# Restore original rate
settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = current_rate
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import axios from 'axios';
import { describe, expect, test, vi } from 'vitest';

vi.mock('../src/env', () => ({
COLLABORATION_BACKEND_BASE_URL: 'http://app-dev:8000',
Y_PROVIDER_API_KEY: 'test-yprovider-key',
}));

describe('CollaborationBackend', () => {
test('fetchDocument sends X-Y-Provider-Key header', async () => {
const axiosGetSpy = vi.spyOn(axios, 'get').mockResolvedValue({
status: 200,
data: {
id: 'test-doc-id',
abilities: { retrieve: true, update: true },
},
});

const { fetchDocument } = await import('@/api/collaborationBackend');
const documentId = 'test-document-123';

await fetchDocument(documentId, { cookie: 'test-cookie' });

expect(axiosGetSpy).toHaveBeenCalledWith(
`http://app-dev:8000/api/v1.0/documents/${documentId}/`,
expect.objectContaining({
headers: expect.objectContaining({
'X-Y-Provider-Key': 'test-yprovider-key',
cookie: 'test-cookie',
}),
}),
);

axiosGetSpy.mockRestore();
});

test('fetchCurrentUser sends X-Y-Provider-Key header', async () => {
const axiosGetSpy = vi.spyOn(axios, 'get').mockResolvedValue({
status: 200,
data: {
id: 'test-user-id',
email: '[email protected]',
},
});

const { fetchCurrentUser } = await import('@/api/collaborationBackend');

await fetchCurrentUser({
cookie: 'test-cookie',
origin: 'http://localhost:3000',
});

expect(axiosGetSpy).toHaveBeenCalledWith(
'http://app-dev:8000/api/v1.0/users/me/',
expect.objectContaining({
headers: expect.objectContaining({
'X-Y-Provider-Key': 'test-yprovider-key',
cookie: 'test-cookie',
origin: 'http://localhost:3000',
}),
}),
);

axiosGetSpy.mockRestore();
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { IncomingHttpHeaders } from 'http';

import axios from 'axios';

import { COLLABORATION_BACKEND_BASE_URL } from '@/env';
import { COLLABORATION_BACKEND_BASE_URL, Y_PROVIDER_API_KEY } from '@/env';

export interface User {
id: string;
Expand Down Expand Up @@ -61,6 +61,7 @@ async function fetch<T>(
headers: {
cookie: requestHeaders['cookie'],
origin: requestHeaders['origin'],
'X-Y-Provider-Key': Y_PROVIDER_API_KEY,
},
},
);
Expand Down
Loading