Skip to content
Snippets Groups Projects
Commit 3cb9cbb3 authored by pilavciy's avatar pilavciy
Browse files

clean up

parent 48e4c2dc
Branches
No related tags found
No related merge requests found
%% Cell type:code id:f5605c04 tags:
``` python
import numpy as np
import scipy
import matplotlib.pyplot as plt
import quaternion # load the quaternion module
import bispy as bsp
import torch
import tqdm
import time
from sklearn.preprocessing import normalize
from utils import STIS,optimize_loop,snr_bivariate,param_search,objective_STIS,objective_KReSP,KReSP
from ray import tune,init
from tempfile import TemporaryFile
import pickle
init(num_cpus=13)
## PLOT AN EXAMPLE
np.random.seed(5)
N = 1024 # length of the signal
t = np.linspace(0, 2*np.pi/4, N) # time vector
dt = t[1]-t[0]
# ellipse parameters - AM-FM-PM polarized
theta1 = np.pi/4 - 2*t
chi1 = np.pi/16 - t
phi1 = 0
f0 = 25/N/dt
S0 = bsp.utils.windows.hanning(N)
x_quad = bsp.signals.bivariateAMFM(S0, theta1, chi1, 2*np.pi*f0*t+ phi1)
x = quaternion.as_float_array(x_quad)[:,:2]
bsp.utils.visual.plot2D(t,x_quad)
plt.savefig("clean_sig.pdf")
sigma = 0.05
n = np.zeros([N,4])
noise_complex = np.random.randn(N,2)
y = x + noise_complex
uH = np.imag(scipy.signal.hilbert(noise_complex[:,0]))
vH = np.imag(scipy.signal.hilbert(noise_complex[:,1]))
n[:,0] = noise_complex[:,0]
n[:,1] = noise_complex[:,1]
n[:,2] = uH
n[:,3] = vH
n = quaternion.from_float_array(n)
y_quad = sigma*n + x_quad # Noisy signal
bsp.utils.visual.plot2D(t,y_quad)
plt.savefig("noisy_sig.pdf")
n = np.random.randn(N,2)
y = sigma*n + x # Noisy signal
print("sigma: " + str(sigma) + " Noise SNR: "+ str(snr_bivariate(x,y) ) )
search_space = {"x":tune.grid_search([x]),"t":tune.grid_search([t]),"y":tune.grid_search([y]),"lambdax": tune.grid_search((0.1)**np.linspace(5,15,7)), "lambdaS": tune.grid_search((0.1)**np.linspace(5,10,7)) , "beta1":tune.grid_search((0.10)**np.linspace(-2,2,5)),"beta2":tune.grid_search((0.10)**np.linspace(-2,2,5)),"sigma2":tune.grid_search([sigma**2])}
config = param_search(objective_STIS,search_space)
model = STIS(t,y,lambdax=config["lambdax"],lambdaS=config["lambdaS"],beta1=config["beta1"],beta2=config["beta2"],sigma2=sigma**2,p=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model = optimize_loop(model,optimizer,numit=1000)
x_stis = quaternion.from_float_array(model.Xquad.detach().numpy())
bsp.utils.visual.plot2D(t,x_stis)
plt.savefig("denoised_via_all_terms.pdf")
print("sigma: " + str(sigma) + " STIS SNR: "+ str(snr_bivariate(x,model.X.detach().numpy())))
# NO STOKES REGULARIZATION
search_space = {"x":tune.grid_search([x]),"t":tune.grid_search([t]),"y":tune.grid_search([y]),"lambdax": tune.grid_search((0.1)**np.linspace(5,15,7)), "lambdaS":tune.grid_search([0.0]) , "beta1":tune.grid_search([0.0]),"beta2":tune.grid_search([0.0]),"sigma2":tune.grid_search([sigma**2])}
config = param_search(objective_STIS,search_space)
model = STIS(t,y,lambdax=config["lambdax"],lambdaS=config["lambdaS"],beta1=config["beta1"],beta2=config["beta2"],sigma2=sigma**2,p=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model = optimize_loop(model,optimizer,numit=1000)
print("sigma: " + str(sigma) + " STIS SNR: "+ str((snr_bivariate(x,model.X.detach().numpy()))))
x_stis = quaternion.from_float_array(model.Xquad.detach().numpy())
bsp.utils.visual.plot2D(t,x_stis)
plt.savefig("denoised_via_no_stokes.pdf")
# ONLY SMOOTH STOKES
search_space = {"x":tune.grid_search([x]),"t":tune.grid_search([t]),"y":tune.grid_search([y]),"lambdax": tune.grid_search([0.0]), "lambdaS": tune.grid_search((0.1)**np.linspace(5,10,7)) , "beta1":tune.grid_search((0.1)**np.linspace(-2,2,7)),"beta2":tune.grid_search((0.1)**np.linspace(-2,2,5)),"sigma2":tune.grid_search([sigma**2])}
config = param_search(objective_STIS,search_space)
model = STIS(t,y,lambdax=config["lambdax"],lambdaS=config["lambdaS"],beta1=config["beta1"],beta2=config["beta2"],sigma2=sigma**2,p=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model = optimize_loop(model,optimizer,numit=1000)
x_stis = quaternion.from_float_array(model.Xquad.detach().numpy())
bsp.utils.visual.plot2D(t,x_stis)
plt.savefig("denoised_via_no_signal_smoother.pdf")
print("sigma: " + str(sigma) + " STIS SNR: "+ str((snr_bivariate(x,model.X.detach().numpy()))))
# Kernel regression on normalized
search_space = {"x":tune.grid_search([x]),"t":tune.grid_search([t]),"y":tune.grid_search([y]),"alpha": tune.grid_search((0.1)**np.linspace(5,15,7)),"lambda_1": tune.grid_search((0.1)**np.linspace(5,15,7)), "lambda_s": tune.grid_search((0.1)**np.linspace(5,10,7)) , "beta":tune.grid_search((0.10)**np.linspace(-2,2,7)),"gamma":tune.grid_search([0.2]),"sigma2":tune.grid_search([sigma**2])}
config = param_search(objective_KReSP,search_space)
model = KReSP(t,y,lambda_1=config["lambda_1"],beta=config["beta"],lambda_s=config["lambda_s"],alpha=config["alpha"],gamma=config["gamma"],eps=10**-7,win_width=32,sigma2=sigma**2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model = optimize_loop(model,optimizer,numit=300)
print("sigma: " + str(sigma) + " KReSP SNR: "+ str(snr_bivariate(x,model.X.detach().numpy())))
x_kerreg = quaternion.from_float_array(model.Xquad.detach().numpy())
bsp.utils.visual.plot2D(t,x_kerreg)
plt.savefig("denoised_via_kerreg.pdf")
```
{}
\ No newline at end of file
File deleted
File deleted
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment