From 119f42e5df1da064f98565a646aecdc41931abba Mon Sep 17 00:00:00 2001
From: labaro <etienne.levecque@univ-lille.fr>
Date: Mon, 10 Oct 2022 14:39:39 +0200
Subject: [PATCH] feature: ks_test and bonferroni correction

---
 data.py       |  4 ++--
 embed_juni.py | 10 +++-------
 main.py       | 55 ++++++++++++++++++++++++++++++++++++++++++++++-----
 utils.py      |  5 +++++
 4 files changed, 60 insertions(+), 14 deletions(-)

diff --git a/data.py b/data.py
index 8565878..d424bce 100644
--- a/data.py
+++ b/data.py
@@ -3,7 +3,7 @@ import jpegio as jio
 import numpy as np
 import multiprocessing as mp
 
-from skimage import view_as_block
+from skimage.util import view_as_blocks
 
 from utils import decompress_structure
 from embed_juni import embed_img
@@ -76,7 +76,7 @@ def variance_filter(img_generator, variance_threshold, block_per_threshold):
     ignored = 0
     try:
         for img in img_generator:
-            view = view_as_block(img, (8, 8))
+            view = view_as_blocks(img, (8, 8))
             mask_var = np.var(view, axis=(2, 3)) >= variance_threshold
             mask_saturated = np.any(view == 255, axis=(2,3)) | np.any(view == 0, axis=(2,3))
             remaining_blocks = view.reshape((-1, 8, 8))[mask_var.flatten() & ~mask_saturated.flatten()]
diff --git a/embed_juni.py b/embed_juni.py
index 57e2d14..3405809 100644
--- a/embed_juni.py
+++ b/embed_juni.py
@@ -1,13 +1,9 @@
 import os
 import scipy.signal
-import scipy.fftpack
+from scipy.fftpack import dct, idct
 import numpy as np
-from tqdm import tqdm as tqdm
-import multiprocessing
-from multiprocessing import Pool
 import jpegio as jio
 import cv2
-import pickle
 
 from utils import decompress_structure
 
@@ -17,11 +13,11 @@ os.environ['OPENBLAS_NUM_THREADS'] = '1'
 
 
 def dct2(a):
-    return scipy.fftpack.dct(scipy.fftpack.dct(a, axis=0, norm='ortho'), axis=1, norm='ortho')
+    return dct(dct(a, axis=0, norm='ortho'), axis=1, norm='ortho')
 
 
 def idct2(a):
-    return scipy.fftpack.idct(scipy.fftpack.idct(a, axis=0, norm='ortho'), axis=1, norm='ortho')
+    return idct(idct(a, axis=0, norm='ortho'), axis=1, norm='ortho')
 
 
 def entropy_ternary(pP1, pM1):
diff --git a/main.py b/main.py
index 2977fe0..8a0f845 100644
--- a/main.py
+++ b/main.py
@@ -1,13 +1,27 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import multiprocessing as mp
+
+from scipy.stats import kstest
+
 from data import get_train_test_generator, embed_generator, variance_filter, feature_extractor, img_generator
+from utils import bonferroni_correction
 
 cover_dir = "/home/labaro/Documents/These/datasets/images/alaska/jpeg/qf100"
 stego_dir = "/home/labaro/Documents/These/datasets/images/alaska/jpeg/embedded"
-compute_stego = False
-train_size = 0.1
-payload = 0.0
+compute_stego = True
+train_size = 0.01
+payload = 0.1
 stego_percentage = 0.1
-variance_threshold = 0.0
-block_per_threshold = 0.0
+variance_threshold = 20
+block_per_threshold = 0.5
+threshold = np.geomspace(1e-4, 1, 100)
+
+
+def ks_test(error_tuple):
+    ref_cdf, error = error_tuple
+    return [kstest(ref_cdf[:, pos], error.reshape((-1, 64))[:, pos])[1] for pos in range(64)]
+
 
 if __name__ == "__main__":
     if compute_stego:
@@ -28,3 +42,34 @@ if __name__ == "__main__":
     test_cover_features = feature_extractor(variance_filter(test_cover_gen,
                                                             variance_threshold,
                                                             block_per_threshold))
+
+    ref_cdf = np.concatenate([error for error in train_features]).reshape((-1, 64))
+
+    y = []
+    label = []
+
+    with mp.Pool() as p:
+        try:
+            for p_val in p.imap_unordered(ks_test, ((ref_cdf, error) for error in test_stego_features)):
+                corrected_p = bonferroni_correction(p_val)
+                y.append(np.min(corrected_p))
+                print(np.min(corrected_p))
+                label.append(1)
+        except StopIteration as ex:
+            stats = ex.value
+            print(stats)
+
+        try:
+            for p_val in p.imap_unordered(ks_test, ((ref_cdf, error) for error in test_cover_features)):
+                corrected_p = bonferroni_correction(p_val)
+                y.append(np.min(corrected_p))
+                label.append(0)
+                print(np.min(corrected_p))
+        except StopIteration as ex:
+            stats = ex.value
+            print(stats)
+
+    y = np.array(y)
+    label = np.array(label)
+
+    plt.plot(threshold, [np.mean(y[label == 0] < t) for t in threshold])
diff --git a/utils.py b/utils.py
index 8885a6c..d19a996 100644
--- a/utils.py
+++ b/utils.py
@@ -40,3 +40,8 @@ def decompress_structure(S, grayscale=True):
         fun = lambda x: fftpack.idct(fftpack.idct(x, norm='ortho', axis=2), norm='ortho', axis=3) + 128
         I[:, :, i] = segmented_stride(C, fun)
     return I
+
+
+def bonferroni_correction(p):
+    p = np.array(p)
+    return np.clip(p * p.shape[0], 0, 1)
-- 
GitLab