Skip to content

Commit c0ef641

Browse files
authored
fix: PLT-903: Use db alias in migrations (#8748)
Co-authored-by: triklozoid <[email protected]>
1 parent 9d37a23 commit c0ef641

18 files changed

+190
-154
lines changed

label_studio/data_manager/migrations/0002_remove_annotations_ids.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55

66
def remove(apps, schema_editor):
7+
db_alias = schema_editor.connection.alias
78
View = apps.get_model('data_manager', 'View')
8-
views = View.objects.all()
9+
views = View.objects.using(db_alias).all()
910

1011
for view in views:
1112
if 'hiddenColumns' in view.data:
@@ -16,7 +17,7 @@ def remove(apps, schema_editor):
1617
view.data['hiddenColumns']['labeling'].append('tasks:annotations_ids')
1718
view.data['hiddenColumns']['labeling'] = list(set(view.data['hiddenColumns']['labeling']))
1819

19-
view.save()
20+
view.save(using=db_alias)
2021

2122

2223
class Migration(migrations.Migration):

label_studio/data_manager/migrations/0003_remove_predictions_model_versions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55

66
def remove(apps, schema_editor):
7+
db_alias = schema_editor.connection.alias
78
View = apps.get_model('data_manager', 'View')
8-
views = View.objects.all()
9+
views = View.objects.using(db_alias).all()
910

1011
for view in views:
1112
if 'hiddenColumns' in view.data:
@@ -16,7 +17,7 @@ def remove(apps, schema_editor):
1617
view.data['hiddenColumns']['labeling'].append('tasks:predictions_model_versions')
1718
view.data['hiddenColumns']['labeling'] = list(set(view.data['hiddenColumns']['labeling']))
1819

19-
view.save()
20+
view.save(using=db_alias)
2021

2122

2223
class Migration(migrations.Migration):

label_studio/data_manager/migrations/0017_update_agreement_selected_to_nested_structure.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from django.db import migrations, connection
1+
from django.db import migrations
22
from copy import deepcopy
33
from django.apps import apps as django_apps
44
from django.conf import settings
@@ -12,7 +12,7 @@
1212
logger = logging.getLogger(__name__)
1313

1414

15-
def forward_migration():
15+
def forward_migration(db_alias):
1616
"""
1717
Migrates views that have agreement_selected populated to the new structure
1818
@@ -36,7 +36,7 @@ def forward_migration():
3636
'ground_truth': bool
3737
}
3838
"""
39-
migration, created = AsyncMigrationStatus.objects.get_or_create(
39+
migration, created = AsyncMigrationStatus.objects.using(db_alias).get_or_create(
4040
name=migration_name,
4141
defaults={'status': AsyncMigrationStatus.STATUS_STARTED}
4242
)
@@ -49,7 +49,7 @@ def forward_migration():
4949
# Iterate using values() to avoid loading full model instances
5050
# Fetch only the fields we need, filtering to views that have 'agreement_selected' in data
5151
qs = (
52-
View.objects
52+
View.objects.using(db_alias)
5353
.filter(data__has_key='agreement_selected')
5454
.filter(data__agreement_selected__isnull=False)
5555
.values('id', 'data')
@@ -69,18 +69,19 @@ def forward_migration():
6969
}
7070

7171
# Update only the JSON field via update(); do not load model instance or call save()
72-
View.objects.filter(id=view_id).update(data=new_data)
72+
View.objects.using(db_alias).filter(id=view_id).update(data=new_data)
7373
logger.info(f'Updated View {view_id} agreement selected to default all annotators + all models')
7474
updated += 1
7575

7676
if updated:
7777
logger.info(f'{migration_name} Updated {updated} View rows')
7878

7979
migration.status = AsyncMigrationStatus.STATUS_FINISHED
80-
migration.save(update_fields=['status'])
80+
migration.save(update_fields=['status'], using=db_alias)
8181

8282
def forwards(apps, schema_editor):
83-
start_job_async_or_sync(forward_migration, queue_name=settings.SERVICE_QUEUE_NAME)
83+
db_alias = schema_editor.connection.alias
84+
start_job_async_or_sync(forward_migration, db_alias=db_alias, queue_name=settings.SERVICE_QUEUE_NAME)
8485

8586

8687
def backwards(apps, schema_editor):
@@ -100,4 +101,3 @@ class Migration(migrations.Migration):
100101
]
101102

102103

103-

label_studio/io_storages/migrations/0014_init_statuses.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,19 @@
66
logger = logging.getLogger(__name__)
77

88

9-
def update_storage(storage):
9+
def update_storage(storage, db_alias=None):
1010
logger.info(f'=> Migration for {storage._meta.label} statuses started')
11-
storage.objects.update(status='initialized')
12-
instances = list(storage.objects.all().only('id', 'meta', 'status', 'last_sync_count'))
11+
manager = storage.objects.using(db_alias) if db_alias else storage.objects
12+
manager.update(status='initialized')
13+
instances = list(manager.all().only('id', 'meta', 'status', 'last_sync_count', 'project_id'))
1314

1415
for instance in instances:
15-
prefix = f'Project ID={instance.project.id} {instance}'
16+
prefix = f'Project ID={instance.project_id} {instance}'
1617

1718
# import source storages
1819
if 'import' in storage._meta.label_lower:
19-
count = instance.links.count() - instance.last_sync_count if instance.last_sync_count else 0
20+
links_manager = instance.links.using(db_alias) if db_alias else instance.links
21+
count = links_manager.count() - instance.last_sync_count if instance.last_sync_count else 0
2022
instance.meta['tasks_existed'] = count if count > 0 else 0
2123
if instance.meta['tasks_existed'] and instance.meta['tasks_existed'] > 0:
2224
instance.status = 'completed'
@@ -29,11 +31,12 @@ def update_storage(storage):
2931
instance.status = 'completed'
3032
logger.info(f'{prefix} total_annotations = {instance.last_sync_count}')
3133

32-
storage.objects.bulk_update(instances, fields=['meta', 'status'], batch_size=100)
34+
manager.bulk_update(instances, fields=['meta', 'status'], batch_size=100)
3335
logger.info(f'=> Migration for {storage._meta.label} statuses finished')
3436

3537

3638
def forwards(apps, schema_editor):
39+
db_alias = schema_editor.connection.alias
3740
storages = [
3841
apps.get_model('io_storages', 'AzureBlobImportStorage'),
3942
apps.get_model('io_storages', 'AzureBlobExportStorage'),
@@ -48,7 +51,7 @@ def forwards(apps, schema_editor):
4851
]
4952

5053
for storage in storages:
51-
update_storage(storage)
54+
update_storage(storage, db_alias)
5255

5356

5457
def backwards(apps, schema_editor):

label_studio/io_storages/migrations/0017_auto_20240731_1638.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def drop_index_sql(table_name, index_name, column_name):
6969
]
7070

