# Source: pull-request diff view (PR state: closed). The surrounding page
# navigation text was not part of the config and is summarised here as
# comments so that the file parses as YAML.
# File: one_off_projects/2025_05_10_detr_debugging/detr.yaml
# Diff: 173 changes (173 additions, 0 deletions); hunk @@ -0,0 +1,173 @@
# Model section: a Lightning module wrapping a multi-task model made of a
# Helios encoder and a single "detect" decoder head (Conv projection -> DETR).
model:
  class_path: rslearn.train.lightning_module.RslearnLightningModule
  init_args:
    model:
      class_path: rslearn.models.multitask.MultiTaskModel
      init_args:
        encoder:
          - class_path: rslp.helios.model.Helios
            init_args:
              # Template placeholder kept for reference; this debugging run
              # pins a concrete checkpoint instead.
              # checkpoint_path: "{CHECKPOINT_PATH}"
              checkpoint_path: "/weka/dfive-default/helios/checkpoints/favyen/20250502_train_budget6000/step168500"
              selector: ["encoder"]
              forward_kwargs:
                # Template placeholder kept for reference.
                # patch_size: {PATCH_SIZE}
                patch_size: 8
        decoders:
          detect:
            # Disabled alternative: a four-layer Conv stack (768 -> 192
            # channels) that was tried before the single projection below.
            # - class_path: rslearn.models.conv.Conv
            #   init_args:
            #     # in_channels: {ENCODER_EMBEDDING_SIZE}
            #     in_channels: 768
            #     out_channels: 192
            #     kernel_size: 3
            # - class_path: rslearn.models.conv.Conv
            #   init_args:
            #     in_channels: 192
            #     out_channels: 192
            #     kernel_size: 3
            #     activation:
            #       class_path: torch.nn.LayerNorm
            #       init_args:
            #         # normalized_shape: [192, {256/PATCH_SIZE}, {256/PATCH_SIZE}]
            #         normalized_shape: [192, 32, 32]
            # - class_path: rslearn.models.conv.Conv
            #   init_args:
            #     in_channels: 192
            #     out_channels: 192
            #     kernel_size: 3
            # - class_path: rslearn.models.conv.Conv
            #   init_args:
            #     in_channels: 192
            #     out_channels: 192
            #     kernel_size: 3
            #     activation:
            #       class_path: torch.nn.Identity

            # Active head: one 3x3 Conv projecting the 768 encoder channels to
            # 256, with Identity as its activation, followed by DETR.
            - class_path: rslearn.models.conv.Conv
              init_args:
                # Template placeholder kept for reference.
                # in_channels: {ENCODER_EMBEDDING_SIZE}
                in_channels: 768
                out_channels: 256
                kernel_size: 3
                activation:
                  class_path: torch.nn.Identity
            # NOTE(review): the nesting below (aux_loss under predictor,
            # criterion as a sibling of predictor) was reconstructed from key
            # names after indentation was lost -- confirm against the Detr
            # constructor signature.
            - class_path: rslearn.models.detr.Detr
              init_args:
                predictor:
                  # Equals out_channels of the Conv projection above.
                  in_channels: 256
                  num_classes: 2
                  num_queries: 32
                  transformer:
                    dropout: 0.0
                    num_encoder_layers: 2
                    num_decoder_layers: 2
                    # Same width as predictor.in_channels.
                    d_model: 256
                  aux_loss: true
                criterion:
                  # Same value as predictor.num_classes.
                  num_classes: 2
    # Optimiser learning rate and plateau-scheduler settings.
    lr: 0.0001
    plateau: true
    plateau_factor: 0.2
    plateau_patience: 2
    plateau_min_lr: 0
    plateau_cooldown: 10
# Data section: dataset location, input layers, the detection task definition,
# the transform pipeline, and the per-split dataset configs.
data:
  class_path: rslearn.train.data_module.RslearnDataModule
  init_args:
    path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/
    inputs:
      # 12 Sentinel-2 bands; the order here matches band_names in the
      # HeliosNormalize transform below.
      image:
        data_type: "raster"
        layers: ["sentinel2"]
        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
        passthrough: true
        dtype: FLOAT32
      mask:
        data_type: "raster"
        layers: ["mask"]
        bands: ["mask"]
        passthrough: true
        dtype: FLOAT32
        is_target: true
      targets:
        data_type: "vector"
        layers: ["label"]
        is_target: true
    task:
      class_path: rslearn.train.tasks.multi_task.MultiTask
      init_args:
        tasks:
          # Task key "detect" matches the decoder key in the model section and
          # the "val_detect/..." metric prefix monitored by the trainer.
          detect:
            class_path: rslp.satlas.train.MarineInfraTask
            init_args:
              property_name: "category"
              # Two classes, consistent with num_classes: 2 in the model.
              classes: ["platform", "turbine"]
              box_size: 15
              remap_values: [[0, 0.25], [0, 255]]
              image_bands: [2, 1, 0]
              exclude_by_center: true
              enable_map_metric: true
              enable_f1_metric: true
              # One entry per threshold set: the first sweeps 0.05-0.95, the
              # rest are single fixed thresholds.
              f1_metric_thresholds:
                - [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
                - [0.1]
                - [0.2]
                - [0.3]
                - [0.4]
                - [0.5]
                - [0.6]
                - [0.7]
                - [0.8]
                - [0.9]
              skip_unknown_categories: true
              f1_metric_kwargs:
                cmp_mode: "distance"
                # Same numeric value as box_size above.
                cmp_threshold: 15
                flatten_classes: true
        input_mapping:
          detect:
            targets: "targets"
    batch_size: 4
    num_workers: 32
    default_config:
      transforms:
        - class_path: rslp.transforms.mask.Mask
        # Re-exposes the "image" input under the key "sentinel2_l2a", which is
        # the key HeliosNormalize's band_names refers to.
        - class_path: rslearn.train.transforms.concatenate.Concatenate
          init_args:
            selections:
              image: []
            output_selector: sentinel2_l2a
        - class_path: rslp.helios.norm.HeliosNormalize
          init_args:
            config_fname: "/opt/helios/data/norm_configs/computed.json"
            band_names:
              sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
    # "yes" is quoted so it stays a string rather than a YAML 1.1 boolean.
    train_config:
      patch_size: 256
      tags:
        split: train
        nonempty: "yes"
    val_config:
      patch_size: 256
      tags:
        split: val
        nonempty: "yes"
    # NOTE(review): test_config selects split "val", the same as val_config --
    # confirm this is intentional (e.g. the dataset has no separate test split).
    test_config:
      patch_size: 256
      tags:
        split: val
        nonempty: "yes"
# Trainer section: epoch budget plus callbacks for LR logging, prediction
# writing, checkpointing, and delayed encoder unfreezing.
trainer:
  max_epochs: 500
  callbacks:
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "epoch"
    - class_path: rslearn.train.prediction_writer.RslearnWriter
      init_args:
        # Literal placeholder value; presumably overridden when predictions
        # are actually written -- confirm.
        path: placeholder
        output_layer: output
        selector: ["detect"]
    # Keep the single best checkpoint by validation mAP (higher is better),
    # plus the most recent one.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: 1
        save_last: true
        monitor: val_detect/mAP
        mode: max
    # Selects model -> encoder -> element 0 (the Helios encoder entry) and
    # unfreezes it at epoch 30.
    - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
      init_args:
        module_selector: ["model", "encoder", 0]
        unfreeze_at_epoch: 30
# Project and experiment identifiers for this run (presumably consumed by the
# rslp launcher for naming/logging -- confirm).
rslp_project: "helios_finetuning"
# Template value kept for reference:
# rslp_experiment: placeholder
rslp_experiment: "detrdebug_09"
# (diff viewer "Loading" placeholder -- not part of detr.yaml)