Awesome
QMedSAM
Quantization of MedSAM / LiteMedSAM
Environment
For inference only:
pip install opencv-python-headless openvino tqdm
Further install dependencies for training: install PyTorch, then
pip install brevitas timm monai pyyaml scipy
Training
We use ddp_dist_tune.py
to achieve the quantization-aware training. DDP was implemented in the style of torch.multiprocessing.spawn
, so this script can be run as a normal python script.
The control panel of the training script is at the very beginning of the file:
# control panel
CONFIG_YAML = os.path.join('config', 'finaltune_quantize_iemd-2.yaml')
PORT = 10311
GPU_RANK = [2, 7, 8, 9]
- Write a configure file, in our pipeline, the three stages training protocols corresponds to
tune_quantize_ie-2.yaml
,tune_quantize_iemd-2.yaml
, andfinaltune_quantize_iemd-2.yaml
. - Specify the DDP port.
- Specify the local GPU ids you want to use.
Before training, one should prepare the dataset. We propose sdk.ModalityDataset
and sdk.GatherModalityDataset
for general purposes. Organization of the dataset should be:
Root Dir
├── Modality Name (such as CT, MR, ...)
├── ...
├── Modality Dataset Name (such as AbdomenCT1K in CT)
├── ...
├── npz files
├── ...
Basically this is the same as the original dataset. We can pass the root directory to sdk.GatherModalityDataset
and specify some of its parameters. The only parameter that worths a word is mod_sample
, which controls the number of samples from each modality. Passing an integer or None to mod_sample
simply spreads the strategy to all the modalities, while you can also pass a list or tuple to specify values for each modality. We provide sdk.GatherModalityDataset
with some helpful functions, you can locate modalities in the list using mod_names()
, fetch the real-time samples from all the modalities using mod_samples()
, and modify mod_sample
using modify_mod_sample()
.
Export
Specify the checkpoint to load and the directory to save in export.py
. By default we will export the quantized model to docker_python
folder, where we will build our docker.
Inference
The general inference script is at /docker_python/validate.py
.
Docker
We provide two docker, one inferencing with cpp and the other running in python. You will see the python version runs faster. Thanks team seno for the reference in cpp code.
cd docker_python
docker build -f Dockerfile -t uestcsd .
docker save uestcsd | gzip -c > uestcsd.tar.gz