Source code for ferrmion.optimize.bonsai

"""Bonsai Algorithm."""

import logging

import numpy as np
import rustworkx as rx

from ferrmion import TernaryTree
from ferrmion.encode.ternary_tree_node import TTNode

logger = logging.getLogger(__name__)


def bonsai_algorithm(
    graph: rx.PyGraph,
    homogenous: bool = False,
    max_nodes: None | int = None,
) -> TernaryTree:
    """Create a TernaryTree encoding using the Bonsai Algorithm.

    The tree is grown breadth-first from the graph node with the smallest
    eccentricity (the smallest maximum distance to any other node). Each tree
    node adopts up to three of its not-yet-used graph neighbours as children.
    If the breadth-first pass cannot reach every graph node (for example
    because a node has more than three unused neighbours), each leftover node
    is attached to the nearest tree node that still has a free branch.

    Args:
        graph (rx.PyGraph): A RustworkX graph of device qubit-connectivity.
        homogenous (bool): "homogenous" labelling if true, else "heterogenous".
            Children are assigned branches in the order x, y, z when true and
            z, x, y otherwise.
        max_nodes (int): Maximum number of nodes to include in output tree.
            The root always counts as one node.

    Returns:
        TernaryTree: A ternary tree encoding.

    Raises:
        ValueError: If ``graph`` contains no nodes.
    """
    logger.debug("Starting Bonsai Algorithm.")
    chars = ["x", "y", "z"] if homogenous else ["z", "x", "y"]

    n_graph_nodes = graph.num_nodes()
    if n_graph_nodes == 0:
        # Without this guard the NumPy reductions below fail with an opaque
        # "zero-size array" / "empty sequence" ValueError.
        raise ValueError("Cannot build a ternary tree from an empty graph.")

    distances = rx.distance_matrix(graph)
    # The root is the node whose farthest neighbour is closest.
    root_index = int(np.argmin(np.max(distances, axis=1)))

    node_queue = [root_index]
    used_indices = {root_index}
    nodes = [TTNode(parent=None) for _ in range(n_graph_nodes)]
    nodes[root_index].root_path = ""
    nodes[root_index].qubit_label = root_index

    def limit_reached() -> bool:
        """Return True once the tree holds ``max_nodes`` nodes."""
        return max_nodes is not None and len(used_indices) >= max_nodes

    # Breadth-first growth. Checking the limit before expanding a node makes
    # ``max_nodes=1`` yield a root-only tree instead of being ignored.
    while node_queue and not limit_reached():
        logger.debug("node_queue=%r", node_queue)
        logger.debug("used_indices=%r", used_indices)
        node: int = node_queue.pop(0)
        parent = nodes[node]
        logger.debug("%s", node)

        neighbors = sorted(set(graph.neighbors(node)).difference(used_indices))
        logger.debug("neighbors=%r", neighbors)

        # zip stops at the shorter input, so at most three children are added.
        for neighbor, char in zip(neighbors, chars):
            node_queue.append(neighbor)
            used_indices.add(neighbor)
            parent.add_child(
                char,
                child_node=nodes[neighbor],
                root_path=f"{parent.root_path}{char}",
                qubit_label=neighbor,
            )
            nodes[neighbor].parent = parent
            if limit_reached():
                break

        if limit_reached():
            break
        logger.debug(node_queue)
        logger.debug("")

    if limit_reached():
        logger.debug("Found %s nodes.", max_nodes)
    elif len(used_indices) == n_graph_nodes:
        logger.debug("Found spanning tree")
    else:
        logger.debug("Tree does not span the graph.")
        # Iterate over a sorted snapshot: the set of used nodes changes as
        # leftovers are attached, and a fixed order keeps the result
        # deterministic.
        leftovers = sorted(set(range(n_graph_nodes)).difference(used_indices))
        for unused in leftovers:
            if limit_reached():
                break
            # Only nodes already in the tree may act as parents; ranking all
            # graph nodes would put ``unused`` itself (distance 0) first.
            # NOTE(review): rustworkx reports unreachable pairs with its
            # null value (0.0 by default), so for disconnected graphs those
            # candidates sort first -- confirm this is acceptable.
            candidates = sorted(
                (float(distances[unused][index]), index) for index in used_indices
            )
            for _, used in candidates:
                used_node: TTNode = nodes[used]
                free_branch = next(
                    (
                        branch
                        for branch in chars
                        if getattr(used_node, branch) is None
                    ),
                    None,
                )
                if free_branch is None:
                    continue
                used_node.add_child(
                    free_branch,
                    child_node=nodes[unused],
                    root_path=f"{used_node.root_path}{free_branch}",
                    qubit_label=unused,
                )
                # Mirror the parent link set in the breadth-first pass.
                nodes[unused].parent = used_node
                used_indices.add(unused)
                # Attach each leftover exactly once.
                break

        if len(used_indices) < n_graph_nodes:
            logger.debug("Error, not all qubits assigned to nodes.")
        else:
            logger.debug("All graph nodes assigned to tree.")

    logger.debug("Creating encoding.")
    n_nodes = n_graph_nodes if max_nodes is None else max_nodes
    tree = TernaryTree(
        n_modes=n_nodes, n_qubits=n_graph_nodes, root_node=nodes[root_index]
    )
    tree.enumeration_scheme = tree.default_enumeration_scheme()
    return tree