Source code for ferrmion.optimize.rett

"""Reduced entanglement Ternary Tree."""

import logging

import numpy as np
from numpy.typing import NDArray

from ..encode.ternary_tree import TernaryTree
from ..encode.ternary_tree_node import TTNode

logger = logging.getLogger(__name__)


def reduced_entanglement_ternary_tree(
    mutual_information: NDArray,
    cutoff: float = 0.5,
    max_branches: int | None = None,
    squash: bool = False,
) -> TernaryTree:
    """Creates the reduced entanglement TernaryTree.

    Orbital pairs are visited in order of decreasing mutual information (MI).
    A pair whose MI is at least ``cutoff``, and whose orbitals have not
    already been used, has its four modes ``(2i, 2i + 1, 2j, 2j + 1)`` placed
    together on one branch: the ``k``-th mode of the ``b``-th branch is
    assigned to the node path ``"z" * b + "x" * k``. Every mode that does not
    end up on a branch is assigned to one of the remaining nodes of the tree.

    Args:
        mutual_information (NDArray): A square 2D array of mode mutual
            information.
        cutoff (float): Minimum MI a pair of orbitals must have to be placed
            on a shared branch.
        max_branches (int | None): The maximum allowed number of branches,
            or ``None`` for no limit.
        squash (bool): Whether to squash the mutual_information from
            spin-orbit form to spinless.

    Returns:
        TernaryTree: A new ternary tree with its enumeration scheme set.

    Raises:
        ValueError: If ``mutual_information`` is not a square 2D array, or if
            ``squash`` is True and its dimension is odd.

    Note:
        With ``squash=False`` the MI matrix is taken to be between spinless
        spatial orbitals, so the tree has twice as many modes as the matrix
        has rows. With ``squash=True`` the matrix is taken to be in
        spin-orbital form, where each 2x2 block contains
        ``[[aa, ab], [ba, bb]]``; every block is averaged into one entry and
        the tree has as many modes as the matrix has rows.

    Example:
        >>> import numpy as np
        >>> from ferrmion.optimize.rett import reduced_entanglement_ternary_tree
        >>> mi = 0.5 * np.random.random((6,6))
        >>> mi = mi + mi.T
        >>> np.fill_diagonal(mi, 0)
        >>> tree = reduced_entanglement_ternary_tree(mi)
        >>> tree.as_dict()

        Advanced example (with options):
        >>> tree = reduced_entanglement_ternary_tree(mi, cutoff=0.1, max_branches=2, squash=False)
    """
    logger.debug("Creating Reduced entanglement TT.")

    if (
        mutual_information.ndim != 2
        or mutual_information.shape[0] != mutual_information.shape[1]
    ):
        raise ValueError(
            "mutual_information must be a square 2D array, "
            f"got shape {mutual_information.shape}."
        )
    if squash and mutual_information.shape[0] % 2:
        # An odd dimension would otherwise broadcast silently into a
        # wrongly-shaped squashed matrix.
        raise ValueError(
            "mutual_information must have an even dimension when squash=True, "
            f"got shape {mutual_information.shape}."
        )

    n_rows = mutual_information.shape[0]
    n_modes = n_rows if squash else 2 * n_rows
    new_tree = TernaryTree(n_modes, root_node=TTNode())

    if squash:
        # First combine the MI information for alpha and beta spins:
        # sum each 2x2 block and scale by 1/4. A fresh array is built rather
        # than multiplying in place so that integer-typed input also works.
        squash_rows = mutual_information[::2] + mutual_information[1::2]
        squash_matrix = 0.25 * (squash_rows[:, ::2] + squash_rows[:, 1::2])
    else:
        squash_matrix = mutual_information
    n_orbitals = squash_matrix.shape[0]

    # Rank the flattened upper triangle (diagonal included) by decreasing MI.
    mi_rank = np.triu(squash_matrix).flatten().argsort()[::-1]
    # Convert back to square (row, col) format from the flattened C-order index.
    sorted_indices = [divmod(flat_index, n_orbitals) for flat_index in mi_rank.tolist()]
    logger.debug("Matrix indices sorted by decreasing MI: %s", sorted_indices)

    branches: list[tuple[int, int, int, int]] = []
    unused_indices = set(range(n_orbitals))
    for row, col in sorted_indices:
        if row == col:
            # A diagonal entry can never form a pair; only warn when it
            # actually carries a non-zero value.
            if squash_matrix[row, col] != 0:
                logger.warning("MI Matrix contains non-zero diagonal elements, skipping.")
            continue
        if max_branches is not None and len(branches) >= max_branches:
            break
        if row not in unused_indices or col not in unused_indices:
            logger.debug("Indices %s previously assigned to branch.", (row, col))
            continue
        if squash_matrix[row, col] >= cutoff:
            # Orbital k contributes the two modes 2k and 2k + 1.
            branch = (2 * row, 2 * row + 1, 2 * col, 2 * col + 1)
            logger.debug("Adding branch %s", branch)
            unused_indices.remove(row)
            unused_indices.remove(col)
            branches.append(branch)
        if len(unused_indices) <= 1:
            # Fewer than two free orbitals left, so no further pair can form.
            break

    enumeration_scheme = {}
    unused_modes = set(range(new_tree.n_qubits))
    for branch_number, branch in enumerate(branches):
        for position, mode in enumerate(branch):
            # Branch b hangs off the "z" chain at depth b and extends along "x".
            node_path = "z" * branch_number + "x" * position
            new_tree.add_node(node_path)
            enumeration_scheme[node_path] = (mode, mode)
            unused_modes.remove(mode)

    # Extend the "z" chain so that the modes not on a branch also get a node.
    remaining_modes = new_tree.n_qubits - (4 * len(branches))
    new_tree.add_node("z" * (remaining_modes + len(branches) - 1))

    for node_path in new_tree.root_node.child_strings:
        if node_path not in enumeration_scheme:
            # NOTE(review): set.pop() returns an arbitrary element, so which
            # leftover mode lands on which node is not specified.
            mode = unused_modes.pop()
            enumeration_scheme[node_path] = (mode, mode)

    logger.debug("Setting enumeration scheme")
    logger.debug(enumeration_scheme)
    new_tree.enumeration_scheme = enumeration_scheme
    return new_tree