The Pretrained Remote Sensing Transformer (Presto)
This code accompanies our paper, Lightweight, Pre-trained Transformers for Remote Sensing Timeseries.
Environment Setup
python -m venv venv
source venv/bin/activate
pip install -e .
wandb can additionally be installed for full functionality of the train.py script.
Entrypoints
Three entrypoints to the code are available: train.py, eval.py and mosaiks.py.
In addition, a jupyter notebook is available demonstrating how Presto can be finetuned on different downstream tasks.
Finally, Presto can also be loaded directly from the python package.
We have also included Presto in a single file (i.e. with no imports from elsewhere in the package) at single_file_presto.py, if you want to easily integrate it into a different application.
We test that these models are equivalent:
# either import works. The single_file_presto has no load_pretrained function, since this
# requires knowing where the pretrained file is. The state dict can be loaded directly
# from data/default_models.pt
from single_file_presto import Presto
from presto import Presto
# to make a randomly initialized encoder-decoder model
encoder_decoder = Presto.construct()
# alternatively, the pre-trained model can also be loaded
encoder_decoder = Presto.load_pretrained()
# to isolate the encoder
encoder_only = encoder_decoder.encoder
# to add a linear transformation to the encoder's output for finetuning
finetuning_model = encoder_decoder.construct_finetuning_model(num_outputs=1, regression=True)
The default arguments to construct are the same as the default parameters described in default.json.
Presto expects the following values as input, and returns the following outputs:
reconstructed_x, reconstructed_dynamic_world = encoder_decoder(x, dynamic_world, latlons, mask, month)
globally_pooled_tokens = encoder_only(x, dynamic_world, latlons, mask, month, eval_task=True)
predictions = finetuning_model(x, dynamic_world, latlons, mask, month)
- x: torch.Tensor of shape [batch_size, num_timesteps, bands], where bands is described by NORMED_BANDS.
- dynamic_world: torch.Tensor of shape [batch_size, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.
- latlons: torch.Tensor of shape [batch_size, 2] describing the latitude and longitude of each input instance.
- mask: An optional torch.Tensor of shape [batch_size, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. If the mask is None, no values in x are ignored.
- month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.
The number of timesteps is not fixed: it can be any value between 1 and 24 (2 years of data).
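The input shapes described above can be sketched with placeholder arrays. This is a minimal illustration rather than part of the package. It uses NumPy arrays (which can be converted with torch.from_numpy), the helper name make_dummy_inputs is hypothetical, and the band count of 17 is a placeholder standing in for len(NORMED_BANDS). The per-instance shape of month and its indexing convention are assumptions, so check them against the package before relying on them.

```python
import numpy as np

# Placeholder values for illustration only. In practice, take the band count
# from len(NORMED_BANDS) and the "missing" class from
# DynamicWorld2020_2021.class_amount in the presto package.
NUM_BANDS = 17
DW_MISSING_CLASS = 9


def make_dummy_inputs(batch_size, num_timesteps, start_month):
    """Build arrays with the shapes Presto expects (hypothetical helper)."""
    if not 1 <= num_timesteps <= 24:
        raise ValueError("num_timesteps must be between 1 and 24")
    x = np.zeros((batch_size, num_timesteps, NUM_BANDS), dtype=np.float32)
    # No Dynamic World labels available: fill with the "ignore" value.
    dynamic_world = np.full(
        (batch_size, num_timesteps), DW_MISSING_CLASS, dtype=np.int64
    )
    latlons = np.zeros((batch_size, 2), dtype=np.float32)
    # mask == 1 marks a value as masked; all zeros means nothing is masked.
    mask = np.zeros_like(x)
    # Assumption: one starting month per instance, identical across the batch.
    month = np.full((batch_size,), start_month, dtype=np.int64)
    return x, dynamic_world, latlons, mask, month


x, dynamic_world, latlons, mask, month = make_dummy_inputs(
    batch_size=4, num_timesteps=12, start_month=1
)
print(x.shape, dynamic_world.shape, latlons.shape, mask.shape, month.shape)
```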
Three of the input tensors (x, dynamic_world, mask) can be generated using presto.construct_single_presto_input.
An example of this is in the downstream task jupyter notebook.
For example, if I have access to some RGB imagery, it can be turned into Presto-compatible inputs:
import presto
x, mask, dynamic_world = presto.construct_single_presto_input(
s2=rgb_imagery, # of shape [num_timesteps, 3]
s2_bands=["B2", "B3", "B4"]
)
Here, x will contain only the (normalized) RGB values in the correct indices, and mask will communicate to Presto to ignore every other input.
Similarly, dynamic_world will contain only DynamicWorld2020_2021.class_amount, so Presto will ignore it.
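Conceptually, the RGB example amounts to writing the available bands into their slots and masking everything else. The sketch below illustrates that idea only and is not the package's implementation. The band ordering in EXAMPLE_BANDS and the helper place_bands are hypothetical, and normalization is omitted.

```python
import numpy as np

# Hypothetical band ordering, for illustration only. The real ordering is
# given by NORMED_BANDS in the presto package.
EXAMPLE_BANDS = ["VV", "VH", "B2", "B3", "B4", "B8"]


def place_bands(values, band_names, all_bands=EXAMPLE_BANDS):
    """Scatter [num_timesteps, len(band_names)] values into the full band
    layout and mask every band that was not provided (hypothetical helper)."""
    num_timesteps = values.shape[0]
    x = np.zeros((num_timesteps, len(all_bands)), dtype=np.float32)
    mask = np.ones_like(x)  # 1 = masked (ignored by the model)
    for column, name in enumerate(band_names):
        index = all_bands.index(name)
        x[:, index] = values[:, column]
        mask[:, index] = 0  # 0 = visible to the model
    return x, mask


rgb_imagery = np.random.rand(12, 3).astype(np.float32)
x, mask = place_bands(rgb_imagery, ["B2", "B3", "B4"])
print(mask[0])  # only the B2, B3 and B4 slots are unmasked
```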
Training
The train.py script contains code for self-supervised training. This can be run locally on a small subset of the data with:
# Barebones local run
python train.py \
--train_url "data/dw_144_mini_shard_44.tar" \
--val_url "data/dw_144_mini_shard_44.tar" \
--val_per_n_steps 1 \
--cropharvest_per_n_validations 0 \
--skip_finetuning
Evaluation
A trained model (or a randomly initialized model) can be run against the evaluation tasks using eval.py. If an --id and an --epoch are passed to the script, a model will be loaded from models/{id}/{epoch}.pt; otherwise, a randomly initialized model will be evaluated.
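The checkpoint lookup described above can be sketched as a small path helper. This is an illustration of the rule only, not the code in eval.py. The function resolve_checkpoint and the run id "my_run" are hypothetical, and returning None here stands for "evaluate a randomly initialized model".

```python
from pathlib import Path
from typing import Optional


def resolve_checkpoint(
    model_id: Optional[str], epoch: Optional[int], models_dir: str = "models"
) -> Optional[Path]:
    """Return models/{id}/{epoch}.pt when both values are given, else None."""
    if model_id is None or epoch is None:
        return None
    return Path(models_dir) / model_id / f"{epoch}.pt"


print(resolve_checkpoint("my_run", 10))
print(resolve_checkpoint(None, None))
```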
Mosaiks
The MOSAIKS1D benchmark can be run against evaluation tasks using the mosaiks.py script.
Generating new data
Prerequisites:
- Account with Google Cloud access and Earth Engine access
curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-387.0.0-linux-x86_64.tar.gz
tar -xf google-cloud-cli-387.0.0-linux-x86_64.tar.gz
exec bash
./google-cloud-sdk/install.sh
gcloud init
earthengine authenticate
- Create buckets for processing
gcloud storage mb -l us-central1 $(python -c "from dataops import EE_BUCKET; print(EE_BUCKET)")
gcloud storage mb -l us-central1 $(python -c "from dataops import NPY_BUCKET; print(NPY_BUCKET)")
gcloud storage mb -l us-central1 $(python -c "from dataops import TAR_BUCKET; print(TAR_BUCKET)")
- Deploy tif-to-np Cloud Function
sh scripts/deploy_tif_to_np.sh
Once prerequisites are satisfied, data can be generated by running:
python scripts/generate_data.py
⚠️ This script assumes you have a Google Cloud project named presto. If your project has a different name, change it in the script. ⚠️
The script will generate:
- data/tile_processing.txt: A summary of tiles being processed
- data/tile_stats.yaml: Stats for all tiles available for training
Behind the scenes for each tile the script will:
- Begin Earth Engine exports to get data for tile from specific data pipeline. Examples:
gs://<EE_BUCKET>/<SHARD_1>/<PIPELINE_1>/*.tif
gs://<EE_BUCKET>/<SHARD_1>/<PIPELINE_2>/*.tif
gs://<EE_BUCKET>/<SHARD_1>/<PIPELINE_3>/*.tif
- Process each tif file to npy. Examples:
gs://<NPY_BUCKET>/<SHARD_1>/<PIPELINE_1>/*.npy
gs://<NPY_BUCKET>/<SHARD_1>/<PIPELINE_2>/*.npy
gs://<NPY_BUCKET>/<SHARD_1>/<PIPELINE_3>/*.npy
- Combine all npy files into a tar file accessible through webdataset. Example:
gs://<TAR_BUCKET>/<DATASET_NAME>/<SHARD_1>.tar
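The three stages above follow a regular naming scheme, which can be sketched as a small helper that lists the URI patterns for one shard. This is an illustration only: stage_uris is a hypothetical function, and the bucket, shard, pipeline and dataset names are the same placeholders used in the examples above rather than real values.

```python
def stage_uris(
    shard,
    pipelines,
    dataset_name,
    ee_bucket="<EE_BUCKET>",
    npy_bucket="<NPY_BUCKET>",
    tar_bucket="<TAR_BUCKET>",
):
    """List the URI patterns each stage writes for one shard."""
    return {
        # Stage 1: Earth Engine exports, one folder per pipeline.
        "tif": [f"gs://{ee_bucket}/{shard}/{p}/*.tif" for p in pipelines],
        # Stage 2: each tif converted to npy, mirroring the same layout.
        "npy": [f"gs://{npy_bucket}/{shard}/{p}/*.npy" for p in pipelines],
        # Stage 3: all npy files for the shard combined into one tar.
        "tar": f"gs://{tar_bucket}/{dataset_name}/{shard}.tar",
    }


uris = stage_uris("<SHARD_1>", ["<PIPELINE_1>", "<PIPELINE_2>"], "<DATASET_NAME>")
for stage, value in uris.items():
    print(stage, value)
```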
Accessing new data
In [0]:
import webdataset as wds
import pandas as pd
uri = "gs://lem-assets2/S1_S2_ERA5_SRTM_2020_2021_DynamicWorld2020_2021_tars/dw_144_shard_0.tar"
dataset = wds.WebDataset(f"pipe:gcloud storage cat {uri}").decode()
for sample in dataset:
    break
In [1]: list(sample.keys())
Out[1]: ['__key__', '__url__', 'dynamicworld2020_2021.npy', 's1_s2_era5_srtm_2020_2021.npy', 'worldcover2020.npy']
In [2]: sample["s1_s2_era5_srtm_2020_2021.npy"].shape
Out[2]: (625, 24, 18)
In [3]: sample["latlon.npy"].shape
Out[3]: (625, 2)
In [4]: sample["worldcover2020.npy"].shape
Out[4]: (625, 1)
In [5]: sample["dynamicworld2020_2021.npy"].shape
Out[5]: (625, 24)
In [6]: pd.Series(sample["dynamicworld2020_2021.npy"].flatten()).value_counts()
Out[6]:
0 14978
7 22
dtype: int64
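As a quick consistency check on the session above, the sketch below uses only the shapes and counts shown there and does not read the dataset. It assumes the first axis indexes the instances within a sample. Under that reading, every array holds 625 instances and the Dynamic World value counts cover every (instance, timestep) pair.

```python
# Shapes and counts copied from the session above.
shapes = {
    "s1_s2_era5_srtm_2020_2021.npy": (625, 24, 18),
    "latlon.npy": (625, 2),
    "worldcover2020.npy": (625, 1),
    "dynamicworld2020_2021.npy": (625, 24),
}
value_counts = {0: 14978, 7: 22}

# Every array shares the same leading dimension.
leading_dims = {shape[0] for shape in shapes.values()}
assert leading_dims == {625}

# The Dynamic World counts sum to instances * timesteps.
instances, timesteps = shapes["dynamicworld2020_2021.npy"]
total = instances * timesteps
assert sum(value_counts.values()) == total
print(total)  # 15000
```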
Reference
If you find this code useful, please cite the following paper:
@misc{tseng2023lightweight,
title={Lightweight, Pre-trained Transformers for Remote Sensing Timeseries},
author={Gabriel Tseng and Ruben Cartuyvels and Ivan Zvonkov and Mirali Purohit and David Rolnick and Hannah Kerner},
year={2023},
eprint={2304.14065},
archivePrefix={arXiv},
primaryClass={cs.CV}
}