3D-UCaps: 3D Capsules Unet for Volumetric Image Segmentation (MICCAI 2021)

VinAIResearch, updated 2022-01-21 05:54:13


3D-UCaps is a voxel-based capsule network for medical image segmentation. Our architecture is based on the symmetric U-net, with two parts: the encoder is formed by capsule layers, whereas the decoder contains traditional convolutional layers. 3D-UCaps therefore inherits the merits of both capsule networks, which preserve part-to-whole relationships, and CNNs, which learn translation-invariant representations. We conducted experiments on various datasets (including iSeg-2017, LUNA16, Hippocampus, and Cardiac) to demonstrate the superior performance of 3D-UCaps: our method outperforms the baseline SegCaps while being more robust against rotational transformations than 3D-Unet.


Details of the UCaps model architecture and experimental results can be found in our following paper:

    @inproceedings{nguyen20213d,
      title={3D-UCaps: 3D Capsules Unet for Volumetric Image Segmentation},
      author={Nguyen, Tan and Hua, Binh-Son and Le, Ngan},
      booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
      pages={548--558},
      year={2021},
      organization={Springer}
    }

Please CITE our paper when UCaps is used to help produce published results or is incorporated into other software.

Usage

Installation

We provide instructions on how to install dependencies via conda. First, clone the repository locally:

git clone https://github.com/VinAIResearch/3D-UCaps.git

Then, install the dependencies that match your CUDA version. We provide environment files for CUDA 10 and CUDA 11:

conda env create -f environment_cuda11.yml

or

conda env create -f environment_cuda10.yml

Data preparation

Download and extract these datasets:

  * iSeg-2017 challenge (infant brain MRI segmentation): https://iseg2017.web.unc.edu/download/
  * Lung Nodule Analysis 2016 (LUNA16): https://luna16.grand-challenge.org/Download/
  * Cardiac and Hippocampus datasets from the Medical Segmentation Decathlon: http://medicaldecathlon.com/

We expect the directory structure to be the following:

```
path/to/iseg/
  domainA/
  domainA_val/

path/to/cardiac/
  imagesTr
  labelsTr

path/to/hippocampus/
  imagesTr
  labelsTr

path/to/luna/
  imgs
  segs
```

Note: some files in the LUNA16 dataset can lead to an error when training, so we have removed them:

1.3.6.1.4.1.14519.5.2.1.6279.6001.771741891125176943862272696845.mhd
1.3.6.1.4.1.14519.5.2.1.6279.6001.927394449308471452920270961822.mhd
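If you are preparing LUNA16 yourself, here is a minimal cleanup sketch for dropping those two files. The imgs/ location and the .raw companion files are assumptions about a typical local layout, not part of this repository:

```python
# Hypothetical cleanup helper: removes the two problematic LUNA16 volumes
# listed above. Adjust `luna_imgs` to wherever your .mhd files live.
from pathlib import Path

BAD_IDS = [
    "1.3.6.1.4.1.14519.5.2.1.6279.6001.771741891125176943862272696845",
    "1.3.6.1.4.1.14519.5.2.1.6279.6001.927394449308471452920270961822",
]

luna_imgs = Path("path/to/luna/imgs")
for uid in BAD_IDS:
    for ext in (".mhd", ".raw"):  # MetaImage headers usually pair with .raw data
        f = luna_imgs / f"{uid}{ext}"
        if f.exists():
            f.unlink()
            print(f"removed {f}")
```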

Training

Arguments for training can be divided into 3 groups:

  1. Trainer args to initialize the Trainer class from Pytorch Lightning.
     * Important arguments: gpus, accelerator, check_val_every_n_epoch, max_epochs.
     * Fixed arguments in train.py: benchmark, logger, callbacks, num_sanity_val_steps, terminate_on_nan.
  2. Model args, which depend on the model you use (UCaps, SegCaps or U-net) and are defined in the add_model_specific_args method of that module.
     * Important arguments: in_channels, out_channels, val_frequency, val_patch_size, sw_batch_size, overlap. The last three are used by the sliding window inference method from the MONAI library (see the first sketch after this list).
  3. Args specific to training: root_dir, log_dir, dataset, fold, cache_rate, cache_dir, model_name, train_patch_size, num_workers, batch_size, num_samples.
     * cache_rate and cache_dir define whether CacheDataset or PersistentDataset is used when loading data (see the second sketch after this list).
     * num_samples is an argument of the RandCropByPosNegLabel method; the effective batch size is batch_size x num_samples.
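To make the role of val_patch_size, sw_batch_size, and overlap concrete, here is a minimal sketch of MONAI's sliding-window inference; the stand-in model and all sizes below are illustrative placeholders, not our actual configuration:

```python
# Minimal sketch: how MONAI stitches patch predictions over a full volume.
# The Conv3d stand-in and the sizes are placeholders, not this repo's model.
import torch
from monai.inferers import sliding_window_inference

model = torch.nn.Conv3d(1, 2, kernel_size=3, padding=1)  # stand-in predictor
volume = torch.rand(1, 1, 64, 64, 64)                    # (B, C, D, H, W)

with torch.no_grad():
    logits = sliding_window_inference(
        inputs=volume,
        roi_size=(32, 32, 32),  # --val_patch_size
        sw_batch_size=4,        # --sw_batch_size: patches predicted per step
        predictor=model,
        overlap=0.75,           # --overlap between neighbouring patches
    )
print(logits.shape)  # torch.Size([1, 2, 64, 64, 64])
```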

The full list of arguments can be shown with the command: python train.py -h
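Regarding cache_rate and cache_dir above, here is a minimal sketch of the two MONAI dataset types they select; the file paths and transform chain are placeholders to replace with your own:

```python
# Minimal sketch of the two caching strategies; replace the paths with real
# image/label files before running.
from monai.data import CacheDataset, PersistentDataset
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd

data_dicts = [{"image": "/path/to/image.nii.gz", "label": "/path/to/label.nii.gz"}]
transform = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
])

# cache_rate > 0: keep preprocessed samples in RAM (fast, memory-hungry).
train_ds = CacheDataset(data=data_dicts, transform=transform, cache_rate=1.0)

# cache_dir set: persist preprocessed samples on disk, reusable across runs.
train_ds = PersistentDataset(data=data_dicts, transform=transform,
                             cache_dir="/path/to/cache_dir")
```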

We provide bash scripts with our configs to train the UCaps model on each dataset; they can be run as follows: bash scripts/train_ucaps_iseg.sh
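For reference, the kind of invocation these scripts wrap can be seen in the command logged in the last issue at the bottom of this page; a sketch for the cardiac task, with placeholder paths:

```bash
# Sketch of a full training invocation (adapted from the command visible in
# the issue logs below); --log_dir and --root_dir are placeholders.
python train.py --log_dir ./logs --root_dir /path/to/cardiac \
    --gpus 1 --accelerator ddp --check_val_every_n_epoch 5 --max_epochs 100 \
    --dataset task02_heart --model_name ucaps --fold 0 --cache_rate 1.0 \
    --train_patch_size 128 128 128 --num_workers 4 --batch_size 1 \
    --num_samples 1 --in_channels 1 --out_channels 2
```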

Validation

Arguments for validation can be divided into 3 groups:

  1. Trainer args to initialize the Trainer class. The only argument we need here is gpus.
  2. Args for the sliding window inference method.
  3. Args specific to validation: root_dir, output_dir, save_image, model_name, dataset, fold, checkpoint_path.

The full list of arguments can be shown with the command: python evaluate.py -h

We provide bash scripts with our configs to validate trained UCaps models on all datasets; you just need to download our models from the Model Zoo and put them in the logs folder. After that, you can run the evaluation script for the targeted dataset as follows: bash scripts/evaluate_ucaps_iseg.sh
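Under the hood these scripts call evaluate.py; below is a hypothetical invocation assembled from the argument groups listed above. Every path and value here is a placeholder, not our tested config:

```bash
# Hypothetical evaluate.py call built from the documented arguments; replace
# the checkpoint, dataset, and directories with your own.
python evaluate.py --gpus 1 \
    --model_name ucaps --dataset task02_heart --fold 0 \
    --root_dir /path/to/cardiac --output_dir /path/to/predictions \
    --checkpoint_path /path/to/checkpoint.ckpt
```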

Rotation experiment

Same as validation, but with two more arguments: rotate_angle (in degrees) and axis (z/y/x or all), used to create the rotated test subject.

The full list of arguments can be shown through the command: python evaluate_iseg.py -h

We provide a bash script with our config to compare trained UCaps (download) and U-net (download) models on subject #9 of the iSeg-2017 dataset; the first argument is rotate_angle and the second argument is axis: bash scripts/evaluate_rotation.sh 0 z
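For example, to test robustness against a 30-degree rotation about the z axis:

```bash
bash scripts/evaluate_rotation.sh 30 z
```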

Rotation experiment on SkipDenseSeg model

  1. Clone the SkipDenseSeg project.
  2. Replace their val.py with our val.py.
  3. Run val.py with the appropriate args, for example:

python val.py --gpu 1 --sw_batch_size 32 --overlap 0.75 --output_dir=/home/ubuntu/

Model Zoo

About the code

This repository has been refactored to use the Pytorch Lightning framework and the MONAI library for data preprocessing, data loading, and inference, to ensure the reproducibility and extensibility of our work as well as to improve training efficiency. Hence, the results here are slightly better than their counterparts in the paper.

Dice coefficient on subject #9 of the iSeg-2017 dataset:

| Model | CSF | GM | WM | Average | Pretrained model |
|-------|:---:|:---:|:---:|:-----:|------------------|
| 3D-UCaps | 95.01 | 91.51 | 90.59 | 92.37 | download |
| Paper | 94.21 | 91.34 | 90.95 | 92.17 | |

Dice coefficient of 3D-UCaps on the hippocampus dataset with 4-fold cross-validation:

| | Anterior | Posterior | Average | Pretrained model |
|-------|:--------:|:---------:|:-------:|------------------|
| Fold 0 | 86.33 | 83.79 | 85.06 | download |
| Fold 1 | 86.57 | 84.51 | 85.54 | download |
| Fold 2 | 84.29 | 83.23 | 83.76 | download |
| Fold 3 | 85.71 | 83.53 | 84.62 | download |
| Mean | 85.73 | 83.77 | 84.75 | |
| Paper | 85.07 | 82.49 | 83.78 | |

Results of 3D-UCaps on the cardiac dataset with 4-fold cross-validation:

| | Recall | Precision | Dice | Pretrained model |
|-------|:------:|:---------:|:----:|------------------|
| Fold 0 | 91.38 | 89.66 | 90.51 | download |
| Fold 1 | 89.68 | 95.10 | 91.76 | download |
| Fold 2 | 93.12 | 93.00 | 92.53 | download |
| Fold 3 | 91.55 | 94.84 | 90.89 | download |
| Mean | 91.43 | 93.15 | 91.42 | |
| Paper | 92.69 | 89.45 | 90.82 | |

Acknowledgement

The implementation of the dynamic routing algorithm and the capsule layers is based on the official TensorFlow implementation of CapsNet by its authors in this link.
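At the core of that routing scheme is the CapsNet "squash" nonlinearity (Sabour et al., 2017), which scales each capsule vector's length into [0, 1) while preserving its direction. A generic sketch of the function, not this repository's exact code:

```python
# Generic CapsNet squash nonlinearity (Sabour et al., 2017); a sketch, not
# the exact implementation used in this repository.
import torch

def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    """v = (|s|^2 / (1 + |s|^2)) * (s / |s|): short vectors shrink toward 0,
    long vectors approach unit length, direction is preserved."""
    sq_norm = (s * s).sum(dim=dim, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm)
    return scale * s / torch.sqrt(sq_norm + eps)

capsules = torch.randn(2, 8, 16)      # (batch, num_capsules, capsule_dim)
print(squash(capsules).norm(dim=-1))  # all lengths now in [0, 1)
```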

Issues

RuntimeError: Given groups=1, weight of size [16, 2, 5, 5, 5], expected input[1, 1, 32, 32, 32] to have 2 channels, but got 1 channels instead

opened on 2022-10-12 07:47:42 by 217dalao

File "/home/mtc206/anaconda3/envs/lcj/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 219, in validation_step return self.model.validation_step(args, kwargs) File "/home/mtc206/0qsw/SSL4MIS-master/code/3D-UCaps/3D-UCaps-main/module/segcaps.py", line 203, in validation_step overlap=self.overlap, File "/home/mtc206/anaconda3/envs/lcj/lib/python3.6/site-packages/monai/inferers/utils.py", line 130, in sliding_window_inference seg_prob = predictor(window_data, args, kwargs).to(device) # batched patch segmentation File "/home/mtc206/0qsw/SSL4MIS-master/code/3D-UCaps/3D-UCaps-main/module/segcaps.py", line 124, in forward x = self.feature_extractor(x) File "/home/mtc206/anaconda3/envs/lcj/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl result = self.forward(*input, kwargs) File "/home/mtc206/anaconda3/envs/lcj/lib/python3.6/site-packages/torch/nn/modules/container.py", line 117, in forward input = module(input) File "/home/mtc206/anaconda3/envs/lcj/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl result = self.forward(input, kwargs) File "/home/mtc206/anaconda3/envs/lcj/lib/python3.6/site-packages/torch/nn/modules/container.py", line 117, in forward input = module(input) File "/home/mtc206/anaconda3/envs/lcj/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl result = self.forward(input, **kwargs) File "/home/mtc206/anaconda3/envs/lcj/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 573, in forward self.padding, self.dilation, self.groups) RuntimeError: Given groups=1, weight of size [16, 2, 5, 5, 5], expected input[1, 1, 32, 32, 32] to have 2 channels, but got 1 channels instead

When running the code, how can I input data with 2 channels?

Complexity of the model

opened on 2022-08-28 12:10:06 by cugwu

Hi, just for reproducibility purposes: what are the number of parameters, FLOPs, and inference time of the paper's final model? I want to check whether I'm recreating a similar setup with your code. Thanks in advance.

RuntimeError: CUDA error: device-side assert triggered

opened on 2022-06-24 06:36:04 by jsong0041

First of all, congratulations on your recent paper '3D-UCaps: 3D Capsules Unet for Volumetric Image Segmentation' accepted at MICCAI'21; it's really great work, and thank you very much for open-sourcing the code on GitHub.

As for the code, I used a new dataset with .tif-format inputs, but the following errors are thrown:

    Validation sanity check: 0%| | 0/1 [00:00<?, ?it/s]
    C:/w/b/windows/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: block: [1926,0,0], thread: [32,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
    ...
    === Transform input info -- AsDiscrete ===
    Traceback (most recent call last):
      File "C:\Python36\lib\site-packages\monai\transforms\transform.py", line 84, in apply_transform
        return _apply_transform(transform, data, unpack_items)
      File "C:\Python36\lib\site-packages\monai\transforms\transform.py", line 52, in _apply_transform
        return transform(parameters)
      File "C:\Python36\lib\site-packages\monai\transforms\post\array.py", line 174, in __call__
        img = one_hot(img, num_classes=_nclasses, dim=0)
      File "C:\Python36\lib\site-packages\monai\networks\utils.py", line 86, in one_hot
        labels = o.scatter_(dim=dim, index=labels.long(), value=1)
    RuntimeError: CUDA error: device-side assert triggered

During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "train.py", line 129, in <module>
        trainer.fit(net, datamodule=data_module)
      File "C:\Python36\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 741, in fit
        self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
      File "C:\Python36\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 685, in _call_and_handle_interrupt
        return trainer_fn(*args, **kwargs)
      File "C:\Python36\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 777, in _fit_impl
        self._run(model, ckpt_path=ckpt_path)
      File "C:\Python36\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1199, in _run
        self._dispatch()
      File "C:\Python36\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1279, in _dispatch
        self.training_type_plugin.start_training(self)
      File "C:\Python36\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 202, in start_training
        self._results = trainer.run_stage()
      File "C:\Python36\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1289, in run_stage
        return self._run_train()
      File "C:\Python36\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1311, in _run_train
        self._run_sanity_check(self.lightning_module)
      File "C:\Python36\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1375, in _run_sanity_check
        self._evaluation_loop.run()
      File "C:\Python36\lib\site-packages\pytorch_lightning\loops\base.py", line 145, in run
        self.advance(*args, **kwargs)
      File "C:\Python36\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py", line 110, in advance
        dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
      File "C:\Python36\lib\site-packages\pytorch_lightning\loops\base.py", line 145, in run
        self.advance(*args, **kwargs)
      File "C:\Python36\lib\site-packages\pytorch_lightning\loops\epoch\evaluation_epoch_loop.py", line 122, in advance
        output = self._evaluation_step(batch, batch_idx, dataloader_idx)
      File "C:\Python36\lib\site-packages\pytorch_lightning\loops\epoch\evaluation_epoch_loop.py", line 217, in _evaluation_step
        output = self.trainer.accelerator.validation_step(step_kwargs)
      File "C:\Python36\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 239, in validation_step
        return self.training_type_plugin.validation_step(*step_kwargs.values())
      File "C:\Python36\lib\site-packages\pytorch_lightning\plugins\training_type\dp.py", line 104, in validation_step
        return self.model(*args, **kwargs)
      File "C:\Python36\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "C:\Python36\lib\site-packages\torch\nn\parallel\data_parallel.py", line 159, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "C:\Python36\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "C:\Python36\lib\site-packages\pytorch_lightning\overrides\data_parallel.py", line 63, in forward
        output = super().forward(*inputs, **kwargs)
      File "C:\Python36\lib\site-packages\pytorch_lightning\overrides\base.py", line 92, in forward
        output = self.module.validation_step(*inputs, **kwargs)
      File "E:\#project_b\3d-ucaps-master\module\ucaps.py", line 265, in validation_step
        labels = [self.post_label(label) for label in decollate_batch(labels)]
      File "E:\#project_b\3d-ucaps-master\module\ucaps.py", line 265, in <listcomp>
        labels = [self.post_label(label) for label in decollate_batch(labels)]
      File "C:\Python36\lib\site-packages\monai\transforms\compose.py", line 159, in __call__
        input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items)
      File "C:\Python36\lib\site-packages\monai\transforms\transform.py", line 107, in apply_transform
        _log_stats(data=data)
      File "C:\Python36\lib\site-packages\monai\transforms\transform.py", line 98, in _log_stats
        datastats(img=data, data_shape=True, value_range=True, prefix=prefix)  # type: ignore
      File "C:\Python36\lib\site-packages\monai\transforms\utility\array.py", line 524, in __call__
        lines.append(f"Value range: ({torch.min(img)}, {torch.max(img)})")
    RuntimeError: CUDA error: device-side assert triggered


Any help is much appreciated.

RuntimeError: CUDA error: an illegal memory access was encountered

opened on 2021-12-08 07:23:06 by wsonia

I deployed the same environment and used the public cardiac data to run the code, but got this problem while training:

    Validation sanity check: 0%| | 0/1 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "train.py", line 137, in <module>
        trainer.fit(net, datamodule=data_module)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
        self._call_and_handle_interrupt(
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
        return trainer_fn(*args, **kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
        self._run(model, ckpt_path=ckpt_path)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
        self._dispatch()
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
        self.training_type_plugin.start_training(self)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
        self._results = trainer.run_stage()
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
        return self._run_train()
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1306, in _run_train
        self._run_sanity_check(self.lightning_module)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1370, in _run_sanity_check
        self._evaluation_loop.run()
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
        dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
        output = self._evaluation_step(batch, batch_idx, dataloader_idx)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
        output = self.trainer.accelerator.validation_step(step_kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
        return self.training_type_plugin.validation_step(*step_kwargs.values())
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 444, in validation_step
        return self.model(*args, **kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 619, in forward
        output = self.module(*inputs[0], **kwargs[0])
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
        output = self.module.validation_step(*inputs, **kwargs)
      File "/3D-UCaps-main/module/ucaps.py", line 265, in validation_step
        val_outputs = sliding_window_inference(
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/monai/inferers/utils.py", line 130, in sliding_window_inference
        seg_prob = predictor(window_data, *args, **kwargs).to(device)  # batched patch segmentation
      File "/3D-UCaps-main/module/ucaps.py", line 171, in forward
        x = self.feature_extractor(x)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
        input = module(input)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
        input = module(input)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/anaconda3/envs/UCaps/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 572, in forward
        return F.conv3d(input, self.weight, self.bias, self.stride,
    RuntimeError: CUDA error: an illegal memory access was encountered
    terminate called after throwing an instance of 'std::runtime_error'
      what(): NCCL error in: /opt/conda/conda-bld/pytorch_1607370172916/work/torch/lib/c10d/../c10d/NCCLUtils.hpp:136, unhandled cuda error, NCCL version 2.7.8
    ./train_ucaps_cardiac.sh: line 25: 171684 Aborted (core dumped) python train.py --log_dir ./3D-UCaps-main/logs_heart --gpus 1 --accelerator ddp --check_val_every_n_epoch 5 --max_epochs 100 --dataset task02_heart --model_name ucaps --root_dir ./3D-UCaps-main/Task02_Heart --fold 0 --cache_rate 1.0 --train_patch_size 128 128 128 --num_workers 64 --batch_size 1 --share_weight 0 --num_samples 1 --in_channels 1 --out_channels 2 --val_patch_size

Getting stuck at validation step

opened on 2021-11-18 18:27:07 by cndu234

    Epoch 0: 0%| | 0/10 [00:00<?, ?it/s]
    Trying to infer the batch_size from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use self.log(..., batch_size=batch_size).
    The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details.

    Epoch 0: 70%|█████████ | 7/10 [00:14<00:06, 2.12s/it, loss=0.724, v_num=11]
    Validating: 0it [00:00, ?it/s]
    Validating: 0%| | 0/3 [00:00<?, ?it/s]

I have been training on a larger dataset, unlike the error here, but I always get stuck at the validation stage. With the hippocampus dataset the results are fine; with my custom data, I face this problem. What could be the reason?

How to compute the metrics between testset predictions and true labels?

opened on 2021-11-18 07:53:55 by cndu234

I am using my custom data... After training, how can I compute the metrics between test-set predictions and true labels? I am using the Hippocampus data loader provided by you, but I have imagesTr and labelsTr for training, and imagesTs and labelsTs for testing. I want to compute metrics for the test set.


Tags: medical-image-segmentation, capsule-network, luna16, iseg-challenge, hippocampus, cardiac-segmentation, capsnet, segcaps, unet