Skip to content
Snippets Groups Projects
Commit b89a770d authored by Levecque Etienne's avatar Levecque Etienne
Browse files

fix: add small corrections

parent 48b2802a
Branches
No related tags found
No related merge requests found
......@@ -94,11 +94,11 @@ class Image:
self.trivial_search(blocks)
n = len(blocks)
for i, block in enumerate(blocks):
self.init_shared_dict(shared_dict, task_id, i, n)
if block.status[self.pipeline] == 0:
self.update_shared_dict(shared_dict, task_id, i, n)
if self.pipeline not in block.status:
block.search_antecedent(self.pipeline, max_iter, shared_dict, task_id, verbose)
def init_shared_dict(self, shared_dict, task_id, i, n, completed=0, total=0):
def update_shared_dict(self, shared_dict, task_id, i, n, completed=0, total=0):
if shared_dict is not None and task_id is not None:
shared_dict[task_id] = {"filename": self.filename,
"current_block": i + 1,
......@@ -119,7 +119,7 @@ class Image:
all_status = []
n = len(single_channel_blocks)
for i, block in enumerate(single_channel_blocks):
self.init_shared_dict(shared_dict, task_id, i, n)
self.update_shared_dict(shared_dict, task_id, i, n)
if block.status[grayscale_pipeline] == 1: # Solved with the trivial search
antecedents.append(block.antecedents[grayscale_pipeline])
all_status.append(1)
......@@ -148,14 +148,17 @@ class Image:
block.iterations[self.pipeline] = np.sum(iterations)
block.status[self.pipeline] = status
def trivial_search(self, blocks, pipeline=None):
def trivial_search(self, blocks, shared_dict=None, task_id=None, pipeline=None):
if pipeline is None:
pipeline = self.pipeline
block_values = np.concatenate([block.value for block in blocks])
block_values = np.stack([block.value.reshape(-1,8,8) for block in blocks])
starts = pipeline.backward(block_values)
transformed_starts = pipeline.forward(starts)
completed = 0
for i in range(len(starts)):
if np.allclose(block_values[i], transformed_starts[i]):
completed += 1
self.update_shared_dict(shared_dict, task_id, i, len(starts), completed=completed, total=len(starts))
blocks[i].antecedents[pipeline] = starts[i]
blocks[i].iterations[pipeline] = 0
blocks[i].status[pipeline] = 1
......@@ -175,7 +178,7 @@ class Image:
'NodefileStart': parameters['node_file_start'],
'Cutoff': parameters['cutoff']}
for i, block in enumerate(blocks):
self.init_shared_dict(shared_dict, task_id, i, len(blocks), 0, 0)
self.update_shared_dict(shared_dict, task_id, i, len(blocks), 0, 0)
block.search_antecedent_ilp(gurobi_parameters, shared_dict, task_id, verbose)
def classify(self, likelihood_file):
......
......@@ -123,7 +123,7 @@ def to_csv(path, image):
status = block.status[pipeline]
antecedent = block.antecedents[pipeline]
iteration_heuristic = block.iterations[pipeline]
if status != 1 and block.is_clipped[pipeline]:
if status != 1 and (pipeline in block.is_clipped and block.is_clipped[pipeline]):
status = 3
antecedent = None
iteration_heuristic = 0
......@@ -134,7 +134,7 @@ def to_csv(path, image):
sol = None
writer.writerow(
[filename, str(block.pos), str(pipeline), str(status), str(iteration_heuristic),
[filename, str(pos), str(pipeline), str(status), str(iteration_heuristic),
str(iteration_gurobi), sol])
......@@ -147,7 +147,8 @@ def validate_config(config):
raise ValueError(f"starting_index must be lower than ending_index.")
if data['preprocessing']['percentage_block_per_image'] <= 0:
raise ValueError("The percentage of block used per image must be strictly positive")
create_pipeline(antecedent_search['pipeline'], antecedent_search['quality'], True)
# TODO: improve pipeline verification
# create_pipeline(antecedent_search['pipeline'], antecedent_search['quality'], True)
def create_output_folder(config, config_path, verbose):
......
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment