Skip to content
Snippets Groups Projects
Commit 6b909cae authored by Rony Abecidan's avatar Rony Abecidan
Browse files

[Update] : Update of the demo notebook and the readme

parent 403740f3
Branches
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
from utils import *
from simulations import *
```
%% Cell type:markdown id: tags:
### Before reading that notebook please follow the instructions of the file [INSTALL.md](./INSTALL.md)
%% Cell type:markdown id: tags:
## I - How to launch a simulation ?
- All you need is to precise some hyperparameters relative to the experiment. Please find below the list of hyperparameters you have to give according to the setup (note that the function `initialize_hyperparameters` can simplify this task) :
| Name of the hyperparameter | Description | Default value (ours) |
|----------------------------|--------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------|
| seed | The seed used to make the training reproducible | 2021 |
| N_fold | The number of fold you use for your cross validation | 3 |
| im_size | The size of the patches used for the training phases | 128 |
| max_epochs | The maximal number of epochs for the training phases | 30 |
| earlystop_patience | The maximal number of epochs we wait before the earlystopping | 5 |
| lr | The initial learning rate for our training phases | 0.0001 |
| train_batch_size | The size of the batch size used during the training phases | 128 |
| eval_batch_size | The size of the batch size used during the evaluation phases | 512 |
| detector_name | The name of the forgery detector you use | 'Bayar' |
| source_path | The filename of your source domain | 'source-none.hdf5' |
| target_path | The filename of your target domain | 'target-qf(5).hdf5' |
| source_name | Name of the source domain (deduced from source_path) | 'source-none' |
| target_name | Name of the target domain (deduce from target_path) | 'target-qf(5)' |
| setup | The setup you consider for your experiment | 'SrcOnly' |
| domain paths | The filenames of the domains used for the evaluation phases | ["target-qf(5).hdf5", "target-qf(10).hdf5", "target-qf(20).hdf5", "target-qf(50).hdf5", "target-qf(100).hdf5", "target-none.hdf5"] |
| domain_names | The name of the domains for the evaluation phases (deduced from domain_path) | ["qf(5)", "qf(10)", "qf(20)", "qf(50)", "qf(100)", "none"] |
| nb_source_max | The maximal number of patches you want to use for the source during training | 10**(8) |
| nb_target_max | The maximal number of patches you want to use for the target during training | 10**(8) |
| save_at_each_epoch | if True, for your first fold only, the weights of the detector will be saved at each epoch | true |
| precisions | Some precisions about the experiments (deduced from source_path and target_path) | s=none_t=qf(5) |
For what follows, note that the source and target filenames are stored in the list `sources` and `targets` implicitly imported above
%% Cell type:code id: tags:
``` python
print(sources)
print(targets)
```
%% Cell type:markdown id: tags:
- Example 1 : We want to test the Experiment `SrcOnly_s=none_t=qf(5)`
%% Cell type:code id: tags:
``` python
hyperparameters=initialize_hyperparameters(source_path=sources[2],target_path=target[0],eval_domains=targets,details=None,save_at_each_epoch=False,setup='SrcOnly')
simulate(hyperparameters)
```
%% Cell type:markdown id: tags:
- Example 2 : We want to test the Experiment `TgtOnly_s=qf(5)_t=qf(5)`
*Technically, the TgtOnly setup is just a SrcOnly setup with an other source. Hence, we didn't explicitly considered a TgtOnly setup in our code*
%% Cell type:code id: tags:
``` python
hyperparameters=initialize_hyperparameters(source_path=sources[0],target_path=target[0],eval_domains=targets,details=None,save_at_each_epoch=False,setup='SrcOnly')
simulate(hyperparameters)
```
%% Cell type:markdown id: tags:
- Example 3 : We want to test the Experiment `Update(sigma=8)_s=None_t=qf(5)`
*For that we need to precise also the bandwiths parameter at the level of each final dense layer. This is possible with an extra hyperparameters 'sigmas' that you need to add*
*You can also precise in 'details' that you choose a specific bandwith for your experiment so that it appeared in the name of the file containing the results*
%% Cell type:code id: tags:
``` python
hyperparameters=initialize_hyperparameters(source_path=sources[2],target_path=target[0],eval_domains=targets,details='sigma=8',save_at_each_epoch=False,setup='Update')
hyperparameters['sigmas']=[8,8,8]
simulate(hyperparameters)
```
%% Cell type:markdown id: tags:
- Example 4 : We want to test the Experiment `Update(sigmas=[2,3,4])_s=None_t=qf(5)_N_t=1000 `
*For that we need to precise also the bandwiths parameter at the level of each final dense layer and change the default value of nb_target_max.*
*You can also precise in 'details' that you choose specific bandwiths for your experiment so that it appeared in the name of the file containing the results*
%% Cell type:code id: tags:
``` python
hyperparameters=initialize_hyperparameters(source_path=sources[2],target_path=target[0],eval_domains=targets,details='sigma=8',save_at_each_epoch=False,setup='Update')
hyperparameters['sigmas']=[2,3,4]
hyperparameters['nb_target_max']=1000
simulate(hyperparameters)
```
%% Cell type:markdown id: tags:
## II - Can I reproduce the nice gif you gave in the Readme to see what is going one for each experiment ?
%% Cell type:code id: tags:
``` python
import torch
import imageio
```
%% Cell type:markdown id: tags:
Of course ! Setting the key `save_at_each_epoch` to True enables to save the weights of your detector at each epoch for the first training phase (first fold).
When you have all the weights, you can use the function below.
It requires use to install imageio doing `pip install imageio`.
Moreover, you need before to obtain a batch and its associated labels from your domain
To do so you can simply do something like below :
```
my_set=MyDataset(f'{your_domain_path}',key1=f'test_0',key2=f'l_test_0')
my_dataloader=DataLoader(my_set, batch_size=512, shuffle=True)
torch.manual_seed(10)
it=iter(my_dataloader)
batch,labels=next(it)
```
Pay attention that you also need to precise again the hyperparameters that you used for your experiment
%% Cell type:code id: tags:
``` python
def create_gif(hyperparameters,batch,labels):
my_detector=ForgeryDetector(hyperparameters)
my_detector.to(device)
for i in range(0,25):
my_detector.load_state_dict(torch.load(f'./Results/{my_detector.folder_path}/{hyperparameters['setup']}-{i+1}.pt'))
my_detector.eval()
embedding=(my_detector(batch)).cpu().detach().numpy()
plt.figure(figsize=(24,8))
norm0=(my_detector(batch[labels==0]).view(-1)).cpu().std().detach().numpy()
norm1=(my_detector(batch[labels==1]).view(-1)).cpu().std().detach().numpy()
plt.hist((target_embedding[labels==0]).reshape(-1)/norm0,alpha=0.5,label='real',bins=50,color='#1ABC9C',density=True);
plt.hist((target_embedding[labels==1]).reshape(-1)/norm1,alpha=0.5,label='forged',bins=50,color='#186A3B',density=True)
plt.plot([0,0],[0,1],color='black',lw=5,linestyle='--',alpha=0.5)
plt.title(f'Distribution of the final embeddings from your domain ({hyperparameters['setup']},epoch {i})');
plt.legend()
plt.xlim(-5,5)
plt.ylim(0,1)
with imageio.get_writer(f'Evolution.gif', mode='I') as writer:
for filename in np.array([10*[f'{i}.png'] for i in range(0,30)]).reshape(-1):
image = imageio.imread(filename)
writer.append_data(image)
```
......
......@@ -24,6 +24,8 @@ To be able to reproduce our experiments and do your own ones, please follow our
To have a quick idea of the adaptation impact on the training phase, we selected a batch of size 512 from the target and, we represented the evolution of the final embeddings distributions from this batch during the training according to the setups **SrcOnly** and **Update($`\sigma=8`$)**
described in the paper. The training relative to the SrcOnly setup is on the left meanwhile the one relative to **Update($`\sigma=8`$)** is on the right.
**Don't hesitate to click on the gif below to see it better !**
![](https://s10.gifyu.com/images/Adaptationf80f69ab9e1dfcaa.gif)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment