Skip to content
Snippets Groups Projects
Commit 9679f003 authored by Levecque Etienne's avatar Levecque Etienne
Browse files

feat: change clipped blocks behavior

parent 554073f8
No related merge requests found
......@@ -69,7 +69,11 @@ class Block:
open_set = {start.data.tobytes().__hash__()}
s = c * n * m # (channel * row * column)
mask = np.ravel(pipeline.pipelines[0].quant_tbl.astype(float) <= pipeline.upper_bound)
upper_bound_offset = 0
if self.is_clipped[pipeline]:
upper_bound_offset = pipeline.upper_bound
mask = np.ravel(pipeline.pipelines[0].quant_tbl.astype(float) <= pipeline.upper_bound + upper_bound_offset)
n_changes = np.sum(mask)
changes = np.stack([np.eye(s, dtype=np.int8)[mask],
-np.eye(s, dtype=np.int8)[mask]]).reshape(2 * n_changes, c, n, m)
......@@ -93,7 +97,7 @@ class Block:
children.flags.writeable = True
distance = np.abs((children - float_start) * pipeline.pipelines[0].quant_tbl)
norm_distance = np.linalg.norm(distance, axis=(2, 3), ord='fro')
not_ignored = not_ignored & np.all(norm_distance <= np.ravel(pipeline.upper_bound), axis=-1)
not_ignored = not_ignored & np.all(norm_distance <= np.ravel(pipeline.upper_bound + upper_bound_offset), axis=-1)
if np.any(not_ignored):
transformed_children = pipeline.forward(children[not_ignored])
......
......@@ -6,9 +6,11 @@ import shutil
import argparse
import multiprocessing
import concurrent.futures
import pandas as pd
from time import strftime, localtime
from rich import progress
from PIL import Image as pimg
from antecedent.image import Image
from antecedent.pipeline import create_pipeline
......@@ -126,10 +128,7 @@ def to_csv(path_to_csv, image):
status = block.status[pipeline]
antecedent = block.antecedents[pipeline]
iteration_heuristic = block.iterations[pipeline]
if status != 1 and (pipeline in block.is_clipped and block.is_clipped[pipeline]):
status = 3
antecedent = None
iteration_heuristic = 0
is_clipped = block.is_clipped[pipeline]
iteration_gurobi = 0
if antecedent is not None:
sol = str(list(np.ravel(antecedent).astype(int)))
......@@ -137,7 +136,7 @@ def to_csv(path_to_csv, image):
sol = None
writer.writerow(
[filename, str(pos), str(pipeline), str(status), str(iteration_heuristic),
[filename, str(pos), str(pipeline), str(status), str(is_clipped), str(iteration_heuristic),
str(iteration_gurobi), sol])
......@@ -186,7 +185,8 @@ def create_output_folder(config, config_path, verbose):
with open(output_file, "w") as f:
writer = csv.writer(f)
writer.writerow(
['filename', 'pos', 'pipeline', 'status', 'iteration_heuristic', 'iteration_gurobi', 'solution'])
['filename', 'pos', 'pipeline', 'status', 'is_clipped', 'iteration_heuristic', 'iteration_gurobi',
'solution'])
if verbose:
print(f'Output folder successfully created at: \n{output_dir}')
return output_file, image_mask_output_dir, npy_mask_output_dir
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment