---
# one_off_projects/2025_05_10_detr_debugging/detr.yaml
#
# Config with model / data / trainer sections, each written as class_path +
# init_args entries, for a single "detect" task: a Helios encoder followed by a
# Conv + Detr decoder stack.
#
# NOTE(review): this file was restored from a whitespace-flattened new-file
# diff; key nesting was inferred from the class_path / init_args pattern.
# Confirm the nesting against the consuming classes before relying on it.

model:
  class_path: rslearn.train.lightning_module.RslearnLightningModule
  init_args:
    model:
      class_path: rslearn.models.multitask.MultiTaskModel
      init_args:
        encoder:
          - class_path: rslp.helios.model.Helios
            init_args:
              # Templated form, replaced below by a fixed checkpoint path.
              # checkpoint_path: "{CHECKPOINT_PATH}"
              checkpoint_path: "/weka/dfive-default/helios/checkpoints/favyen/20250502_train_budget6000/step168500"
              selector: ["encoder"]
              forward_kwargs:
                # Templated form, replaced below by a fixed value.
                # patch_size: {PATCH_SIZE}
                patch_size: 8
        decoders:
          detect:
            # Disabled alternative: four stacked Conv layers (768 -> 192),
            # kept commented out for reference.
            # - class_path: rslearn.models.conv.Conv
            #   init_args:
            #     # in_channels: {ENCODER_EMBEDDING_SIZE}
            #     in_channels: 768
            #     out_channels: 192
            #     kernel_size: 3
            # - class_path: rslearn.models.conv.Conv
            #   init_args:
            #     in_channels: 192
            #     out_channels: 192
            #     kernel_size: 3
            #     activation:
            #       class_path: torch.nn.LayerNorm
            #       init_args:
            #         # normalized_shape: [192, {256/PATCH_SIZE}, {256/PATCH_SIZE}]
            #         normalized_shape: [192, 32, 32]
            # - class_path: rslearn.models.conv.Conv
            #   init_args:
            #     in_channels: 192
            #     out_channels: 192
            #     kernel_size: 3
            # - class_path: rslearn.models.conv.Conv
            #   init_args:
            #     in_channels: 192
            #     out_channels: 192
            #     kernel_size: 3
            #     activation:
            #       class_path: torch.nn.Identity

            # Active decoder: one Conv (768 -> 256, Identity activation)
            # feeding the Detr module, whose predictor takes 256 channels.
            - class_path: rslearn.models.conv.Conv
              init_args:
                # in_channels: {ENCODER_EMBEDDING_SIZE}
                in_channels: 768
                out_channels: 256
                kernel_size: 3
                activation:
                  class_path: torch.nn.Identity
            - class_path: rslearn.models.detr.Detr
              init_args:
                predictor:
                  in_channels: 256
                  num_classes: 2
                  num_queries: 32
                  transformer:
                    dropout: 0.0
                    num_encoder_layers: 2
                    num_decoder_layers: 2
                    d_model: 256
                  aux_loss: true
                criterion:
                  num_classes: 2
    # Learning rate and plateau_* settings for the Lightning module.
    lr: 0.0001
    plateau: true
    plateau_factor: 0.2
    plateau_patience: 2
    plateau_min_lr: 0
    plateau_cooldown: 10

data:
  class_path: rslearn.train.data_module.RslearnDataModule
  init_args:
    path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/
    inputs:
      # 12-band raster input read from the "sentinel2" layer.
      image:
        data_type: "raster"
        layers: ["sentinel2"]
        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
        passthrough: true
        dtype: FLOAT32
      mask:
        data_type: "raster"
        layers: ["mask"]
        bands: ["mask"]
        passthrough: true
        dtype: FLOAT32
        is_target: true
      # Vector labels read from the "label" layer.
      targets:
        data_type: "vector"
        layers: ["label"]
        is_target: true
    task:
      class_path: rslearn.train.tasks.multi_task.MultiTask
      init_args:
        tasks:
          detect:
            class_path: rslp.satlas.train.MarineInfraTask
            init_args:
              property_name: "category"
              classes: ["platform", "turbine"]
              box_size: 15
              remap_values: [[0, 0.25], [0, 255]]
              image_bands: [2, 1, 0]
              exclude_by_center: true
              enable_map_metric: true
              enable_f1_metric: true
              # One threshold list per entry: a full sweep first, then
              # single-value lists from 0.1 to 0.9.
              f1_metric_thresholds:
                - [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
                - [0.1]
                - [0.2]
                - [0.3]
                - [0.4]
                - [0.5]
                - [0.6]
                - [0.7]
                - [0.8]
                - [0.9]
              skip_unknown_categories: true
              f1_metric_kwargs:
                cmp_mode: "distance"
                cmp_threshold: 15
                flatten_classes: true
        input_mapping:
          detect:
            targets: "targets"
    batch_size: 4
    num_workers: 32
    default_config:
      transforms:
        - class_path: rslp.transforms.mask.Mask
        - class_path: rslearn.train.transforms.concatenate.Concatenate
          init_args:
            selections:
              image: []
            output_selector: sentinel2_l2a
        - class_path: rslp.helios.norm.HeliosNormalize
          init_args:
            config_fname: "/opt/helios/data/norm_configs/computed.json"
            band_names:
              # Same band order as inputs.image.bands above.
              sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
    train_config:
      patch_size: 256
      tags:
        split: train
        # Quoted so the tag value stays the string "yes", not a boolean.
        nonempty: "yes"
    val_config:
      patch_size: 256
      tags:
        split: val
        nonempty: "yes"
    # NOTE(review): test_config uses the same tags as val_config (split: val);
    # confirm that evaluating "test" on the validation split is intended.
    test_config:
      patch_size: 256
      tags:
        split: val
        nonempty: "yes"

trainer:
  max_epochs: 500
  callbacks:
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "epoch"
    - class_path: rslearn.train.prediction_writer.RslearnWriter
      init_args:
        # NOTE(review): literal "placeholder" — presumably overridden at
        # run time; confirm before using this config for prediction.
        path: placeholder
        output_layer: output
        selector: ["detect"]
    # Keeps the last checkpoint plus the single best one by val_detect/mAP.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: 1
        save_last: true
        monitor: val_detect/mAP
        mode: max
    # NOTE(review): selector presumably points at the first encoder entry
    # (the Helios module) — confirm against FreezeUnfreeze.
    - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
      init_args:
        module_selector: ["model", "encoder", 0]
        unfreeze_at_epoch: 30

rslp_project: helios_finetuning
# rslp_experiment: placeholder
rslp_experiment: detrdebug_09