7171

72-
def forward_migration(migration_name):
73-
migration = AsyncMigrationStatus.objects.create(
72+
def forward_migration(migration_name, db_alias):
73+
migration = AsyncMigrationStatus.objects.using(db_alias).create(
7474
name=migration_name,
7575
status=AsyncMigrationStatus.STATUS_STARTED,
7676
)
@@ -79,7 +79,8 @@ def forward_migration(migration_name):
7979
)
8080

8181
# Get db cursor
82-
cursor = connection.cursor()
82+
from django.db import connections
83+
cursor = connections[db_alias].cursor()
8384
for table in tables:
8485
index_sql = create_index_sql(table['table_name'], table['index_name'], table['column_name'])
8586
fk_sql = create_fk_sql(table['table_name'], table['fk_constraint'], table['column_name'], "task_completion",
@@ -90,13 +91,13 @@ def forward_migration(migration_name):
9091
cursor.execute(fk_sql)
9192

9293
migration.status = AsyncMigrationStatus.STATUS_FINISHED
93-
migration.save()
94+
migration.save(using=db_alias)
9495
logger.debug(
9596
f'Async migration {migration_name} complete'
9697
)
9798

98-
def reverse_migration(migration_name):
99-
migration = AsyncMigrationStatus.objects.create(
99+
def reverse_migration(migration_name, db_alias):
100+
migration = AsyncMigrationStatus.objects.using(db_alias).create(
100101
name=migration_name,
101102
status=AsyncMigrationStatus.STATUS_STARTED,
102103
)
@@ -105,26 +106,29 @@ def reverse_migration(migration_name):
105106
)
106107

107108
# Get db cursor
108-
cursor = connection.cursor()
109+
from django.db import connections
110+
cursor = connections[db_alias].cursor()
109111
for table in tables:
110112
reverse_sql = drop_index_sql(table['table_name'], table['index_name'], table['column_name'])
111113
# Run reverse_sql
112114
cursor.execute(reverse_sql)
113115

114116
migration.status = AsyncMigrationStatus.STATUS_FINISHED
115-
migration.save()
117+
migration.save(using=db_alias)
116118
logger.debug(
117119
f'Async migration {migration_name} complete'
118120
)
119121

120122

121123
def forwards(apps, schema_editor):
122124
# Dispatch migrations to rqworkers
123-
start_job_async_or_sync(forward_migration, migration_name=migration_name)
125+
db_alias = schema_editor.connection.alias
126+
start_job_async_or_sync(forward_migration, migration_name=migration_name, db_alias=db_alias)
124127

125128

126129
def backwards(apps, schema_editor):
127-
start_job_async_or_sync(reverse_migration, migration_name=migration_name)
130+
db_alias = schema_editor.connection.alias
131+
start_job_async_or_sync(reverse_migration, migration_name=migration_name, db_alias=db_alias)
128132

129133

130134
def get_operations():
@@ -158,4 +162,3 @@ class Migration(migrations.Migration):
158162
]
159163

160164
operations = get_operations()
161-

label_studio/ml_models/migrations/0011_thirdpartymodelversion_model_provider_connection.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,25 @@
88
from ml_model_providers.models import ModelProviderConnection, ModelProviders
99

1010

11-
def _fill_model_version_model_provider_connection():
11+
def _fill_model_version_model_provider_connection(db_alias: str):
1212
for provider in [ModelProviders.OPENAI, ModelProviders.AZURE_OPENAI]:
13-
this_provider_model_versions = ThirdPartyModelVersion.objects.filter(provider=provider).values('id', 'organization_id', 'provider_model_id')
13+
this_provider_model_versions = (
14+
ThirdPartyModelVersion.objects.using(db_alias)
15+
.filter(provider=provider)
16+
.values('id', 'organization_id', 'provider_model_id')
17+
)
1418
for provider_model_version in this_provider_model_versions:
15-
connection_ids = ModelProviderConnection.objects.filter(
19+
connection_ids = ModelProviderConnection.objects.using(db_alias).filter(
1620
organization_id=provider_model_version['organization_id'],
1721
provider=provider,
1822
**({'deployment_name': provider_model_version['provider_model_id']} if provider == ModelProviders.AZURE_OPENAI else {}),
1923
).values_list('id', flat=True)[:1]
2024
connection_id = connection_ids[0] if connection_ids else None
21-
ThirdPartyModelVersion.objects.filter(id=provider_model_version['id']).update(model_provider_connection_id=connection_id)
25+
ThirdPartyModelVersion.objects.using(db_alias).filter(id=provider_model_version['id']).update(model_provider_connection_id=connection_id)
2226

2327
def forwards(apps, schema_editor):
24-
start_job_async_or_sync(_fill_model_version_model_provider_connection)
28+
db_alias = schema_editor.connection.alias
29+
start_job_async_or_sync(_fill_model_version_model_provider_connection, db_alias=db_alias)
2530

2631

2732
def backwards(apps, schema_editor):

label_studio/organizations/migrations/0001_squashed_0008_auto_20201005_1552.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010

1111
def rename_disabled_to_off0006(apps, schema_editor):
12+
db_alias = schema_editor.connection.alias
1213
OrganizationMember = apps.get_model('organizations', 'OrganizationMember')
13-
OrganizationMember.objects.filter(role="Disabled").update(role="Off")
14+
OrganizationMember.objects.using(db_alias).filter(role="Disabled").update(role="Off")
1415

1516
migrations.AlterField(
1617
model_name='organizationmember',
@@ -23,8 +24,9 @@ def rename_disabled_to_off0006(apps, schema_editor):
2324

2425

2526
def rename_disabled_to_off0007(apps, schema_editor):
27+
db_alias = schema_editor.connection.alias
2628
OrganizationMember = apps.get_model('organizations', 'OrganizationMember')
27-
OrganizationMember.objects.filter(role="Off").update(role="Deactivated")
29+
OrganizationMember.objects.using(db_alias).filter(role="Off").update(role="Deactivated")
2830

2931
migrations.AlterField(
3032
model_name='organizationmember',

label_studio/projects/migrations/0026_auto_20231103_0020.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,31 @@
1010
logger = logging.getLogger(__name__)
1111

1212

13-
def _fill_label_config_hash(migration_name):
14-
project_tuples = Project.objects.all().values_list('id', 'parsed_label_config')
13+
def _fill_label_config_hash(migration_name, db_alias):
14+
project_tuples = Project.objects.using(db_alias).all().values_list('id', 'parsed_label_config')
1515
for project_id, parsed_label_config in project_tuples:
16-
migration = AsyncMigrationStatus.objects.create(
16+
migration = AsyncMigrationStatus.objects.using(db_alias).create(
1717
project_id=project_id,
1818
name=migration_name,
1919
status=AsyncMigrationStatus.STATUS_STARTED,
2020
)
2121

2222
hashed_label_config = hash(str(parsed_label_config))
23-
Project.objects.filter(id=project_id).update(label_config_hash=hashed_label_config)
23+
Project.objects.using(db_alias).filter(id=project_id).update(label_config_hash=hashed_label_config)
2424

2525
migration.status = AsyncMigrationStatus.STATUS_FINISHED
26-
migration.save()
26+
migration.save(using=db_alias)
2727

2828

29-
def fill_label_config_hash(migration_name):
29+
def fill_label_config_hash(migration_name, db_alias):
3030
logger.info('Start filling label config hash')
31-
start_job_async_or_sync(_fill_label_config_hash, migration_name=migration_name)
31+
start_job_async_or_sync(_fill_label_config_hash, migration_name=migration_name, db_alias=db_alias)
3232
logger.info('Finished filling label config hash')
3333

3434

3535
def forward(apps, schema_editor):
36-
fill_label_config_hash('0026_auto_20231103_0020')
36+
db_alias = schema_editor.connection.alias
37+
fill_label_config_hash('0026_auto_20231103_0020', db_alias)
3738

3839

3940
def backwards(apps, schema_editor):
Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from django.db import migrations, connection
1+
from django.db import migrations, connections
22
from django.conf import settings
33
from core.redis import start_job_async_or_sync
44
from core.models import AsyncMigrationStatus
@@ -8,8 +8,8 @@
88
migration_name = '0030_project_search_vector_index'
99

1010
# Actual DDL to run
11-
def forward_migration(migration_name):
12-
migration = AsyncMigrationStatus.objects.create(
11+
def forward_migration(migration_name, db_alias):
12+
migration = AsyncMigrationStatus.objects.using(db_alias).create(
1313
name=migration_name,
1414
status=AsyncMigrationStatus.STATUS_STARTED,
1515
)
@@ -21,16 +21,16 @@ def forward_migration(migration_name):
2121
CREATE INDEX CONCURRENTLY IF NOT EXISTS project_search_vector_idx ON project USING GIN (search_vector);
2222
'''
2323

24-
with connection.cursor() as cursor:
24+
with connections[db_alias].cursor() as cursor:
2525
cursor.execute(sql)
2626

2727
migration.status = AsyncMigrationStatus.STATUS_FINISHED
28-
migration.save()
28+
migration.save(using=db_alias)
2929
logger.debug(f'Async migration {migration_name} complete')
3030

3131
# Reverse DDL
32-
def reverse_migration(migration_name):
33-
migration = AsyncMigrationStatus.objects.create(
32+
def reverse_migration(migration_name, db_alias):
33+
migration = AsyncMigrationStatus.objects.using(db_alias).create(
3434
name=migration_name,
3535
status=AsyncMigrationStatus.STATUS_STARTED,
3636
)
@@ -39,23 +39,27 @@ def reverse_migration(migration_name):
3939
# Drop index (handle database differences)
4040
sql = 'DROP INDEX CONCURRENTLY IF EXISTS "project_search_vector_idx";'
4141

42-
with connection.cursor() as cursor:
42+
with connections[db_alias].cursor() as cursor:
4343
cursor.execute(sql)
4444

4545
migration.status = AsyncMigrationStatus.STATUS_FINISHED
46-
migration.save()
46+
migration.save(using=db_alias)
4747
logger.debug(f'Async migration rollback {migration_name} complete')
4848

4949
# Hook into Django migration
5050
def forwards(apps, schema_editor):
51-
if connection.vendor == 'postgresql':
52-
start_job_async_or_sync(forward_migration, migration_name=migration_name)
51+
db_alias = schema_editor.connection.alias
52+
conn = connections[db_alias]
53+
if conn.vendor == 'postgresql':
54+
start_job_async_or_sync(forward_migration, migration_name=migration_name, db_alias=db_alias)
5355
else:
5456
logger.debug(f'No index to create if is sqllite')
5557

5658
def backwards(apps, schema_editor):
57-
if connection.vendor == 'postgresql':
58-
start_job_async_or_sync(reverse_migration, migration_name=migration_name)
59+
db_alias = schema_editor.connection.alias
60+
conn = connections[db_alias]
61+
if conn.vendor == 'postgresql':
62+
start_job_async_or_sync(reverse_migration, migration_name=migration_name, db_alias=db_alias)
5963
else:
6064
logger.debug(f'No index to drop if is sqllite')
6165

@@ -66,4 +70,4 @@ class Migration(migrations.Migration):
6670
]
6771
operations = [
6872
migrations.RunPython(forwards, backwards),
69-
]
73+
]

0 commit comments

Comments
 (0)