Skip to content

Utilities API

This section documents the utility components of OKAPI.

Fitness Functions

Functions for evaluating the fitness of trees.

from okapi.fitness import accuracy_fitness, average_precision_fitness, roc_auc_score_fitness

accuracy_fitness(tree, gt, task='binary')

Calculate the Accuracy score as a fitness measure using torchmetrics.

Accuracy is the proportion of correct predictions among the total number of cases processed. This implementation supports binary, multiclass, and multilabel classification.

Parameters:

Name Type Description Default
tree Tree

The tree whose evaluation will be compared against ground truth

required
gt Tensor

Ground truth tensor containing labels

required
task Literal['binary', 'multiclass', 'multilabel']

Classification task type: - 'binary': Binary classification (default) - 'multiclass': Multiclass classification - 'multilabel': Multilabel classification

'binary'

Returns:

Type Description
float

Accuracy score as a float between 0 and 1 (higher is better)

Source code in okapi/fitness.py
def accuracy_fitness(tree: Tree, gt: Tensor, task: Literal["binary", "multiclass", "multilabel"] = "binary") -> float:
    """
    Calculate the Accuracy score as a fitness measure using torchmetrics.

    Accuracy is the proportion of correct predictions among the total number of cases processed.
    This implementation supports binary, multiclass, and multilabel classification.

    Args:
        tree: The tree whose evaluation will be compared against ground truth
        gt: Ground truth tensor containing labels
        task: Classification task type:
            - 'binary': Binary classification (default)
            - 'multiclass': Multiclass classification
            - 'multilabel': Multilabel classification

    Returns:
        Accuracy score as a float between 0 and 1 (higher is better)
    """
    from torchmetrics.classification import Accuracy

    # Ensure inputs are PyTorch tensors
    if not isinstance(tree.evaluation, torch.Tensor):
        pred = torch.tensor(B.to_numpy(tree.evaluation))
    else:
        pred = tree.evaluation

    if not isinstance(gt, torch.Tensor):
        gt = torch.tensor(B.to_numpy(gt))

    # Infer number of classes from ground truth
    num_classes = _infer_num_classes(gt, task)
    gt = gt.squeeze().int()

    # Create metric with appropriate parameters based on task
    if task == "multiclass":
        metric = Accuracy(task=task, num_classes=num_classes)
    elif task == "multilabel":
        metric = Accuracy(task=task, num_labels=num_classes)
    else:  # binary
        # Consistent with average_precision_fitness: binary torchmetrics
        # expect 1-D positive-class scores, so reduce (N, 2) predictions to
        # the positive column and flatten (N, 1) predictions.
        if pred.ndim > 1:
            if pred.shape[1] == 2:
                pred = pred[:, 1]
            elif pred.shape[1] == 1:
                pred = pred.squeeze(1)
        metric = Accuracy(task=task)

    # Calculate and return the score
    return metric(pred, gt).item()

average_precision_fitness(tree, gt, task='binary')

Calculate the Average Precision (AP) score as a fitness measure using torchmetrics.

Average Precision summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold, with the increase in recall from the previous threshold used as the weight. This implementation supports binary, multiclass, and multilabel classification.

Parameters:

Name Type Description Default
tree Tree

The tree whose evaluation will be compared against ground truth

required
gt Tensor

Ground truth tensor containing labels

required
task Literal['binary', 'multiclass', 'multilabel']

Classification task type: - 'binary': Binary classification (default) - 'multiclass': Multiclass classification - 'multilabel': Multilabel classification

'binary'

Returns:

Type Description
float

Average Precision score as a float between 0 and 1 (higher is better)

Source code in okapi/fitness.py
def average_precision_fitness(tree: Tree, gt: Tensor, task: Literal["binary", "multiclass", "multilabel"] = "binary") -> float:
    """
    Calculate the Average Precision (AP) score as a fitness measure using torchmetrics.

    Average Precision summarizes a precision-recall curve as the weighted mean of precisions
    achieved at each threshold, with the increase in recall from the previous threshold used
    as the weight. This implementation supports binary, multiclass, and multilabel classification.

    Args:
        tree: The tree whose evaluation will be compared against ground truth
        gt: Ground truth tensor containing labels
        task: Classification task type:
            - 'binary': Binary classification (default)
            - 'multiclass': Multiclass classification
            - 'multilabel': Multilabel classification

    Returns:
        Average Precision score as a float between 0 and 1 (higher is better)
    """
    from torchmetrics.classification import AveragePrecision

    # Ensure inputs are PyTorch tensors
    if not isinstance(tree.evaluation, torch.Tensor):
        pred = torch.tensor(B.to_numpy(tree.evaluation))
    else:
        pred = tree.evaluation

    if not isinstance(gt, torch.Tensor):
        gt = torch.tensor(B.to_numpy(gt))

    # Infer number of classes from ground truth
    num_classes = _infer_num_classes(gt, task)
    gt = gt.squeeze().int()

    # Create metric with appropriate parameters based on task
    if task == "multiclass":
        metric = AveragePrecision(task=task, num_classes=num_classes)
    elif task == "multilabel":
        metric = AveragePrecision(task=task, num_labels=num_classes)
    else:  # binary
        # Binary torchmetrics expect 1-D positive-class scores. Guard on ndim
        # first: the original `pred.shape[1]` raised IndexError when the
        # prediction was already 1-D.
        if pred.ndim > 1:
            if pred.shape[1] == 2:
                pred = pred[:, 1]
            elif pred.shape[1] == 1:
                pred = pred.squeeze(1)
        metric = AveragePrecision(task=task)

    # Calculate and return the score
    return metric(pred, gt).item()

roc_auc_score_fitness(tree, gt, task='binary')

Calculate the Area Under the ROC Curve (AUC-ROC) score as a fitness measure using torchmetrics.

The AUC-ROC score represents the probability that a randomly chosen positive instance is ranked higher than a randomly chosen negative instance. This implementation supports binary, multiclass, and multilabel classification.

Parameters:

Name Type Description Default
tree Tree

The tree whose evaluation will be compared against ground truth

required
gt Tensor

Ground truth tensor containing labels

required
task Literal['binary', 'multiclass', 'multilabel']

Classification task type: - 'binary': Binary classification (default) - 'multiclass': Multiclass classification - 'multilabel': Multilabel classification

'binary'

Returns:

Type Description
float

ROC AUC score as a float between 0 and 1 (higher is better)

Source code in okapi/fitness.py
def roc_auc_score_fitness(tree: Tree, gt: Tensor, task: Literal["binary", "multiclass", "multilabel"] = "binary") -> float:
    """
    Calculate the Area Under the ROC Curve (AUC-ROC) score as a fitness measure using torchmetrics.

    The AUC-ROC score represents the probability that a randomly chosen positive instance
    is ranked higher than a randomly chosen negative instance. This implementation supports
    binary, multiclass, and multilabel classification.

    Args:
        tree: The tree whose evaluation will be compared against ground truth
        gt: Ground truth tensor containing labels
        task: Classification task type:
            - 'binary': Binary classification (default)
            - 'multiclass': Multiclass classification
            - 'multilabel': Multilabel classification

    Returns:
        ROC AUC score as a float between 0 and 1 (higher is better)
    """
    from torchmetrics.classification import AUROC

    # Ensure inputs are PyTorch tensors
    if not isinstance(tree.evaluation, torch.Tensor):
        pred = torch.tensor(B.to_numpy(tree.evaluation))
    else:
        pred = tree.evaluation

    if not isinstance(gt, torch.Tensor):
        gt = torch.tensor(B.to_numpy(gt))

    # Infer number of classes from ground truth
    num_classes = _infer_num_classes(gt, task)
    gt = gt.squeeze().int()

    # Create metric with appropriate parameters based on task
    if task == "multiclass":
        metric = AUROC(task=task, num_classes=num_classes)
    elif task == "multilabel":
        metric = AUROC(task=task, num_labels=num_classes)
    else:  # binary
        # Consistent with average_precision_fitness: binary torchmetrics
        # expect 1-D positive-class scores, so reduce (N, 2) predictions to
        # the positive column and flatten (N, 1) predictions.
        if pred.ndim > 1:
            if pred.shape[1] == 2:
                pred = pred[:, 1]
            elif pred.shape[1] == 1:
                pred = pred.squeeze(1)
        metric = AUROC(task=task)

    # Calculate and return the score
    return metric(pred, gt).item()

Callbacks

The callback system allows customizing the evolutionary process.

from okapi.callback import Callback

Base class for callbacks that can be triggered during the evolutionary process.

Callbacks allow monitoring and potentially modifying the evolution process at specific points: before/after each generation, and at the start/end of the entire evolution. Custom callbacks should inherit from this class and override the methods corresponding to the desired intervention points.

Source code in okapi/callback.py
class Callback:
    """
    Base class for hooks into the evolutionary process.

    Subclasses can observe (and potentially influence) evolution at four
    well-defined points: the start and end of the whole run, and the start
    and end of every generation. Override only the hooks you need; every
    default implementation is a no-op.
    """

    def __init__(self) -> None:
        pass

    def on_evolution_start(self, okapi: "Okapi") -> None:
        """
        Hook fired once before the first generation runs.

        Invoked after the initial population has been created but before any
        generation begins.

        Args:
            okapi: The Okapi instance running the evolution
        """
        pass

    def on_generation_start(self, okapi: "Okapi") -> None:
        """
        Hook fired at the beginning of every generation.

        Invoked before the generation performs any selection, crossover, or
        mutation.

        Args:
            okapi: The Okapi instance running the evolution
        """
        pass

    def on_generation_end(self, okapi: "Okapi") -> None:
        """
        Hook fired at the end of every generation.

        Invoked once the generation's selection, crossover, and mutation
        operations have all completed.

        Args:
            okapi: The Okapi instance running the evolution
        """
        pass

    def on_evolution_end(self, okapi: "Okapi") -> None:
        """
        Hook fired once after the whole run finishes.

        Invoked when every generation has completed, or when the evolution is
        stopped manually.

        Args:
            okapi: The Okapi instance running the evolution
        """
        pass

on_evolution_end(okapi)

Called at the end of the entire evolution process.

This hook is triggered when all generations have been completed or when the evolution process is manually stopped.

Parameters:

Name Type Description Default
okapi Okapi

The Okapi instance running the evolution

required
Source code in okapi/callback.py
def on_evolution_end(self, okapi: "Okapi") -> None:
    """
    Hook fired once after the whole run finishes.

    Invoked when every generation has completed, or when the evolution is
    stopped manually. The default implementation does nothing.

    Args:
        okapi: The Okapi instance running the evolution
    """
    pass

on_evolution_start(okapi)

Called at the start of the evolution process.

This hook is triggered before any generations are run, after the initial population has been created.

Parameters:

Name Type Description Default
okapi Okapi

The Okapi instance running the evolution

required
Source code in okapi/callback.py
def on_evolution_start(self, okapi: "Okapi") -> None:
    """
    Hook fired once before the first generation runs.

    Invoked after the initial population has been created but before any
    generation begins. The default implementation does nothing.

    Args:
        okapi: The Okapi instance running the evolution
    """
    pass

on_generation_end(okapi)

Called at the end of each generation.

This hook is triggered after a generation (iteration) of evolution has completed, including selection, crossover, and mutation operations.

Parameters:

Name Type Description Default
okapi Okapi

The Okapi instance running the evolution

required
Source code in okapi/callback.py
def on_generation_end(self, okapi: "Okapi") -> None:
    """
    Hook fired at the end of every generation.

    Invoked once the generation's selection, crossover, and mutation
    operations have all completed. The default implementation does nothing.

    Args:
        okapi: The Okapi instance running the evolution
    """
    pass

on_generation_start(okapi)

Called at the start of each generation.

This hook is triggered before a generation (iteration) of evolution begins, before any selection, crossover, or mutation operations.

Parameters:

Name Type Description Default
okapi Okapi

The Okapi instance running the evolution

required
Source code in okapi/callback.py
def on_generation_start(self, okapi: "Okapi") -> None:
    """
    Hook fired at the beginning of every generation.

    Invoked before the generation performs any selection, crossover, or
    mutation. The default implementation does nothing.

    Args:
        okapi: The Okapi instance running the evolution
    """
    pass

Visualization

Functions for visualizing trees.

from okapi.draw import draw_tree

Create a visual representation of a tree structure using Graphviz.

This function generates a directed graph visualization of a tree structure, showing nodes and their hierarchical relationships. For ValueNodes, it can optionally display the tensor values and evaluations if they're small enough.

Parameters:

Name Type Description Default
to_draw Node | OperatorNode | Tree | ValueNode

The object to visualize (can be a Tree, Node, OperatorNode, or ValueNode)

required
dot

Optional existing Digraph object to add to. If None, a new one is created.

None
add_val_eval

If True, include value and evaluation information for ValueNodes

True

Returns:

Type Description

A Graphviz Digraph object representing the tree structure

Source code in okapi/draw.py
def draw_tree(to_draw: Node | OperatorNode | Tree | ValueNode, dot=None, add_val_eval=True, remove_after_dot=False):
    """
    Create a visual representation of a tree structure using Graphviz.

    This function generates a directed graph visualization of a tree structure,
    showing nodes and their hierarchical relationships. For ValueNodes, it can
    optionally display the tensor values and evaluations if they're small enough.

    Args:
        to_draw: The object to visualize (can be a Tree, Node, OperatorNode, or ValueNode)
        dot: Optional existing Digraph object to add to. If None, a new one is created.
        add_val_eval: If True, include value and evaluation information for ValueNodes
        remove_after_dot: If True, truncate a ValueNode's model ID at the first '.'

    Returns:
        A Graphviz Digraph object representing the tree structure
    """
    node: None | Node | OperatorNode | ValueNode
    if isinstance(to_draw, Tree):
        node = to_draw.root
    else:
        node = to_draw

    if dot is None:
        dot = Digraph(comment="Tree")
        dot.attr('node', shape='box', width='1.0', margin='0.1')

    # Graphviz node names must be unique; the object's memory address is a
    # stable per-object key. Computed once instead of re-deriving it below.
    node_key = f"{hex(id(node))}"

    if isinstance(node, ValueNode):
        # Only render tensors with <= 9 elements inline; larger ones get a placeholder.
        if node.value is not None:
            value = B.to_numpy(node.value) if (np.prod(node.value.shape) <= 9) else f"Tensor with memory address: {hex(id(node.value))}"
        else:
            value = None

        if node.evaluation is not None:
            evaluation = (
                B.to_numpy(node.evaluation) if (np.prod(B.shape(node.evaluation)) <= 9) else f"Tensor with memory address: {hex(id(node.evaluation))}"
            )
        else:
            evaluation = None

        display_string = "Value Node\n"

        if node.id is not None:
            node_id = str(node.id)
            if remove_after_dot:
                node_id = node_id.split('.')[0]
            display_string += f"Model ID:\n {node_id}\n"

        if add_val_eval:
            display_string += f"Value: {value} | Eval: {evaluation}"

        dot.node(node_key, display_string)
    else:
        dot.node(node_key, f"Op\n{str(node)}")

    # Recurse into children and connect each to this node.
    for child in node.children:
        draw_tree(child, dot, add_val_eval, remove_after_dot)
        dot.edge(node_key, f"{hex(id(child))}")

    return dot

Postprocessing Functions

Functions for postprocessing tree outputs.

from okapi.functions import scale_vector_to_sum_1, set_multiclass_postprocessing

scale_vector_to_sum_1(tensor)

Normalize a tensor so that each vector sums to 1.

This function scales each vector in the tensor along the last dimension such that its elements sum to 1, making it suitable for representing probability distributions.

Parameters:

Name Type Description Default
tensor Tensor

Input tensor to be normalized

required

Returns:

Type Description

Normalized tensor where each vector sums to 1

Source code in okapi/functions.py
def scale_vector_to_sum_1(tensor: Tensor):
    """
    Rescale each last-axis vector of a tensor so its entries sum to 1.

    Dividing every vector by its own total turns raw scores into values that
    can be interpreted as a probability distribution.

    Args:
        tensor: Input tensor to be normalized

    Returns:
        Tensor of the same shape where each last-axis vector sums to 1
    """
    totals = B.sum(tensor, axis=-1)
    return tensor / B.unsqueeze(totals, -1)

set_multiclass_postprocessing()

Configure the global postprocessing function for multiclass classification.

This function sets the global postprocessing to normalize tensor outputs so that they can be interpreted as probability distributions, which is required for multiclass classification tasks. This matters when the model outputs probabilities rather than logits.

Source code in okapi/functions.py
def set_multiclass_postprocessing():
    """
    Install the global postprocessing step used for multiclass classification.

    Registers scale_vector_to_sum_1 as the global postprocessing function so
    that tree outputs are normalized into probability distributions, as
    multiclass classification requires. This matters when the outputs are
    probabilities rather than logits.
    """
    set_postprocessing_function(scale_vector_to_sum_1)

Other Utilities

Additional utility functions.

from okapi.utils import Pickle, first_uniques_mask

Pickle

A utility class for serializing and deserializing Python objects using pickle.

This class provides static methods for saving objects to files and loading them back, which is particularly useful for persisting tree architectures.

Source code in okapi/utils.py
class Pickle:
    """
    Static helpers for pickling Python objects to disk and back.

    Wraps pickle.dump/pickle.load with logging and error reporting; handy for
    persisting tree architectures between runs.
    """

    @staticmethod
    def save(path, obj):
        """
        Serialize an object and write it to a pickle file.

        Args:
            path: File path where the object will be saved
            obj: The Python object to serialize and save
        """
        logger.debug(f"Saving object to pickle file: {path}")
        try:
            with open(path, "wb") as handle:
                pickle.dump(obj, handle)
            logger.debug(f"Successfully saved object to {path}")
        except Exception as e:
            logger.error(f"Failed to save object to {path}: {str(e)}")
            raise

    @staticmethod
    def load(path):
        """
        Read a pickle file and return the deserialized object.

        Args:
            path: File path to load the object from

        Returns:
            The deserialized Python object
        """
        logger.debug(f"Loading pickle file from: {path}")
        try:
            with open(path, "rb") as handle:
                loaded = pickle.load(handle)
            logger.debug(f"Successfully loaded object from {path}")
            return loaded
        except Exception as e:
            logger.error(f"Failed to load pickle file from {path}: {str(e)}")
            raise
load(path) staticmethod

Load a Python object from a pickle file.

Parameters:

Name Type Description Default
path

File path to load the object from

required

Returns:

Type Description

The deserialized Python object

Source code in okapi/utils.py
@staticmethod
def load(path):
    """
    Read a pickle file and return the deserialized object.

    Args:
        path: File path to load the object from

    Returns:
        The deserialized Python object
    """
    logger.debug(f"Loading pickle file from: {path}")
    try:
        with open(path, "rb") as handle:
            loaded = pickle.load(handle)
        logger.debug(f"Successfully loaded object from {path}")
        return loaded
    except Exception as e:
        logger.error(f"Failed to load pickle file from {path}: {str(e)}")
        raise

save(path, obj) staticmethod

Save a Python object to a pickle file.

Parameters:

Name Type Description Default
path

File path where the object will be saved

required
obj

The Python object to serialize and save

required
Source code in okapi/utils.py
@staticmethod
def save(path, obj):
    """
    Serialize an object and write it to a pickle file.

    Args:
        path: File path where the object will be saved
        obj: The Python object to serialize and save
    """
    logger.debug(f"Saving object to pickle file: {path}")
    try:
        with open(path, "wb") as handle:
            pickle.dump(obj, handle)
        logger.debug(f"Successfully saved object to {path}")
    except Exception as e:
        logger.error(f"Failed to save object to {path}: {str(e)}")
        raise

first_uniques_mask(arr)

Create a boolean mask that identifies the first occurrence of each unique item in an array.

This function is useful for filtering duplicates from an array while preserving the order of first appearances.

Parameters:

Name Type Description Default
arr

An array-like object to analyze

required

Returns:

Type Description

A list of booleans where True indicates the first occurrence of a value and

False indicates a duplicate of a previously seen value

Source code in okapi/utils.py
def first_uniques_mask(arr):
    """
    Build a boolean mask marking the first occurrence of each distinct item.

    Useful for dropping duplicates from a sequence while keeping the order in
    which values first appeared.

    Args:
        arr: An array-like object to analyze

    Returns:
        A list of booleans where True indicates the first occurrence of a value and
        False indicates a duplicate of a previously seen value
    """
    logger.trace(f"Creating unique items mask for array of length {len(arr)}")
    seen = set()
    mask = []

    for element in arr:
        first_time = element not in seen
        if first_time:
            seen.add(element)
        mask.append(first_time)

    unique_count = len(seen)
    logger.trace(f"Found {unique_count} unique items out of {len(arr)} total items")
    return mask

mark_paths(list_of_paths)

Mark each path in the list with its type (directory or file) or None if it doesn't exist.

Parameters:

Name Type Description Default
list_of_paths

A list of paths to be marked.

required

Returns:

Type Description
tuple[list[str | None], bool]

A tuple containing: - A list of strings or None values representing the type of each path. - A boolean indicating whether all paths are of the same type.

Source code in okapi/utils.py
def mark_paths(list_of_paths) -> tuple[list[str | None], bool]:
    """
    Mark each path in the list with its type (directory or file) or None if it doesn't exist.

    Args:
        list_of_paths: A list of paths to be marked.

    Returns:
        A tuple containing:
            - A list of strings or None values representing the type of each path.
            - A boolean indicating whether all paths are of the same type.
    """
    marked_paths: list[str | None] = []

    for path in list_of_paths:
        if os.path.exists(path):
            if os.path.isdir(path):
                marked_paths.append("dir")
            else:
                marked_paths.append("file")
        else:
            marked_paths.append(None)
    all_same = all(item == marked_paths[0] for item in marked_paths)
    return marked_paths, all_same