diff --git a/metaflow/graph.py b/metaflow/graph.py index 56948d2e8be..346b2fc5589 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -463,7 +463,9 @@ def populate_block(start_name, end_name): cur_name = cur_node.matching_join elif node_type == "split-switch": all_paths = [ - populate_block(s, end_name) for s in cur_node.out_funcs + populate_block(s, end_name) + for s in cur_node.out_funcs + if s != cur_name ] resulting_list.append(all_paths) cur_name = end_name diff --git a/metaflow/lint.py b/metaflow/lint.py index 1b450544594..8467cbe0d05 100644 --- a/metaflow/lint.py +++ b/metaflow/lint.py @@ -175,6 +175,9 @@ def check_for_acyclicity(graph): def check_path(node, seen): for n in node.out_funcs: + if node.type == "split-switch" and n == node.name: + continue + if n in seen: path = "->".join(seen + [n]) raise LintWarn( @@ -241,6 +244,8 @@ def traverse(node, split_stack): elif node.type == "split-switch": # For a switch, continue traversal down each path with the same stack for n in node.out_funcs: + if node.type == "split-switch" and n == node.name: + continue traverse(graph[n], split_stack) return elif node.type == "end": @@ -297,6 +302,8 @@ def parents(n): new_stack = split_stack for n in node.out_funcs: + if node.type == "split-switch" and n == node.name: + continue traverse(graph[n], new_stack) traverse(graph["start"], []) diff --git a/metaflow/runtime.py b/metaflow/runtime.py index b099401caac..13e629f38a9 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -698,7 +698,7 @@ def _translate_index(self, task, next_step, type, split_index=None): # Store the parameters needed for task creation, so that pushing on items # onto the run_queue is an inexpensive operation. - def _queue_push(self, step, task_kwargs, index=None): + def _queue_push(self, step, task_kwargs, index=None, is_self_loop_switch=False): # In the case of cloning, we set all the cloned tasks as the # finished tasks when pushing tasks using _queue_tasks. This means that we # could potentially try to push the same task multiple times (for example @@ -706,7 +706,7 @@ def _queue_push(self, step, task_kwargs, index=None): # has executed (been cloned) or what has been scheduled and avoid scheduling # it again. if index: - if index in self._ran_or_scheduled_task_index: + if index in self._ran_or_scheduled_task_index and not is_self_loop_switch: # It has already run or been scheduled return # Note that we are scheduling this to run @@ -902,8 +902,14 @@ def _queue_task_switch(self, task, next_steps): raise Exception(msg.format(step=task.step, actual=len(next_steps))) chosen_step = next_steps[0] + is_self_loop_switch = task.step == chosen_step index = self._translate_index(task, chosen_step, "linear") - self._queue_push(chosen_step, {"input_paths": [task.path]}, index) + self._queue_push( + chosen_step, + {"input_paths": [task.path]}, + index, + is_self_loop_switch, + ) def _queue_task_foreach(self, task, next_steps): # CHECK: this condition should be enforced by the linter but