Skip to content
Snippets Groups Projects
Commit 316e9a94 authored by Levecque Etienne's avatar Levecque Etienne
Browse files

feat: add mask and npy outputs

parent 00825b59
Branches
No related tags found
No related merge requests found
......@@ -16,9 +16,10 @@ from antecedent.pipeline import create_pipeline
def main(config, config_path, verbose=False):
validate_config(config)
path_to_csv = create_output_folder(config, config_path, verbose)
path_to_csv, image_mask_output_dir, npy_mask_output_dir = create_output_folder(config, config_path, verbose)
input_path = config['data']['input_path']
filenames = sorted(os.listdir(config['data']['input_path']))
filenames = sorted(os.listdir(input_path))
start, stop = config['data']['starting_index'], config['data']['ending_index']
if stop == -1:
stop = len(filenames)
......@@ -31,7 +32,7 @@ def main(config, config_path, verbose=False):
"[progress.percentage]{task.percentage:>3.0f}%",
progress.TimeRemainingColumn(),
progress.TimeElapsedColumn(),
refresh_per_second=2) as progress_bar:
refresh_per_second=10) as progress_bar:
with concurrent.futures.ProcessPoolExecutor(max_workers=config['antecedent_search']['max_workers']) as executor:
for filename in filenames[start:stop]:
......@@ -48,14 +49,16 @@ def main(config, config_path, verbose=False):
overall_task = progress_bar.add_task("Overall", visible=True)
while sum([future.done() for future in written_to_csv]) < len(written_to_csv):
update_bar(progress_bar, written_to_csv, overall_task, shared_dict)
try:
for future in concurrent.futures.as_completed(written_to_csv, timeout=1):
if not written_to_csv[future]:
to_csv(path_to_csv, future.result())
written_to_csv[future] = True
except concurrent.futures.TimeoutError:
continue
update_bar(progress_bar, written_to_csv, overall_task, shared_dict)
for future in concurrent.futures.as_completed(written_to_csv):
# FIXME: When the writing process is longer than timeout, some image can be skipped entirely!
#if not written_to_csv[future]:
to_csv(path_to_csv, future.result())
#written_to_csv[future] = True
# except concurrent.futures.TimeoutError:
# continue
get_image_from_output(path_to_csv, image_mask_output_dir, npy_mask_output_dir, input_path)
def update_bar(bar, futures, overall_task, shared_dict):
......@@ -105,7 +108,7 @@ def handle_image(filename, config, shared_dict, task_id, verbose):
return img
def to_csv(path, image):
def to_csv(path_to_csv, image):
"""
Write data to a csv file.
......@@ -115,7 +118,7 @@ def to_csv(path, image):
"""
filename = image.filename
pipeline = image.pipeline
with open(path, 'a') as f:
with open(path_to_csv, 'a') as f:
writer = csv.writer(f)
for pos, block in image.block_collection.items():
......@@ -172,8 +175,12 @@ def create_output_folder(config, config_path, verbose):
output_dir = os.path.join(exp_dir, output_dir_name)
output_file = os.path.join(output_dir, "outputs.csv")
image_mask_output_dir = os.path.join(exp_dir, output_dir_name, 'image_mask')
npy_mask_output_dir = os.path.join(exp_dir, output_dir_name, 'npy_mask')
os.makedirs(exp_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(image_mask_output_dir, exist_ok=True)
os.makedirs(npy_mask_output_dir, exist_ok=True)
shutil.copy2(config_path.name, os.path.join(output_dir, "config.yaml"))
with open(output_file, "w") as f:
......@@ -182,7 +189,7 @@ def create_output_folder(config, config_path, verbose):
['filename', 'pos', 'pipeline', 'status', 'iteration_heuristic', 'iteration_gurobi', 'solution'])
if verbose:
print(f'Output folder successfully created at: \n{output_dir}')
return output_file
return output_file, image_mask_output_dir, npy_mask_output_dir
def parse_args():
......@@ -212,6 +219,23 @@ def parse_args():
return config, args.verbose, args.config_file, args.job_id
def get_image_from_output(path_to_csv, image_mask_output_dir, npy_mask_output_dir, input_path):
df = pd.read_csv(path_to_csv)
for filename, group in df.groupby('filename'):
with pimg.open(os.path.join(input_path, filename)) as img:
n, m = img.height, img.width
mask = np.zeros((n, m), dtype=np.uint8)
for i in range(group.shape[0]):
pos = eval(group.iloc[i]['pos'])
status = int(group.iloc[i]['status'])
is_clipped = bool(group.iloc[i]['is_clipped'])
if (status == -1 and not is_clipped) or (status == 0):# and not is_clipped):
mask[pos[0] * 8: (pos[0] + 1) * 8, pos[1] * 8: (pos[1] + 1) * 8] = 255
pimg.fromarray(mask).save(os.path.join(image_mask_output_dir, filename) + '.png')
np.save(os.path.join(npy_mask_output_dir, filename), mask)
if __name__ == "__main__":
config, verbose, config_path, _ = parse_args()
main(config, config_path, verbose)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment