diff --git a/CITATION.cff b/CITATION.cff
index d45fdca12ec721cb23920a28bbacabfd7f3df242..8b640c88cb75c1fc378272a324d27b8dc7c0536a 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -33,7 +33,7 @@ preferred-citation:
   month: 10
   # start: 1 # First page number
   # end: 10 # Last page number
-  title: "A distributed Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems"
+  title: "A Distributed Split-Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems"
   # issue: 1
   # volume: 1
   year: 2022
diff --git a/README.md b/README.md
index 874912b22285a4357a36f7713ec14f3675bf57b3..243685b0210fe64a159dccfe88508fc91f3c47a7 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ ______________________________________________________________________
 
 Python codes associated with the method described in the following paper.
 
-> \[1\] P.-A. Thouvenin, A. Repetti, P. Chainais - **A distributed Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems**, [arxiv preprint 2210.02341](http://arxiv.org/abs/2210.02341), October 2022, under review.
+> P.-A. Thouvenin, A. Repetti, P. Chainais - **A Distributed Block-Split Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems**, [arXiv preprint 2210.02341](http://arxiv.org/abs/2210.02341), October 2023, to appear in JCGS.
 
 **Authors**: P.-A. Thouvenin, A. Repetti, P. Chainais
 
@@ -53,34 +53,34 @@ ______________________________________________________________________
 
 ```bash
 # Cloning the repo. or unzip the dsgs-main.zip code archive
-# git clone --recurse-submodules https://gitlab.com/pthouvenin/...git
+# git clone --recurse-submodules https://gitlab.cristal.univ-lille.fr/pthouven/dsgs.git
 unzip dsgs-main.zip
 cd dsgs-main
 
 # Create a conda environment using one of the lock files provided in the archive
-# (use jcgs_review_environment_osx.lock.yml for MAC OS)
-mamba env create --name jcgs-review --file jcgs_review_environment_linux.lock.yml
-# mamba env create --name jcgs-review --file jcgs_review_environment.yml
+# (use dsgs_environment_osx.lock.yml for macOS)
+mamba env create --name dsgs --file dsgs_environment_linux.lock.yml
+# mamba env create --name dsgs --file dsgs_environment.yml
 
 # Activate the environment
-mamba activate jcgs-review
+mamba activate dsgs
 
 # Install the library in editable mode
 mamba develop src/
 
 # Deleting the environment (if needed)
-# mamba env remove --name jcgs-review
+# mamba env remove --name dsgs
 
 # Generating lock file from existing environment (if needed)
-# mamba env export --name jcgs-review --file jcgs_review_environment_linux.lock.yml
+# mamba env export --name dsgs --file dsgs_environment_linux.lock.yml
 # or
-# mamba list --explicit --md5 > explicit_jcgs_env_linux-64.txt
-# mamba create --name jcgs-test -c conda-forge --file explicit_jcgs_env_linux-64.txt
+# mamba list --explicit --md5 > explicit_dsgs_env_linux-64.txt
+# mamba create --name dsgs-test -c conda-forge --file explicit_dsgs_env_linux-64.txt
 # pip install docstr-coverage genbadge wily sphinxcontrib-apa sphinx_copybutton
 
 # Manual install (if absolutely needed)
-# mamba create --name jcgs-review numpy numba mpi4py "h5py>=2.9=mpi*" scipy scikit-image matplotlib imageio tqdm jupyterlab pytest black flake8 isort coverage pre-commit sphinx sphinx_rtd_theme sphinxcontrib-bibtex sphinx-autoapi sphinxcontrib furo conda-lock conda-build
-# mamba activate jcgs-review
+# mamba create --name dsgs numpy numba mpi4py "h5py>=2.9=mpi*" scipy scikit-image matplotlib imageio tqdm jupyterlab pytest black flake8 isort coverage pre-commit sphinx sphinx_rtd_theme sphinxcontrib-bibtex sphinx-autoapi sphinxcontrib furo conda-lock conda-build
+# mamba activate dsgs
 # pip install sphinxcontrib-apa sphinx_copybutton docstr-coverage genbadge wily
 # mamba develop src
 ```
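A quick way to confirm that the editable install succeeded is to import the package from the freshly activated environment; a minimal sketch, assuming `mamba develop src/` exposes the top-level `dsgs` module:

```python
# Minimal post-install smoke test (assumes the `dsgs` environment is active).
import dsgs

# For an editable install, the module should resolve inside the cloned src/ tree.
print(dsgs.__file__)
```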
@@ -94,7 +94,7 @@ export HDF5_USE_FILE_LOCKING='FALSE'
- To check that the installation went well, you can run the unit-tests provided in the package using the command below.
 
 ```bash
-mamba activate jcgs-review
+mamba activate dsgs
 pytest --collect-only
 export NUMBA_DISABLE_JIT=1  # need to disable jit compilation to check test coverage
 coverage run -m pytest      # run all the unit-tests (see Documentation section for more details)
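The `NUMBA_DISABLE_JIT=1` step matters because jit-compiled functions execute as native code that the `coverage` tracer cannot see; with the flag set, `@jit`-decorated functions run as plain Python. A small illustration (the variable must be set before `numba` is imported):

```python
# Illustration of NUMBA_DISABLE_JIT: with the flag set, @jit is a no-op and
# the function body runs as ordinary Python, so coverage can trace it.
import os

os.environ["NUMBA_DISABLE_JIT"] = "1"  # must happen before importing numba

from numba import jit


@jit(nopython=True)
def add(a, b):
    return a + b


print(add(1, 2))  # runs uncompiled, hence visible to coverage
```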
@@ -139,7 +139,7 @@ tmux kill-session -t session_name
 ```bash
 # from a terminal at the root of the archive
 
-mamba activate jcgs-review
+mamba activate dsgs
 cd examples/jcgs
 
 # generate all the synthetic datasets used in the experiments (to be run only once)
@@ -161,7 +161,7 @@ mamba deactivate
 - The content of an [`.h5`](https://docs.h5py.org/en/stable/mpi.html?highlight=h5dump#using-parallel-hdf5-from-h5py) file can be quickly checked from the terminal (see the [`h5py`](https://docs.h5py.org/en/stable/quick.html) documentation for further details). Some examples are provided below.
 
 ```bash
-mamba activate jcgs-review
+mamba activate dsgs
 
 # replace <filename> by the name of your file in the instructions below
 h5dump --header <filename>.h5 # displays the name and size of all variables contained in the file
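For more involved checks than `h5dump`, the same inspection can be scripted with `h5py`; a sketch where `result.h5` and the `x_mmse` dataset name are placeholders:

```python
# Script-side equivalent of the h5dump calls above.
import h5py

with h5py.File("result.h5", "r") as f:
    # print the name, shape and dtype of every dataset in the file
    f.visititems(
        lambda name, obj: print(name, obj.shape, obj.dtype)
        if isinstance(obj, h5py.Dataset)
        else None
    )
    x_mmse = f["x_mmse"][()]  # load one variable into memory (name assumed)
```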
@@ -181,7 +181,7 @@ ______________________________________________________________________
 - The documentation can be generated in `.html` format using the following commands issued from a terminal.
 
 ```bash
-mamba activate jcgs-review
+mamba activate dsgs
 cd docs
 make html
 ```
@@ -191,7 +191,7 @@ make html
 To test the code/docstring coverage, run the following commands from a terminal.
 
 ```bash
-mamba activate jcgs-review
+mamba activate dsgs
 pytest --collect-only
 export NUMBA_DISABLE_JIT=1  # need to disable jit compilation to check test coverage
 coverage run -m pytest  # check all tests
@@ -205,20 +205,13 @@ docstr-coverage .  # check docstring coverage and generate the associated covera
 To launch a single test, run a command of the form
 
 ```bash
-mamba activate jcgs-review
+mamba activate dsgs
 python -m pytest tests/models/test_crop.py
 pytest --markers  # check full list of available markers
 pytest -m "not mpi" --ignore-glob=**/archive_unittest/* # run all tests not marked as mpi + ignore files in any directory "archive_unittest"
 mpiexec -n 2 python -m mpi4py -m pytest -m mpi  # run all tests marked mpi with 2 cores
 ```
 
-<!-- To measure code quality with `wily`
-
-```bash
-wily build src/ -n 100
-wily report dsgs -f HTML -o docs/build/wily_report.html
-``` -->
-
 ______________________________________________________________________
 
 ## License
diff --git a/docs/source/biblio.bib b/docs/source/biblio.bib
index 34fd23fc4b374fc8aad421fd7cda56e9fc74ce91..28e03bf7ffbae54b9d4baf9d055a69b9b25741d9 100644
--- a/docs/source/biblio.bib
+++ b/docs/source/biblio.bib
@@ -126,13 +126,6 @@
     eprint  = {https://doi.org/10.1137/19M1283719},
 }
 
-@PhdThesis{Prusa2012,
-    author = {Zden\v{e}k Pru\v{s}a},
-    title  = {Segmentwise discrete wavelet transform},
-    school = {Brno university of technology},
-    year   = {2012},
-}
-
 @inproceedings{Salim2020,
     author = {Salim, Adil and Richt\`{a}rik, Peter},
     booktitle = NIPS,
@@ -142,10 +135,10 @@
     year = {2020}
 }
 
-@misc{Thouvenin2022submitted,
-    title={A distributed Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems},
+@misc{Thouvenin2023,
+    title={A Distributed Block-Split Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems},
     author={Pierre-Antoine Thouvenin and Audrey Repetti and Pierre Chainais},
-    year={2022},
+    year={2023},
     eprint={2210.02341},
     archivePrefix={arXiv},
     primaryClass={stat.ME}
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 900fc51e1b34cecb5239b532545b26af5f2a7468..dfb7aaad45139533eb93d32d2247b71a1f915d98 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -23,7 +23,7 @@ copyright = "2023, P.-A. Thouvenin, A. Repetti and P. Chainais"
 author = "P.-A. Thouvenin, A. Repetti and P. Chainais"
 
 # The full version, including alpha/beta/rc tags
-release = "0.1.0"
+release = "1.0"
 
 
 # -- General configuration ---------------------------------------------------
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b3538ec374e3bfc2a124267a81c5490591aae298..276a65c39800065bab2ce1b83b5878d40316c17e 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -8,8 +8,7 @@ The library currently contains codes to reproduce the experiments reported in :c
 
 .. warning::
 
-   This project is under active development, and the API my evolve significantly until version ``1.0``.
-..    - A complete code coverage report is available `here <../coverage_html_report/index.html#http://>`_.
+   The code provided in this archive corresponds to the version of the library used to produce the results reported in the paper.
 
 
 .. toctree::
@@ -28,18 +27,9 @@ The library currently contains codes to reproduce the experiments reported in :c
    autoapi/index
 ..    Code coverage report<../coverage_html_report/index.html#http://>
 ..    License<../../../LICENSE#http://>
-   Gitlab repository<https://gitlab.cristal.univ-lille.fr/pthouven/dspa>
+   Gitlab repository<https://gitlab.cristal.univ-lille.fr/pthouven/dsgs>
 ..    Code quality<../wily_report.html#http://>
 
-..
-   .. .. autosummary::
-   ..    :toctree: _autosummary
-   ..    :template: custom-module-template.rst
-   ..    :recursive:
-
-   ..    dsgs
-
-
 Indices and tables
 ==================
 
diff --git a/docs/source/setup.rst b/docs/source/setup.rst
index b3865c8506a5e8afb97fae05aa5559313e39888d..4f4ef7f482b758aa578aea61771e5fa533cf273a 100644
--- a/docs/source/setup.rst
+++ b/docs/source/setup.rst
@@ -21,17 +21,17 @@ To install the library, issue the following commands in a terminal.
     cd dsgs-main
 
     # Create anaconda environment (from one of the provided lock files)
-    mamba env create --name jcgs-review --file jcgs_review_environment_linux.lock.yml  # use jcgs_review_environment_osx.lock.yml for osx
-    # mamba env create --name jcgs-review --file jcgs_review_environment.yml
+    mamba env create --name dsgs --file dsgs_environment_linux.lock.yml  # use dsgs_environment_osx.lock.yml for osx
+    # mamba env create --name dsgs --file dsgs_environment.yml
 
     # Activate the environment
-    mamba activate jcgs-review
+    mamba activate dsgs
 
     # Install the library in editable mode
     mamba develop src/
 
     # Deleting the environment (if needed)
-    # mamba env remove --name jcgs-review
+    # mamba env remove --name dsgs
 
 To avoid `file lock issue in h5py <https://github.com/h5py/h5py/issues/1101>`_, add the following line to the ``~/.zshrc`` file (or ``~/.bashrc``)
 
@@ -43,7 +43,7 @@ To test the installation went well, you can run the unit-tests provided in the p
 
 .. code-block:: bash
 
-    conda activate jcgs-review
+    conda activate dsgs
     pytest --collect-only
     export NUMBA_DISABLE_JIT=1  # need to disable jit compilation to check test coverage
     coverage run -m pytest      # run all the unit-tests (see Documentation section for more details)
@@ -92,7 +92,7 @@ The folder ``./examples/jcgs/configs`` contains ``.json`` files summarizing the
 
     # from a terminal at the root of the archive
 
-    mamba activate jcgs-review
+    mamba activate dsgs
     cd examples/jcgs
 
     # generate all the synthetic datasets used in the experiments (to be run only once)
@@ -117,7 +117,7 @@ The content of an ``.h5``  file can be quickly checked from the terminal (`refer
 
-    conda activate jcgs-review
+    conda activate dsgs
 
     # replace <filename> by the name of your file in the instructions below
     h5dump --header <filename>.h5 # displays the name and size of all variables contained in the file
@@ -143,7 +143,7 @@ The documentation can be generated in ``.html`` format using the following comma
 
 .. code-block:: bash
 
-    conda activate jcgs-review
+    conda activate dsgs
    cd docs
     make html # compile documentation in html, latex or linkcheck
 
@@ -155,7 +155,7 @@ To test the code/docstring coverage, run the following commands from a terminal.
 
 .. code-block:: bash
 
-    conda activate jcgs-review
+    conda activate dsgs
     export NUMBA_DISABLE_JIT=1 # need to disable jit compilation to check test coverage
    coverage run -m pytest # check all tests
     coverage report # generate a coverage report in the terminal
@@ -168,7 +168,7 @@ To launch a single test, run a command of the form
 
 .. code-block:: bash
 
-    conda activate jcgs-review
+    conda activate dsgs
     python -m pytest tests/operators/test_crop.py
    pytest --markers  # check full list of available markers
     pytest -m "not mpi" --ignore-glob=**/archive_unittest/* # run all tests not marked as mpi + ignore files in any directory "archive_unittest"
diff --git a/jcgs_review_environment.yml b/dsgs_environment.yml
similarity index 100%
rename from jcgs_review_environment.yml
rename to dsgs_environment.yml
diff --git a/jcgs_review_environment_linux.lock.yml b/dsgs_environment_linux.lock.yml
similarity index 100%
rename from jcgs_review_environment_linux.lock.yml
rename to dsgs_environment_linux.lock.yml
diff --git a/jcgs_review_environment_osx.lock.yml b/dsgs_environment_osx.lock.yml
similarity index 100%
rename from jcgs_review_environment_osx.lock.yml
rename to dsgs_environment_osx.lock.yml
diff --git a/explicit_jcgs_env_linux-64.txt b/explicit_dsgs_env_linux-64.txt
similarity index 100%
rename from explicit_jcgs_env_linux-64.txt
rename to explicit_dsgs_env_linux-64.txt
diff --git a/explicit_jcgs_env_osx-64.txt b/explicit_dsgs_env_osx-64.txt
similarity index 100%
rename from explicit_jcgs_env_osx-64.txt
rename to explicit_dsgs_env_osx-64.txt
diff --git a/img/5.1.13.tiff b/img/5.1.13.tiff
deleted file mode 100644
index d09cef3ad54b5152c4667e3744cdc1949cb9391d..0000000000000000000000000000000000000000
Binary files a/img/5.1.13.tiff and /dev/null differ
diff --git a/img/5.3.01.tiff b/img/5.3.01.tiff
deleted file mode 100644
index 13a756d8b7566d29438c6d03910d0ca977156d59..0000000000000000000000000000000000000000
Binary files a/img/5.3.01.tiff and /dev/null differ
diff --git a/img/bank.png b/img/bank.png
deleted file mode 100644
index 71529dc024f13f94d5ce42a86eb4659ae50f486a..0000000000000000000000000000000000000000
Binary files a/img/bank.png and /dev/null differ
diff --git a/img/bank_color.png b/img/bank_color.png
deleted file mode 100644
index b390136fce633be696aaa28991bb40f1f01ed7a8..0000000000000000000000000000000000000000
Binary files a/img/bank_color.png and /dev/null differ
diff --git a/img/barb.png b/img/barb.png
deleted file mode 100755
index e9f29e5ccc8d9296066192006074591bcd76f357..0000000000000000000000000000000000000000
Binary files a/img/barb.png and /dev/null differ
diff --git a/img/boat.png b/img/boat.png
deleted file mode 100755
index 9098a93d998332bca33e3a20a6921770d62bec86..0000000000000000000000000000000000000000
Binary files a/img/boat.png and /dev/null differ
diff --git a/img/chessboard.png b/img/chessboard.png
deleted file mode 100755
index 0531a8dad9cda0fb1b69cdcff97e5e08915389ef..0000000000000000000000000000000000000000
Binary files a/img/chessboard.png and /dev/null differ
diff --git a/img/corral.png b/img/corral.png
deleted file mode 100755
index 5a789fbe07e6d193b56b9cffc0f207927e02c10d..0000000000000000000000000000000000000000
Binary files a/img/corral.png and /dev/null differ
diff --git a/img/cortex.png b/img/cortex.png
deleted file mode 100755
index 48bd58918ed660e476681d46e97ce51d7c18252e..0000000000000000000000000000000000000000
Binary files a/img/cortex.png and /dev/null differ
diff --git a/img/grating.png b/img/grating.png
deleted file mode 100755
index 2317973d33a85a6330b78635f4a43c00b56ae80e..0000000000000000000000000000000000000000
Binary files a/img/grating.png and /dev/null differ
diff --git a/img/hair.png b/img/hair.png
deleted file mode 100755
index c953f129381224f42af3e56f919389bd8683edc6..0000000000000000000000000000000000000000
Binary files a/img/hair.png and /dev/null differ
diff --git a/img/lena.png b/img/lena.png
deleted file mode 100755
index f14918282436fcd454f2587942fbe5e2564e468a..0000000000000000000000000000000000000000
Binary files a/img/lena.png and /dev/null differ
diff --git a/img/line.png b/img/line.png
deleted file mode 100755
index b48ddad36cbdf76ce6ec58d271791b67bad106c1..0000000000000000000000000000000000000000
Binary files a/img/line.png and /dev/null differ
diff --git a/img/line_horizontal.png b/img/line_horizontal.png
deleted file mode 100755
index 06c39f4c3ed1892231a80dd892c3bfa7e780be00..0000000000000000000000000000000000000000
Binary files a/img/line_horizontal.png and /dev/null differ
diff --git a/img/line_vertical.png b/img/line_vertical.png
deleted file mode 100755
index 7bf11fb97b4d1dfac2d69e620fd7ff65ba38cbd3..0000000000000000000000000000000000000000
Binary files a/img/line_vertical.png and /dev/null differ
diff --git a/img/mandrill.png b/img/mandrill.png
deleted file mode 100755
index 58b49fbc2e57cde1124fbfaae6be90a24ca2e7c1..0000000000000000000000000000000000000000
Binary files a/img/mandrill.png and /dev/null differ
diff --git a/img/mosque.png b/img/mosque.png
deleted file mode 100644
index 32b02fc039467d5d8b89c7e65502332b3850dc48..0000000000000000000000000000000000000000
Binary files a/img/mosque.png and /dev/null differ
diff --git a/img/mosque_color.png b/img/mosque_color.png
deleted file mode 100644
index 9048f0b009617c6d7e0a65834541cf4164261616..0000000000000000000000000000000000000000
Binary files a/img/mosque_color.png and /dev/null differ
diff --git a/img/parrot-mask.png b/img/parrot-mask.png
deleted file mode 100755
index 831ab3775bddc0ab8e387f1268a8276905e0425d..0000000000000000000000000000000000000000
Binary files a/img/parrot-mask.png and /dev/null differ
diff --git a/img/parrot.png b/img/parrot.png
deleted file mode 100755
index 4cfdf35ee9ef9f172f83c0f889fc5b778b6f7498..0000000000000000000000000000000000000000
Binary files a/img/parrot.png and /dev/null differ
diff --git a/img/parrotgray.png b/img/parrotgray.png
deleted file mode 100644
index 57551b3e4b374e7cc8378140623c4cccad2e0a6d..0000000000000000000000000000000000000000
Binary files a/img/parrotgray.png and /dev/null differ
diff --git a/img/periodic_bumps.png b/img/periodic_bumps.png
deleted file mode 100755
index fb4fe1b2facdb8dd9d82a1a094644898878eabaf..0000000000000000000000000000000000000000
Binary files a/img/periodic_bumps.png and /dev/null differ
diff --git a/img/rubik1.png b/img/rubik1.png
deleted file mode 100755
index bbe824072a2feecaea4ee742967e921c8a938190..0000000000000000000000000000000000000000
Binary files a/img/rubik1.png and /dev/null differ
diff --git a/img/rubik2.png b/img/rubik2.png
deleted file mode 100755
index d56bd3e11c12cd2ff683dcb68458a7d759b5020a..0000000000000000000000000000000000000000
Binary files a/img/rubik2.png and /dev/null differ
diff --git a/img/rubik3.png b/img/rubik3.png
deleted file mode 100755
index e2a140af7b9cb89af5b15babaacf9209833b8228..0000000000000000000000000000000000000000
Binary files a/img/rubik3.png and /dev/null differ
diff --git a/img/taxi1.png b/img/taxi1.png
deleted file mode 100755
index d3859086df432710c74fd0408e61faead55e936b..0000000000000000000000000000000000000000
Binary files a/img/taxi1.png and /dev/null differ
diff --git a/img/taxi2.png b/img/taxi2.png
deleted file mode 100755
index 618840842b18f53f79af146b670684d4ce422d2e..0000000000000000000000000000000000000000
Binary files a/img/taxi2.png and /dev/null differ
diff --git a/img/taxi3.png b/img/taxi3.png
deleted file mode 100755
index 583a76f197f5e9e537033e70b54bf7b4949ceaf6..0000000000000000000000000000000000000000
Binary files a/img/taxi3.png and /dev/null differ
diff --git a/img/vessels.png b/img/vessels.png
deleted file mode 100755
index 3998b37cd71df9869cba1aece77c19e130ab05b1..0000000000000000000000000000000000000000
Binary files a/img/vessels.png and /dev/null differ
diff --git a/src/dsgs/main_metrics_serial.py b/src/dsgs/main_metrics_serial.py
index 539d242b7e48c5c5044f78d2c34107e20e5b2ec7..29fb87747dc19968496ac0d038752637bca291d0 100644
--- a/src/dsgs/main_metrics_serial.py
+++ b/src/dsgs/main_metrics_serial.py
@@ -76,14 +76,8 @@ def main_metrics(
 
     f = h5py.File(filename + ".h5", "r+", driver="mpio", comm=MPI.COMM_WORLD)
     N = f["N"][()]
-    # M = f["M"][()]
-    # h = f["h"][()]
     f.close()
 
-    # overlap_size = np.array(h.shape, dtype="i")
-    # overlap_size -= 1
-    # data_size = N + overlap_size
-
     # slice to extract image tile from full image file
     tile_pixels = local_split_range_nd(grid_size, N, ranknd)
     tile_size = tile_pixels[:, 1] - tile_pixels[:, 0] + 1
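The `local_split_range_nd` call above determines which tile of the full image each MPI rank owns; its actual implementation lives in `dsgs.utils.communications`. The sketch below is an illustrative stand-in (not the library's code) showing the inclusive `[start, stop]` convention implied by the `tile_size` computation:

```python
# Illustrative even block-splitting across a process grid; stops are
# inclusive, so tile_size = stop - start + 1 as in the code above.
import numpy as np


def split_range_1d(nblocks, n, rank):
    # spread n items over nblocks; the first n % nblocks blocks get one extra
    base, extra = divmod(n, nblocks)
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0) - 1
    return start, stop


grid_size = np.array([2, 3])  # 2 x 3 process grid
N = np.array([7, 11])         # full image size
ranknd = np.array([1, 2])     # grid coordinates of this process

tile_pixels = np.array(
    [split_range_1d(grid_size[d], N[d], ranknd[d]) for d in range(2)]
)
tile_size = tile_pixels[:, 1] - tile_pixels[:, 0] + 1
print(tile_pixels, tile_size)  # [[4 6] [8 10]] [3 3]
```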
diff --git a/src/dsgs/main_metrics_spmd.py b/src/dsgs/main_metrics_spmd.py
index bc2259ac0b35a517438591ef03f28b7b5e03fcf9..32a76ca469e52d1568b894d5d937e57ef7e0e437 100644
--- a/src/dsgs/main_metrics_spmd.py
+++ b/src/dsgs/main_metrics_spmd.py
@@ -76,14 +76,8 @@ def main_metrics(
 
     f = h5py.File(filename + ".h5", "r+", driver="mpio", comm=MPI.COMM_WORLD)
     N = f["N"][()]
-    # M = f["M"][()]
-    # h = f["h"][()]
     f.close()
 
-    # overlap_size = np.array(h.shape, dtype="i")
-    # overlap_size -= 1
-    # data_size = N + overlap_size
-
     # slice to extract image tile from full image file
     tile_pixels = local_split_range_nd(grid_size, N, ranknd)
     tile_size = tile_pixels[:, 1] - tile_pixels[:, 0] + 1
@@ -246,7 +240,6 @@ def main_metrics(
     # MMSE and MAP (need all processes)
     saver.save("", [N], [global_slice_tile], [None], mode="w", x_mmse=local_x_mmse)
 
-    # ! seems to create a segfault for a large number of processes: why?
     saver.save("", [N], [global_slice_tile], [None], mode="a", x_map=local_x_map)
 
     snr_mmse = 0.0
diff --git a/src/dsgs/main_serial_poisson_deconvolution.py b/src/dsgs/main_serial_poisson_deconvolution.py
index 9defb8028a135533b378d1377856d6f4da16cbae..c807ea51d448a6d0ebfa22aff0144f3a1098a18b 100644
--- a/src/dsgs/main_serial_poisson_deconvolution.py
+++ b/src/dsgs/main_serial_poisson_deconvolution.py
@@ -165,10 +165,6 @@ def main(
         f.close()
         logger.info("End: loading data and images")
 
-        # parameters of the Gamma prior on the regularization parameter
-        # a = 1e-3
-        # b = 1e-3
-
         logger.info("Parameters defined, setup sampler")
 
         if args.prof:
@@ -245,17 +241,6 @@ def main(
         else:
             raise ValueError("Unknown sampler: {}".format(sampler))
 
-        if args.prof:
-            pr.disable()
-            # Dump results:
-            # - for binary dump
-            pr.dump_stats("debug/cpu_serial.prof")
-            # - for text dump
-            with open("debug/cpu_serial.txt", "w") as output_file:
-                sys.stdout = output_file
-                pr.print_stats(sort="time")
-                sys.stdout = sys.__stdout__
-
 
 if __name__ == "__main__":
     import argparse
@@ -291,11 +276,9 @@ if __name__ == "__main__":
             ]
 
     args = parser.parse_args()
-    # print(args)
-
-    # args = utils.args.parse_args()
 
     # # * debugging values
+    # print(args)
     # args.imfile = "img/image_micro_8.h5"
     # args.datafilename = "data_image_micro_8_ds1_M30_k8"
     # args.rpath = "debug"
@@ -379,21 +362,6 @@ if __name__ == "__main__":
     datafilename = join(args.dpath, args.datafilename + ".h5")
     checkpointname = join(args.rpath, args.checkpointname)
 
-    # tau.run("""main(
-    #     args.data,
-    #     args.rpath,
-    #     args.imfile,
-    #     datafilename,
-    #     args.checkpointname,
-    #     logger,
-    #     profiling=args.prof,
-    #     alpha=args.alpha,
-    #     beta=args.beta,
-    #     Nmc=args.Nmc,
-    #     checkpoint_frequency=args.checkpoint_frequency,
-    #     monitor_frequency=5,
-    # )""")
-
     main(
         args.data,
         args.rpath,
diff --git a/src/dsgs/main_spmd_poisson_deconvolution.py b/src/dsgs/main_spmd_poisson_deconvolution.py
index 704fa7586cc3764bf2dbc517997f9833bb098997..9ab6998697013235c674fbe455282734243b8168 100644
--- a/src/dsgs/main_spmd_poisson_deconvolution.py
+++ b/src/dsgs/main_spmd_poisson_deconvolution.py
@@ -21,9 +21,6 @@ from dsgs.samplers.parallel.spmd_psgla_poisson_deconvolution import SpmdPsglaSGS
 from dsgs.utils.checkpoint import DistributedCheckpoint, SerialCheckpoint
 from dsgs.utils.communications import local_split_range_nd
 
-# import tau
-# https://forum.hdfgroup.org/t/crash-when-writing-parallel-compressed-chunks/6186
-
 
 def main(
     comm,
@@ -282,17 +279,6 @@ def main(
 
         spmd_sampler.sample()
 
-        if args.prof:
-            pr.disable()
-            # Dump results:
-            # - for binary dump
-            pr.dump_stats("debug/cpu_%d.prof" % comm.rank)
-            # - for text dump
-            with open("debug/cpu_%d.txt" % comm.rank, "w") as output_file:
-                sys.stdout = output_file
-                pr.print_stats(sort="time")
-                sys.stdout = sys.__stdout__
-
 
 if __name__ == "__main__":
     import argparse
@@ -331,11 +317,9 @@ if __name__ == "__main__":
             ]
 
     args = parser.parse_args()
-    # print(args)
-
-    # args = utils.args.parse_args()
 
     # * debugging values
+    # print(args)
     # args.imfile = "img/cameraman.png"
     # args.datafilename = "conv_data_cameraman_ds1_M30_k8"
     # args.rpath = "results_conv_cameraman_ds1_M30_k8_h5"
diff --git a/src/dsgs/operators/convolutions.py b/src/dsgs/operators/convolutions.py
index f54aca143a44906371a575d8779d42da11bb3623..ffafaeb8f9c23ac90d52b8143d091e72f1a9b971 100755
--- a/src/dsgs/operators/convolutions.py
+++ b/src/dsgs/operators/convolutions.py
@@ -128,8 +128,6 @@ def linear_convolution(x, h, mode="constant"):
     rsize = np.array(h.shape, dtype="i") - 1
     y = pad_array_nd(x, lsize, rsize, mode="constant")
     return scipy.ndimage.convolve(y, h, mode="constant", cval=0.0)
-    # y = pad_array_nd(x, lsize, rsize, mode=mode)
-    # return scipy.ndimage.convolve(y, h, mode="constant", cval=0.0)
 
 
 def adjoint_linear_convolution(y, h, mode="constant"):
@@ -156,236 +154,3 @@ def adjoint_linear_convolution(y, h, mode="constant"):
     )
     # ! np.flip(h, axis=None) flips all axes
     return adjoint_padding(x, lsize, rsize, mode="constant")
-
-    # ! symmetric bd condition
-    # ! not functional for now
-    # lsize = (np.array(h.shape, dtype="i") - 1)
-    # rsize = lsize
-    # yp = pad_array_nd(y, np.zeros(len(h.shape), dtype="i"), rsize, mode="constant")
-    # x = scipy.ndimage.convolve(yp, np.conj(np.flip(h, axis=None)), mode='constant', cval=0.0)
-    # # ! np.flip(h, axis=None) flips all axes
-    # return adjoint_padding(x, lsize, rsize, mode=mode)
-
-
-# TODO: write a generic version
-# direct: padding, conv. in valid mode
-# Hy_ = sg.convolve2d(y_, h, boundary="fill", mode="full")
-# Hadj_y_ = adjoint_padding(Hy_, ext_size, ext_size, mode="symmetric")
-
-# same kind for symmetric when not based on fft (overlap-add for distributed version)
-
-
-# ! to be made more generic (quick test for now)
-# TODO: adjoint of a convolution operator with any boundary extension involves
-# the adjoint of the circular convolution and the adjoint of the padding
-# operator
-# def adjoint_conv(x, h, shape):
-#     """Adjoint of the linear convolution operator.
-
-#     Parameters
-#     ----------
-#     x : _type_
-#         _description_
-#     h : _type_
-#         _description_
-#     shape : _type_
-#         _description_
-
-#     Returns
-#     -------
-#     _type_
-#         _description_
-#     """
-#     Hx = sg.convolve2d(x, np.flip(h, axis=None), boundary="fill", mode="full")
-#     # ! np.flip(m, axis=None) flips all axes
-#     s = tuple([np.s_[: shape[k]] for k in range(len(shape))])
-#     return Hx[s]
-
-
-if __name__ == "__main__":
-    # # TODO: structure the example better, make sure this is included in a
-    # # TODO: unit-test
-    # import matplotlib.pyplot as plt
-    # import scipy.signal as sg
-    # from PIL import Image
-
-    # from dsgs.operators.linear_convolution import SerialConvolution
-    # from dsgs.operators.padding import adjoint_padding, pad_array, pad_array_nd
-
-    # # Generate 2D Gaussian convolution kernel
-    # vr = 1
-    # M = 7
-    # if np.mod(M, 2) > 0:  # M odd
-    #     n = np.arange(-(M - 1) // 2, (M - 1) // 2 + 1)
-    # else:
-    #     n = np.arange(-M // 2, M // 2)
-    # h = np.exp(-(n**2 + n[:, np.newaxis] ** 2) / (2 * vr))
-
-    # # plt.imshow(h, cmap=plt.cm.gray)
-    # # plt.show()
-
-    # x = np.asarray(Image.open("img/cameraman.png", "r"), dtype="d")
-    # N = x.shape
-    # M = h.shape
-
-    # # * version 1: circular convolution
-    # K = N
-    # hpad = pad_array(h, K, padmode="after")  # after, using fft convention for center
-    # yc, H = fft2_conv(x, hpad, K)
-
-    # circ_conv = SerialConvolution(np.array(K, dtype="i"), h, np.array(K, dtype="i"))
-    # yc2 = circ_conv.forward(x)
-    # print("yc2 == yc ? {0}".format(np.allclose(yc2, yc)))
-
-    # # plt.imshow(yc, cmap=plt.cm.gray)
-    # # plt.show()
-
-    # # check adjoint operator (circular convolution)
-    # rng = np.random.default_rng(1234)
-    # x_ = rng.standard_normal(N)
-    # Hx_ = circ_conv.forward(x_)
-    # y_ = rng.standard_normal(K)
-    # Hadj_y_ = circ_conv.adjoint(y_)
-    # hp1 = np.sum(Hx_ * y_)
-    # hp2 = np.sum(x_ * Hadj_y_)
-
-    # print(
-    #     "Correct adjoint operator (circular convolution)? {}".format(
-    #         np.isclose(hp1, hp2)
-    #     )
-    # )
-
-    # # * version 1.2: circular convolution w/o Fourier
-    # K = N
-    # # with fft
-    # hpad = pad_array(h, K, padmode="after")  # after, using fft convention for center
-    # yc, H = fft2_conv(x, hpad, K)
-
-    # # w/o fft (pad signal, linear convolution with x w/o additional border effect) -> ok
-    # shift = [[M[n] - 1, 0] for n in range(len(M))]
-    # # xp = np.pad(x, shift, mode="wrap")
-    # # yc2 = sg.convolve2d(xp, h, boundary="fill", mode="valid")
-    # # print("yc2 == yc? {0}".format(np.allclose(yc2, yc)))
-    # # yc3 = sg.convolve2d(xp, h, boundary="fill", mode="full")
-    # # yc3 = yc3[M[0]-1:-(M[0]-1), M[1]-1:-(M[1]-1)]
-    # # print("yc3 == yc2? {0}".format(np.allclose(yc2, yc3)))
-
-    # # check adjoint operator (circular convolution, w/o Fourier)
-    # rng = np.random.default_rng(1234)
-    # x_ = rng.standard_normal(N)
-    # Hx_ = np.pad(x_, shift, mode="wrap")
-    # Hx_ = sg.convolve2d(Hx_, h, boundary="fill", mode="valid")
-
-    # y_ = rng.standard_normal(K)
-    # Hadj_y_ = sg.convolve2d(y_, np.conj(np.flip(h)), boundary="fill", mode="full")
-    # # ! manual adjoint padding (wrap condition)
-    # # Hadj_y_[-(M[0]-1):, :] += Hadj_y_[:M[0]-1, :]
-    # # Hadj_y_[:, -(M[1]-1):] += Hadj_y_[:, :M[1]-1]
-    # # Hadj_y_ = Hadj_y_[M[0]-1:, M[1]-1:]
-    # Hadj_y_ = adjoint_padding(
-    #     Hadj_y_, [M[n] - 1 for n in range(len(M))], 2 * [None], mode="wrap"
-    # )
-
-    # hp1 = np.sum(Hx_ * y_)
-    # hp2 = np.sum(x_ * Hadj_y_)
-    # print(
-    #     "Correct adjoint operator (circular convolution)? {}".format(
-    #         np.isclose(hp1, hp2)
-    #     )
-    # )
-
-    # # * version 2: linear convolution
-    # # linear convolution
-    # K = [N[n] + M[n] - 1 for n in range(len(N))]
-    # H = np.fft.rfft2(h, K)
-    # yl = np.fft.irfft2(H * np.fft.rfft2(x, K), K)  # zeros appear around
-
-    # yl2 = fft_conv(x, H, K)
-    # print("yl2 == yl ? {0}".format(np.allclose(yl2, yl)))
-
-    # linear_conv = SerialConvolution(np.array(N, dtype="i"), h, np.array(K, dtype="i"))
-    # yl3 = linear_conv.forward(x)
-    # print("yl3 == yl ? {0}".format(np.allclose(yl3, yl)))
-
-    # # plt.imshow(yl, cmap=plt.cm.gray)
-    # # plt.show()
-
-    # # check adjoint operator (linear convolution)
-    # rng = np.random.default_rng(1234)
-    # x_ = rng.standard_normal(N)
-    # Hx_ = linear_conv.forward(x_)
-    # y_ = rng.standard_normal(K)
-    # Hadj_y_ = linear_conv.adjoint(y_)
-    # hp1 = np.sum(Hx_ * y_)
-    # hp2 = np.sum(x_ * Hadj_y_)
-
-    # print(
-    #     "Correct adjoint operator (linear convolution)? {}".format(np.isclose(hp1, hp2))
-    # )
-
-    # # adjoint operator (w/o fft)
-    # Hadj_y = adjoint_conv(x_, h, N)
-    # hp2 = np.sum(np.conj(x_) * Hadj_y_)
-
-    # print(
-    #     "Correct adjoint convolution operator (linear convolution, no fft)? {}".format(
-    #         np.isclose(hp1, hp2)
-    #     )
-    # )
-
-    # # Notes:
-    # # 5::4 -> slice(5, None, 4)
-    # # 5:-1 -> slice(5, -1) -> np.s_[5:-1]
-    # # s = alpha[5::4] == t = alpha[slice(5, None, 4)]
-
-    # * convolution with symmetric boundary conditions
-    # rng = np.random.default_rng(1234)
-    # N_ = np.array(N, dtype="i")
-    # M_ = np.array(M, dtype="i")
-    # ext_size = M_ - 1
-    # K_ = N_ + ext_size
-    # Fh_d = np.fft.rfft2(h, s=N_ + 3 * M_ - 3)
-    # Fh_a = np.fft.rfft2(h, s=K_ + M_ - 1)
-
-    # x_ = rng.standard_normal(N)
-    # y_ = rng.standard_normal(K_)
-
-    # # direct operator (with sg.convolve2d)
-    # Hx_ = sg.convolve2d(x_, h, boundary="symm", mode="full")
-
-    # # direct operator (alternative using manual padding)
-    # # xp = pad_array_nd(x_, ext_size, ext_size, mode="symmetric")
-    # # Hx2 = sg.convolve(xp, h, mode="valid")
-
-    # # direct operator (alternative based on fft)
-    # xp = pad_array_nd(x_, ext_size, ext_size, mode="symmetric")
-    # Hx2_ = fft_conv(xp, Fh_d, K_ + 2 * M_ - 2)
-    # Hx2 = adjoint_padding(
-    #     Hx2_, ext_size, ext_size, mode="constant"
-    # )  # equivalent to valid
-    # print(
-    #     "Correct direct convolution (symmetric extension)? {}".format(
-    #         np.allclose(Hx2, Hx_)
-    #     )
-    # )
-    # # -> y = CHPx, C cropping (valid, remove M-1 on each side), P boundary extension, H convolution
-
-    # # adjoint operator (using sg.convolve2d)
-    # Hy_ = sg.convolve2d(y_, h, boundary="fill", mode="full")
-    # Hadj_y_ = adjoint_padding(Hy_, ext_size, ext_size, mode="symmetric")
-
-    # # adjoint operator (using fft)
-    # # ! correct, but no conjugate of Fh_a here! why?
-    # Hy0_ = fft_conv(y_, Fh_a, K_ + M_ - 1)
-    # Hadj_y0 = adjoint_padding(Hy0_, ext_size, ext_size, mode="symmetric")
-
-    # hp1 = np.sum(np.conj(Hx_) * y_)
-    # hp2 = np.sum(np.conj(x_) * Hadj_y_)
-
-    # print(
-    #     "Correct adjoint convolution operator (symmetric extension)? {}".format(
-    #         np.isclose(hp1, hp2)
-    #     )
-    # )
-
-    pass
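The long commented-out `__main__` experiment removed above essentially ran adjoint dot-product tests. A condensed, runnable version against the retained function pair (shapes are arbitrary; the import path follows the usage elsewhere in this diff):

```python
# Adjoint dot-product check: <Hx, y> must equal <x, H*y> up to rounding.
import numpy as np

from dsgs.operators.convolutions import (
    adjoint_linear_convolution,
    linear_convolution,
)

rng = np.random.default_rng(1234)
x = rng.standard_normal((16, 16))
h = rng.standard_normal((5, 5))

Hx = linear_convolution(x, h)
y = rng.standard_normal(Hx.shape)
Hadj_y = adjoint_linear_convolution(y, h)

print(np.isclose(np.sum(Hx * y), np.sum(x * Hadj_y)))  # expect True
```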
diff --git a/src/dsgs/operators/data.py b/src/dsgs/operators/data.py
index e0f907817e161f54f3629d8a0f415b7a8873f163..1c546a37e5e0cd3dd19ed414ccf221f3ee7243da 100644
--- a/src/dsgs/operators/data.py
+++ b/src/dsgs/operators/data.py
@@ -8,9 +8,6 @@ MPI processes.
 # Sampler with Hypergraph Structure for High-Dimensional Inverse Problems**,
 # [arxiv preprint 2210.02341](http://arxiv.org/abs/2210.02341), October 2022.
 
-# TODO: to be simplified, e.g. with checkpoint and model objects, or even
-# TODO: removed
-
 from os.path import splitext
 
 import h5py
@@ -117,7 +114,7 @@ def generate_random_mask(image_size, percent, rng):
             "Fraction of observed pixels percent should be such that: 0 <= percent <= 1."
         )
 
-    N = np.prod(image_size)  # total number of pixels
+    N = np.prod(image_size)
     masked_id = np.unravel_index(
         rng.choice(N, (percent * N).astype(int), replace=False), image_size
     )
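Read in isolation, the masking logic kept above draws a fraction `percent` of flat pixel indices without replacement and converts them to n-d coordinates; a self-contained sketch of the same idea:

```python
# Standalone illustration of the random-mask construction.
import numpy as np

image_size = (4, 5)
percent = 0.3
rng = np.random.default_rng(1234)

N = np.prod(image_size)  # total number of pixels
masked_id = np.unravel_index(
    rng.choice(N, int(percent * N), replace=False), image_size
)
mask = np.zeros(image_size, dtype=bool)
mask[masked_id] = True
print(mask.sum(), "pixels selected out of", N)
```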
@@ -430,16 +427,11 @@ def generate_local_data(
 
     # local convolution
     fft_h = np.fft.rfftn(h, local_conv_size)
-    # H = np.fft.rfft2(
-    #     np.fft.fftshift(padding.pad_array(h, N, padmode="around"))
-    # )  # doing as in Vono's reference code
     local_coeffs = ucomm.slice_valid_coefficients(ranknd, grid_size, overlap_size)
 
     # ! issue: need to make sure Hx >= 0, not necessarily the case numerically
     # ! with a fft-based convolution
     # https://github.com/pytorch/pytorch/issues/30934
-    # Hx = scipy.ndimage.convolve(facet, h, output=Hx, mode='constant', cval=0.0)
-    # Hx = convolve2d(facet, h, mode='full')[local_coeffs]
     Hx = fft_conv(facet, fft_h, local_conv_size)[local_coeffs]
     prox_nonegativity(Hx)
     local_data = local_rng.poisson(Hx)
@@ -483,14 +475,6 @@ def mpi_load_data_from_h5(
 ):
     ndims = data_size.size
 
-    # slice for indexing into global arrays
-    # global_slice_data = tuple(
-    #     [
-    #         np.s_[tile_pixels[d, 0] : tile_pixels[d, 0] + local_data_size[d]]
-    #         for d in range(ndims)
-    #     ]
-    # )
-
     local_slice = tuple(ndims * [np.s_[:]])
     global_slice_data = create_local_to_global_slice(
         tile_pixels, ranknd, overlap_size, local_data_size, backward=backward
@@ -508,7 +492,3 @@ def mpi_load_data_from_h5(
     f.close()
 
     return local_data
-
-
-# if __name__ == "__main__":
-#     pass
diff --git a/src/dsgs/operators/jtv.py b/src/dsgs/operators/jtv.py
index 3ef00f74430a2b1bebf609a0b7d18d604bab2a3e..77792debd47b6549ccb4de29288e3cc0dedcc21d 100755
--- a/src/dsgs/operators/jtv.py
+++ b/src/dsgs/operators/jtv.py
@@ -18,12 +18,6 @@ from numba import jit
 # ! does not support type elision
 # ! only jit costly parts (by decomposing function), keep flexibility of Python
 # ! as much as possible
-# import importlib
-# importlib.reload(...)
-
-# TODO: investigate jitted nD version for the TV (not only 2D)
-# TODO: try to simplify the 2d implementation of the chunked version of the TV
-# TODO (many conditions to be checked at the moment)
 
 # * Useful numba links
 # https://stackoverflow.com/questions/57662631/vectorizing-a-function-returning-tuple-using-numba-guvectorize
@@ -352,105 +346,3 @@ def gradient_smooth_tv(x, eps):
     u = gradient_2d(x)
     w = np.sqrt(np.abs(u[0]) ** 2 + np.abs(u[1]) ** 2 + eps)
     return gradient_2d_adjoint(u[0] / w, u[1] / w)
-
-
-# if __name__ == "__main__":
-#     import timeit
-#     from inspect import cleandoc  # handle identation in multi-line strings
-
-#     rng = np.random.default_rng(1234)
-#     x = rng.standard_normal((10, 5))
-#     eps = np.finfo(float).eps
-
-#     stmt_s = "tv.gradient_smooth_tv(x, eps)"
-#     setup_s = cleandoc(
-#         """
-#     import tv
-#     import from __main__ import x, eps
-#     """
-#     )
-#     t_tv_np = timeit.timeit(
-#         stmt_s, number=100, globals=globals()
-#     )  # save in pandas (prettier display)
-#     print("TV (numpy version): ", t_tv_np)
-
-#     _ = gradient_smooth_tv(x, eps)  # trigger numba compilation
-#     stmt_s = "gradient_smooth_tv(x, eps)"
-#     setup_s = "from __main__ import gradient_smooth_tv, x, eps"
-#     t_tv_numba = timeit.timeit(stmt_s, setup=setup_s, number=100)
-#     print("TV (numba version): ", t_tv_numba)
-
-#     # ! multiple facets (serial test)
-#     # uh0, uv0 = gradient_2d(x)
-
-#     # yh0 = (1 + 1j) * rng.standard_normal(uh0.shape)
-#     # yv0 = (1 + 1j) * rng.standard_normal(uv0.shape)
-#     # z0 = gradient_2d_adjoint(yh0, yv0)
-
-#     # ndims = 2
-#     # grid_size = np.array([3, 2], dtype="i")
-#     # nchunks = np.prod(grid_size)
-#     # N = np.array(x.shape, dtype="i")
-
-#     # overlap = (grid_size > 1).astype(int)
-#     # isfirst = np.empty((nchunks, ndims), dtype=bool)
-#     # islast = np.empty((nchunks, ndims), dtype=bool)
-
-#     # range0 = []
-#     # range_direct = []
-#     # range_adjoint = []
-#     # uh = np.zeros(x.shape)
-#     # uv = np.zeros(x.shape)
-#     # z = np.zeros(x.shape, dtype=complex)
-
-#     # for k in range(nchunks):
-
-#     #     ranknd = np.array(np.unravel_index(k, grid_size), dtype="i")
-#     #     islast[k] = ranknd == grid_size - 1
-#     #     isfirst[k] = ranknd == 0
-#     #     range0.append(ucomm.local_split_range_nd(grid_size, N, ranknd))
-#     #     Nk = range0[k][:,1] - range0[k][:,0] + 1
-
-#     #     # * direct operator
-#     #     # version with backward overlap
-#     #     # range_direct.append(ucomm.local_split_range_nd(grid_size, N, ranknd, overlap=overlap, backward=True))
-#     #     # facet = x[range_direct[k][0,0]:range_direct[k][0,1]+1, \
-#     #     # range_direct[k][1,0]:range_direct[k][1,1]+1]
-#     #     # uh_k, uv_k = chunk_gradient_2d(facet, islast[k])
-
-#     #     # start = range_direct[k][:,0]
-#     #     # stop = start + np.array(uv_k.shape, dtype="i")
-#     #     # uv[start[0]:stop[0], start[1]:stop[1]] = uv_k
-#     #     # stop = start + np.array(uh_k.shape, dtype="i")
-#     #     # uh[start[0]:stop[0], start[1]:stop[1]] = uh_k
-
-#     #     # version with forward overlap
-#     #     range_direct.append(ucomm.local_split_range_nd(grid_size, N, ranknd, overlap=overlap, backward=False))
-#     #     facet = x[range_direct[k][0,0]:range_direct[k][0,1]+1, \
-#     #     range_direct[k][1,0]:range_direct[k][1,1]+1]
-#     #     uh_k, uv_k = chunk_gradient_2d(facet, islast[k])
-
-#     #     start = range0[k][:,0]
-#     #     stop = start + np.array(uv_k.shape, dtype="i")
-#     #     uv[start[0]:stop[0], start[1]:stop[1]] = uv_k
-#     #     stop = start + np.array(uh_k.shape, dtype="i")
-#     #     uh[start[0]:stop[0], start[1]:stop[1]] = uh_k
-
-#     #     # * adjoint (backward overlap only, forward more difficult to encode)
-#     #     range_adjoint.append(ucomm.local_split_range_nd(grid_size, N, ranknd, overlap=overlap, backward=True))
-#     #     facet_h = yh0[range_adjoint[k][0,0]:range_adjoint[k][0,1]+1, \
-#     #     range_adjoint[k][1,0]:range_adjoint[k][1,1]+1]
-#     #     facet_v = yv0[range_adjoint[k][0,0]:range_adjoint[k][0,1]+1, \
-#     #     range_adjoint[k][1,0]:range_adjoint[k][1,1]+1]
-#     #     x_k = np.zeros(Nk, dtype=complex)
-#     #     chunk_gradient_2d_adjoint(facet_h, facet_v, x_k, isfirst[k], islast[k])
-
-#     #     start = range0[k][:,0]
-#     #     stop = start + Nk
-#     #     z[start[0]:stop[0], start[1]:stop[1]] = x_k
-
-#     # print(np.allclose(uh, uh0))
-#     # print(np.allclose(uv, uv0))
-#     # print(np.allclose(z, z0))
-
-#     pass
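For orientation, the retained `gradient_smooth_tv` in `jtv.py` (and its variant in `tv.py`) evaluates the gradient of the smoothed isotropic TV, a standard identity rather than anything specific to this diff:

$$
\nabla \mathrm{TV}_\varepsilon(x) \;=\; \nabla^{*}\!\left( \frac{\nabla x}{\sqrt{\lvert \nabla_h x \rvert^{2} + \lvert \nabla_v x \rvert^{2} + \varepsilon}} \right),
$$

where $\nabla = (\nabla_h, \nabla_v)$ is the 2D discrete gradient and $\nabla^{*}$ its adjoint, matching the `gradient_2d` / `gradient_2d_adjoint` calls in the code.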
diff --git a/src/dsgs/operators/linear_convolution.py b/src/dsgs/operators/linear_convolution.py
index 384411d6c07c3e3aab63263626e89f4eb254f767..dbbe687118288dbb5f41aee5866c62db24da3682 100644
--- a/src/dsgs/operators/linear_convolution.py
+++ b/src/dsgs/operators/linear_convolution.py
@@ -16,13 +16,8 @@ from dsgs.operators.convolutions import fft_conv
 from dsgs.operators.distributed_convolutions import calculate_local_data_size
 from dsgs.operators.linear_operator import LinearOperator
 
-# ? = Keep kernel value out of the associated model object?
-# TODO: - add circular communications for distributed convolutions
-# TODO: - trim-down number of attributes in the classes?
-
 
 # * Convolution model class (serial and distributed)
-# ! keep kernel out of the class? (e.g., for blind deconvolution)
 class SerialConvolution(LinearOperator):
     r"""Serial (FFT-based) convolution operator.
 
@@ -39,8 +34,6 @@ class SerialConvolution(LinearOperator):
         - If ``data_size == image_size + kernel_size - 1``: linear convolution.
     """
 
-    # TODO: make sure the interface works for both linear and circular
-    # TODO- convolutions
     def __init__(
         self,
         image_size,
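Based on the docstring kept above and the removed `__main__` example in `convolutions.py`, a hedged usage sketch of `SerialConvolution` (the constructor arguments and the `forward`/`adjoint` methods are inferred from those two sources):

```python
# Circular convolution when data_size == image_size; linear convolution when
# data_size == image_size + kernel_size - 1 (per the class docstring).
import numpy as np

from dsgs.operators.linear_convolution import SerialConvolution

image_size = np.array([16, 16], dtype="i")
h = np.ones((3, 3)) / 9.0  # small blur kernel
data_size = image_size + np.array(h.shape, dtype="i") - 1  # linear mode

op = SerialConvolution(image_size, h, data_size)
x = np.ones(tuple(image_size))
y = op.forward(x)      # apply H
x_adj = op.adjoint(y)  # apply H*
```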
diff --git a/src/dsgs/operators/padding.py b/src/dsgs/operators/padding.py
index 900166df00b7219a28642a019635cc0bc6cd9c25..d6a6683f3b0b5b41929401937a71c98a732cca76 100755
--- a/src/dsgs/operators/padding.py
+++ b/src/dsgs/operators/padding.py
@@ -322,7 +322,6 @@ def adjoint_padding(y, lsize, rsize, mode="constant"):
                 )
             )
 
-        # sel_core = tuple([np.s_[lsize[d] : rsize[d]] for d in range(ndims)])
         sel_core = []
         x_ = np.copy(y)
 
@@ -373,7 +372,6 @@ def adjoint_padding(y, lsize, rsize, mode="constant"):
                 )
             )
 
-        # sel_core = tuple([np.s_[lsize[d] : rsize[d]] for d in range(ndims)])
         sel_core = []
         x_ = np.copy(y)
 
@@ -413,7 +411,6 @@ def adjoint_padding(y, lsize, rsize, mode="constant"):
         x = x_[tuple(sel_core)]
 
     elif mode == "wrap":
-        # sel_core = tuple([np.s_[lsize[d] : rsize[d]] for d in range(ndims)])
         sel_core = []
         x_ = np.copy(y)
 
@@ -481,8 +478,6 @@ class Padding(LinearOperator):
         "constant".
     """
 
-    # TODO: include tests at the level of the object (not the application of
-    # TODO: - the functions / method)
     def __init__(self, lsize, rsize, mode="constant"):
         """Implementation of padding as a linear operator.
 
@@ -575,56 +570,3 @@ class Padding(LinearOperator):
             Output of the adjoint padding operator.
         """
         return adjoint_padding(y, self.lsize, self.rsize, self.mode)
-
-
-# if __name__ == "__main__":
-#     # import matplotlib.pyplot as plt
-#     # from imageio import imread
-
-#     # Generate 2D Gaussian convolution kernel
-#     vr = 1
-#     M = 7
-#     # if np.mod(M, 2) > 0:  # M odd
-#     #     n = np.arange(-(M - 1) // 2, (M - 1) // 2 + 1)
-#     # else:
-#     #     n = np.arange(-M // 2, M // 2)
-#     # h = np.exp(-(n ** 2 + n[:, np.newaxis] ** 2) / (2 * vr))
-
-#     # plt.imshow(h, cmap=plt.cm.gray)
-#     # plt.show()
-
-#     # x = imread("img/cameraman.png")
-#     # N = x.shape
-#     # M = h.shape
-
-#     # # version 1: circular convolution
-#     # # circular convolution: pad around and fftshift kernel for nicer results
-#     # input_size = N
-#     # hpad = pad_array(h, input_size, padmode="after")  # around
-
-#     # plt.imshow(hpad, cmap=plt.cm.gray)
-#     # plt.show()
-
-#     # testing symmetric padding (with adjoint) (constant is fine)
-#     rng = np.random.default_rng(1234)
-#     x = rng.standard_normal(size=(M,))
-#     y = rng.standard_normal(size=(M + 4,))
-
-#     Hx = pad_array(
-#         x, [M + 4], padmode="around", mode="symmetric"
-#     )  # around, 2 on both sides
-#     Hadj_y = adjoint_padding((y, np.array([2], dtype="i"), np.array([2], dtype="i"), mode="symmetric")
-
-#     hp1 = np.sum(Hx * y)
-#     hp2 = np.sum(x * Hadj_y)
-
-#     print("Correct adjoint operator? {}".format(np.isclose(hp1, hp2)))
-
-#     # 2D: ok
-#     # hpad = pad_array(
-#     #     h, [M[0] + 3, M[1] + 1], padmode="around", mode="symmetric"
-#     # )  # around
-#     # plt.imshow(hpad, cmap=plt.cm.gray)
-#     # plt.show()
-
-#     pass
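The commented-out check removed above contained a syntax error (a stray parenthesis in the `adjoint_padding` call); a corrected, runnable version of the same dot-product test:

```python
# Adjoint test for symmetric padding: <Px, y> must equal <x, P*y>.
import numpy as np

from dsgs.operators.padding import adjoint_padding, pad_array

M = 7
rng = np.random.default_rng(1234)
x = rng.standard_normal(size=(M,))
y = rng.standard_normal(size=(M + 4,))

# pad two entries on each side ("around"), symmetric boundary extension
Hx = pad_array(x, [M + 4], padmode="around", mode="symmetric")
Hadj_y = adjoint_padding(
    y, np.array([2], dtype="i"), np.array([2], dtype="i"), mode="symmetric"
)

print(np.isclose(np.sum(Hx * y), np.sum(x * Hadj_y)))  # expect True
```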
diff --git a/src/dsgs/operators/tv.py b/src/dsgs/operators/tv.py
index 0695912339e88f19460028af63e3c029362713ba..b4c7eb4a0b911d169b9d0bbf07a506f94b24291c 100755
--- a/src/dsgs/operators/tv.py
+++ b/src/dsgs/operators/tv.py
@@ -95,8 +95,6 @@ def tv(x):
     return np.sum(np.sqrt(np.sum(np.abs(u) ** 2, axis=0)))
 
 
-# TODO: see of the n-D version of the TV can be possibly simplified (i.e., to
-# TODO: enable jit support)
 def gradient_nd(x):
     r"""Nd discrete gradient operator.
 
@@ -253,23 +251,3 @@ def gradient_smooth_tv(x, eps):
     u = gradient_2d(x)
     v = gradient_2d_adjoint(u / np.sqrt(np.sum(np.abs(u) ** 2, axis=0) + eps))
     return v
-
-
-# if __name__ == "__main__":
-#     rng = np.random.default_rng(1234)
-#     x = rng.standard_normal((5, 5))
-#     eps = np.finfo(float).eps
-
-#     u2 = gradient_2d(x)
-#     y2 = gradient_2d_adjoint(u2)
-
-#     u = gradient_nd(x)
-#     y = gradient_nd_adjoint(u)
-
-#     err_ = np.linalg.norm(y - y2)
-#     print("Error: {0:1.5e}".format(err_))
-
-#     tv_x = tv_nd(x)
-#     tv_x_2d = tv(x)
-#     err = np.linalg.norm(tv_x - tv_x_2d)
-#     print("Error: {0:1.5e}".format(err))
diff --git a/src/dsgs/samplers/base_sampler.py b/src/dsgs/samplers/base_sampler.py
index 4d873429304e9b2ff86a133c36535f94f3891e38..853daa8c2fa62c57525d1563b7d311e3eeaefae2 100644
--- a/src/dsgs/samplers/base_sampler.py
+++ b/src/dsgs/samplers/base_sampler.py
@@ -160,10 +160,6 @@ class BaseSampler(ABC):
             k for k in (self.hyperparameters.keys() - self.updated_hyperparameters)
         ]
 
-        # ! this is specific to a serial sampler (only needed on the root
-        # ! worker)
-        # ! create a mock "rank parameter" to reuse the base sampler for the distributed version?
-
         # monitoring variables (iteration counter + timing)
         self.atime = np.zeros((1,), dtype="d")
         self.asqtime = np.zeros((1,), dtype="d")
@@ -257,15 +253,12 @@ class BaseSampler(ABC):
         potential_ : float
             Current value of the potential function.
         """
-        # ! could be created somewhere else (just need to return the value of
-        # ! the cost, and the iteration self._it to set up value in the right
-        # ! position on the buffer)
         potential = np.zeros((self.checkpointfreq,), dtype="d")
         self._setup_rng()
         if self.warmstart:
             self._load()
             self.start_iter = self.warmstart_it
-            self._it = -1  # iteration from which aux. params are initialized
+            self._it = -1  # iteration from which auxiliary parameters are initialized
             aux = self._initialize_auxiliary_parameters(self._it)
             potential_ = self._compute_potential(aux, self._it)
             potential[-1] = potential_
diff --git a/src/dsgs/samplers/parallel/spmd_psgla_poisson_deconvolution.py b/src/dsgs/samplers/parallel/spmd_psgla_poisson_deconvolution.py
index 5b5d9f87cbd8ae2d77b5f1ee4c374de2d396f958..2fa6e6e5973e51e7441f5f82ef5770baa058a58f 100644
--- a/src/dsgs/samplers/parallel/spmd_psgla_poisson_deconvolution.py
+++ b/src/dsgs/samplers/parallel/spmd_psgla_poisson_deconvolution.py
@@ -4,15 +4,12 @@
 # Sampler with Hypergraph Structure for High-Dimensional Inverse Problems**,
 # [arxiv preprint 2210.02341](http://arxiv.org/abs/2210.02341), October 2022.
 
-# based on main_s.py, spa_psgla_sync_s and spa_psgla_sync_mml.py
-
 from logging import Logger
 
 import numpy as np
 from mpi4py import MPI
 from numba import jit
 
-# import dsgs.utils.checkpoint_parallel as chkpt
 import dsgs.utils.communications as ucomm
 from dsgs.functionals.prox import (
     kullback_leibler,
@@ -57,16 +54,6 @@ def loading(
     numpy.ndarray, int, float
         Variables required to restart the sampler.
     """
-    # ! version when saving all iterations for all variables
-    # [(np.s_[-1], *global_slice_tile)]
-    #     + 2 * [(np.s_[-1], *global_slice_data)]
-    #     + 2 * [(np.s_[-1], np.s_[:], *global_slice_tile)]
-    #     + 3 * [np.s_[-1]],
-    # ! saving all iterations only for x
-    # [(np.s_[-1], *global_slice_tile)]
-    # + 2 * [global_slice_data]
-    # + 2 * [(np.s_[:], *global_slice_tile)]
-    # + 3 * [np.s_[-1]],
     # ! saving only estimator and last state of each variable
     dic = checkpointer.load(
         warmstart_iter,
@@ -170,111 +157,6 @@ def loading_per_process(
     )
 
 
-# ! ISSUE: code hanging forever in MPI, to be revised..
-def saving(
-    checkpointer: DistributedCheckpoint,
-    iter_mc: int,
-    rng: np.random.Generator,
-    local_x,
-    local_u1,
-    local_u2,
-    local_z1,
-    local_z2,
-    beta,
-    rho1,
-    rho2,
-    alpha1,
-    alpha2,
-    potential,
-    counter,
-    atime,
-    asqtime,
-    nsamples,
-    global_slice_tile,
-    global_slice_data,
-    image_size,
-    data_size,
-):
-    root_id = 0
-    it = iter_mc + 1
-
-    # find index of the candidate MAP within the current checkpoint
-    id_map = np.empty(1, dtype="i")
-    if checkpointer.rank == root_id:
-        id_map[0] = np.argmin(potential[:nsamples])
-
-    checkpointer.comm.Bcast([id_map, 1, MPI.INT], root=0)
-
-    # save MMSE, MAP (only for x)
-    select_ = (
-        [(np.s_[:], *global_slice_tile)]
-        + 2 * [(*global_slice_tile,)]
-        + 2 * [(*global_slice_data,)]
-        + 2 * [(np.s_[:], *global_slice_tile)]
-    )
-    shape_ = (
-        [(nsamples, *image_size)]
-        + 2 * [(*image_size,)]
-        + 2 * [(*data_size,)]
-        + 2 * [(2, *image_size)]
-    )
-    chunk_sizes_ = (
-        [(1, *local_x.shape[1:])]  # ! see if saving only the last point or more...
-        + 2 * [(*local_x.shape[1:],)]
-        + 2 * [(*local_z1.shape,)]
-        + 2 * [(1, *local_x.shape[1:])]
-    )
-
-    x_mmse = np.mean(local_x[:nsamples, ...], axis=0)
-    x_map = local_x[id_map[0]]
-    checkpointer.save(
-        it,
-        shape_,
-        select_,
-        chunk_sizes_,
-        rng=rng,
-        mode="w",
-        rdcc_nbytes=1024**2 * 200,  # 200 MB cache
-        x=local_x[:nsamples],
-        x_map=x_map,
-        x_mmse=x_mmse,
-        z1=local_z1,
-        u1=local_u1,  # u1 and/or z1 create an issue: check the data selection + size fpr these 2 variables
-        z2=local_z2,
-        u2=local_u2,
-    )
-
-    # ! z1 and u1 create a problem when code run from MPI (code hanging
-    # ! forever...): why?
-    # z1=local_z1,
-    # u1=local_u1,
-
-    # ! saving from process 0 only (beta, potential, iter, time)
-    if checkpointer.rank == root_id:
-        select_ = 10 * [np.s_[:]]
-        chunk_sizes_ = 10 * [None]
-        checkpointer.save_from_process(
-            root_id,
-            it,
-            select_,
-            chunk_sizes_,
-            mode="a",
-            rdcc_nbytes=1024**2 * 200,  # 200 MB cache
-            asqtime=np.array(asqtime),
-            atime=np.array(atime),
-            iter=it,
-            potential=potential[:nsamples],
-            beta=beta[:nsamples],
-            rho1=rho1[:nsamples],
-            rho2=rho2[:nsamples],
-            alpha1=alpha1[:nsamples],
-            alpha2=alpha2[:nsamples],
-            counter=counter,
-        )
-
-    pass
-
-
 def saving_per_process(
     rank,
     comm,
@@ -348,7 +230,7 @@ def saving_per_process(
     # checkpointer configuration
     chunk_sizes_ = (
         # [(1, *local_x.shape[1:])]  # ! see if saving only the last point or more...
-        [local_x.shape[1:]]  # ! see if saving only the last point or more...
+        [local_x.shape[1:]]
         + 3 * [(*local_x.shape[1:],)]
         + 2 * [local_z1.shape]
         + 2 * [(1, *local_x.shape[1:])]
@@ -360,7 +242,7 @@ def saving_per_process(
         rng=rng,
         mode="w",
         rdcc_nbytes=1024**2 * 200,  # 200 MB cache
-        x=local_x[-1],  # [:nsamples]
+        x=local_x[-1],
         x_map=x_map,
         x_m=x_mmse,
         x_sq_m=x_sq_m,
@@ -672,32 +554,21 @@ def potentials_function(y, Hx, Gx, z1, u1, z2, u2):
     return np.array([data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior], dtype="d")
 
 
-# ! TO investigate:
-# - where to define local parameter sizes? global sizes?
-# - sampling step / initialization step
-# - checkpointing modes
-
-# - observations / buffers are all chuncked, except for the hyperparameters (known only on master process)
-
-
 class SpmdPsglaSGS(BaseSPMDSampler):
     def __init__(
         self,
         xmax,
         observations,
-        model,  #: SyncLinearConvolution,
+        model,
         hyperparameters,
         Nmc: int,
         seed: int,
-        checkpointer,  #: DistributedCheckpoint,
+        checkpointer,
         checkpointfreq: int,
         logger: Logger,
         warmstart_it: int = -1,
         warmstart: bool = False,
     ):
-        # TODO: encapsulate all this into a "partition" object (to be defined
-        # in relation to the communicator...)
-
         # * model parameters (AXDA regularization parameter)
         self.max_sq_kernel = np.max(np.abs(model.fft_kernel)) ** 2
 
@@ -718,7 +589,6 @@ class SpmdPsglaSGS(BaseSPMDSampler):
 
         # data
         # ! larger number of data points on the border (forward overlap)
-        # local_data_size = tile_size + (ranknd == 0) * overlap_size
         (
             self.local_data_size,
             self.facet_size,
@@ -756,30 +626,6 @@ class SpmdPsglaSGS(BaseSPMDSampler):
             model.ranknd, grid_size, self.offset_tv_adj, backward=True
         )
 
-        # indexing into global arrays
-        # if not save_mode == "process":
-        #     global_slice_tile = tuple(
-        #         [np.s_[tile_pixels[d, 0] : tile_pixels[d, 1] + 1] for d in range(ndims)]
-        #     )
-        #     global_slice_data = create_local_to_global_slice(
-        #         tile_pixels,
-        #         ranknd,
-        #         sync_conv_model.overlap_size,
-        #         local_data_size,
-        #         backward=False,
-        #     )
-
-        # * slice selection for checkpoints
-        # single file checkpoint
-        # "x", "z1", "u1", "z2", "u2", "beta", "rho1", "rho2", "alpha1",
-        # "alpha2", "potential", "iter",
-        # checkpoint_select = [(*self.global_slice_tile)] + 2 * [(*self.global_slice_data,)] + 2 * [(np.s_[:], *self.global_slice_tile)]
-        # chunk_sizes = (
-        #     [(*self.tile_size)]  # ! see if saving only the last point or more...
-        #     + 2 * [(*self.local_data_size,)]
-        #     + 2 * [(1, *self.tile_size)]
-        # )
-
         # per-process checkpoints
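+        # ! one full slice per saved variable (x, z1, u1, z2, u2)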
         checkpoint_select = [np.s_[:]] + 4 * [np.s_[:]]
         chunk_sizes = (
@@ -797,17 +643,6 @@ class SpmdPsglaSGS(BaseSPMDSampler):
             "z2": (2, *self.tile_size),
         }
 
-        # TODO: to be double checked
-        # TODO: check is properly loaded into auxiliary buffer for x (main
-        # parameter)
-        # chunk_sizes = 5 * [None]  # chunk sizes for x, z1, u1, z2, u2
-        # checkpoint_select = len(parameter_sizes) * [
-        #     np.s_[:]  # ! saving only last sample to disk for all these variables
-        # ]
-
-        # create dict local_parameter_size
-        # with, in order, data + all variables/parameters involved?
-
         self.xmax = xmax
 
         super(SpmdPsglaSGS, self).__init__(
@@ -827,9 +662,6 @@ class SpmdPsglaSGS(BaseSPMDSampler):
             updated_hyperparameters=[],
         )
 
-        # TODO: move these buffers to the initialization? (auxiliary vars, ...)
-        # TODO: check if using hyperparameter_batch from the beginning or not
-        # TODO - (even when the hyperparameters are fixed...)
         # ! auxiliary buffer for in-place updates / communications of the main
         # ! - variable x
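+        # ! - (facet-sized: local tile plus the overlap shared with neighbours)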
         self.local_x = np.empty(self.facet_size, dtype="d")
@@ -872,7 +704,6 @@ class SpmdPsglaSGS(BaseSPMDSampler):
     def _initialize_auxiliary_parameters(self, it: int):
         # called after initialize_parameters
         # it: position of the element initialized in the larger x buffer
-        # non need for the splitting and augmentation variables
 
         self.local_x[self.local_slice_tile] = self.parameter_batch["x"][it]
 
@@ -897,7 +728,7 @@ class SpmdPsglaSGS(BaseSPMDSampler):
 
         # * PSGLA step-sizes
         # TODO: replace by self.hyperparameter_batch["..."][it] when
-        # TODO - hyperparameters are updates
+        # TODO - hyperparameters are updated
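+        # ! 0.99 / L, where L = ||H||^2 / rho1 + ||G||^2 / rho2 majorizes the
+        # ! - gradient Lipschitz constant (||G||^2 <= 8 for the 2d gradient)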
         stepsize_x = 0.99 / (
             self.max_sq_kernel / self.hyperparameters["rho1"]
             + 8 / self.hyperparameters["rho2"]
@@ -912,18 +743,9 @@ class SpmdPsglaSGS(BaseSPMDSampler):
             "stepsize_z1": stepsize_z1,
             "stepsize_z2": stepsize_z2,
         }
-
-        # TODO: see if computation of ojective is needed from here (only
-        # TODO - for MML version)
-
-        # compute parts of the potential for later on
-        # self._compute_potentials(aux, local_potentials, global_potentials)
-
         return aux
 
-    # TODO: add _compute_potentials? (only for mml version, debug initial one first)
-
-    # TODO: version based on local_potentials / global_potentials?
     def _compute_potential(self, aux, local_potential, global_potential):
         local_potential[0] = potential_function(
             self.observations,
@@ -947,37 +769,8 @@ class SpmdPsglaSGS(BaseSPMDSampler):
             root=0,
         )
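+        # ! after the sum-reduction, only rank 0 holds the global potential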
 
-        # potential = (
-        #     aux["potentials"][0]
-        #     + aux["potentials"][1] / self.hyperparameter_batch["rho1"][it]
-        #     + aux["potentials"][2] / self.hyperparameter_batch["rho2"][it]
-        #     + aux["potentials"][3] / self.hyperparameter_batch["alpha1"][it]
-        #     + aux["potentials"][4] / self.hyperparameter_batch["alpha2"][it]
-        #     + aux["potentials"][5] * self.hyperparameter_batch["beta"][it]
-        # )
-
         pass
 
-    # def _compute_potentials(self, aux, local_potentials, global_potentials):
-    #     # data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior
-    #     local_potentials = potentials_function(
-    #         self.observations,
-    #         aux["Hx"][self.local_slice_conv_adj],
-    #         aux["Gx"][tuple((np.s_[:], *self.local_slice_tv_adj))],
-    #         self.parameter_batch["z1"],
-    #         self.parameter_batch["u1"],
-    #         self.parameter_batch["z2"],
-    #         self.parameter_batch["u2"],
-    #     )
-
-    #     self.model.comm.Reduce(
-    #         [local_potentials, MPI.DOUBLE],
-    #         [global_potentials, MPI.DOUBLE],
-    #         op=MPI.SUM,
-    #         root=0,
-    #     )
-
-    #     pass
-
     def _sample_step(self, iter_mc: int, current_iter: int, past_iter: int, aux):
         # notational shortcuts (for in-place assignments)
@@ -1065,402 +858,4 @@ class SpmdPsglaSGS(BaseSPMDSampler):
             self.local_rng,
         )
 
-        # ! sample TV regularization parameter beta (to be debugged)
-        # local_l21_z2 = np.sum(np.sqrt(np.sum(local_z2_mc ** 2, axis=0)))
-        # ...
-
-        pass
-
-
-class SpmdPsglaSGSmml(BaseSPMDSampler):
-    def __init__(
-        self,
-        xmax,
-        observations,
-        model,  #: SyncLinearConvolution,
-        hyperparameters,
-        mml_dict,
-        Nmc: int,
-        seed: int,
-        checkpointer,  #: DistributedCheckpoint,
-        checkpointfreq: int,
-        logger: Logger,
-        warmstart_it: int = -1,
-        warmstart: bool = False,
-    ):
-        # TODO: encapsulate all this into a "partition" object (to be defined
-        # in relation to the communicator...)
-
-        # ! only need to be defined on worker 0 (broadcast updates afterwards)
-        self.mml_dict = mml_dict
-
-        # * model parameters (AXDA regularization parameter)
-        self.max_sq_kernel = np.max(np.abs(model.fft_kernel)) ** 2
-
-        # * auxiliary quantities for stochastic gradient updates
-        # dimensions
-        self.d_N = np.prod(model.image_size)
-        # dimension of the proper space (removing 0s)
-        # d_tv = (N[0] - 1) * N[1] + (N[1] - 1) * N[0]
-        self.d_M = np.prod(model.data_size)
-
-        # tile
-        grid_size = MPI.Compute_dims(model.comm.Get_size(), model.ndims)
-        grid_size = np.array(grid_size, dtype="i")
-        self.tile_pixels = ucomm.local_split_range_nd(
-            grid_size, model.image_size, model.ranknd
-        )
-        self.tile_size = self.tile_pixels[:, 1] - self.tile_pixels[:, 0] + 1
-
-        # data
-        # ! larger number of data points on the border (forward overlap)
-        # local_data_size = tile_size + (ranknd == 0) * overlap_size
-        (
-            self.local_data_size,
-            self.facet_size,
-            self.facet_size_adj,
-        ) = calculate_local_data_size(
-            self.tile_size, model.ranknd, model.overlap_size, grid_size, backward=False
-        )
-
-        # facet (convolution)
-        self.offset = self.facet_size - self.tile_size
-        self.offset_adj = self.facet_size_adj - self.tile_size
-
-        # facet (tv)
-        self.offset_tv = self.offset - (self.offset > 0).astype("i")
-        self.offset_tv_adj = np.logical_and(model.ranknd > 0, grid_size > 1).astype("i")
-        self.tv_facet_size_adj = self.tile_size + self.offset_tv_adj
-
-        # * Useful slices (direct operators)
-        # extract tile from local facet (direct conv. operator)
-        self.local_slice_tile = ucomm.get_local_slice(
-            model.ranknd, grid_size, self.offset, backward=False
-        )
-        # extract values from local conv facet to apply local gradient operator
-        self.local_slice_tv = ucomm.get_local_slice(
-            model.ranknd, grid_size, self.offset_tv, backward=False
-        )
-
-        # * Useful slices (adjoint operators)
-        # set value of local convolution in the adjoint buffer
-        self.local_slice_conv_adj = ucomm.get_local_slice(
-            model.ranknd, grid_size, self.offset_adj, backward=True
-        )
-        # set value of local discrete gradient into the adjoint gradient buffer
-        self.local_slice_tv_adj = ucomm.get_local_slice(
-            model.ranknd, grid_size, self.offset_tv_adj, backward=True
-        )
-
-        # * slice selection for checkpoints
-        # single file checkpoint
-        # "x", "z1", "u1", "z2", "u2", "beta", "rho1", "rho2", "alpha1",
-        # "alpha2", "potential", "iter",
-        # checkpoint_select = [(*self.global_slice_tile)] + 2 * [(*self.global_slice_data,)] + 2 * [(np.s_[:], *self.global_slice_tile)]
-        # chunk_sizes = (
-        #     [(*self.tile_size)]  # ! see if saving only the last point or more...
-        #     + 2 * [(*self.local_data_size,)]
-        #     + 2 * [(1, *self.tile_size)]
-        # )
-
-        # per-process checkpoints
-        checkpoint_select = [np.s_[:]] + 4 * [np.s_[:]]
-        chunk_sizes = (
-            [tuple(self.tile_size)]
-            + 2 * [tuple(self.local_data_size)]
-            + 2 * [(1, *self.tile_size)]
-        )
-
-        # ! local parameter sizes
-        parameter_sizes = {
-            "x": self.tile_size,
-            "u1": self.local_data_size,
-            "z1": self.local_data_size,
-            "u2": (2, *self.tile_size),
-            "z2": (2, *self.tile_size),
-        }
-
-        self.xmax = xmax
-
-        super(SpmdPsglaSGSmml, self).__init__(
-            observations,
-            model,
-            parameter_sizes,
-            hyperparameters,
-            Nmc,
-            seed,
-            checkpointer,
-            checkpointfreq,
-            checkpoint_select,
-            chunk_sizes,
-            logger,
-            warmstart_it=warmstart_it,
-            warmstart=warmstart,
-            updated_hyperparameters=["rho1", "rho2", "alpha1", "alpha2", "beta"],
-        )
-
-        # TODO: move these buffers to the initialization? (auxiliary vars, ...)
-        # TODO: check if using hyperparameter_batch from the beginning or not
-        # TODO - (even when the hyperparameters are fixed...)
-        # ! auxiliary buffer for in-place updates / communications of the main
-        # ! - variable x
-        self.local_x = np.empty(self.facet_size, dtype="d")
-
-        # * setup communication scheme
-        # ! convolution (direct + adjoint) + direct TV covered by self.model
-        # adjoint TV communicator
-        # ! need a different object for the moment (because of the size required...)
-        self.adjoint_tv_communicator = SyncCartesianCommunicatorTV(
-            self.model.comm,
-            self.model.cartcomm,
-            self.model.grid_size,
-            self.local_x.itemsize,
-            self.tv_facet_size_adj,
-            direction=True,
-        )
-
-    def _initialize_parameters(self):
-        for key in self.parameter_batch.keys():
-            if key == "x":  # ! batch only kept for the main variable x
-                self.parameter_batch[key][0] = self.local_rng.integers(
-                    0, high=self.xmax, size=self.parameter_sizes[key], endpoint=True
-                ).astype(float)
-            else:
-                self.parameter_batch[key] = self.local_rng.integers(
-                    0, high=self.xmax, size=self.parameter_sizes[key], endpoint=True
-                ).astype(float)
-
-        # ! need to broadcast latest value of all the hyperparameters later on
-        # ! in _initialize_auxiliary_parameters
-        if self.rank == 0:
-            for key in ["rho1", "rho2", "alpha1", "alpha2", "beta"]:
-                # in self.hyperparameters.keys():
-                self.hyperparameter_batch[key][0] = self.hyperparameters[key]
-
-        pass
-
-    def _initialize_auxiliary_parameters(self, it: int):
-        # called after initialize_parameters
-        # it: position of the element initialized in the larger x buffer
-        # non need for the splitting and augmentation variables
-
-        self.local_x[self.local_slice_tile] = self.parameter_batch["x"][it]
-
-        # * setup auxiliary buffers
-        # ! communicating facet borders to neighbours already done in-place
-        # ! with the direction convolution operator
-        # ! Hx, Gx updated whenever buffer_Hx and buffer_Gx are
-        buffer_Hx = np.empty(self.facet_size_adj)
-        buffer_Hx[self.local_slice_conv_adj] = self.model.forward(self.local_x)
-
-        buffer_Gx = np.empty((2, *self.tv_facet_size_adj))
-        (
-            buffer_Gx[tuple((0, *self.local_slice_tv_adj))],
-            buffer_Gx[tuple((1, *self.local_slice_tv_adj))],
-        ) = chunk_gradient_2d(self.local_x[self.local_slice_tv], self.islast)
-
-        # communicate facet borders to neighbours (Hx, Gx)
-        # ! Hx updated in place
-        self.model.adjoint_communicator.update_borders(buffer_Hx)
-        # ! Gx updated in place
-        self.adjoint_tv_communicator.update_borders(buffer_Gx)
-
-        # * create hyperparameter vector (for comms between workers)
-        self.hyperparameter_vector = np.empty((len(self.hyperparameters)), dtype="d")
-        if self.rank == 0:
-            c_ = 0
-            for key in ["rho1", "rho2", "alpha1", "alpha2", "beta"]:
-                self.hyperparameter_vector[c_] = self.hyperparameter_batch[key][it]
-                c_ += 1
-        self.model.comm.Bcast(
-            [self.hyperparameter_vector, len(self.hyperparameters), MPI.DOUBLE], root=0
-        )
-
-        # * PSGLA step-sizes
-        stepsize_x = 0.99 / (
-            self.max_sq_kernel / self.hyperparameter_vector[0]
-            + 8 / self.hyperparameter_vector[1]
-        )
-        stepsize_z1 = 0.99 * self.hyperparameter_vector[0]
-        stepsize_z2 = 0.99 * self.hyperparameter_vector[1]
-
-        aux = {
-            "Hx": buffer_Hx,
-            "Gx": buffer_Gx,
-            "stepsize_x": stepsize_x,
-            "stepsize_z1": stepsize_z1,
-            "stepsize_z2": stepsize_z2,
-        }
-
-        # compute parts of the potential for later on (MML updates)
-        if self.rank == 0:
-            aux.update({"potentials": np.empty((len(self.hyperparameters) + 1), "d")})
-        else:
-            aux.update({"potentials": None})
-
-        self._compute_potentials(aux)
-
-        return aux
-
-    # TODO: version based on local_potentials / global_potentials?
-    # ! revise local_potential / global potential (not needed if modifying the algo to
-    # ! - accommodate hyperparameter updates)
-    def _compute_potential(self, aux, local_potential, global_potential):
-        if self.rank == 0:
-            global_potential[0] = (
-                aux["potentials"][0]
-                + aux["potentials"][1] / self.hyperparameter_vector[0]  # rho1
-                + aux["potentials"][2] / self.hyperparameter_vector[1]  # rho2
-                + aux["potentials"][3] / self.hyperparameter_vector[2]  # alpha1
-                + aux["potentials"][4] / self.hyperparameter_vector[3]  # alpha2
-                + aux["potentials"][5] * self.hyperparameter_vector[4]  # beta
-            )
-
-        pass
-
-    def _compute_potentials(self, aux):
-        # data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior
-        local_potentials = potentials_function(
-            self.observations,
-            aux["Hx"][self.local_slice_conv_adj],
-            aux["Gx"][tuple((np.s_[:], *self.local_slice_tv_adj))],
-            self.parameter_batch["z1"],
-            self.parameter_batch["u1"],
-            self.parameter_batch["z2"],
-            self.parameter_batch["u2"],
-        )
-
-        self.model.comm.Reduce(
-            [local_potentials, MPI.DOUBLE],
-            [aux["potentials"], MPI.DOUBLE],
-            op=MPI.SUM,
-            root=0,
-        )
-
-        pass
-
-    def _sample_step(self, iter_mc: int, current_iter: int, past_iter: int, aux):
-        # * PSGLA step-sizes
-        aux["stepsize_x"] = 0.99 / (
-            self.max_sq_kernel / self.hyperparameter_vector[0]
-            + 8 / self.hyperparameter_vector[1]
-        )
-        aux["stepsize_z1"] = 0.99 * self.hyperparameter_vector[0]
-        aux["stepsize_z2"] = 0.99 * self.hyperparameter_vector[1]
-
-        # notational shortcuts (for in-place assignments)
-        # ? can be defined just once? (to be double checked)
-        # ! Hx, Gx updated whenever buffer_Hx and buffer_Gx are
-        Hx = aux["Hx"][self.local_slice_conv_adj]
-        Gx = aux["Gx"][tuple((np.s_[:], *self.local_slice_tv_adj))]
-
-        # sample image x (update local tile)
-        grad_x = gradient_x(
-            self.model,
-            self.adjoint_tv_communicator,
-            self.isfirst,
-            self.islast,
-            aux["Hx"],
-            aux["Gx"],
-            self.parameter_batch["z1"],
-            self.parameter_batch["u1"],
-            self.parameter_batch["z2"],
-            self.parameter_batch["u2"],
-            self.hyperparameter_vector[0],  # rho1
-            self.hyperparameter_vector[1],  # rho2
-            self.local_slice_tv,
-            self.local_slice_tv_adj,
-        )
-
-        sample_x(
-            self.local_x[self.local_slice_tile],
-            aux["stepsize_x"],
-            grad_x,
-            self.local_rng,
-        )
-        self.parameter_batch["x"][current_iter] = self.local_x[self.local_slice_tile]
-
-        # communicate borders of each facet to appropriate neighbours
-        # ! synchronous case: need to update Hx and Gx for the next step
-        # ! local_x updated in-place here (border communication involved in
-        # ! direct operator)
-        aux["Hx"][self.local_slice_conv_adj] = self.model.forward(self.local_x)
-        # ! Hx and Gx updated in-place whenever buffer_Hx, buffer_Gx are
-        (
-            aux["Gx"][tuple((0, *self.local_slice_tv_adj))],
-            aux["Gx"][tuple((1, *self.local_slice_tv_adj))],
-        ) = chunk_gradient_2d(self.local_x[self.local_slice_tv], self.islast)
-
-        # communicate borders of buffer_Hx, buffer_Gx (adjoint op) for the
-        # next iteration
-        self.model.adjoint_communicator.update_borders(aux["Hx"])
-        self.adjoint_tv_communicator.update_borders(aux["Gx"])
-
-        # * sample auxiliary variables (z1, u1)
-        self.parameter_batch["z1"] = sample_z1(
-            self.parameter_batch["z1"],
-            self.observations,
-            Hx,
-            self.parameter_batch["u1"],
-            self.hyperparameter_vector[0],  # rho1
-            aux["stepsize_z1"],
-            self.local_rng,
-        )
-        self.parameter_batch["u1"] = sample_u(
-            self.parameter_batch["z1"],
-            Hx,
-            self.hyperparameter_vector[0],  # rho1
-            self.hyperparameter_vector[2],  # alpha1
-            self.local_rng,
-        )
-
-        # * sample auxiliary variables (z2, u2)
-        self.parameter_batch["z2"] = sample_z2(
-            self.parameter_batch["z2"],
-            Gx,
-            self.parameter_batch["u2"],
-            self.hyperparameter_vector[1],  # rho2
-            aux["stepsize_z2"],
-            self.hyperparameter_vector[4] * aux["stepsize_z2"],  # beta
-            self.local_rng,
-        )
-
-        self.parameter_batch["u2"] = sample_u(
-            self.parameter_batch["z2"],
-            Gx,
-            self.hyperparameter_vector[1],  # rho2
-            self.hyperparameter_vector[3],  # alpha2
-            self.local_rng,
-        )
-
-        # * update hyperparameters (using MML)
-        # compute auxiliary potentials for MML updates + computation of the
-        # overall potential
-        # ! need to define local_potentials, global_potentials
-        self._compute_potentials(aux)
-
-        # TODO: revise and simplify this part
-        # order: [data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior]
-        # check if this is fine in practice
-        if self.rank == 0:
-            c_ = 0
-
-            # if iter_mc == 20:
-            #     breakpoint()
-
-            for key in ["rho1", "rho2", "alpha1", "alpha2", "beta"]:
-                self.hyperparameter_batch[key][current_iter] = self.mml_dict[
-                    key
-                ].update(
-                    self.hyperparameter_batch[key][past_iter],
-                    iter_mc,
-                    aux["potentials"][c_ + 1],  # ! data-fidelity term in position 0
-                )
-                self.hyperparameter_vector[c_] = self.hyperparameter_batch[key][
-                    current_iter
-                ]
-                c_ += 1
-
-        self.model.comm.Bcast([self.hyperparameter_vector, MPI.DOUBLE], root=0)
-
         pass
diff --git a/src/dsgs/samplers/serial/serial_pmyula_poisson_deconvolution.py b/src/dsgs/samplers/serial/serial_pmyula_poisson_deconvolution.py
index c9624ee39a05ddceaff91777f6611d1ccf4be5e9..a15292f581b8bc0d7edec9fbdf096f23d1599479 100644
--- a/src/dsgs/samplers/serial/serial_pmyula_poisson_deconvolution.py
+++ b/src/dsgs/samplers/serial/serial_pmyula_poisson_deconvolution.py
@@ -505,226 +505,3 @@ class MyulaSGS(SerialSampler):
         )
 
         pass
-
-
-class MyulaSGSmml(SerialSampler):
-    def __init__(
-        self,
-        xmax,
-        observations,
-        model: SerialConvolution,
-        hyperparameters,
-        mml_dict,
-        Nmc: int,
-        seed: int,
-        checkpointer: SerialCheckpoint,
-        checkpointfreq: int,
-        logger: Logger,
-        warmstart_it=-1,
-        warmstart=False,
-        save_batch=False,
-    ):
-        data_size = model.data_size
-        N = model.image_size
-        parameter_sizes = {
-            "x": N,
-            "u1": data_size,
-            "z1": data_size,
-            "u2": N,
-            "z2": N,
-            "u3": N,
-            "z3": N,
-        }
-        chunk_sizes = 7 * [None]  # chunk sizes for x, z1, u1, z2, u2, z3, u3
-        checkpoint_select = len(parameter_sizes) * [
-            np.s_[:]  # ! saving only last sample to disk for all these variables
-        ]
-
-        self.xmax = xmax
-        self.mml_dict = mml_dict
-
-        super(MyulaSGSmml, self).__init__(
-            observations,
-            model,
-            parameter_sizes,
-            hyperparameters,
-            Nmc,
-            seed,
-            checkpointer,
-            checkpointfreq,
-            checkpoint_select,
-            chunk_sizes,
-            logger,
-            warmstart_it=warmstart_it,
-            warmstart=warmstart,
-            save_batch=save_batch,
-            updated_hyperparameters=[
-                "rho1",
-                "rho2",
-                "rho3",
-                "alpha1",
-                "alpha2",
-                "alpha3",
-                "beta",
-            ],
-        )
-
-    def _initialize_parameters(self):
-        for key in self.parameter_batch.keys():
-            self.parameter_batch[key][0] = self.rng.integers(
-                0, high=self.xmax, size=self.parameter_sizes[key], endpoint=True
-            ).astype(float)
-
-        for key in self.hyperparameter_batch.keys():
-            self.hyperparameter_batch[key][0] = self.hyperparameters[key]
-
-        pass
-
-    def _initialize_auxiliary_parameters(self, it: int):
-        Hx = self.model.forward(self.parameter_batch["x"][it])
-
-        aux = {
-            "Hx": Hx,
-            "myula1": MYULA(1 / self.hyperparameter_batch["rho1"][it]),
-            "myula2": MYULA(1 / self.hyperparameter_batch["rho2"][it]),
-            "myula3": MYULA(1 / self.hyperparameter_batch["rho3"][it]),
-        }
-
-        # compute parts of the potential for later on
-        self._compute_potentials(aux, it)
-
-        return aux
-
-    def _compute_potentials(self, aux, it):
-        aux["potentials"] = potentials_function(
-            self.observations,
-            aux["Hx"],
-            self.parameter_batch["x"][it],
-            self.parameter_batch["z1"][it],
-            self.parameter_batch["u1"][it],
-            self.parameter_batch["z2"][it],
-            self.parameter_batch["u2"][it],
-            self.parameter_batch["z3"][it],
-            self.parameter_batch["u3"][it],
-        )
-
-        pass
-
-    def _compute_potential(self, aux, it):
-        # np.array([data_fidelity, lrho1, lrho2, lrho3, lalpha1, lalpha2, lalpha3, prior],
-        #     dtype="d",
-        # )
-
-        potential = (
-            aux["potentials"][0]
-            + aux["potentials"][1] / self.hyperparameter_batch["rho1"][it]
-            + aux["potentials"][2] / self.hyperparameter_batch["rho2"][it]
-            + aux["potentials"][3] / self.hyperparameter_batch["rho3"][it]
-            + aux["potentials"][4] / self.hyperparameter_batch["alpha1"][it]
-            + aux["potentials"][5] / self.hyperparameter_batch["alpha2"][it]
-            + aux["potentials"][6] / self.hyperparameter_batch["alpha3"][it]
-            + aux["potentials"][7] * self.hyperparameter_batch["beta"][it]
-        )
-
-        return potential
-
-    def _sample_step(self, iter_mc, current_iter, past_iter, aux):
-        # * update MYULA step-sizes
-        aux["myula1"].reset_lipschitz_constant(
-            1 / self.hyperparameter_batch["rho1"][past_iter]
-        )
-        aux["myula2"].reset_lipschitz_constant(
-            1 / self.hyperparameter_batch["rho2"][past_iter]
-        )
-        aux["myula3"].reset_lipschitz_constant(
-            1 / self.hyperparameter_batch["rho3"][past_iter]
-        )
-
-        # * sample image x
-        self.parameter_batch["x"][current_iter], aux["Hx"] = sample_x(
-            self.model,
-            self.parameter_batch["u1"][past_iter],
-            self.parameter_batch["z1"][past_iter],
-            self.parameter_batch["u2"][past_iter],
-            self.parameter_batch["z2"][past_iter],
-            self.parameter_batch["u3"][past_iter],
-            self.parameter_batch["z3"][past_iter],
-            self.hyperparameter_batch["rho1"][past_iter],
-            self.hyperparameter_batch["rho2"][past_iter],
-            self.hyperparameter_batch["rho3"][past_iter],
-            self.rng,
-        )
-
-        # * sample auxiliary variables (z1, u1)
-        self.parameter_batch["z1"][current_iter] = sample_z1(
-            self.parameter_batch["z1"][past_iter],
-            self.observations,
-            aux["Hx"],
-            self.parameter_batch["u1"][past_iter],
-            self.hyperparameter_batch["rho1"][past_iter],
-            aux["myula1"],
-            self.rng,
-        )
-        self.parameter_batch["u1"][current_iter] = sample_u(
-            self.parameter_batch["z1"][current_iter],
-            aux["Hx"],
-            self.hyperparameter_batch["rho1"][past_iter],
-            self.hyperparameter_batch["alpha1"][past_iter],
-            self.rng,
-        )
-
-        # * sample auxiliary variables (z2, u2)
-        self.parameter_batch["z2"][current_iter] = sample_z2(
-            self.parameter_batch["z2"][past_iter],
-            self.parameter_batch["x"][current_iter],
-            self.parameter_batch["u2"][past_iter],
-            self.hyperparameter_batch["rho2"][past_iter],
-            self.hyperparameter_batch["beta"][past_iter],
-            aux["myula2"],
-            self.rng,
-        )
-
-        self.parameter_batch["u2"][current_iter] = sample_u(
-            self.parameter_batch["z2"][current_iter],
-            self.parameter_batch["x"][current_iter],
-            self.hyperparameter_batch["rho2"][past_iter],
-            self.hyperparameter_batch["alpha2"][past_iter],
-            self.rng,
-        )
-
-        # * sample auxiliary variables (z3, u3)
-        self.parameter_batch["z3"][current_iter] = sample_z3(
-            self.parameter_batch["z3"][past_iter],
-            self.parameter_batch["x"][current_iter],
-            self.parameter_batch["u3"][past_iter],
-            self.hyperparameter_batch["rho3"][past_iter],
-            aux["myula3"],
-            self.rng,
-        )
-
-        self.parameter_batch["u3"][current_iter] = sample_u(
-            self.parameter_batch["z3"][current_iter],
-            self.parameter_batch["x"][current_iter],
-            self.hyperparameter_batch["rho3"][past_iter],
-            self.hyperparameter_batch["alpha3"][past_iter],
-            self.rng,
-        )
-
-        # * update hyperparameters (using MML)
-        # compute auxiliary potentials for MML updates + computation of the
-        # overall potential
-        self._compute_potentials(aux, current_iter)
-
-        # TODO: revise and simplify this part
-        # order: [data_fidelity, lrho1, lrho2, lrho3, lalpha1, lalpha2, lalpha3, prior]
-        # check if this is fine in practice
-        c_ = 1  # assuming entries are on the rights order in the vector
-        for key in ["rho1", "rho2", "rho3", "alpha1", "alpha2", "alpha3", "beta"]:
-            self.hyperparameter_batch[key][current_iter] = self.mml_dict[key].update(
-                self.hyperparameter_batch[key][past_iter],
-                iter_mc,
-                aux["potentials"][c_],
-            )
-            c_ += 1
-
-        pass
diff --git a/src/dsgs/samplers/serial/serial_psgla_poisson_deconvolution.py b/src/dsgs/samplers/serial/serial_psgla_poisson_deconvolution.py
index 723895b13794810bd2b5dba46d654c16453ede5c..e37fe6b5c03f1402d5b9eb700a203f0337a5e78a 100644
--- a/src/dsgs/samplers/serial/serial_psgla_poisson_deconvolution.py
+++ b/src/dsgs/samplers/serial/serial_psgla_poisson_deconvolution.py
@@ -292,15 +292,6 @@ class PsglaSGS(SerialSampler):
 
         self.xmax = xmax
 
-        # gamma_x = 0.99 / (
-        #     np.max(np.abs(model.fft_kernel)) ** 2 / hyperparameters["rho1"]
-        #     + 8 / hyperparameters["rho2"]
-        # )
-        # gamma1 = 0.99 * hyperparameters["rho1"]
-        # gamma2 = 0.99 * hyperparameters["rho2"]
-
-        # kernels = {"x": PSGLA(gamma_x), "z1": PSGLA(gamma1), "z2": PSGLA(gamma2)}
-
         super(PsglaSGS, self).__init__(
             observations,
             model,
@@ -341,7 +332,6 @@ class PsglaSGS(SerialSampler):
 
         # PSGLA step-sizes
         # ! only if the hyperparameters are kept fixed
-        # ! see what to do if hyperparameters have been updated in the iterations
         stepsize_x = 0.99 / (
             np.max(np.abs(self.model.fft_kernel)) ** 2
             / self.hyperparameter_batch["rho1"][0]
@@ -377,13 +367,6 @@ class PsglaSGS(SerialSampler):
         return potential
 
     def _sample_step(self, iter_mc, current_iter, past_iter, aux):
-        # PSGLA step-sizes (-> now defined earlier)
-        # gamma_x = 0.99 / (
-        #     np.max(np.abs(conv_model.fft_kernel)) ** 2 / rho1_mc[past_iter]
-        #     + 8 / rho2_mc[past_iter]
-        # )
-        # gamma1 = 0.99 * rho1_mc[past_iter]
-        # gamma2 = 0.99 * rho2_mc[past_iter]
-
         # * sample image x
         grad_x = gradient_x(
diff --git a/src/dsgs/samplers/serial_sampler.py b/src/dsgs/samplers/serial_sampler.py
index c21533ffb7eada88328d59d314d057383f196c44..6b6f45456feba1e9599ebdb13cc58ca79228dcd1 100644
--- a/src/dsgs/samplers/serial_sampler.py
+++ b/src/dsgs/samplers/serial_sampler.py
@@ -196,13 +196,6 @@ class SerialSampler(BaseSampler):
         potential : numpy.ndarray
             Evolution of the log-posterior over the checkpoint window.
         """
-        #
-        # f2(**d)  # turning a dict into keyword parameters
-        # f(**d1, **d2)  # need to make sure no duplicate keys, otherwise use collections.ChainMap (https://subscription.packtpub.com/book/application-development/9781788830829/1/ch01lvl1sec14/unpacking-multiple-keyword-arguments)
-        #
-        # TODO: see if copies can be avoided (not sure this occurs with the
-        # dictionaries)
-
         # ! if the full batch for the parameter "x" needs to be saved to disk
         chunk_sizes = []
         last_sample = {}
diff --git a/src/dsgs/samplers/spmd_sampler.py b/src/dsgs/samplers/spmd_sampler.py
index 3d34b4302ab083eaa297dd44946851c6093b840b..d36a2c7530e24dbdf07d8df4a0c4b813a783c43d 100644
--- a/src/dsgs/samplers/spmd_sampler.py
+++ b/src/dsgs/samplers/spmd_sampler.py
@@ -443,10 +443,6 @@ class BaseSPMDSampler(ABC):
         potential : numpy.ndarray
             Evolution of the log-posterior over the checkpoint window.
         """
-        # TODO: give the option to switch from one config to another: single
-        # TODO: save file or one per process.
-
-        # TODO: save MMSE, MAP and last sample only (not mean of squares)
         # ! save last sample for all variables, compute MMSE and MAP for x only
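+        # ! - (two extra x-shaped chunks: the MMSE and MAP estimates)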
         chunk_sizes = self.chunk_sizes + 2 * [self.chunk_sizes[0]]
 
diff --git a/src/dsgs/samplers/transition_kernels/metropolis_hastings.py b/src/dsgs/samplers/transition_kernels/metropolis_hastings.py
index 892be0c3800936d41b10d32a1a9f1cfd95044bbd..ed87eb1daf1f21c6f7a7a0df734c05c31376cf6d 100644
--- a/src/dsgs/samplers/transition_kernels/metropolis_hastings.py
+++ b/src/dsgs/samplers/transition_kernels/metropolis_hastings.py
@@ -146,10 +146,6 @@ class MetropolisHastings(BaseMCMCKernel):
             return current_state
 
 
-# TODO: 1. take a lambda function to compute the relevant potential?
-# TODO: 2. check if a BaseMCMCKernel can be plugged in there or not!
-# TODO: 3. pass a lambda function to pass the expression for the potential of
-# TODO     the target distribution
 class MetropolisedKernel(MetropolisHastings):
     """Generic implementation of a Metropolis-Hastings transition kernel, where
     the proposal distribution is generated by another kernel."""
diff --git a/src/dsgs/samplers/transition_kernels/myula.py b/src/dsgs/samplers/transition_kernels/myula.py
index d713b93730454b22eb46048b36e85bf087be37d3..2f382fc2603c944bc8534f2e68f426b3cadc5b04 100644
--- a/src/dsgs/samplers/transition_kernels/myula.py
+++ b/src/dsgs/samplers/transition_kernels/myula.py
@@ -59,7 +59,6 @@ def default_myula_parameters(
     return np.array([stepsize, regularization]), np.array([c1, c2])
 
 
-# TODO: implement prox as a lambda function?
 class MYULA(BaseMCMCKernel):
     r"""MYULA transition kernel :cite:p:`Durmus2018` associated with a target
     density of the form