diff --git a/CITATION.cff b/CITATION.cff index d45fdca12ec721cb23920a28bbacabfd7f3df242..8b640c88cb75c1fc378272a324d27b8dc7c0536a 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -33,7 +33,7 @@ preferred-citation: month: 10 # start: 1 # First page number # end: 10 # Last page number - title: "A distributed Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems" + title: "A Distributed Block-Split Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems" # issue: 1 # volume: 1 year: 2022 diff --git a/README.md b/README.md index 874912b22285a4357a36f7713ec14f3675bf57b3..243685b0210fe64a159dccfe88508fc91f3c47a7 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ ______________________________________________________________________ Python code associated with the method described in the following paper. -> \[1\] P.-A. Thouvenin, A. Repetti, P. Chainais - **A distributed Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems**, [arxiv preprint 2210.02341](http://arxiv.org/abs/2210.02341), October 2022, under review. +> P.-A. Thouvenin, A. Repetti, P. Chainais - **A Distributed Block-Split Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems**, [arXiv preprint 2210.02341](http://arxiv.org/abs/2210.02341), October 2023, to appear in JCGS. **Authors**: P.-A. Thouvenin, A. Repetti, P. Chainais @@ -53,34 +53,34 @@ ______________________________________________________________________ ```bash # Clone the repo or unzip the dsgs-main.zip code archive -# git clone --recurse-submodules https://gitlab.com/pthouvenin/...git +# git clone --recurse-submodules https://gitlab.cristal.univ-lille.fr/pthouven/dsgs.git unzip dsgs-main.zip cd dsgs-main # Create a conda environment using one of the lock files provided in the archive -# (use jcgs_review_environment_osx.lock.yml for MAC OS) -mamba env create --name jcgs-review --file jcgs_review_environment_linux.lock.yml -# mamba env create --name jcgs-review --file jcgs_review_environment.yml +# (use dsgs_environment_osx.lock.yml for macOS) +mamba env create --name dsgs --file dsgs_environment_linux.lock.yml +# mamba env create --name dsgs --file dsgs_environment.yml # Activate the environment -mamba activate jcgs-review +mamba activate dsgs # Install the library in editable mode mamba develop src/ # Deleting the environment (if needed) -# mamba env remove --name jcgs-review +# mamba env remove --name dsgs # Generating lock file from existing environment (if needed) -# mamba env export --name jcgs-review --file jcgs_review_environment_linux.lock.yml +# mamba env export --name dsgs --file dsgs_environment_linux.lock.yml # or -# mamba list --explicit --md5 > explicit_jcgs_env_linux-64.txt -# mamba create --name jcgs-test -c conda-forge --file explicit_jcgs_env_linux-64.txt +# mamba list --explicit --md5 > explicit_dsgs_env_linux-64.txt +# mamba create --name dsgs-test -c conda-forge --file explicit_dsgs_env_linux-64.txt # pip install docstr-coverage genbadge wily sphinxcontrib-apa sphinx_copybutton # Manual install (if absolutely needed) -# mamba create --name jcgs-review numpy numba mpi4py "h5py>=2.9=mpi*" scipy scikit-image matplotlib imageio tqdm jupyterlab pytest black flake8 isort coverage pre-commit sphinx sphinx_rtd_theme sphinxcontrib-bibtex sphinx-autoapi sphinxcontrib furo conda-lock conda-build -# mamba activate jcgs-review +# mamba create --name dsgs numpy numba mpi4py "h5py>=2.9=mpi*" scipy scikit-image matplotlib imageio tqdm jupyterlab pytest black flake8 isort coverage pre-commit sphinx sphinx_rtd_theme sphinxcontrib-bibtex sphinx-autoapi sphinxcontrib furo conda-lock conda-build +# mamba activate dsgs # pip install sphinxcontrib-apa sphinx_copybutton docstr-coverage genbadge wily # mamba develop src ```
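Once the environment is set up, a quick sanity check that the editable install is visible from Python (a minimal sketch; it only assumes the `dsgs` package from `src/` is importable after `mamba develop src/`):

```bash
mamba activate dsgs
python -c "import dsgs; print(dsgs.__file__)"  # should point inside src/dsgs
```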
@@ -94,7 +94,7 @@ export HDF5_USE_FILE_LOCKING='FALSE' - To check that the installation went well, you can run the unit tests provided in the package using the commands below. ```bash -mamba activate jcgs-review +mamba activate dsgs pytest --collect-only export NUMBA_DISABLE_JIT=1 # need to disable jit compilation to check test coverage coverage run -m pytest # run all the unit-tests (see Documentation section for more details) @@ -139,7 +139,7 @@ tmux kill-session -t session_name ```bash # from a terminal at the root of the archive -mamba activate jcgs-review +mamba activate dsgs cd examples/jcgs # generate all the synthetic datasets used in the experiments (to be run only once) @@ -161,7 +161,7 @@ mamba deactivate - The content of an [`.h5`](https://docs.h5py.org/en/stable/mpi.html?highlight=h5dump#using-parallel-hdf5-from-h5py) file can be quickly checked from the terminal (see the [`h5py`](https://docs.h5py.org/en/stable/quick.html) documentation for further details). Some examples are provided below. ```bash -mamba activate jcgs-review +mamba activate dsgs # replace <filename> by the name of your file in the instructions below h5dump --header <filename>.h5 # displays the name and size of all variables contained in the file
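The same inspection can be done from Python with `h5py`; a minimal sketch (it assumes a flat file layout with the datasets at the root, as in the checkpoint files written by this code, and `checkpoint.h5` is a placeholder name):

```python
import h5py

# open the file read-only and list every dataset with its shape and dtype
with h5py.File("checkpoint.h5", "r") as f:
    for name, dset in f.items():
        print(name, dset.shape, dset.dtype)
```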
@@ -181,7 +181,7 @@ ______________________________________________________________________ - The documentation can be generated in `.html` format using the following commands issued from a terminal. ```bash -mamba activate jcgs-review +mamba activate dsgs cd build docs/build/html make html ``` @@ -191,7 +191,7 @@ make html To test the code/docstring coverage, run the following commands from a terminal. ```bash -mamba activate jcgs-review +mamba activate dsgs pytest --collect-only export NUMBA_DISABLE_JIT=1 # need to disable jit compilation to check test coverage coverage run -m pytest # check all tests @@ -205,20 +205,13 @@ docstr-coverage . # check docstring coverage and generate the associated coverage badge To launch a single test, run a command of the form ```bash -mamba activate jcgs-review +mamba activate dsgs python -m pytest tests/models/test_crop.py pytest --markers # check full list of markers available pytest -m "not mpi" --ignore-glob=**/archive_unittest/* # run all tests not marked as mpi + ignore files in any directory "archive_unittest" mpiexec -n 2 python -m mpi4py -m pytest -m mpi # run all tests marked mpi with 2 cores ```
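For reference, the tests selected by `-m mpi` are regular pytest functions tagged with the `mpi` marker; a minimal sketch of such a test (hypothetical test name, assuming the marker is registered in the pytest configuration):

```python
import pytest
from mpi4py import MPI

@pytest.mark.mpi
def test_allreduce_counts_ranks():
    # run with: mpiexec -n 2 python -m mpi4py -m pytest -m mpi
    comm = MPI.COMM_WORLD
    assert comm.allreduce(1, op=MPI.SUM) == comm.Get_size()
```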
-<!-- To measure code quality with `wily` - -```bash -wily build src/ -n 100 -wily report dsgs -f HTML -o docs/build/wily_report.html -``` --> - ______________________________________________________________________ ## License diff --git a/docs/source/biblio.bib b/docs/source/biblio.bib index 34fd23fc4b374fc8aad421fd7cda56e9fc74ce91..28e03bf7ffbae54b9d4baf9d055a69b9b25741d9 100644 --- a/docs/source/biblio.bib +++ b/docs/source/biblio.bib @@ -126,13 +126,6 @@ eprint = {https://doi.org/10.1137/19M1283719}, } -@PhdThesis{Prusa2012, - author = {Zden\v{e}k Pru\v{s}a}, - title = {Segmentwise discrete wavelet transform}, - school = {Brno university of technology}, - year = {2012}, -} - @inproceedings{Salim2020, author = {Salim, Adil and Richt\`{a}rik, Peter}, booktitle = NIPS, @@ -142,10 +135,10 @@ year = {2020} } -@misc{Thouvenin2022submitted, - title={A distributed Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems}, +@misc{Thouvenin2023, + title={A Distributed Block-Split Gibbs Sampler with Hypergraph Structure for High-Dimensional Inverse Problems}, author={Pierre-Antoine Thouvenin and Audrey Repetti and Pierre Chainais}, - year={2022}, + year={2023}, eprint={2210.02341}, archivePrefix={arXiv}, primaryClass={stat.ME} diff --git a/docs/source/conf.py b/docs/source/conf.py index 900fc51e1b34cecb5239b532545b26af5f2a7468..dfb7aaad45139533eb93d32d2247b71a1f915d98 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,7 +23,7 @@ copyright = "2023, P.-A. Thouvenin, A. Repetti and P. Chainais" author = "P.-A. Thouvenin, A. Repetti and P. Chainais" # The full version, including alpha/beta/rc tags -release = "0.1.0" +release = "1.0" # -- General configuration --------------------------------------------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index b3538ec374e3bfc2a124267a81c5490591aae298..276a65c39800065bab2ce1b83b5878d40316c17e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -8,8 +8,7 @@ The library currently contains codes to reproduce the experiments reported in :c .. warning:: - This project is under active development, and the API my evolve significantly until version ``1.0``. -.. - A complete code coverage report is available `here <../coverage_html_report/index.html#http://>`_. + The code provided in this archive corresponds to the version used to produce the results reported in the paper. .. toctree:: @@ -28,18 +27,9 @@ The library currently contains codes to reproduce the experiments reported in :c autoapi/index .. Code coverage report<../coverage_html_report/index.html#http://> .. License<../../../LICENSE#http://> - Gitlab repository<https://gitlab.cristal.univ-lille.fr/pthouven/dspa> + Gitlab repository<https://gitlab.cristal.univ-lille.fr/pthouven/dsgs> .. Code quality<../wily_report.html#http://> -.. - .. .. autosummary:: - .. :toctree: _autosummary - .. :template: custom-module-template.rst - .. :recursive: - - .. dsgs - - Indices and tables ================== diff --git a/docs/source/setup.rst b/docs/source/setup.rst index b3865c8506a5e8afb97fae05aa5559313e39888d..4f4ef7f482b758aa578aea61771e5fa533cf273a 100644 --- a/docs/source/setup.rst +++ b/docs/source/setup.rst @@ -21,17 +21,17 @@ To install the library, issue the following commands in a terminal. cd dsgs-main # Create anaconda environment (from one of the provided lock files) - mamba env create --name jcgs-review --file jcgs_review_environment_linux.lock.yml # use jcgs_review_environment_osx.lock.yml for osx - # mamba env create --name jcgs-review --file jcgs_review_environment.yml + mamba env create --name dsgs --file dsgs_environment_linux.lock.yml # use dsgs_environment_osx.lock.yml for osx + # mamba env create --name dsgs --file dsgs_environment.yml # Activate the environment - mamba activate jcgs-review + mamba activate dsgs # Install the library in editable mode mamba develop src/ # Deleting the environment (if needed) - # mamba env remove --name jcgs-review + # mamba env remove --name dsgs To avoid the `file lock issue in h5py <https://github.com/h5py/h5py/issues/1101>`_, add the following line to the ``~/.zshrc`` file (or ``~/.bashrc``) @@ -43,7 +43,7 @@ To test the installation went well, you can run the unit-tests provided in the p .. code-block:: bash - conda activate jcgs-review + conda activate dsgs pytest --collect-only export NUMBA_DISABLE_JIT=1 # need to disable jit compilation to check test coverage coverage run -m pytest # run all the unit-tests (see Documentation section for more details) @@ -92,7 +92,7 @@ The folder ``./examples/jcgs/configs`` contains ``.json`` files summarizing the # from a terminal at the root of the archive - mamba activate jcgs-review + mamba activate dsgs cd examples/jcgs # generate all the synthetic datasets used in the experiments (to be run only once) @@ -117,7 +117,7 @@ The content of an ``.h5`` file can be quickly checked from the terminal (`refer # replace <filename> by the name of your file h5dump --header <filename>.h5 # displays the name and size of all variables contained in the file - conda activate jcgs-review + conda activate dsgs # replace <filename> by the name of your file in the instructions below h5dump --header <filename>.h5 # displays the name and size of all variables contained in the file @@ -143,7 +143,7 @@ The documentation can be generated in ``.html`` format using the following comma .. code-block:: bash - conda activate jcgs-review + conda activate dsgs cd build docs/build/html make html # compile documentation in html, latex or linkcheck @@ -155,7 +155,7 @@ To test the code/docstring coverage, run the following commands from a terminal. .. code-block:: bash - conda activate jcgs-review + conda activate dsgs export NUMBA_DISABLE_JIT=1 # need to disable jit compilation to check test coverage coverage run -m unittest # check all tests coverage report # generate a coverage report in the terminal @@ -168,7 +168,7 @@ To launch a single test, run a command of the form .. 
code-block:: bash - conda activate jcgs-review + conda activate dsgs python -m pytest tests/operators/test_crop.py pytest --markers # check full list of markers available pytest -m "not mpi" --ignore-glob=**/archive_unittest/* # run all tests not marked as mpi + ignore files in any directory "archive_unittest" diff --git a/jcgs_review_environment.yml b/dsgs_environment.yml similarity index 100% rename from jcgs_review_environment.yml rename to dsgs_environment.yml diff --git a/jcgs_review_environment_linux.lock.yml b/dsgs_environment_linux.lock.yml similarity index 100% rename from jcgs_review_environment_linux.lock.yml rename to dsgs_environment_linux.lock.yml diff --git a/jcgs_review_environment_osx.lock.yml b/dsgs_environment_osx.lock.yml similarity index 100% rename from jcgs_review_environment_osx.lock.yml rename to dsgs_environment_osx.lock.yml diff --git a/explicit_jcgs_env_linux-64.txt b/explicit_dsgs_env_linux-64.txt similarity index 100% rename from explicit_jcgs_env_linux-64.txt rename to explicit_dsgs_env_linux-64.txt diff --git a/explicit_jcgs_env_osx-64.txt b/explicit_dsgs_env_osx-64.txt similarity index 100% rename from explicit_jcgs_env_osx-64.txt rename to explicit_dsgs_env_osx-64.txt diff --git a/img/5.1.13.tiff b/img/5.1.13.tiff deleted file mode 100644 index d09cef3ad54b5152c4667e3744cdc1949cb9391d..0000000000000000000000000000000000000000 Binary files a/img/5.1.13.tiff and /dev/null differ diff --git a/img/5.3.01.tiff b/img/5.3.01.tiff deleted file mode 100644 index 13a756d8b7566d29438c6d03910d0ca977156d59..0000000000000000000000000000000000000000 Binary files a/img/5.3.01.tiff and /dev/null differ diff --git a/img/bank.png b/img/bank.png deleted file mode 100644 index 71529dc024f13f94d5ce42a86eb4659ae50f486a..0000000000000000000000000000000000000000 Binary files a/img/bank.png and /dev/null differ diff --git a/img/bank_color.png b/img/bank_color.png deleted file mode 100644 index b390136fce633be696aaa28991bb40f1f01ed7a8..0000000000000000000000000000000000000000 Binary files a/img/bank_color.png and /dev/null differ diff --git a/img/barb.png b/img/barb.png deleted file mode 100755 index e9f29e5ccc8d9296066192006074591bcd76f357..0000000000000000000000000000000000000000 Binary files a/img/barb.png and /dev/null differ diff --git a/img/boat.png b/img/boat.png deleted file mode 100755 index 9098a93d998332bca33e3a20a6921770d62bec86..0000000000000000000000000000000000000000 Binary files a/img/boat.png and /dev/null differ diff --git a/img/chessboard.png b/img/chessboard.png deleted file mode 100755 index 0531a8dad9cda0fb1b69cdcff97e5e08915389ef..0000000000000000000000000000000000000000 Binary files a/img/chessboard.png and /dev/null differ diff --git a/img/corral.png b/img/corral.png deleted file mode 100755 index 5a789fbe07e6d193b56b9cffc0f207927e02c10d..0000000000000000000000000000000000000000 Binary files a/img/corral.png and /dev/null differ diff --git a/img/cortex.png b/img/cortex.png deleted file mode 100755 index 48bd58918ed660e476681d46e97ce51d7c18252e..0000000000000000000000000000000000000000 Binary files a/img/cortex.png and /dev/null differ diff --git a/img/grating.png b/img/grating.png deleted file mode 100755 index 2317973d33a85a6330b78635f4a43c00b56ae80e..0000000000000000000000000000000000000000 Binary files a/img/grating.png and /dev/null differ diff --git a/img/hair.png b/img/hair.png deleted file mode 100755 index c953f129381224f42af3e56f919389bd8683edc6..0000000000000000000000000000000000000000 Binary files a/img/hair.png and /dev/null differ diff --git 
a/img/lena.png b/img/lena.png deleted file mode 100755 index f14918282436fcd454f2587942fbe5e2564e468a..0000000000000000000000000000000000000000 Binary files a/img/lena.png and /dev/null differ diff --git a/img/line.png b/img/line.png deleted file mode 100755 index b48ddad36cbdf76ce6ec58d271791b67bad106c1..0000000000000000000000000000000000000000 Binary files a/img/line.png and /dev/null differ diff --git a/img/line_horizontal.png b/img/line_horizontal.png deleted file mode 100755 index 06c39f4c3ed1892231a80dd892c3bfa7e780be00..0000000000000000000000000000000000000000 Binary files a/img/line_horizontal.png and /dev/null differ diff --git a/img/line_vertical.png b/img/line_vertical.png deleted file mode 100755 index 7bf11fb97b4d1dfac2d69e620fd7ff65ba38cbd3..0000000000000000000000000000000000000000 Binary files a/img/line_vertical.png and /dev/null differ diff --git a/img/mandrill.png b/img/mandrill.png deleted file mode 100755 index 58b49fbc2e57cde1124fbfaae6be90a24ca2e7c1..0000000000000000000000000000000000000000 Binary files a/img/mandrill.png and /dev/null differ diff --git a/img/mosque.png b/img/mosque.png deleted file mode 100644 index 32b02fc039467d5d8b89c7e65502332b3850dc48..0000000000000000000000000000000000000000 Binary files a/img/mosque.png and /dev/null differ diff --git a/img/mosque_color.png b/img/mosque_color.png deleted file mode 100644 index 9048f0b009617c6d7e0a65834541cf4164261616..0000000000000000000000000000000000000000 Binary files a/img/mosque_color.png and /dev/null differ diff --git a/img/parrot-mask.png b/img/parrot-mask.png deleted file mode 100755 index 831ab3775bddc0ab8e387f1268a8276905e0425d..0000000000000000000000000000000000000000 Binary files a/img/parrot-mask.png and /dev/null differ diff --git a/img/parrot.png b/img/parrot.png deleted file mode 100755 index 4cfdf35ee9ef9f172f83c0f889fc5b778b6f7498..0000000000000000000000000000000000000000 Binary files a/img/parrot.png and /dev/null differ diff --git a/img/parrotgray.png b/img/parrotgray.png deleted file mode 100644 index 57551b3e4b374e7cc8378140623c4cccad2e0a6d..0000000000000000000000000000000000000000 Binary files a/img/parrotgray.png and /dev/null differ diff --git a/img/periodic_bumps.png b/img/periodic_bumps.png deleted file mode 100755 index fb4fe1b2facdb8dd9d82a1a094644898878eabaf..0000000000000000000000000000000000000000 Binary files a/img/periodic_bumps.png and /dev/null differ diff --git a/img/rubik1.png b/img/rubik1.png deleted file mode 100755 index bbe824072a2feecaea4ee742967e921c8a938190..0000000000000000000000000000000000000000 Binary files a/img/rubik1.png and /dev/null differ diff --git a/img/rubik2.png b/img/rubik2.png deleted file mode 100755 index d56bd3e11c12cd2ff683dcb68458a7d759b5020a..0000000000000000000000000000000000000000 Binary files a/img/rubik2.png and /dev/null differ diff --git a/img/rubik3.png b/img/rubik3.png deleted file mode 100755 index e2a140af7b9cb89af5b15babaacf9209833b8228..0000000000000000000000000000000000000000 Binary files a/img/rubik3.png and /dev/null differ diff --git a/img/taxi1.png b/img/taxi1.png deleted file mode 100755 index d3859086df432710c74fd0408e61faead55e936b..0000000000000000000000000000000000000000 Binary files a/img/taxi1.png and /dev/null differ diff --git a/img/taxi2.png b/img/taxi2.png deleted file mode 100755 index 618840842b18f53f79af146b670684d4ce422d2e..0000000000000000000000000000000000000000 Binary files a/img/taxi2.png and /dev/null differ diff --git a/img/taxi3.png b/img/taxi3.png deleted file mode 100755 index 
583a76f197f5e9e537033e70b54bf7b4949ceaf6..0000000000000000000000000000000000000000 Binary files a/img/taxi3.png and /dev/null differ diff --git a/img/vessels.png b/img/vessels.png deleted file mode 100755 index 3998b37cd71df9869cba1aece77c19e130ab05b1..0000000000000000000000000000000000000000 Binary files a/img/vessels.png and /dev/null differ diff --git a/src/dsgs/main_metrics_serial.py b/src/dsgs/main_metrics_serial.py index 539d242b7e48c5c5044f78d2c34107e20e5b2ec7..29fb87747dc19968496ac0d038752637bca291d0 100644 --- a/src/dsgs/main_metrics_serial.py +++ b/src/dsgs/main_metrics_serial.py @@ -76,14 +76,8 @@ def main_metrics( f = h5py.File(filename + ".h5", "r+", driver="mpio", comm=MPI.COMM_WORLD) N = f["N"][()] - # M = f["M"][()] - # h = f["h"][()] f.close() - # overlap_size = np.array(h.shape, dtype="i") - # overlap_size -= 1 - # data_size = N + overlap_size - # slice to extract image tile from full image file tile_pixels = local_split_range_nd(grid_size, N, ranknd) tile_size = tile_pixels[:, 1] - tile_pixels[:, 0] + 1 diff --git a/src/dsgs/main_metrics_spmd.py b/src/dsgs/main_metrics_spmd.py index bc2259ac0b35a517438591ef03f28b7b5e03fcf9..32a76ca469e52d1568b894d5d937e57ef7e0e437 100644 --- a/src/dsgs/main_metrics_spmd.py +++ b/src/dsgs/main_metrics_spmd.py @@ -76,14 +76,8 @@ def main_metrics( f = h5py.File(filename + ".h5", "r+", driver="mpio", comm=MPI.COMM_WORLD) N = f["N"][()] - # M = f["M"][()] - # h = f["h"][()] f.close() - # overlap_size = np.array(h.shape, dtype="i") - # overlap_size -= 1 - # data_size = N + overlap_size - # slice to extract image tile from full image file tile_pixels = local_split_range_nd(grid_size, N, ranknd) tile_size = tile_pixels[:, 1] - tile_pixels[:, 0] + 1 @@ -246,7 +240,6 @@ def main_metrics( # MMSE and MAP (need all processes) saver.save("", [N], [global_slice_tile], [None], mode="w", x_mmse=local_x_mmse) - # ! seems to create a segfault for a large number of processes: why? 
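+    # collective append of the MAP estimate: mode="a" reopens the file created just above, and every rank must take part in the call 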
saver.save("", [N], [global_slice_tile], [None], mode="a", x_map=local_x_map) snr_mmse = 0.0 diff --git a/src/dsgs/main_serial_poisson_deconvolution.py b/src/dsgs/main_serial_poisson_deconvolution.py index 9defb8028a135533b378d1377856d6f4da16cbae..c807ea51d448a6d0ebfa22aff0144f3a1098a18b 100644 --- a/src/dsgs/main_serial_poisson_deconvolution.py +++ b/src/dsgs/main_serial_poisson_deconvolution.py @@ -165,10 +165,6 @@ def main( f.close() logger.info("End: loading data and images") - # parameters of the Gamma prior on the regularization parameter - # a = 1e-3 - # b = 1e-3 - logger.info("Parameters defined, setup sampler") if args.prof: @@ -245,17 +241,6 @@ def main( else: raise ValueError("Unknown sampler: {}".format(sampler)) - if args.prof: - pr.disable() - # Dump results: - # - for binary dump - pr.dump_stats("debug/cpu_serial.prof") - # - for text dump - with open("debug/cpu_serial.txt", "w") as output_file: - sys.stdout = output_file - pr.print_stats(sort="time") - sys.stdout = sys.__stdout__ - if __name__ == "__main__": import argparse @@ -291,11 +276,9 @@ if __name__ == "__main__": ] args = parser.parse_args() - # print(args) - - # args = utils.args.parse_args() # # * debugging values + # print(args) # args.imfile = "img/image_micro_8.h5" # args.datafilename = "data_image_micro_8_ds1_M30_k8" # args.rpath = "debug" @@ -379,21 +362,6 @@ if __name__ == "__main__": datafilename = join(args.dpath, args.datafilename + ".h5") checkpointname = join(args.rpath, args.checkpointname) - # tau.run("""main( - # args.data, - # args.rpath, - # args.imfile, - # datafilename, - # args.checkpointname, - # logger, - # profiling=args.prof, - # alpha=args.alpha, - # beta=args.beta, - # Nmc=args.Nmc, - # checkpoint_frequency=args.checkpoint_frequency, - # monitor_frequency=5, - # )""") - main( args.data, args.rpath, diff --git a/src/dsgs/main_spmd_poisson_deconvolution.py b/src/dsgs/main_spmd_poisson_deconvolution.py index 704fa7586cc3764bf2dbc517997f9833bb098997..9ab6998697013235c674fbe455282734243b8168 100644 --- a/src/dsgs/main_spmd_poisson_deconvolution.py +++ b/src/dsgs/main_spmd_poisson_deconvolution.py @@ -21,9 +21,6 @@ from dsgs.samplers.parallel.spmd_psgla_poisson_deconvolution import SpmdPsglaSGS from dsgs.utils.checkpoint import DistributedCheckpoint, SerialCheckpoint from dsgs.utils.communications import local_split_range_nd -# import tau -# https://forum.hdfgroup.org/t/crash-when-writing-parallel-compressed-chunks/6186 - def main( comm, @@ -282,17 +279,6 @@ def main( spmd_sampler.sample() - if args.prof: - pr.disable() - # Dump results: - # - for binary dump - pr.dump_stats("debug/cpu_%d.prof" % comm.rank) - # - for text dump - with open("debug/cpu_%d.txt" % comm.rank, "w") as output_file: - sys.stdout = output_file - pr.print_stats(sort="time") - sys.stdout = sys.__stdout__ - if __name__ == "__main__": import argparse @@ -331,11 +317,9 @@ if __name__ == "__main__": ] args = parser.parse_args() - # print(args) - - # args = utils.args.parse_args() # * debugging values + # print(args) # args.imfile = "img/cameraman.png" # args.datafilename = "conv_data_cameraman_ds1_M30_k8" # args.rpath = "results_conv_cameraman_ds1_M30_k8_h5" diff --git a/src/dsgs/operators/convolutions.py b/src/dsgs/operators/convolutions.py index f54aca143a44906371a575d8779d42da11bb3623..ffafaeb8f9c23ac90d52b8143d091e72f1a9b971 100755 --- a/src/dsgs/operators/convolutions.py +++ b/src/dsgs/operators/convolutions.py @@ -128,8 +128,6 @@ def linear_convolution(x, h, mode="constant"): rsize = np.array(h.shape, dtype="i") - 1 y = 
pad_array_nd(x, lsize, rsize, mode="constant") return scipy.ndimage.convolve(y, h, mode="constant", cval=0.0) - # y = pad_array_nd(x, lsize, rsize, mode=mode) - # return scipy.ndimage.convolve(y, h, mode="constant", cval=0.0) def adjoint_linear_convolution(y, h, mode="constant"): @@ -156,236 +154,3 @@ def adjoint_linear_convolution(y, h, mode="constant"): ) # ! np.flip(h, axis=None) flips all axes return adjoint_padding(x, lsize, rsize, mode="constant") - - # ! symmetric bd condition - # ! not functional for now - # lsize = (np.array(h.shape, dtype="i") - 1) - # rsize = lsize - # yp = pad_array_nd(y, np.zeros(len(h.shape), dtype="i"), rsize, mode="constant") - # x = scipy.ndimage.convolve(yp, np.conj(np.flip(h, axis=None)), mode='constant', cval=0.0) - # # ! np.flip(h, axis=None) flips all axes - # return adjoint_padding(x, lsize, rsize, mode=mode) - - -# TODO: write a generic version -# direct: padding, conv. in valid mode -# Hy_ = sg.convolve2d(y_, h, boundary="fill", mode="full") -# Hadj_y_ = adjoint_padding(Hy_, ext_size, ext_size, mode="symmetric") - -# same kind for symmetric when not based on fft (overlap-add for distributed version) - - -# ! to be made more generic (quick test for now) -# TODO: adjoint of a convolution operator with any boundary extension involves -# the adjoint of the circular convolution and the adjoint of the padding -# operator -# def adjoint_conv(x, h, shape): -# """Adjoint of the linear convolution operator. - -# Parameters -# ---------- -# x : _type_ -# _description_ -# h : _type_ -# _description_ -# shape : _type_ -# _description_ - -# Returns -# ------- -# _type_ -# _description_ -# """ -# Hx = sg.convolve2d(x, np.flip(h, axis=None), boundary="fill", mode="full") -# # ! np.flip(m, axis=None) flips all axes -# s = tuple([np.s_[: shape[k]] for k in range(len(shape))]) -# return Hx[s] - - -if __name__ == "__main__": - # # TODO: structure the example better, make sure this is included in a - # # TODO: unit-test - # import matplotlib.pyplot as plt - # import scipy.signal as sg - # from PIL import Image - - # from dsgs.operators.linear_convolution import SerialConvolution - # from dsgs.operators.padding import adjoint_padding, pad_array, pad_array_nd - - # # Generate 2D Gaussian convolution kernel - # vr = 1 - # M = 7 - # if np.mod(M, 2) > 0: # M odd - # n = np.arange(-(M - 1) // 2, (M - 1) // 2 + 1) - # else: - # n = np.arange(-M // 2, M // 2) - # h = np.exp(-(n**2 + n[:, np.newaxis] ** 2) / (2 * vr)) - - # # plt.imshow(h, cmap=plt.cm.gray) - # # plt.show() - - # x = np.asarray(Image.open("img/cameraman.png", "r"), dtype="d") - # N = x.shape - # M = h.shape - - # # * version 1: circular convolution - # K = N - # hpad = pad_array(h, K, padmode="after") # after, using fft convention for center - # yc, H = fft2_conv(x, hpad, K) - - # circ_conv = SerialConvolution(np.array(K, dtype="i"), h, np.array(K, dtype="i")) - # yc2 = circ_conv.forward(x) - # print("yc2 == yc ? {0}".format(np.allclose(yc2, yc))) - - # # plt.imshow(yc, cmap=plt.cm.gray) - # # plt.show() - - # # check adjoint operator (circular convolution) - # rng = np.random.default_rng(1234) - # x_ = rng.standard_normal(N) - # Hx_ = circ_conv.forward(x_) - # y_ = rng.standard_normal(K) - # Hadj_y_ = circ_conv.adjoint(y_) - # hp1 = np.sum(Hx_ * y_) - # hp2 = np.sum(x_ * Hadj_y_) - - # print( - # "Correct adjoint operator (circular convolution)? 
{}".format( - # np.isclose(hp1, hp2) - # ) - # ) - - # # * version 1.2: circular convolution w/o Fourier - # K = N - # # with fft - # hpad = pad_array(h, K, padmode="after") # after, using fft convention for center - # yc, H = fft2_conv(x, hpad, K) - - # # w/o fft (pad signal, linear convolution with x w/o additional border effect) -> ok - # shift = [[M[n] - 1, 0] for n in range(len(M))] - # # xp = np.pad(x, shift, mode="wrap") - # # yc2 = sg.convolve2d(xp, h, boundary="fill", mode="valid") - # # print("yc2 == yc? {0}".format(np.allclose(yc2, yc))) - # # yc3 = sg.convolve2d(xp, h, boundary="fill", mode="full") - # # yc3 = yc3[M[0]-1:-(M[0]-1), M[1]-1:-(M[1]-1)] - # # print("yc3 == yc2? {0}".format(np.allclose(yc2, yc3))) - - # # check adjoint operator (circular convolution, w/o Fourier) - # rng = np.random.default_rng(1234) - # x_ = rng.standard_normal(N) - # Hx_ = np.pad(x_, shift, mode="wrap") - # Hx_ = sg.convolve2d(Hx_, h, boundary="fill", mode="valid") - - # y_ = rng.standard_normal(K) - # Hadj_y_ = sg.convolve2d(y_, np.conj(np.flip(h)), boundary="fill", mode="full") - # # ! manual adjoint padding (wrap condition) - # # Hadj_y_[-(M[0]-1):, :] += Hadj_y_[:M[0]-1, :] - # # Hadj_y_[:, -(M[1]-1):] += Hadj_y_[:, :M[1]-1] - # # Hadj_y_ = Hadj_y_[M[0]-1:, M[1]-1:] - # Hadj_y_ = adjoint_padding( - # Hadj_y_, [M[n] - 1 for n in range(len(M))], 2 * [None], mode="wrap" - # ) - - # hp1 = np.sum(Hx_ * y_) - # hp2 = np.sum(x_ * Hadj_y_) - # print( - # "Correct adjoint operator (circular convolution)? {}".format( - # np.isclose(hp1, hp2) - # ) - # ) - - # # * version 2: linear convolution - # # linear convolution - # K = [N[n] + M[n] - 1 for n in range(len(N))] - # H = np.fft.rfft2(h, K) - # yl = np.fft.irfft2(H * np.fft.rfft2(x, K), K) # zeros appear around - - # yl2 = fft_conv(x, H, K) - # print("yl2 == yl ? {0}".format(np.allclose(yl2, yl))) - - # linear_conv = SerialConvolution(np.array(N, dtype="i"), h, np.array(K, dtype="i")) - # yl3 = linear_conv.forward(x) - # print("yl3 == yl ? {0}".format(np.allclose(yl3, yl))) - - # # plt.imshow(yl, cmap=plt.cm.gray) - # # plt.show() - - # # check adjoint operator (linear convolution) - # rng = np.random.default_rng(1234) - # x_ = rng.standard_normal(N) - # Hx_ = linear_conv.forward(x_) - # y_ = rng.standard_normal(K) - # Hadj_y_ = linear_conv.adjoint(y_) - # hp1 = np.sum(Hx_ * y_) - # hp2 = np.sum(x_ * Hadj_y_) - - # print( - # "Correct adjoint operator (linear convolution)? {}".format(np.isclose(hp1, hp2)) - # ) - - # # adjoint operator (w/o fft) - # Hadj_y = adjoint_conv(x_, h, N) - # hp2 = np.sum(np.conj(x_) * Hadj_y_) - - # print( - # "Correct adjoint convolution operator (linear convolution, no fft)? 
{}".format( - # np.isclose(hp1, hp2) - # ) - # ) - - # # Notes: - # # 5::4 -> slice(5, None, 4) - # # 5:-1 -> slice(5, -1) -> np.s_[5:-1] - # # s = alpha[5::4] == t = alpha[slice(5, None, 4)] - - # * convolution with symmetric boundary conditions - # rng = np.random.default_rng(1234) - # N_ = np.array(N, dtype="i") - # M_ = np.array(M, dtype="i") - # ext_size = M_ - 1 - # K_ = N_ + ext_size - # Fh_d = np.fft.rfft2(h, s=N_ + 3 * M_ - 3) - # Fh_a = np.fft.rfft2(h, s=K_ + M_ - 1) - - # x_ = rng.standard_normal(N) - # y_ = rng.standard_normal(K_) - - # # direct operator (with sg.convolve2d) - # Hx_ = sg.convolve2d(x_, h, boundary="symm", mode="full") - - # # direct operator (alternative using manual padding) - # # xp = pad_array_nd(x_, ext_size, ext_size, mode="symmetric") - # # Hx2 = sg.convolve(xp, h, mode="valid") - - # # direct operator (alternative based on fft) - # xp = pad_array_nd(x_, ext_size, ext_size, mode="symmetric") - # Hx2_ = fft_conv(xp, Fh_d, K_ + 2 * M_ - 2) - # Hx2 = adjoint_padding( - # Hx2_, ext_size, ext_size, mode="constant" - # ) # equivalent to valid - # print( - # "Correct direct convolution (symmetric extension)? {}".format( - # np.allclose(Hx2, Hx_) - # ) - # ) - # # -> y = CHPx, C cropping (valid, remove M-1 on each side), P boundary extension, H convolution - - # # adjoint operator (using sg.convolve2d) - # Hy_ = sg.convolve2d(y_, h, boundary="fill", mode="full") - # Hadj_y_ = adjoint_padding(Hy_, ext_size, ext_size, mode="symmetric") - - # # adjoint operator (using fft) - # # ! correct, but no conjugate of Fh_a here! why? - # Hy0_ = fft_conv(y_, Fh_a, K_ + M_ - 1) - # Hadj_y0 = adjoint_padding(Hy0_, ext_size, ext_size, mode="symmetric") - - # hp1 = np.sum(np.conj(Hx_) * y_) - # hp2 = np.sum(np.conj(x_) * Hadj_y_) - - # print( - # "Correct adjoint convolution operator (symmetric extension)? {}".format( - # np.isclose(hp1, hp2) - # ) - # ) - - pass diff --git a/src/dsgs/operators/data.py b/src/dsgs/operators/data.py index e0f907817e161f54f3629d8a0f415b7a8873f163..1c546a37e5e0cd3dd19ed414ccf221f3ee7243da 100644 --- a/src/dsgs/operators/data.py +++ b/src/dsgs/operators/data.py @@ -8,9 +8,6 @@ MPI processes. # Sampler with Hypergraph Structure for High-Dimensional Inverse Problems**, # [arxiv preprint 2210.02341](http://arxiv.org/abs/2210.02341), October 2022. -# TODO: to be simplified, e.g. with checkpoint and model objects, or even -# TODO: removed - from os.path import splitext import h5py @@ -117,7 +114,7 @@ def generate_random_mask(image_size, percent, rng): "Fraction of observed pixels percent should be such that: 0 <= percent <= 1." ) - N = np.prod(image_size) # total number of pixels + N = np.prod(image_size) masked_id = np.unravel_index( rng.choice(N, (percent * N).astype(int), replace=False), image_size ) @@ -430,16 +427,11 @@ def generate_local_data( # local convolution fft_h = np.fft.rfftn(h, local_conv_size) - # H = np.fft.rfft2( - # np.fft.fftshift(padding.pad_array(h, N, padmode="around")) - # ) # doing as in Vono's reference code local_coeffs = ucomm.slice_valid_coefficients(ranknd, grid_size, overlap_size) # ! issue: need to make sure Hx >= 0, not necessarily the case numerically # ! 
with a fft-based convolution # https://github.com/pytorch/pytorch/issues/30934 - # Hx = scipy.ndimage.convolve(facet, h, output=Hx, mode='constant', cval=0.0) - # Hx = convolve2d(facet, h, mode='full')[local_coeffs] Hx = fft_conv(facet, fft_h, local_conv_size)[local_coeffs] prox_nonegativity(Hx) local_data = local_rng.poisson(Hx) @@ -483,14 +475,6 @@ def mpi_load_data_from_h5( ): ndims = data_size.size - # slice for indexing into global arrays - # global_slice_data = tuple( - # [ - # np.s_[tile_pixels[d, 0] : tile_pixels[d, 0] + local_data_size[d]] - # for d in range(ndims) - # ] - # ) - local_slice = tuple(ndims * [np.s_[:]]) global_slice_data = create_local_to_global_slice( tile_pixels, ranknd, overlap_size, local_data_size, backward=backward @@ -508,7 +492,3 @@ def mpi_load_data_from_h5( f.close() return local_data - - -# if __name__ == "__main__": -# pass diff --git a/src/dsgs/operators/jtv.py b/src/dsgs/operators/jtv.py index 3ef00f74430a2b1bebf609a0b7d18d604bab2a3e..77792debd47b6549ccb4de29288e3cc0dedcc21d 100755 --- a/src/dsgs/operators/jtv.py +++ b/src/dsgs/operators/jtv.py @@ -18,12 +18,6 @@ from numba import jit # ! does not support type elision # ! only jit costly parts (by decomposing function), keep flexibility of Python # ! as much as possible -# import importlib -# importlib.reload(...) - -# TODO: investigate jitted nD version for the TV (not only 2D) -# TODO: try to simplify the 2d implementation of the chunked version of the TV -# TODO (many conditions to be checked at the moment) # * Useful numba links # https://stackoverflow.com/questions/57662631/vectorizing-a-function-returning-tuple-using-numba-guvectorize @@ -352,105 +346,3 @@ def gradient_smooth_tv(x, eps): u = gradient_2d(x) w = np.sqrt(np.abs(u[0]) ** 2 + np.abs(u[1]) ** 2 + eps) return gradient_2d_adjoint(u[0] / w, u[1] / w) - - -# if __name__ == "__main__": -# import timeit -# from inspect import cleandoc # handle identation in multi-line strings - -# rng = np.random.default_rng(1234) -# x = rng.standard_normal((10, 5)) -# eps = np.finfo(float).eps - -# stmt_s = "tv.gradient_smooth_tv(x, eps)" -# setup_s = cleandoc( -# """ -# import tv -# import from __main__ import x, eps -# """ -# ) -# t_tv_np = timeit.timeit( -# stmt_s, number=100, globals=globals() -# ) # save in pandas (prettier display) -# print("TV (numpy version): ", t_tv_np) - -# _ = gradient_smooth_tv(x, eps) # trigger numba compilation -# stmt_s = "gradient_smooth_tv(x, eps)" -# setup_s = "from __main__ import gradient_smooth_tv, x, eps" -# t_tv_numba = timeit.timeit(stmt_s, setup=setup_s, number=100) -# print("TV (numba version): ", t_tv_numba) - -# # ! 
multiple facets (serial test) -# # uh0, uv0 = gradient_2d(x) - -# # yh0 = (1 + 1j) * rng.standard_normal(uh0.shape) -# # yv0 = (1 + 1j) * rng.standard_normal(uv0.shape) -# # z0 = gradient_2d_adjoint(yh0, yv0) - -# # ndims = 2 -# # grid_size = np.array([3, 2], dtype="i") -# # nchunks = np.prod(grid_size) -# # N = np.array(x.shape, dtype="i") - -# # overlap = (grid_size > 1).astype(int) -# # isfirst = np.empty((nchunks, ndims), dtype=bool) -# # islast = np.empty((nchunks, ndims), dtype=bool) - -# # range0 = [] -# # range_direct = [] -# # range_adjoint = [] -# # uh = np.zeros(x.shape) -# # uv = np.zeros(x.shape) -# # z = np.zeros(x.shape, dtype=complex) - -# # for k in range(nchunks): - -# # ranknd = np.array(np.unravel_index(k, grid_size), dtype="i") -# # islast[k] = ranknd == grid_size - 1 -# # isfirst[k] = ranknd == 0 -# # range0.append(ucomm.local_split_range_nd(grid_size, N, ranknd)) -# # Nk = range0[k][:,1] - range0[k][:,0] + 1 - -# # # * direct operator -# # # version with backward overlap -# # # range_direct.append(ucomm.local_split_range_nd(grid_size, N, ranknd, overlap=overlap, backward=True)) -# # # facet = x[range_direct[k][0,0]:range_direct[k][0,1]+1, \ -# # # range_direct[k][1,0]:range_direct[k][1,1]+1] -# # # uh_k, uv_k = chunk_gradient_2d(facet, islast[k]) - -# # # start = range_direct[k][:,0] -# # # stop = start + np.array(uv_k.shape, dtype="i") -# # # uv[start[0]:stop[0], start[1]:stop[1]] = uv_k -# # # stop = start + np.array(uh_k.shape, dtype="i") -# # # uh[start[0]:stop[0], start[1]:stop[1]] = uh_k - -# # # version with forward overlap -# # range_direct.append(ucomm.local_split_range_nd(grid_size, N, ranknd, overlap=overlap, backward=False)) -# # facet = x[range_direct[k][0,0]:range_direct[k][0,1]+1, \ -# # range_direct[k][1,0]:range_direct[k][1,1]+1] -# # uh_k, uv_k = chunk_gradient_2d(facet, islast[k]) - -# # start = range0[k][:,0] -# # stop = start + np.array(uv_k.shape, dtype="i") -# # uv[start[0]:stop[0], start[1]:stop[1]] = uv_k -# # stop = start + np.array(uh_k.shape, dtype="i") -# # uh[start[0]:stop[0], start[1]:stop[1]] = uh_k - -# # # * adjoint (backward overlap only, forward more difficult to encode) -# # range_adjoint.append(ucomm.local_split_range_nd(grid_size, N, ranknd, overlap=overlap, backward=True)) -# # facet_h = yh0[range_adjoint[k][0,0]:range_adjoint[k][0,1]+1, \ -# # range_adjoint[k][1,0]:range_adjoint[k][1,1]+1] -# # facet_v = yv0[range_adjoint[k][0,0]:range_adjoint[k][0,1]+1, \ -# # range_adjoint[k][1,0]:range_adjoint[k][1,1]+1] -# # x_k = np.zeros(Nk, dtype=complex) -# # chunk_gradient_2d_adjoint(facet_h, facet_v, x_k, isfirst[k], islast[k]) - -# # start = range0[k][:,0] -# # stop = start + Nk -# # z[start[0]:stop[0], start[1]:stop[1]] = x_k - -# # print(np.allclose(uh, uh0)) -# # print(np.allclose(uv, uv0)) -# # print(np.allclose(z, z0)) - -# pass diff --git a/src/dsgs/operators/linear_convolution.py b/src/dsgs/operators/linear_convolution.py index 384411d6c07c3e3aab63263626e89f4eb254f767..dbbe687118288dbb5f41aee5866c62db24da3682 100644 --- a/src/dsgs/operators/linear_convolution.py +++ b/src/dsgs/operators/linear_convolution.py @@ -16,13 +16,8 @@ from dsgs.operators.convolutions import fft_conv from dsgs.operators.distributed_convolutions import calculate_local_data_size from dsgs.operators.linear_operator import LinearOperator -# ? = Keep kernel value out of the associated model object? -# TODO: - add circular communications for distributed convolutions -# TODO: - trim-down number of attributes in the classes? 
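For orientation, here is a minimal usage sketch of the `SerialConvolution` operator defined below; the constructor call `SerialConvolution(image_size, h, data_size)` and the dot-product adjoint check mirror the commented-out demo removed from `convolutions.py` above, so treat it as illustrative rather than canonical.

```python
import numpy as np

from dsgs.operators.linear_convolution import SerialConvolution

rng = np.random.default_rng(1234)
N, M = (64, 64), (7, 7)  # image size, kernel size
h = rng.standard_normal(M)  # convolution kernel
x = rng.standard_normal(N)  # input image

# data_size == image_size gives a circular convolution;
# data_size == image_size + kernel_size - 1 gives a linear convolution
K = tuple(N[d] + M[d] - 1 for d in range(2))
conv = SerialConvolution(np.array(N, dtype="i"), h, np.array(K, dtype="i"))

# dot-product test: <Hx, y> must match <x, H^T y> for a correct adjoint pair
y = rng.standard_normal(K)
print(np.isclose(np.sum(conv.forward(x) * y), np.sum(x * conv.adjoint(y))))
```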
- # * Convolution model class (serial and distributed) -# ! keep kernel out of the class? (e.g., for blind deconvolution) class SerialConvolution(LinearOperator): r"""Serial (FFT-based) convolution operator. @@ -39,8 +34,6 @@ class SerialConvolution(LinearOperator): - If ``data_size == image_size + kernel_size - 1``: linear convolution. """ - # TODO: make sure the interface works for both linear and circular - # TODO- convolutions def __init__( self, image_size, diff --git a/src/dsgs/operators/padding.py b/src/dsgs/operators/padding.py index 900166df00b7219a28642a019635cc0bc6cd9c25..d6a6683f3b0b5b41929401937a71c98a732cca76 100755 --- a/src/dsgs/operators/padding.py +++ b/src/dsgs/operators/padding.py @@ -322,7 +322,6 @@ def adjoint_padding(y, lsize, rsize, mode="constant"): ) ) - # sel_core = tuple([np.s_[lsize[d] : rsize[d]] for d in range(ndims)]) sel_core = [] x_ = np.copy(y) @@ -373,7 +372,6 @@ def adjoint_padding(y, lsize, rsize, mode="constant"): ) ) - # sel_core = tuple([np.s_[lsize[d] : rsize[d]] for d in range(ndims)]) sel_core = [] x_ = np.copy(y) @@ -413,7 +411,6 @@ def adjoint_padding(y, lsize, rsize, mode="constant"): x = x_[tuple(sel_core)] elif mode == "wrap": - # sel_core = tuple([np.s_[lsize[d] : rsize[d]] for d in range(ndims)]) sel_core = [] x_ = np.copy(y) @@ -481,8 +478,6 @@ class Padding(LinearOperator): "constant". """ - # TODO: include tests at the level of the object (not the application of - # TODO: - the functions / method) def __init__(self, lsize, rsize, mode="constant"): """Implementation of padding as a linear operator. @@ -575,56 +570,3 @@ class Padding(LinearOperator): Output of the adjoint padding operator. """ return adjoint_padding(y, self.lsize, self.rsize, self.mode) - - -# if __name__ == "__main__": -# # import matplotlib.pyplot as plt -# # from imageio import imread - -# # Generate 2D Gaussian convolution kernel -# vr = 1 -# M = 7 -# # if np.mod(M, 2) > 0: # M odd -# # n = np.arange(-(M - 1) // 2, (M - 1) // 2 + 1) -# # else: -# # n = np.arange(-M // 2, M // 2) -# # h = np.exp(-(n ** 2 + n[:, np.newaxis] ** 2) / (2 * vr)) - -# # plt.imshow(h, cmap=plt.cm.gray) -# # plt.show() - -# # x = imread("img/cameraman.png") -# # N = x.shape -# # M = h.shape - -# # # version 1: circular convolution -# # # circular convolution: pad around and fftshift kernel for nicer results -# # input_size = N -# # hpad = pad_array(h, input_size, padmode="after") # around - -# # plt.imshow(hpad, cmap=plt.cm.gray) -# # plt.show() - -# # testing symmetric padding (with adjoint) (constant is fine) -# rng = np.random.default_rng(1234) -# x = rng.standard_normal(size=(M,)) -# y = rng.standard_normal(size=(M + 4,)) - -# Hx = pad_array( -# x, [M + 4], padmode="around", mode="symmetric" -# ) # around, 2 on both sides -# Hadj_y = adjoint_padding((y, np.array([2], dtype="i"), np.array([2], dtype="i"), mode="symmetric") - -# hp1 = np.sum(Hx * y) -# hp2 = np.sum(x * Hadj_y) - -# print("Correct adjoint operator? 
{}".format(np.isclose(hp1, hp2))) - -# # 2D: ok -# # hpad = pad_array( -# # h, [M[0] + 3, M[1] + 1], padmode="around", mode="symmetric" -# # ) # around -# # plt.imshow(hpad, cmap=plt.cm.gray) -# # plt.show() - -# pass diff --git a/src/dsgs/operators/tv.py b/src/dsgs/operators/tv.py index 0695912339e88f19460028af63e3c029362713ba..b4c7eb4a0b911d169b9d0bbf07a506f94b24291c 100755 --- a/src/dsgs/operators/tv.py +++ b/src/dsgs/operators/tv.py @@ -95,8 +95,6 @@ def tv(x): return np.sum(np.sqrt(np.sum(np.abs(u) ** 2, axis=0))) -# TODO: see of the n-D version of the TV can be possibly simplified (i.e., to -# TODO: enable jit support) def gradient_nd(x): r"""Nd discrete gradient operator. @@ -253,23 +251,3 @@ def gradient_smooth_tv(x, eps): u = gradient_2d(x) v = gradient_2d_adjoint(u / np.sqrt(np.sum(np.abs(u) ** 2, axis=0) + eps)) return v - - -# if __name__ == "__main__": -# rng = np.random.default_rng(1234) -# x = rng.standard_normal((5, 5)) -# eps = np.finfo(float).eps - -# u2 = gradient_2d(x) -# y2 = gradient_2d_adjoint(u2) - -# u = gradient_nd(x) -# y = gradient_nd_adjoint(u) - -# err_ = np.linalg.norm(y - y2) -# print("Error: {0:1.5e}".format(err_)) - -# tv_x = tv_nd(x) -# tv_x_2d = tv(x) -# err = np.linalg.norm(tv_x - tv_x_2d) -# print("Error: {0:1.5e}".format(err)) diff --git a/src/dsgs/samplers/base_sampler.py b/src/dsgs/samplers/base_sampler.py index 4d873429304e9b2ff86a133c36535f94f3891e38..853daa8c2fa62c57525d1563b7d311e3eeaefae2 100644 --- a/src/dsgs/samplers/base_sampler.py +++ b/src/dsgs/samplers/base_sampler.py @@ -160,10 +160,6 @@ class BaseSampler(ABC): k for k in (self.hyperparameters.keys() - self.updated_hyperparameters) ] - # ! this is specific to a serial sampler (only needed on the root - # ! worker) - # ! create a mock "rank parameter" to reuse the base sampler for the distributed version? - # monitoring variables (iteration counter + timing) self.atime = np.zeros((1,), dtype="d") self.asqtime = np.zeros((1,), dtype="d") @@ -257,15 +253,12 @@ class BaseSampler(ABC): potential_ : float Current value of the potential function. """ - # ! could be created somewhere else (just need to return the value of - # ! the cost, and the iteration self._it to set up value in the right - # ! position on the buffer) potential = np.zeros((self.checkpointfreq,), dtype="d") self._setup_rng() if self.warmstart: self._load() self.start_iter = self.warmstart_it - self._it = -1 # iteration from which aux. params are initialized + self._it = -1 # iteration from which auxiliary parameters are initialized aux = self._initialize_auxiliary_parameters(self._it) potential_ = self._compute_potential(aux, self._it) potential[-1] = potential_ diff --git a/src/dsgs/samplers/parallel/spmd_psgla_poisson_deconvolution.py b/src/dsgs/samplers/parallel/spmd_psgla_poisson_deconvolution.py index 5b5d9f87cbd8ae2d77b5f1ee4c374de2d396f958..2fa6e6e5973e51e7441f5f82ef5770baa058a58f 100644 --- a/src/dsgs/samplers/parallel/spmd_psgla_poisson_deconvolution.py +++ b/src/dsgs/samplers/parallel/spmd_psgla_poisson_deconvolution.py @@ -4,15 +4,12 @@ # Sampler with Hypergraph Structure for High-Dimensional Inverse Problems**, # [arxiv preprint 2210.02341](http://arxiv.org/abs/2210.02341), October 2022. 
-# based on main_s.py, spa_psgla_sync_s and spa_psgla_sync_mml.py - from logging import Logger import numpy as np from mpi4py import MPI from numba import jit -# import dsgs.utils.checkpoint_parallel as chkpt import dsgs.utils.communications as ucomm from dsgs.functionals.prox import ( kullback_leibler, @@ -57,16 +54,6 @@ def loading( numpy.ndarray, int, float Variables required to restart the sampler. """ - # ! version when saving all iterations for all variables - # [(np.s_[-1], *global_slice_tile)] - # + 2 * [(np.s_[-1], *global_slice_data)] - # + 2 * [(np.s_[-1], np.s_[:], *global_slice_tile)] - # + 3 * [np.s_[-1]], - # ! saving all iterations only for x - # [(np.s_[-1], *global_slice_tile)] - # + 2 * [global_slice_data] - # + 2 * [(np.s_[:], *global_slice_tile)] - # + 3 * [np.s_[-1]], # ! saving only estimator and last state of each variable dic = checkpointer.load( warmstart_iter, @@ -170,111 +157,6 @@ def loading_per_process( ) -# ! ISSUE: code hanging forever in MPI, to be revised.. -def saving( - checkpointer: DistributedCheckpoint, - iter_mc: int, - rng: np.random.Generator, - local_x, - local_u1, - local_u2, - local_z1, - local_z2, - beta, - rho1, - rho2, - alpha1, - alpha2, - potential, - counter, - atime, - asqtime, - nsamples, - global_slice_tile, - global_slice_data, - image_size, - data_size, -): - root_id = 0 - it = iter_mc + 1 - - # find index of the candidate MAP within the current checkpoint - id_map = np.empty(1, dtype="i") - if checkpointer.rank == root_id: - id_map[0] = np.argmin(potential[:nsamples]) - - checkpointer.comm.Bcast([id_map, 1, MPI.INT], root=0) - - # save MMSE, MAP (only for x) - select_ = ( - [(np.s_[:], *global_slice_tile)] - + 2 * [(*global_slice_tile,)] - + 2 * [(*global_slice_data,)] - + 2 * [(np.s_[:], *global_slice_tile)] - ) - shape_ = ( - [(nsamples, *image_size)] - + 2 * [(*image_size,)] - + 2 * [(*data_size,)] - + 2 * [(2, *image_size)] - ) - chunk_sizes_ = ( - [(1, *local_x.shape[1:])] # ! see if saving only the last point or more... - + 2 * [(*local_x.shape[1:],)] - + 2 * [(*local_z1.shape,)] - + 2 * [(1, *local_x.shape[1:])] - ) - - x_mmse = np.mean(local_x[:nsamples, ...], axis=0) - x_map = local_x[id_map[0]] - checkpointer.save( - it, - shape_, - select_, - chunk_sizes_, - rng=rng, - mode="w", - rdcc_nbytes=1024**2 * 200, # 200 MB cache - x=local_x[:nsamples], - x_map=x_map, - x_mmse=x_mmse, - z1=local_z1, - u1=local_u1, # u1 and/or z1 create an issue: check the data selection + size fpr these 2 variables - z2=local_z2, - u2=local_u2, - ) - - # ! z1 and u1 create a problem when code run from MPI (code hanging - # ! forever...): why? - # z1=local_z1, - # u1=local_u1, - - # ! saving from process 0 only (beta, potential, iter, time) - if checkpointer.rank == root_id: - select_ = 10 * [np.s_[:]] - chunk_sizes_ = 10 * [None] - checkpointer.save_from_process( - root_id, - it, - select_, - chunk_sizes_, - mode="a", - rdcc_nbytes=1024**2 * 200, # 200 MB cache - asqtime=np.array(asqtime), - atime=np.array(atime), - iter=it, - potential=potential[:nsamples], - beta=beta[:nsamples], - rho1=rho1[:nsamples], - rho2=rho2[:nsamples], - alpha1=alpha1[:nsamples], - alpha2=alpha2[:nsamples], - counter=counter, - ) - - pass - - def saving_per_process( rank, comm, @@ -348,7 +230,7 @@ def saving_per_process( # checkpointer configuration chunk_sizes_ = ( # [(1, *local_x.shape[1:])] # ! see if saving only the last point or more... - [local_x.shape[1:]] # ! see if saving only the last point or more... 
+ [local_x.shape[1:]] + 3 * [(*local_x.shape[1:],)] + 2 * [local_z1.shape] + 2 * [(1, *local_x.shape[1:])] @@ -360,7 +242,7 @@ def saving_per_process( rng=rng, mode="w", rdcc_nbytes=1024**2 * 200, # 200 MB cache - x=local_x[-1], # [:nsamples] + x=local_x[-1], x_map=x_map, x_m=x_mmse, x_sq_m=x_sq_m, @@ -672,32 +554,21 @@ def potentials_function(y, Hx, Gx, z1, u1, z2, u2): return np.array([data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior], dtype="d") -# ! TO investigate: -# - where to define local parameter sizes? global sizes? -# - sampling step / initialization step -# - checkpointing modes - -# - observations / buffers are all chuncked, except for the hyperparameters (known only on master process) - - class SpmdPsglaSGS(BaseSPMDSampler): def __init__( self, xmax, observations, - model, #: SyncLinearConvolution, + model, hyperparameters, Nmc: int, seed: int, - checkpointer, #: DistributedCheckpoint, + checkpointer, checkpointfreq: int, logger: Logger, warmstart_it: int = -1, warmstart: bool = False, ): - # TODO: encapsulate all this into a "partition" object (to be defined - # in relation to the communicator...) - # * model parameters (AXDA regularization parameter) self.max_sq_kernel = np.max(np.abs(model.fft_kernel)) ** 2 @@ -718,7 +589,6 @@ class SpmdPsglaSGS(BaseSPMDSampler): # data # ! larger number of data points on the border (forward overlap) - # local_data_size = tile_size + (ranknd == 0) * overlap_size ( self.local_data_size, self.facet_size, @@ -756,30 +626,6 @@ class SpmdPsglaSGS(BaseSPMDSampler): model.ranknd, grid_size, self.offset_tv_adj, backward=True ) - # indexing into global arrays - # if not save_mode == "process": - # global_slice_tile = tuple( - # [np.s_[tile_pixels[d, 0] : tile_pixels[d, 1] + 1] for d in range(ndims)] - # ) - # global_slice_data = create_local_to_global_slice( - # tile_pixels, - # ranknd, - # sync_conv_model.overlap_size, - # local_data_size, - # backward=False, - # ) - - # * slice selection for checkpoints - # single file checkpoint - # "x", "z1", "u1", "z2", "u2", "beta", "rho1", "rho2", "alpha1", - # "alpha2", "potential", "iter", - # checkpoint_select = [(*self.global_slice_tile)] + 2 * [(*self.global_slice_data,)] + 2 * [(np.s_[:], *self.global_slice_tile)] - # chunk_sizes = ( - # [(*self.tile_size)] # ! see if saving only the last point or more... - # + 2 * [(*self.local_data_size,)] - # + 2 * [(1, *self.tile_size)] - # ) - # per-process checkpoints checkpoint_select = [np.s_[:]] + 4 * [np.s_[:]] chunk_sizes = ( @@ -797,17 +643,6 @@ class SpmdPsglaSGS(BaseSPMDSampler): "z2": (2, *self.tile_size), } - # TODO: to be double checked - # TODO: check is properly loaded into auxiliary buffer for x (main - # parameter) - # chunk_sizes = 5 * [None] # chunk sizes for x, z1, u1, z2, u2 - # checkpoint_select = len(parameter_sizes) * [ - # np.s_[:] # ! saving only last sample to disk for all these variables - # ] - - # create dict local_parameter_size - # with, in order, data + all variables/parameters involved? - self.xmax = xmax super(SpmdPsglaSGS, self).__init__( @@ -827,9 +662,6 @@ class SpmdPsglaSGS(BaseSPMDSampler): updated_hyperparameters=[], ) - # TODO: move these buffers to the initialization? (auxiliary vars, ...) - # TODO: check if using hyperparameter_batch from the beginning or not - # TODO - (even when the hyperparameters are fixed...) # ! auxiliary buffer for in-place updates / communications of the main # ! 
- variable x self.local_x = np.empty(self.facet_size, dtype="d") @@ -872,7 +704,6 @@ class SpmdPsglaSGS(BaseSPMDSampler): def _initialize_auxiliary_parameters(self, it: int): # called after initialize_parameters # it: position of the element initialized in the larger x buffer - # non need for the splitting and augmentation variables self.local_x[self.local_slice_tile] = self.parameter_batch["x"][it] @@ -897,7 +728,7 @@ class SpmdPsglaSGS(BaseSPMDSampler): # * PSGLA step-sizes # TODO: replace by self.hyperparameter_batch["..."][it] when - # TODO - hyperparameters are updates + # TODO - hyperparameters are updated stepsize_x = 0.99 / ( self.max_sq_kernel / self.hyperparameters["rho1"] + 8 / self.hyperparameters["rho2"] @@ -912,18 +743,9 @@ class SpmdPsglaSGS(BaseSPMDSampler): "stepsize_z1": stepsize_z1, "stepsize_z2": stepsize_z2, } - - # TODO: see if computation of ojective is needed from here (only - # TODO - for MML version) - - # compute parts of the potential for later on - # self._compute_potentials(aux, local_potentials, global_potentials) - return aux - # TODO: add _compute_potentials? (only for mml version, debug initial one first) - # TODO: version based on local_potentials / global_potentials? def _compute_potential(self, aux, local_potential, global_potential): local_potential[0] = potential_function( self.observations, @@ -947,37 +769,8 @@ class SpmdPsglaSGS(BaseSPMDSampler): root=0, ) - # potential = ( - # aux["potentials"][0] - # + aux["potentials"][1] / self.hyperparameter_batch["rho1"][it] - # + aux["potentials"][2] / self.hyperparameter_batch["rho2"][it] - # + aux["potentials"][3] / self.hyperparameter_batch["alpha1"][it] - # + aux["potentials"][4] / self.hyperparameter_batch["alpha2"][it] - # + aux["potentials"][5] * self.hyperparameter_batch["beta"][it] - # ) - pass - # def _compute_potentials(self, aux, local_potentials, global_potentials): - # # data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior - # local_potentials = potentials_function( - # self.observations, - # aux["Hx"][self.local_slice_conv_adj], - # aux["Gx"][tuple((np.s_[:], *self.local_slice_tv_adj))], - # self.parameter_batch["z1"], - # self.parameter_batch["u1"], - # self.parameter_batch["z2"], - # self.parameter_batch["u2"], - # ) - - # self.model.comm.Reduce( - # [local_potentials, MPI.DOUBLE], - # [global_potentials, MPI.DOUBLE], - # op=MPI.SUM, - # root=0, - # ) - - # pass def _sample_step(self, iter_mc: int, current_iter: int, past_iter: int, aux): # notational shortcuts (for in-place assignments) @@ -1065,402 +858,4 @@ class SpmdPsglaSGS(BaseSPMDSampler): self.local_rng, ) - # ! sample TV regularization parameter beta (to be debugged) - # local_l21_z2 = np.sum(np.sqrt(np.sum(local_z2_mc ** 2, axis=0))) - # ... - - pass - - -class SpmdPsglaSGSmml(BaseSPMDSampler): - def __init__( - self, - xmax, - observations, - model, #: SyncLinearConvolution, - hyperparameters, - mml_dict, - Nmc: int, - seed: int, - checkpointer, #: DistributedCheckpoint, - checkpointfreq: int, - logger: Logger, - warmstart_it: int = -1, - warmstart: bool = False, - ): - # TODO: encapsulate all this into a "partition" object (to be defined - # in relation to the communicator...) - - # ! 
only need to be defined on worker 0 (broadcast updates afterwards) - self.mml_dict = mml_dict - - # * model parameters (AXDA regularization parameter) - self.max_sq_kernel = np.max(np.abs(model.fft_kernel)) ** 2 - - # * auxiliary quantities for stochastic gradient updates - # dimensions - self.d_N = np.prod(model.image_size) - # dimension of the proper space (removing 0s) - # d_tv = (N[0] - 1) * N[1] + (N[1] - 1) * N[0] - self.d_M = np.prod(model.data_size) - - # tile - grid_size = MPI.Compute_dims(model.comm.Get_size(), model.ndims) - grid_size = np.array(grid_size, dtype="i") - self.tile_pixels = ucomm.local_split_range_nd( - grid_size, model.image_size, model.ranknd - ) - self.tile_size = self.tile_pixels[:, 1] - self.tile_pixels[:, 0] + 1 - - # data - # ! larger number of data points on the border (forward overlap) - # local_data_size = tile_size + (ranknd == 0) * overlap_size - ( - self.local_data_size, - self.facet_size, - self.facet_size_adj, - ) = calculate_local_data_size( - self.tile_size, model.ranknd, model.overlap_size, grid_size, backward=False - ) - - # facet (convolution) - self.offset = self.facet_size - self.tile_size - self.offset_adj = self.facet_size_adj - self.tile_size - - # facet (tv) - self.offset_tv = self.offset - (self.offset > 0).astype("i") - self.offset_tv_adj = np.logical_and(model.ranknd > 0, grid_size > 1).astype("i") - self.tv_facet_size_adj = self.tile_size + self.offset_tv_adj - - # * Useful slices (direct operators) - # extract tile from local facet (direct conv. operator) - self.local_slice_tile = ucomm.get_local_slice( - model.ranknd, grid_size, self.offset, backward=False - ) - # extract values from local conv facet to apply local gradient operator - self.local_slice_tv = ucomm.get_local_slice( - model.ranknd, grid_size, self.offset_tv, backward=False - ) - - # * Useful slices (adjoint operators) - # set value of local convolution in the adjoint buffer - self.local_slice_conv_adj = ucomm.get_local_slice( - model.ranknd, grid_size, self.offset_adj, backward=True - ) - # set value of local discrete gradient into the adjoint gradient buffer - self.local_slice_tv_adj = ucomm.get_local_slice( - model.ranknd, grid_size, self.offset_tv_adj, backward=True - ) - - # * slice selection for checkpoints - # single file checkpoint - # "x", "z1", "u1", "z2", "u2", "beta", "rho1", "rho2", "alpha1", - # "alpha2", "potential", "iter", - # checkpoint_select = [(*self.global_slice_tile)] + 2 * [(*self.global_slice_data,)] + 2 * [(np.s_[:], *self.global_slice_tile)] - # chunk_sizes = ( - # [(*self.tile_size)] # ! see if saving only the last point or more... - # + 2 * [(*self.local_data_size,)] - # + 2 * [(1, *self.tile_size)] - # ) - - # per-process checkpoints - checkpoint_select = [np.s_[:]] + 4 * [np.s_[:]] - chunk_sizes = ( - [tuple(self.tile_size)] - + 2 * [tuple(self.local_data_size)] - + 2 * [(1, *self.tile_size)] - ) - - # ! local parameter sizes - parameter_sizes = { - "x": self.tile_size, - "u1": self.local_data_size, - "z1": self.local_data_size, - "u2": (2, *self.tile_size), - "z2": (2, *self.tile_size), - } - - self.xmax = xmax - - super(SpmdPsglaSGSmml, self).__init__( - observations, - model, - parameter_sizes, - hyperparameters, - Nmc, - seed, - checkpointer, - checkpointfreq, - checkpoint_select, - chunk_sizes, - logger, - warmstart_it=warmstart_it, - warmstart=warmstart, - updated_hyperparameters=["rho1", "rho2", "alpha1", "alpha2", "beta"], - ) - - # TODO: move these buffers to the initialization? (auxiliary vars, ...) 
-
-        # TODO: move these buffers to the initialization? (auxiliary vars, ...)
-        # TODO: check if using hyperparameter_batch from the beginning or not
-        # TODO - (even when the hyperparameters are fixed...)
-        # ! auxiliary buffer for in-place updates / communications of the main
-        # ! - variable x
-        self.local_x = np.empty(self.facet_size, dtype="d")
-
-        # * setup communication scheme
-        # ! convolution (direct + adjoint) + direct TV covered by self.model
-        # adjoint TV communicator
-        # ! need a different object for the moment (because of the size required...)
-        self.adjoint_tv_communicator = SyncCartesianCommunicatorTV(
-            self.model.comm,
-            self.model.cartcomm,
-            self.model.grid_size,
-            self.local_x.itemsize,
-            self.tv_facet_size_adj,
-            direction=True,
-        )
-
-    def _initialize_parameters(self):
-        for key in self.parameter_batch.keys():
-            if key == "x":  # ! batch only kept for the main variable x
-                self.parameter_batch[key][0] = self.local_rng.integers(
-                    0, high=self.xmax, size=self.parameter_sizes[key], endpoint=True
-                ).astype(float)
-            else:
-                self.parameter_batch[key] = self.local_rng.integers(
-                    0, high=self.xmax, size=self.parameter_sizes[key], endpoint=True
-                ).astype(float)
-
-        # ! need to broadcast latest value of all the hyperparameters later on
-        # ! in _initialize_auxiliary_parameters
-        if self.rank == 0:
-            for key in ["rho1", "rho2", "alpha1", "alpha2", "beta"]:
-                # in self.hyperparameters.keys():
-                self.hyperparameter_batch[key][0] = self.hyperparameters[key]
-
-        pass
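`_initialize_parameters` draws integer starting points in `[0, xmax]` from a per-worker generator (`self.local_rng`). One common way to build such independent per-rank streams is `numpy.random.SeedSequence.spawn`; the snippet below is an assumption about the seeding scheme, not necessarily the one used by `BaseSPMDSampler`:

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
# one statistically independent stream per rank from a single master seed
children = np.random.SeedSequence(1234).spawn(comm.Get_size())
local_rng = np.random.default_rng(children[comm.Get_rank()])

# integer start point in [0, xmax], cast to float, as in the samplers above
xmax = 30
x0 = local_rng.integers(0, high=xmax, size=(64, 64), endpoint=True).astype(float)
```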
-
-    def _initialize_auxiliary_parameters(self, it: int):
-        # called after initialize_parameters
-        # it: position of the element initialized in the larger x buffer
-        # non need for the splitting and augmentation variables
-
-        self.local_x[self.local_slice_tile] = self.parameter_batch["x"][it]
-
-        # * setup auxiliary buffers
-        # ! communicating facet borders to neighbours already done in-place
-        # ! with the direction convolution operator
-        # ! Hx, Gx updated whenever buffer_Hx and buffer_Gx are
-        buffer_Hx = np.empty(self.facet_size_adj)
-        buffer_Hx[self.local_slice_conv_adj] = self.model.forward(self.local_x)
-
-        buffer_Gx = np.empty((2, *self.tv_facet_size_adj))
-        (
-            buffer_Gx[tuple((0, *self.local_slice_tv_adj))],
-            buffer_Gx[tuple((1, *self.local_slice_tv_adj))],
-        ) = chunk_gradient_2d(self.local_x[self.local_slice_tv], self.islast)
-
-        # communicate facet borders to neighbours (Hx, Gx)
-        # ! Hx updated in place
-        self.model.adjoint_communicator.update_borders(buffer_Hx)
-        # ! Gx updated in place
-        self.adjoint_tv_communicator.update_borders(buffer_Gx)
-
-        # * create hyperparameter vector (for comms between workers)
-        self.hyperparameter_vector = np.empty((len(self.hyperparameters)), dtype="d")
-        if self.rank == 0:
-            c_ = 0
-            for key in ["rho1", "rho2", "alpha1", "alpha2", "beta"]:
-                self.hyperparameter_vector[c_] = self.hyperparameter_batch[key][it]
-                c_ += 1
-        self.model.comm.Bcast(
-            [self.hyperparameter_vector, len(self.hyperparameters), MPI.DOUBLE], root=0
-        )
-
-        # * PSGLA step-sizes
-        stepsize_x = 0.99 / (
-            self.max_sq_kernel / self.hyperparameter_vector[0]
-            + 8 / self.hyperparameter_vector[1]
-        )
-        stepsize_z1 = 0.99 * self.hyperparameter_vector[0]
-        stepsize_z2 = 0.99 * self.hyperparameter_vector[1]
-
-        aux = {
-            "Hx": buffer_Hx,
-            "Gx": buffer_Gx,
-            "stepsize_x": stepsize_x,
-            "stepsize_z1": stepsize_z1,
-            "stepsize_z2": stepsize_z2,
-        }
-
-        # compute parts of the potential for later on (MML updates)
-        if self.rank == 0:
-            aux.update({"potentials": np.empty((len(self.hyperparameters) + 1), "d")})
-        else:
-            aux.update({"potentials": None})
-
-        self._compute_potentials(aux)
-
-        return aux
-
-    # TODO: version based on local_potentials / global_potentials?
-    # ! revise local_potential / global potential (not needed if modifying the algo to
-    # ! - accommodate hyperparameter updates)
-    def _compute_potential(self, aux, local_potential, global_potential):
-        if self.rank == 0:
-            global_potential[0] = (
-                aux["potentials"][0]
-                + aux["potentials"][1] / self.hyperparameter_vector[0]  # rho1
-                + aux["potentials"][2] / self.hyperparameter_vector[1]  # rho2
-                + aux["potentials"][3] / self.hyperparameter_vector[2]  # alpha1
-                + aux["potentials"][4] / self.hyperparameter_vector[3]  # alpha2
-                + aux["potentials"][5] * self.hyperparameter_vector[4]  # beta
-            )
-
-        pass
-
-    def _compute_potentials(self, aux):
-        # data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior
-        local_potentials = potentials_function(
-            self.observations,
-            aux["Hx"][self.local_slice_conv_adj],
-            aux["Gx"][tuple((np.s_[:], *self.local_slice_tv_adj))],
-            self.parameter_batch["z1"],
-            self.parameter_batch["u1"],
-            self.parameter_batch["z2"],
-            self.parameter_batch["u2"],
-        )
-
-        self.model.comm.Reduce(
-            [local_potentials, MPI.DOUBLE],
-            [aux["potentials"], MPI.DOUBLE],
-            op=MPI.SUM,
-            root=0,
-        )
-
-        pass
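`_compute_potentials` follows a reduce-then-weight pattern: each worker evaluates the unweighted terms `[data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior]` on its tile, the vectors are summed onto rank 0, and only rank 0 applies the hyperparameter weights. A toy reproduction of that pattern (placeholder values, runnable under `mpirun`):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
# placeholder local terms: [data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior]
local_potentials = np.ones(6, dtype="d")
global_potentials = np.empty(6, dtype="d") if comm.Get_rank() == 0 else None

comm.Reduce([local_potentials, MPI.DOUBLE], global_potentials, op=MPI.SUM, root=0)

if comm.Get_rank() == 0:
    rho1 = rho2 = alpha1 = alpha2 = 1.0  # placeholder hyperparameter values
    beta = 0.1
    potential = (
        global_potentials[0]
        + global_potentials[1] / rho1
        + global_potentials[2] / rho2
        + global_potentials[3] / alpha1
        + global_potentials[4] / alpha2
        + global_potentials[5] * beta
    )
```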
-
-    def _sample_step(self, iter_mc: int, current_iter: int, past_iter: int, aux):
-        # * PSGLA step-sizes
-        aux["stepsize_x"] = 0.99 / (
-            self.max_sq_kernel / self.hyperparameter_vector[0]
-            + 8 / self.hyperparameter_vector[1]
-        )
-        aux["stepsize_z1"] = 0.99 * self.hyperparameter_vector[0]
-        aux["stepsize_z2"] = 0.99 * self.hyperparameter_vector[1]
-
-        # notational shortcuts (for in-place assignments)
-        # ? can be defined just once? (to be double checked)
-        # ! Hx, Gx updated whenever buffer_Hx and buffer_Gx are
-        Hx = aux["Hx"][self.local_slice_conv_adj]
-        Gx = aux["Gx"][tuple((np.s_[:], *self.local_slice_tv_adj))]
-
-        # sample image x (update local tile)
-        grad_x = gradient_x(
-            self.model,
-            self.adjoint_tv_communicator,
-            self.isfirst,
-            self.islast,
-            aux["Hx"],
-            aux["Gx"],
-            self.parameter_batch["z1"],
-            self.parameter_batch["u1"],
-            self.parameter_batch["z2"],
-            self.parameter_batch["u2"],
-            self.hyperparameter_vector[0],  # rho1
-            self.hyperparameter_vector[1],  # rho2
-            self.local_slice_tv,
-            self.local_slice_tv_adj,
-        )
-
-        sample_x(
-            self.local_x[self.local_slice_tile],
-            aux["stepsize_x"],
-            grad_x,
-            self.local_rng,
-        )
-        self.parameter_batch["x"][current_iter] = self.local_x[self.local_slice_tile]
-
-        # communicate borders of each facet to appropriate neighbours
-        # ! synchronous case: need to update Hx and Gx for the next step
-        # ! local_x updated in-place here (border communication involved in
-        # ! direct operator)
-        aux["Hx"][self.local_slice_conv_adj] = self.model.forward(self.local_x)
-        # ! Hx and Gx updated in-place whenever buffer_Hx, buffer_Gx are
-        (
-            aux["Gx"][tuple((0, *self.local_slice_tv_adj))],
-            aux["Gx"][tuple((1, *self.local_slice_tv_adj))],
-        ) = chunk_gradient_2d(self.local_x[self.local_slice_tv], self.islast)
-
-        # communicate borders of buffer_Hx, buffer_Gx (adjoint op) for the
-        # next iteration
-        self.model.adjoint_communicator.update_borders(aux["Hx"])
-        self.adjoint_tv_communicator.update_borders(aux["Gx"])
-
-        # * sample auxiliary variables (z1, u1)
-        self.parameter_batch["z1"] = sample_z1(
-            self.parameter_batch["z1"],
-            self.observations,
-            Hx,
-            self.parameter_batch["u1"],
-            self.hyperparameter_vector[0],  # rho1
-            aux["stepsize_z1"],
-            self.local_rng,
-        )
-        self.parameter_batch["u1"] = sample_u(
-            self.parameter_batch["z1"],
-            Hx,
-            self.hyperparameter_vector[0],  # rho1
-            self.hyperparameter_vector[2],  # alpha1
-            self.local_rng,
-        )
-
-        # * sample auxiliary variables (z2, u2)
-        self.parameter_batch["z2"] = sample_z2(
-            self.parameter_batch["z2"],
-            Gx,
-            self.parameter_batch["u2"],
-            self.hyperparameter_vector[1],  # rho2
-            aux["stepsize_z2"],
-            self.hyperparameter_vector[4] * aux["stepsize_z2"],  # beta
-            self.local_rng,
-        )
-
-        self.parameter_batch["u2"] = sample_u(
-            self.parameter_batch["z2"],
-            Gx,
-            self.hyperparameter_vector[1],  # rho2
-            self.hyperparameter_vector[3],  # alpha2
-            self.local_rng,
-        )
-
-        # * update hyperparameters (using MML)
-        # compute auxiliary potentials for MML updates + computation of the
-        # overall potential
-        # ! need to define local_potentials, global_potentials
-        self._compute_potentials(aux)
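Each hyperparameter update in the loop that follows is delegated to `self.mml_dict[key].update(previous_value, iteration, potential_term)`. The actual rule lives in the library's MML objects; the sketch below shows one plausible implementation, a Robbins-Monro step on the marginal likelihood of a scale parameter whose potential term enters as `U / theta` with a `theta**d` normalizer. Treat every detail (class name, gradient expression, step schedule, bounds) as an assumption:

```python
import numpy as np

class MMLUpdate:
    """Hypothetical stochastic marginal-maximum-likelihood update."""

    def __init__(self, dimension, step0=1e-3, decay=0.8, bounds=(1e-8, 1e8)):
        self.dimension = dimension  # size of the variable governed by theta
        self.step0, self.decay, self.bounds = step0, decay, bounds

    def update(self, theta, iteration, potential_term):
        # for a conditional density ~ exp(-U/theta) / theta**d, the gradient
        # of the log-marginal w.r.t. theta is estimated by U/theta^2 - d/theta
        grad = potential_term / theta**2 - self.dimension / theta
        step = self.step0 / (iteration + 1) ** self.decay
        return float(np.clip(theta + step * grad, *self.bounds))

# usage mirroring the loop below (values are placeholders)
mml_rho1 = MMLUpdate(dimension=512 * 512)
rho1 = mml_rho1.update(1.0, iteration=10, potential_term=2.5e5)
```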
-
-        # TODO: revise and simplify this part
-        # order: [data_fidelity, lrho1, lrho2, lalpha1, lalpha2, prior]
-        # check if this is fine in practice
-        if self.rank == 0:
-            c_ = 0
-
-            # if iter_mc == 20:
-            #     breakpoint()
-
-            for key in ["rho1", "rho2", "alpha1", "alpha2", "beta"]:
-                self.hyperparameter_batch[key][current_iter] = self.mml_dict[
-                    key
-                ].update(
-                    self.hyperparameter_batch[key][past_iter],
-                    iter_mc,
-                    aux["potentials"][c_ + 1],  # ! data-fidelity term in position 0
-                )
-                self.hyperparameter_vector[c_] = self.hyperparameter_batch[key][
-                    current_iter
-                ]
-                c_ += 1
-
-        self.model.comm.Bcast([self.hyperparameter_vector, MPI.DOUBLE], root=0)
-        pass
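The file below removes the serial MYULA-based sampler with MML updates. As background, a MYULA transition (cf. the `Durmus2018` citation in `transition_kernels/myula.py`) replaces the non-smooth term `g` of the potential by its Moreau-Yosida envelope and runs an unadjusted Langevin step on the smoothed potential. A generic, self-contained sketch (the library's `MYULA` class has its own interface; the soft-threshold prox is only a toy choice):

```python
import numpy as np

def myula_step(x, grad_f, prox_g, gamma, lam, rng):
    """One MYULA iteration targeting exp(-f(x) - g(x)).

    gamma: step size; lam: Moreau-Yosida smoothing parameter for g.
    """
    drift = -grad_f(x) - (x - prox_g(x, lam)) / lam
    return x + gamma * drift + np.sqrt(2.0 * gamma) * rng.standard_normal(x.shape)

# toy usage: Gaussian f, l1 term g (soft-thresholding prox)
rng = np.random.default_rng(0)
grad_f = lambda x: x  # f(x) = ||x||^2 / 2
prox_g = lambda x, lam: np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)

x = rng.standard_normal(16)
for _ in range(100):
    x = myula_step(x, grad_f, prox_g, gamma=0.1, lam=0.1, rng=rng)
```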
diff --git a/src/dsgs/samplers/serial/serial_pmyula_poisson_deconvolution.py b/src/dsgs/samplers/serial/serial_pmyula_poisson_deconvolution.py
index c9624ee39a05ddceaff91777f6611d1ccf4be5e9..a15292f581b8bc0d7edec9fbdf096f23d1599479 100644
--- a/src/dsgs/samplers/serial/serial_pmyula_poisson_deconvolution.py
+++ b/src/dsgs/samplers/serial/serial_pmyula_poisson_deconvolution.py
@@ -505,226 +505,3 @@ class MyulaSGS(SerialSampler):
         )

         pass
-
-
-class MyulaSGSmml(SerialSampler):
-    def __init__(
-        self,
-        xmax,
-        observations,
-        model: SerialConvolution,
-        hyperparameters,
-        mml_dict,
-        Nmc: int,
-        seed: int,
-        checkpointer: SerialCheckpoint,
-        checkpointfreq: int,
-        logger: Logger,
-        warmstart_it=-1,
-        warmstart=False,
-        save_batch=False,
-    ):
-        data_size = model.data_size
-        N = model.image_size
-        parameter_sizes = {
-            "x": N,
-            "u1": data_size,
-            "z1": data_size,
-            "u2": N,
-            "z2": N,
-            "u3": N,
-            "z3": N,
-        }
-        chunk_sizes = 7 * [None]  # chunk sizes for x, z1, u1, z2, u2, z3, u3
-        checkpoint_select = len(parameter_sizes) * [
-            np.s_[:]  # ! saving only last sample to disk for all these variables
-        ]
-
-        self.xmax = xmax
-        self.mml_dict = mml_dict
-
-        super(MyulaSGSmml, self).__init__(
-            observations,
-            model,
-            parameter_sizes,
-            hyperparameters,
-            Nmc,
-            seed,
-            checkpointer,
-            checkpointfreq,
-            checkpoint_select,
-            chunk_sizes,
-            logger,
-            warmstart_it=warmstart_it,
-            warmstart=warmstart,
-            save_batch=save_batch,
-            updated_hyperparameters=[
-                "rho1",
-                "rho2",
-                "rho3",
-                "alpha1",
-                "alpha2",
-                "alpha3",
-                "beta",
-            ],
-        )
-
-    def _initialize_parameters(self):
-        for key in self.parameter_batch.keys():
-            self.parameter_batch[key][0] = self.rng.integers(
-                0, high=self.xmax, size=self.parameter_sizes[key], endpoint=True
-            ).astype(float)
-
-        for key in self.hyperparameter_batch.keys():
-            self.hyperparameter_batch[key][0] = self.hyperparameters[key]
-
-        pass
-
-    def _initialize_auxiliary_parameters(self, it: int):
-        Hx = self.model.forward(self.parameter_batch["x"][it])
-
-        aux = {
-            "Hx": Hx,
-            "myula1": MYULA(1 / self.hyperparameter_batch["rho1"][it]),
-            "myula2": MYULA(1 / self.hyperparameter_batch["rho2"][it]),
-            "myula3": MYULA(1 / self.hyperparameter_batch["rho3"][it]),
-        }
-
-        # compute parts of the potential for later on
-        self._compute_potentials(aux, it)
-
-        return aux
-
-    def _compute_potentials(self, aux, it):
-        aux["potentials"] = potentials_function(
-            self.observations,
-            aux["Hx"],
-            self.parameter_batch["x"][it],
-            self.parameter_batch["z1"][it],
-            self.parameter_batch["u1"][it],
-            self.parameter_batch["z2"][it],
-            self.parameter_batch["u2"][it],
-            self.parameter_batch["z3"][it],
-            self.parameter_batch["u3"][it],
-        )
-
-        pass
-
-    def _compute_potential(self, aux, it):
-        # np.array([data_fidelity, lrho1, lrho2, lrho3, lalpha1, lalpha2, lalpha3, prior],
-        #     dtype="d",
-        # )
-
-        potential = (
-            aux["potentials"][0]
-            + aux["potentials"][1] / self.hyperparameter_batch["rho1"][it]
-            + aux["potentials"][2] / self.hyperparameter_batch["rho2"][it]
-            + aux["potentials"][3] / self.hyperparameter_batch["rho3"][it]
-            + aux["potentials"][4] / self.hyperparameter_batch["alpha1"][it]
-            + aux["potentials"][5] / self.hyperparameter_batch["alpha2"][it]
-            + aux["potentials"][6] / self.hyperparameter_batch["alpha3"][it]
-            + aux["potentials"][7] * self.hyperparameter_batch["beta"][it]
-        )
-
-        return potential
-
-    def _sample_step(self, iter_mc, current_iter, past_iter, aux):
-        # * update MYULA step-sizes
-        aux["myula1"].reset_lipschitz_constant(
-            1 / self.hyperparameter_batch["rho1"][past_iter]
-        )
-        aux["myula2"].reset_lipschitz_constant(
-            1 / self.hyperparameter_batch["rho2"][past_iter]
-        )
-        aux["myula3"].reset_lipschitz_constant(
-            1 / self.hyperparameter_batch["rho3"][past_iter]
-        )
-
-        # * sample image x
-        self.parameter_batch["x"][current_iter], aux["Hx"] = sample_x(
-            self.model,
-            self.parameter_batch["u1"][past_iter],
-            self.parameter_batch["z1"][past_iter],
-            self.parameter_batch["u2"][past_iter],
-            self.parameter_batch["z2"][past_iter],
-            self.parameter_batch["u3"][past_iter],
-            self.parameter_batch["z3"][past_iter],
-            self.hyperparameter_batch["rho1"][past_iter],
-            self.hyperparameter_batch["rho2"][past_iter],
-            self.hyperparameter_batch["rho3"][past_iter],
-            self.rng,
-        )
-
-        # * sample auxiliary variables (z1, u1)
-        self.parameter_batch["z1"][current_iter] = sample_z1(
-            self.parameter_batch["z1"][past_iter],
-            self.observations,
-            aux["Hx"],
-            self.parameter_batch["u1"][past_iter],
-            self.hyperparameter_batch["rho1"][past_iter],
-            aux["myula1"],
-            self.rng,
-        )
-        self.parameter_batch["u1"][current_iter] = sample_u(
-            self.parameter_batch["z1"][current_iter],
-            aux["Hx"],
-            self.hyperparameter_batch["rho1"][past_iter],
-            self.hyperparameter_batch["alpha1"][past_iter],
-            self.rng,
-        )
-
-        # * sample auxiliary variables (z2, u2)
-        self.parameter_batch["z2"][current_iter] = sample_z2(
-            self.parameter_batch["z2"][past_iter],
-            self.parameter_batch["x"][current_iter],
-            self.parameter_batch["u2"][past_iter],
-            self.hyperparameter_batch["rho2"][past_iter],
-            self.hyperparameter_batch["beta"][past_iter],
-            aux["myula2"],
-            self.rng,
-        )
-
-        self.parameter_batch["u2"][current_iter] = sample_u(
-            self.parameter_batch["z2"][current_iter],
-            self.parameter_batch["x"][current_iter],
-            self.hyperparameter_batch["rho2"][past_iter],
-            self.hyperparameter_batch["alpha2"][past_iter],
-            self.rng,
-        )
-
-        # * sample auxiliary variables (z3, u3)
-        self.parameter_batch["z3"][current_iter] = sample_z3(
-            self.parameter_batch["z3"][past_iter],
-            self.parameter_batch["x"][current_iter],
-            self.parameter_batch["u3"][past_iter],
-            self.hyperparameter_batch["rho3"][past_iter],
-            aux["myula3"],
-            self.rng,
-        )
-
-        self.parameter_batch["u3"][current_iter] = sample_u(
-            self.parameter_batch["z3"][current_iter],
-            self.parameter_batch["x"][current_iter],
-            self.hyperparameter_batch["rho3"][past_iter],
-            self.hyperparameter_batch["alpha3"][past_iter],
-            self.rng,
-        )
-
-        # * update hyperparameters (using MML)
-        # compute auxiliary potentials for MML updates + computation of the
-        # overall potential
-        self._compute_potentials(aux, current_iter)
-
-        # TODO: revise and simplify this part
-        # order: [data_fidelity, lrho1, lrho2, lrho3, lalpha1, lalpha2, lalpha3, prior]
-        # check if this is fine in practice
-        c_ = 1  # assuming entries are on the rights order in the vector
-        for key in ["rho1", "rho2", "rho3", "alpha1", "alpha2", "alpha3", "beta"]:
-            self.hyperparameter_batch[key][current_iter] = self.mml_dict[key].update(
-                self.hyperparameter_batch[key][past_iter],
-                iter_mc,
-                aux["potentials"][c_],
-            )
-            c_ += 1
-
-        pass
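In `MyulaSGSmml._sample_step` above, the kernels are refreshed with `reset_lipschitz_constant(1 / rho_i)` before each sweep: a coupling term of the form `(z - v)**2 / (2 * rho)` has a `1/rho`-Lipschitz gradient, so whenever an MML update changes `rho` the admissible step size changes with it. A hypothetical sketch of that bookkeeping (not the library's `MYULA` class):

```python
class StepSizeTracker:
    """Toy stand-in for a kernel exposing reset_lipschitz_constant."""

    def __init__(self, lipschitz_constant, safety=0.99):
        self.safety = safety
        self.reset_lipschitz_constant(lipschitz_constant)

    def reset_lipschitz_constant(self, lipschitz_constant):
        # a stable explicit gradient step requires gamma <= 1 / L
        self.lipschitz = lipschitz_constant
        self.gamma = self.safety / lipschitz_constant

rho1 = 1.0
kernel1 = StepSizeTracker(1 / rho1)
rho1 = 1.3  # e.g., after an MML update
kernel1.reset_lipschitz_constant(1 / rho1)
```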
diff --git a/src/dsgs/samplers/serial/serial_psgla_poisson_deconvolution.py b/src/dsgs/samplers/serial/serial_psgla_poisson_deconvolution.py
index 723895b13794810bd2b5dba46d654c16453ede5c..e37fe6b5c03f1402d5b9eb700a203f0337a5e78a 100644
--- a/src/dsgs/samplers/serial/serial_psgla_poisson_deconvolution.py
+++ b/src/dsgs/samplers/serial/serial_psgla_poisson_deconvolution.py
@@ -292,15 +292,6 @@ class PsglaSGS(SerialSampler):

         self.xmax = xmax

-        # gamma_x = 0.99 / (
-        #     np.max(np.abs(model.fft_kernel)) ** 2 / hyperparameters["rho1"]
-        #     + 8 / hyperparameters["rho2"]
-        # )
-        # gamma1 = 0.99 * hyperparameters["rho1"]
-        # gamma2 = 0.99 * hyperparameters["rho2"]
-
-        # kernels = {"x": PSGLA(gamma_x), "z1": PSGLA(gamma1), "z2": PSGLA(gamma2)}
-
         super(PsglaSGS, self).__init__(
             observations,
             model,
@@ -341,7 +332,6 @@ class PsglaSGS(SerialSampler):

         # PSGLA step-sizes
         # ! only if the hyperparameters are kept fixed
-        # ! see what to do if hyperparameters have been updated in the iterations
         stepsize_x = 0.99 / (
             np.max(np.abs(self.model.fft_kernel)) ** 2
             / self.hyperparameter_batch["rho1"][0]
@@ -377,13 +367,6 @@ class PsglaSGS(SerialSampler):
         return potential

     def _sample_step(self, iter_mc, current_iter, past_iter, aux):
-        # PSGLA step-sizes (-> now defined earlier)
-        # gamma_x = 0.99 / (
-        #     np.max(np.abs(conv_model.fft_kernel)) ** 2 / rho1_mc[past_iter]
-        #     + 8 / rho2_mc[past_iter]
-        # )
-        # gamma1 = 0.99 * rho1_mc[past_iter]
-        # gamma2 = 0.99 * rho2_mc[past_iter]

         # * sample image x
         grad_x = gradient_x(
diff --git a/src/dsgs/samplers/serial_sampler.py b/src/dsgs/samplers/serial_sampler.py
index c21533ffb7eada88328d59d314d057383f196c44..6b6f45456feba1e9599ebdb13cc58ca79228dcd1 100644
--- a/src/dsgs/samplers/serial_sampler.py
+++ b/src/dsgs/samplers/serial_sampler.py
@@ -196,13 +196,6 @@ class SerialSampler(BaseSampler):
         potential : numpy.ndarray
             Evolution of the log-posterior over the checkpoint window.
         """
-        #
-        # f2(**d)  # turning a dict into keyword parameters
-        # f(**d1, **d2)  # need to make sure no duplicate keys, otherwise use collections.ChainMap (https://subscription.packtpub.com/book/application-development/9781788830829/1/ch01lvl1sec14/unpacking-multiple-keyword-arguments)
-        #
-        # TODO: see if copies can be avoided (not sure this occurs with the
-        # dictionaries)
-
         # ! if the full batch for the parameter "x" needs to be saved to disk
         chunk_sizes = []
         last_sample = {}
diff --git a/src/dsgs/samplers/spmd_sampler.py b/src/dsgs/samplers/spmd_sampler.py
index 3d34b4302ab083eaa297dd44946851c6093b840b..d36a2c7530e24dbdf07d8df4a0c4b813a783c43d 100644
--- a/src/dsgs/samplers/spmd_sampler.py
+++ b/src/dsgs/samplers/spmd_sampler.py
@@ -443,10 +443,6 @@ class BaseSPMDSampler(ABC):
         potential : numpy.ndarray
             Evolution of the log-posterior over the checkpoint window.
         """
-        # TODO: give the option to switch from one config to another: single
-        # TODO: save file or one per process.
-
-        # TODO: save MMSE, MAP and last sample only (not mean of squares)

         # ! save last sample for all variables, compute MMSE and MAP for x only
         chunk_sizes = self.chunk_sizes + 2 * [self.chunk_sizes[0]]
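The `spmd_sampler.py` hunk above keeps the policy of saving the last sample for every variable while computing MMSE and MAP estimates for `x` only. Both can be maintained in a streaming fashion, without storing the chain; a small illustrative sketch (the toy `samples` generator stands in for the actual chain of `(sample, potential)` pairs):

```python
import numpy as np

rng = np.random.default_rng(0)
# toy chain: pairs (sample, potential value)
samples = ((rng.standard_normal((8, 8)), rng.random()) for _ in range(100))

x_mmse = np.zeros((8, 8))  # running mean (MMSE estimate)
x_map, best_potential = None, np.inf  # lowest-potential sample (MAP estimate)

for t, (x_t, potential_t) in enumerate(samples):
    x_mmse += (x_t - x_mmse) / (t + 1)  # numerically stable running mean
    if potential_t < best_potential:
        best_potential, x_map = potential_t, x_t.copy()
```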
diff --git a/src/dsgs/samplers/transition_kernels/metropolis_hastings.py b/src/dsgs/samplers/transition_kernels/metropolis_hastings.py
index 892be0c3800936d41b10d32a1a9f1cfd95044bbd..ed87eb1daf1f21c6f7a7a0df734c05c31376cf6d 100644
--- a/src/dsgs/samplers/transition_kernels/metropolis_hastings.py
+++ b/src/dsgs/samplers/transition_kernels/metropolis_hastings.py
@@ -146,10 +146,6 @@ class MetropolisHastings(BaseMCMCKernel):
         return current_state


-# TODO: 1. take a lambda function to compute the relevant potential?
-# TODO: 2. check if a BaseMCMCKernel can be plugged in there or not!
-# TODO: 3. pass a lambda function to pass the expression for the potential of
-# TODO the target distribution
 class MetropolisedKernel(MetropolisHastings):
     """Generic implementation of a Metropolis-Hastings transition kernel, where
     the proposal distribution is generated by another kernel."""
diff --git a/src/dsgs/samplers/transition_kernels/myula.py b/src/dsgs/samplers/transition_kernels/myula.py
index d713b93730454b22eb46048b36e85bf087be37d3..2f382fc2603c944bc8534f2e68f426b3cadc5b04 100644
--- a/src/dsgs/samplers/transition_kernels/myula.py
+++ b/src/dsgs/samplers/transition_kernels/myula.py
@@ -59,7 +59,6 @@ def default_myula_parameters(
     return np.array([stepsize, regularization]), np.array([c1, c2])


-# TODO: implement prox as a lambda function?
 class MYULA(BaseMCMCKernel):
     r"""MYULA transition kernel :cite:p:`Durmus2018` associated with a target
     density of the form