Skip to content
Snippets Groups Projects
Commit 57ccb6ec authored by Levecque Etienne's avatar Levecque Etienne
Browse files

feat: add verbose support without multiprocessing

parent 1f91a978
No related branches found
No related tags found
No related merge requests found
......@@ -7,10 +7,14 @@ from antecedent.pipeline import ComposedPipeline
import antecedent.utils as utils
def send_logs(verbose, shared_dict, task_id, iteration, max_iteration):
if verbose:
def send_logs(verbose, shared_dict, task_id, iteration, max_iteration, done):
if verbose and shared_dict is not None and iteration % 100 == 0:
tmp_dict = shared_dict[task_id]
shared_dict[task_id] = tmp_dict | {'completed': iteration, 'total': max_iteration}
elif verbose and done:
print(f"DONE {iteration}/{max_iteration}", end="\r")
elif verbose:
print(f"RUNNING {iteration}/{max_iteration}", end="\r")
class Block:
......@@ -50,13 +54,14 @@ class Block:
iter_count: iteration counter
"""
np.random.seed(123)
iter_counter = 0
c, n, m = self.value.shape
target = np.array([self.value], dtype=np.int16)
float_start = pipeline.backward(target)
start = utils.round(float_start).astype(np.int16)
if np.allclose(pipeline.forward(start), target):
send_logs(verbose, shared_dict, task_id, max_iter, max_iter)
send_logs(verbose, shared_dict, task_id, max_iter, max_iter, True)
self.status[pipeline] = 1
self.antecedents[pipeline] = (start, 0)
return self.antecedents[pipeline]
......@@ -81,8 +86,7 @@ class Block:
if not queue:
break
if iter_counter % 100 == 0:
send_logs(verbose, shared_dict, task_id, iter_counter, max_iter)
send_logs(verbose, shared_dict, task_id, iter_counter, max_iter, False)
_, _, current = heappop(queue) # shape (1, 1 for grayscale or 3, 8, 8)
children = current + changes
......@@ -104,7 +108,7 @@ class Block:
error_idx = np.argsort(error)
if error[error_idx[0]] == 0: # check only the first element which is the smallest error
send_logs(verbose, shared_dict, task_id, max_iter, max_iter)
send_logs(verbose, shared_dict, task_id, max_iter, max_iter, True)
self.antecedents[pipeline] = (children[not_ignored][error_idx[0]], iter_counter)
self.status[pipeline] = 1
......@@ -119,7 +123,7 @@ class Block:
heappush(queue, to_enqueue)
open_set.add(children_hash[not_ignored][idx])
send_logs(verbose, shared_dict, task_id, max_iter, max_iter)
send_logs(verbose, shared_dict, task_id, max_iter, max_iter, True)
if iter_counter < max_iter:
self.antecedents[pipeline] = (False, iter_counter + 1)
self.status[pipeline] = -1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment