
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/decomposition/plot_image_denoising.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_decomposition_plot_image_denoising.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_decomposition_plot_image_denoising.py:


=========================================
Image denoising using dictionary learning
=========================================

An example comparing the effect of reconstructing noisy fragments
of a raccoon face image, first using online :ref:`DictionaryLearning` to learn
a dictionary and then applying various transform methods.

The dictionary is fitted on the distorted left half of the image, and
subsequently used to reconstruct the right half. Note that even better
performance could be achieved by fitting to an undistorted (i.e.
noiseless) image, but here we start from the assumption that it is not
available.

A common practice for evaluating the results of image denoising is to look
at the difference between the reconstruction and the original image. If the
reconstruction is perfect, this will look like Gaussian noise.

It can be seen from the plots that the result of :ref:`omp` with two
non-zero coefficients is a bit less biased than when keeping only one
(the edges look less prominent). It is also closer to the ground
truth in Frobenius norm.
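
The norm shown in each figure title below is the Frobenius norm of the
difference image. A minimal sketch of that comparison on small synthetic
arrays follows; ``reference``, ``recon_a``, ``recon_b`` and
``frobenius_distance`` are hypothetical names for illustration, not outputs of
this example.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)
    reference = rng.random((8, 8))  # stand-in for the clean image

    # Two hypothetical reconstructions with different error levels.
    recon_a = reference + 0.10 * rng.standard_normal(reference.shape)
    recon_b = reference + 0.02 * rng.standard_normal(reference.shape)


    def frobenius_distance(image, reference):
        """Square root of the sum of squared pixel differences."""
        difference = image - reference
        return np.sqrt(np.sum(difference**2))


    dist_a = frobenius_distance(recon_a, reference)
    dist_b = frobenius_distance(recon_b, reference)

    # For 2-D arrays, np.linalg.norm defaults to the Frobenius norm.
    same_as_numpy = np.isclose(dist_a, np.linalg.norm(recon_a - reference))
    print(f"A: {dist_a:.3f}  B: {dist_b:.3f}  (lower is closer to the reference)")

A lower value means the reconstruction is closer to the reference overall, but
it says nothing about where the error is located, which is why the difference
image is plotted as well.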

The result of :ref:`least_angle_regression` is much more strongly biased: the
difference is reminiscent of the local intensity value of the original image.

Thresholding is clearly not useful for denoising, but it is here to show that
it can produce a suggestive output with very high speed, and thus be useful
for other tasks such as object classification, where performance is not
necessarily related to visualisation.
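
Thresholding is fast because it needs no iterative solver: each patch is
projected onto every atom and small coefficients are set to zero. The sketch
below assumes a soft-thresholding rule, which is how the ``"threshold"``
transform is understood to behave; ``threshold_code`` and the toy ``atoms``
are hypothetical and only illustrate the idea.

.. code-block:: Python

    import numpy as np


    def threshold_code(patches, atoms, alpha):
        """Soft-threshold the projection of each patch onto each atom."""
        projections = patches @ atoms.T  # shape (n_patches, n_atoms)
        return np.sign(projections) * np.maximum(np.abs(projections) - alpha, 0.0)


    # Two hypothetical unit-norm atoms in a 3-pixel "patch" space.
    atoms = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    patches = np.array([[0.5, 0.05, 0.3], [-0.2, 0.4, 0.0]])

    code = threshold_code(patches, atoms, alpha=0.1)
    print(code)

Coefficients whose magnitude is below ``alpha`` vanish and the others shrink
towards zero, so the code is sparse but no reconstruction error is minimised.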

.. GENERATED FROM PYTHON SOURCE LINES 34-38

.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause








.. GENERATED FROM PYTHON SOURCE LINES 39-41

Generate distorted image
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 41-66

.. code-block:: Python

    import numpy as np
    from scipy.datasets import face

    raccoon_face = face(gray=True)

    # Convert from uint8 representation with values between 0 and 255 to
    # a floating point representation with values between 0 and 1.
    raccoon_face = raccoon_face / 255.0

    # Downsample by a factor of 4 along each axis for higher speed
    raccoon_face = (
        raccoon_face[::4, ::4]
        + raccoon_face[1::4, ::4]
        + raccoon_face[::4, 1::4]
        + raccoon_face[1::4, 1::4]
    )
    raccoon_face /= 4.0
    height, width = raccoon_face.shape

    # Distort the right half of the image
    print("Distorting image...")
    distorted = raccoon_face.copy()
    distorted[:, width // 2 :] += 0.075 * np.random.randn(height, width // 2)
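
To see what the strided slicing above does, the same expression can be applied
to a tiny array. This is only an illustration on a hypothetical 8x8 input: each
output pixel is the mean of the top-left 2x2 corner of a 4x4 block.

.. code-block:: Python

    import numpy as np

    # Hypothetical 8x8 "image" whose pixel values are 0..63 in row-major order.
    image = np.arange(64, dtype=float).reshape(8, 8)

    small = (
        image[::4, ::4]  # rows 0, 4 and columns 0, 4
        + image[1::4, ::4]  # rows 1, 5 and columns 0, 4
        + image[::4, 1::4]  # rows 0, 4 and columns 1, 5
        + image[1::4, 1::4]  # rows 1, 5 and columns 1, 5
    )
    small /= 4.0

    print(small.shape)  # (2, 2)
    print(small)

The remaining pixels of each 4x4 block are ignored, so this is a cheap
approximation rather than a full block average.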








.. GENERATED FROM PYTHON SOURCE LINES 67-69

Display the distorted image
---------------------------

.. GENERATED FROM PYTHON SOURCE LINES 69-96

.. code-block:: Python

    import matplotlib.pyplot as plt


    def show_with_diff(image, reference, title):
        """Plot an image next to its difference from a reference image."""
        plt.figure(figsize=(5, 3.3))
        plt.subplot(1, 2, 1)
        plt.title("Image")
        plt.imshow(image, vmin=0, vmax=1, cmap=plt.cm.gray, interpolation="nearest")
        plt.xticks(())
        plt.yticks(())
        plt.subplot(1, 2, 2)
        difference = image - reference

        plt.title("Difference (norm: %.2f)" % np.sqrt(np.sum(difference**2)))
        plt.imshow(
            difference, vmin=-0.5, vmax=0.5, cmap=plt.cm.PuOr, interpolation="nearest"
        )
        plt.xticks(())
        plt.yticks(())
        plt.suptitle(title, size=16)
        plt.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2)


    show_with_diff(distorted, raccoon_face, "Distorted image")



.. GENERATED FROM PYTHON SOURCE LINES 97-99

Extract reference patches
-------------------------

.. GENERATED FROM PYTHON SOURCE LINES 99-114

.. code-block:: Python

    from time import time

    from sklearn.feature_extraction.image import extract_patches_2d

    # Extract all reference patches from the left half of the image
    print("Extracting reference patches...")
    t0 = time()
    patch_size = (7, 7)
    data = extract_patches_2d(distorted[:, : width // 2], patch_size)
    data = data.reshape(data.shape[0], -1)
    data -= np.mean(data, axis=0)
    data /= np.std(data, axis=0)
    print(f"{data.shape[0]} patches extracted in {time() - t0:.2f}s.")
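
The shape bookkeeping of this step can be checked on a small array. The sketch
below uses NumPy's ``sliding_window_view`` as a stand-in for
``extract_patches_2d`` (an assumption made to keep the snippet dependency-free)
on a hypothetical 10x12 image.

.. code-block:: Python

    import numpy as np
    from numpy.lib.stride_tricks import sliding_window_view

    rng = np.random.default_rng(0)
    image = rng.random((10, 12))  # hypothetical small image
    patch_size = (7, 7)

    # One window per valid top-left corner: (10 - 7 + 1) * (12 - 7 + 1) = 24.
    windows = sliding_window_view(image, patch_size)
    data = windows.reshape(-1, patch_size[0] * patch_size[1]).copy()

    # Standardise each of the 49 pixel positions across all patches.
    data -= data.mean(axis=0)
    data /= data.std(axis=0)

    print(data.shape)  # (24, 49)

Each row is one flattened patch, and each column (a pixel position within the
patch) ends up with zero mean and unit standard deviation.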



.. GENERATED FROM PYTHON SOURCE LINES 115-117

Learn the dictionary from reference patches
-------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 117-147

.. code-block:: Python

    from sklearn.decomposition import MiniBatchDictionaryLearning

    print("Learning the dictionary...")
    t0 = time()
    dico = MiniBatchDictionaryLearning(
        # increase to 300 for higher quality results at the cost of slower
        # training times.
        n_components=50,
        batch_size=200,
        alpha=1.0,
        max_iter=10,
    )
    V = dico.fit(data).components_
    dt = time() - t0
    print(f"{dico.n_iter_} iterations / {dico.n_steps_} steps in {dt:.2f}s.")

    plt.figure(figsize=(4.2, 4))
    for i, comp in enumerate(V[:100]):
        plt.subplot(10, 10, i + 1)
        plt.imshow(comp.reshape(patch_size), cmap=plt.cm.gray_r, interpolation="nearest")
        plt.xticks(())
        plt.yticks(())
    plt.suptitle(
        "Dictionary learned from face patches\n"
        + "Train time %.1fs on %d patches" % (dt, len(data)),
        fontsize=16,
    )
    plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
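
Each row of ``V`` is one atom, shown above reshaped to the patch size. As a
rough illustration of how a single atom can approximate a patch, the sketch
below performs one greedy selection step on a hypothetical random dictionary
with unit-norm atoms. It is not scikit-learn's implementation, only the idea
behind keeping one non-zero coefficient.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)
    atoms = rng.standard_normal((5, 49))  # hypothetical dictionary, 5 atoms
    atoms /= np.linalg.norm(atoms, axis=1, keepdims=True)  # unit-norm rows

    # Build a patch that is exactly 2.5 times atom 3.
    patch = 2.5 * atoms[3]

    # Pick the atom most correlated with the patch and keep only its coefficient.
    correlations = atoms @ patch
    best = int(np.argmax(np.abs(correlations)))
    code = np.zeros(len(atoms))
    code[best] = correlations[best]

    reconstruction = code @ atoms
    print(best, code[best])

With more than one non-zero coefficient, the selection is repeated on the
residual and the kept coefficients are refitted by least squares.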



.. GENERATED FROM PYTHON SOURCE LINES 148-150

Extract noisy patches and reconstruct them using the dictionary
---------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 150-189

.. code-block:: Python

    from sklearn.feature_extraction.image import reconstruct_from_patches_2d

    print("Extracting noisy patches... ")
    t0 = time()
    data = extract_patches_2d(distorted[:, width // 2 :], patch_size)
    data = data.reshape(data.shape[0], -1)
    intercept = np.mean(data, axis=0)
    data -= intercept
    print("done in %.2fs." % (time() - t0))

    transform_algorithms = [
        ("Orthogonal Matching Pursuit\n1 atom", "omp", {"transform_n_nonzero_coefs": 1}),
        ("Orthogonal Matching Pursuit\n2 atoms", "omp", {"transform_n_nonzero_coefs": 2}),
        ("Least-angle regression\n4 atoms", "lars", {"transform_n_nonzero_coefs": 4}),
        ("Thresholding\n alpha=0.1", "threshold", {"transform_alpha": 0.1}),
    ]

    reconstructions = {}
    for title, transform_algorithm, kwargs in transform_algorithms:
        print(title + "...")
        reconstructions[title] = raccoon_face.copy()
        t0 = time()
        dico.set_params(transform_algorithm=transform_algorithm, **kwargs)
        code = dico.transform(data)
        patches = np.dot(code, V)

        patches += intercept
        patches = patches.reshape(len(data), *patch_size)
        if transform_algorithm == "threshold":
            patches -= patches.min()
            patches /= patches.max()
        reconstructions[title][:, width // 2 :] = reconstruct_from_patches_2d(
            patches, (height, width // 2)
        )
        dt = time() - t0
        print("done in %.2fs." % dt)
        show_with_diff(reconstructions[title], raccoon_face, title + " (time: %.1fs)" % dt)

    plt.show()
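
``reconstruct_from_patches_2d`` rebuilds the image by averaging the patches
wherever they overlap. The following plain-NumPy sketch shows that averaging on
a hypothetical 6x8 image; ``average_overlapping_patches`` is an illustrative
helper, not part of scikit-learn, and it assumes the patches are ordered
row by row.

.. code-block:: Python

    import numpy as np
    from numpy.lib.stride_tricks import sliding_window_view


    def average_overlapping_patches(patches, image_shape):
        """Rebuild an image by averaging every patch over the pixels it covers."""
        height, width = image_shape
        p_h, p_w = patches.shape[1:]
        n_cols = width - p_w + 1  # number of patch positions per image row
        total = np.zeros(image_shape)
        counts = np.zeros(image_shape)
        for index, patch in enumerate(patches):
            i, j = divmod(index, n_cols)  # top-left corner, row-major order
            total[i : i + p_h, j : j + p_w] += patch
            counts[i : i + p_h, j : j + p_w] += 1
        return total / counts


    rng = np.random.default_rng(0)
    image = rng.random((6, 8))  # hypothetical small image
    patches = sliding_window_view(image, (3, 3)).reshape(-1, 3, 3)
    rebuilt = average_overlapping_patches(patches, image.shape)
    print(np.allclose(rebuilt, image))

When the patches have been denoised independently, the same averaging combines
several estimates of each pixel, which smooths out part of the remaining error.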


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.005 seconds)


.. _sphx_glr_download_auto_examples_decomposition_plot_image_denoising.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_image_denoising.ipynb <plot_image_denoising.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_image_denoising.py <plot_image_denoising.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_image_denoising.zip <plot_image_denoising.zip>`


.. include:: plot_image_denoising.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
