
Core API

This section documents the core components of OKAPI.

Okapi Class

The Okapi class is the main entry point for evolutionary model ensemble optimization.

from okapi import Okapi

Main class for evolutionary model ensemble optimization.

Okapi uses genetic programming to evolve tree-based ensembles of machine learning models. The algorithm creates a population of trees where each tree represents a different way of combining model predictions. Through evolution (crossover and mutation), it searches for optimal ensemble structures that maximize a fitness function.

Each tree has ValueNodes that contain tensor predictions from individual models, and OperatorNodes that define how to combine these predictions (e.g., mean, min, max, weighted mean). The evolution process selects and combines high-performing trees to produce better ensembles.
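
Below is a minimal usage sketch. The paths, file names, and parameter values are placeholders, and the expected tensor file format depends on the configured backend.

from okapi import Okapi

# Hypothetical layout: a directory with one prediction tensor per model and a ground-truth tensor.
okapi = Okapi(
    preds_source="predictions/",
    gt_path="ground_truth.pt",
    population_size=50,
    population_multiplier=2,
    tournament_size=5,
)

okapi.train(iterations=20)

best_trees = okapi.pareto_trees        # Pareto-optimal ensembles found so far
best_fitnesses = okapi.pareto_fitnesses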

Attributes:

population_size: Number of individuals in the population
population_multiplier: Factor determining how many additional trees to generate in each iteration
tournament_size: Number of trees to consider in tournament selection
fitness_function: Function used to evaluate the fitness of each tree
callbacks: Collection of callbacks for monitoring/modifying the evolution process
allowed_ops: Operator node types allowed in tree construction
train_tensors: Dictionary mapping model names to their prediction tensors
gt_tensor: Ground truth tensor for comparison
population: Current population of trees
additional_population (List[Tree]): Additional trees generated during evolution

Source code in okapi/okapi.py
class Okapi:
    """
    Main class for evolutionary model ensemble optimization.

    Okapi uses genetic programming to evolve tree-based ensembles of machine learning models.
    The algorithm creates a population of trees where each tree represents a different way of
    combining model predictions. Through evolution (crossover and mutation), it searches for
    optimal ensemble structures that maximize a fitness function.

    Each tree has ValueNodes that contain tensor predictions from individual models, and
    OperatorNodes that define how to combine these predictions (e.g., mean, min, max, weighted mean).
    The evolution process selects and combines high-performing trees to produce better ensembles.

    Attributes:
        population_size: Number of individuals in the population
        population_multiplier: Factor determining how many additional trees to generate in each iteration
        tournament_size: Number of trees to consider in tournament selection
        fitness_function: Function used to evaluate the fitness of each tree
        callbacks: Collection of callbacks for monitoring/modifying the evolution process
        allowed_ops: Operator node types allowed in tree construction
        train_tensors: Dictionary mapping model names to their prediction tensors
        gt_tensor: Ground truth tensor for comparison
        population: Current population of trees
        additional_population: Additional trees generated during evolution
    """

    def __init__(
        self,
        preds_source: Union[Path, str, Iterable[Path], Iterable[str]],
        gt_path: Union[Path, str, Iterable[Path], Iterable[str]],
        population_size: int,
        population_multiplier: int,
        tournament_size: int,
        minimize_node_count: bool = True,
        objective_functions: Sequence[Callable[[Tree, lib_types.Tensor], float]] = (average_precision_fitness,),
        objectives: Sequence[Callable[[float, float], bool]] = (maximize,),
        allowed_ops: Sequence[Type[OperatorNode]] = (MEAN, MIN, MAX, WEIGHTED_MEAN, FAR_THRESHOLD, CLOSE_THRESHOLD),
        callbacks: Iterable[Callback] = tuple(),
        backend: Union[str, None] = None,
        seed: int = 0,
        postprocessing_function=None,
    ):
        """
        Initialize the Okapi evolutionary algorithm.

        Args:
            preds_source: Source of model predictions, can be a path to directory or iterable of paths
            gt_path: Path to ground truth data, can be a single path or iterable of paths. Should match preds_source by order
            population_size: Size of the population to evolve
            population_multiplier: Factor determining how many additional trees to generate
            tournament_size: Number of trees to consider in tournament selection
            minimize_node_count: Whether the pareto frontier models should also consider node count.
            objective_functions: Functions that calculate the fitnesses that are to be optimized
            objectives: Functions that compare two fitnesses and return True if first is better than second. Usually maximize or minimize
            allowed_ops: Sequence of operator node types that can be used in trees
            callbacks: Iterable of callback objects for monitoring/modifying evolution
            backend: Optional backend implementation for tensor operations
            seed: Random seed for reproducibility
            postprocessing_function: Function applied after each Op Node.
            Most of the operations may break some data characteristics, for example vector summing to one. This can be used to fix that.
        """
        # Validate parameters
        if population_size <= 0:
            raise ValueError(f"population_size must be positive, got {population_size}")
        if population_multiplier <= 0:
            raise ValueError(f"population_multiplier must be positive, got {population_multiplier}")
        if tournament_size <= 0:
            raise ValueError(f"tournament_size must be positive, got {tournament_size}")
        if tournament_size > population_size:
            raise ValueError(
                f"tournament_size ({tournament_size}) cannot be larger than population_size ({population_size})"
            )

        if backend is not None:
            Backend.set_backend(backend)
        if seed is not None:
            np.random.seed(seed)
        if postprocessing_function:
            set_postprocessing_function(postprocessing_function)

        self.population_size = population_size
        self.population_multiplier = population_multiplier
        self.tournament_size = tournament_size
        self.minimize_node_count = minimize_node_count
        self.seed = seed

        self.objective_functions = objective_functions
        self.objectives = objectives
        assert len(objectives) == len(objective_functions), "The number of (optimization) objectives and objective functions is not the same"
        self.optimal_point = _get_optimal_point_based_on_list_of_objective_functions(self.objectives)

        self.callbacks = callbacks
        self.allowed_ops = allowed_ops

        self.train_tensors, self.gt_tensor = self._build_train_tensors(preds_source, gt_path)
        self.ids, self.models = list(self.train_tensors.keys()), list(self.train_tensors.values())
        self._validate_input()

        # state
        self.should_stop = False

        self.population = self._initialize_population()
        self.additional_population: List[Tree] = []  # for potential callbacks
        self.fitnesses: None | npt.NDArray[np.float64] = None

    def _call_hook(self, hook_name):
        """
        Call a specific hook on all registered callbacks.

        Args:
            hook_name: Name of the hook to call
        """
        for callback in self.callbacks:
            getattr(callback, hook_name)(self)

    def _initialize_population(self):
        """
        Initialize the population of trees.

        Creates simple trees using available prediction tensors.

        Returns:
            List of initialized Tree objects
        """
        logger.info(f"Initializing population with size {self.population_size}")
        population = initialize_individuals(self.train_tensors, self.population_size)
        logger.debug(f"Population initialized with {len(population)} individuals")
        return population

    def _calculate_fitnesses(self, trees: None | List[Tree] = None) -> npt.NDArray[np.float64]:
        """
        Calculate fitness values for the given trees.

        Args:
            trees: List of trees to evaluate. If None, uses the current population.

        Returns:
            NumPy array of fitness values
        """
        if trees is None:
            trees = self.population
        logger.debug(f"Calculating fitness for {len(trees)} trees")
        fitnesses = np.zeros(shape=(len(trees), len(self.objective_functions)))
        for ix, objective_function in enumerate(self.objective_functions):
            fitnesses[:, ix] = np.array([objective_function(tree, self.gt_tensor) for tree in trees])
        return fitnesses

    def run_iteration(self):
        """
        Run a single iteration of the evolutionary algorithm.

        This method:
        1. Calculates fitness values for the current population
        2. Performs tournament selection and crossover to create new trees
        3. Applies mutations to some of the new trees
        4. Removes duplicate trees from the population
        """
        logger.info("Starting evolution iteration")
        if self.fitnesses is None:
            self.fitnesses = self._calculate_fitnesses(self.population).round(
                3
            )  # this generally unnecessarily happens again > probably not with the if

        logger.debug("Performing tournament selection and crossover")
        assert self.fitnesses.shape[0] == len(self.population)
        crossover_count = self._perform_crossovers(self.fitnesses)
        assert self.fitnesses.shape[0] == len(self.population)
        logger.debug(f"Performed {crossover_count} crossover operations")

        logger.debug("Applying mutations")
        mutation_count = self._mutate_additional_population()
        assert self.fitnesses.shape[0] == len(self.population)
        logger.info(f"Applied {mutation_count} mutations")

        joined_population = np.array(self.population + self.additional_population)  # maybe worth it to calculate fitnesses first?
        codes = np.array([tree.__repr__() for tree in joined_population])
        mask = first_uniques_mask(codes)
        self.population = list(joined_population[mask])
        self.fitnesses = self._calculate_fitnesses(self.population).round(3)
        assert self.fitnesses.shape[0] == len(self.population)

        logger.debug(f"Removed {len(joined_population) - sum(mask)} duplicate trees")
        logger.debug(f"New population size: {len(self.population)}")

        self.population, self.fitnesses = choose_pareto_then_proximity(
            self.population, self.fitnesses, self.objectives, self.population_size, self.minimize_node_count
        )

        assert self.fitnesses.shape[0] == len(self.population)

        self.additional_population = []

    def _perform_crossovers(self, fitnesses: npt.NDArray[np.float64]):
        crossover_count = 0
        while len(self.additional_population) < (self.population_multiplier * self.population_size):
            idx1, idx2 = tournament_selection_indexes(fitnesses, self.tournament_size, self.optimal_point)
            parent_1, parent_2 = self.population[idx1], self.population[idx2]
            new_tree_1, new_tree_2 = crossover(parent_1, parent_2)
            self.additional_population += [new_tree_1, new_tree_2]
            crossover_count += 1
        return crossover_count

    def _mutate_additional_population(self) -> int:
        mutation_count = 0
        for tree in self.additional_population:
            mutation_chance = np.random.rand()
            if mutation_chance < tree.mutation_chance:
                allowed_mutations = np.array(get_allowed_mutations(tree))
                chosen_mutation = np.random.choice(allowed_mutations)
                logger.trace(f"Applying mutation: {chosen_mutation.__name__}")
                mutated_tree = chosen_mutation(
                    tree,
                    models=self.models,
                    ids=self.ids,
                    allowed_ops=self.allowed_ops,
                )
                self.additional_population.append(mutated_tree)
                mutation_count += 1
        return mutation_count

    def train(self, iterations: int):
        """
        Run the evolutionary algorithm for a specified number of iterations.

        Args:
            iterations: Number of evolution iterations to run
        """
        logger.info(f"Starting evolution with {iterations} iterations")
        self._call_hook("on_evolution_start")

        for i in tqdm.tqdm(range(iterations)):
            logger.info(f"Generation {i + 1}/{iterations}")
            self._call_hook("on_generation_start")  # possibly move to run_iteration instead
            self.run_iteration()
            self._call_hook("on_generation_end")

            if self.should_stop:
                logger.info("Early stopping triggered")
                break

        logger.info("Evolution complete")
        self._call_hook("on_evolution_end")

    def _build_train_tensors(self, preds_source, gt_path):
        """
        Load prediction tensors and ground truth from files.

        Args:
            preds_source: Source of model predictions (path or iterable of paths)
            gt_path: Path to ground truth data

        Returns:
            Tuple of (train_tensors dictionary, ground truth tensor)
        """
        logger.info("Loading prediction tensors and ground truth")
        tensor_paths = []
        if isinstance(preds_source, str):
            preds_source = Path(preds_source)
        if isinstance(preds_source, Path):
            logger.debug(f"Scanning directory for tensors: {preds_source}")
            tensor_paths = list(preds_source.glob("*"))
        elif hasattr(preds_source, "__iter__"):
            marked_paths, all_same = mark_paths(preds_source)
            if all_same:
                if marked_paths[0] == "dir":
                    for pred_source in preds_source:
                        pred_source = Path(pred_source)
                        tensor_paths += list(pred_source.glob("*"))
                elif marked_paths[0] == "file":
                    tensor_paths = list(preds_source)
            else:
                raise ValueError(
                    "preds source must be either path to directory with predictions,"
                    " list of paths to directories with predictions, or list of paths to predictions"
                )

        train_tensors = {}
        for tensor_path in tensor_paths:
            logger.debug(f"Loading tensor: {tensor_path}")
            tensor_id = Path(tensor_path).name
            if tensor_id not in train_tensors:
                train_tensors[tensor_id] = B.load(tensor_path, device=DEVICE)
            else:
                train_tensors[tensor_id] = B.concat([train_tensors[tensor_id], B.load(tensor_path, device=DEVICE)])

        logger.debug(f"Loaded {len(train_tensors)} prediction tensors")
        logger.debug(f"Loading ground truth from: {gt_path}")

        gt_tensor: None | Tensor = None
        if isinstance(gt_path, str):
            gt_path = Path(gt_path)
        if isinstance(gt_path, Path):
            if os.path.isdir(gt_path):
                for path in gt_path.glob("*"):
                    if gt_tensor is None:
                        gt_tensor = B.load(path, device=DEVICE)
                    else:
                        gt_tensor = B.concat([gt_tensor, B.load(path, device=DEVICE)])  # type: ignore
            else:
                gt_tensor = B.load(gt_path, device=DEVICE)
        elif hasattr(gt_path, "__iter__"):
            for path in gt_path:
                if gt_tensor is None:
                    gt_tensor = B.load(path, device=DEVICE)
                else:
                    gt_tensor = B.concat([gt_tensor, B.load(path, device=DEVICE)])  # type: ignore
        else:
            raise ValueError(f"{gt_path} is not valid for loading gt")

        logger.info("Tensors loaded successfully")
        return train_tensors, gt_tensor

    def _validate_input(self, fix_swapped=True):  # no way to change this argument for now TODO
        """
        Validate that all input tensors have compatible shapes.

        Checks if all prediction tensors have the same shape and if the ground truth
        tensor has a compatible shape. Can optionally fix swapped dimensions in the
        ground truth tensor.

        Args:
            fix_swapped: If True, attempts to fix swapped dimensions in ground truth tensor

        Raises:
            ValueError: If tensor shapes are incompatible and cannot be fixed
        """
        logger.info("Validating input tensors")
        # check if all tensors have the same shape
        shapes = [B.shape(tensor) for tensor in self.train_tensors.values()]

        if len(set(shapes)) > 1:
            logger.error(f"Tensors have different shapes: {shapes}")
            raise ValueError(f"Tensors have different shapes: {shapes}")

        logger.debug(f"All prediction tensors have shape: {shapes[0]}")
        logger.debug(f"Ground truth tensor has shape: {B.shape(self.gt_tensor)}")

        if B.shape(self.gt_tensor) != shapes[0]:
            gt_shape = B.shape(self.gt_tensor)
            if len(shapes[0]) > 1 and (len(gt_shape) == 1 or gt_shape[-1] == 1):
                pass
            elif fix_swapped:
                if (shapes[0] == B.shape(self.gt_tensor)[::-1]) and (len(shapes[0]) == 2):
                    logger.warning(f"Ground truth tensor dimensions appear to be swapped. Reshaping from {B.shape(self.gt_tensor)} to {shapes[0]}")
                    self.gt_tensor = B.reshape(self.gt_tensor, shapes[0])
                    logger.info("Tensor shapes fixed successfully")
                else:
                    logger.error(f"Ground truth tensor shape {B.shape(self.gt_tensor)} incompatible with prediction tensor shape {shapes[0]}")
                    raise ValueError(f"Ground truth tensor has incompatible shape: {B.shape(self.gt_tensor)} vs {shapes[0]}")
            else:
                logger.error(f"Ground truth tensor shape {B.shape(self.gt_tensor)} does not match prediction tensor shape {shapes[0]}")
                raise ValueError(f"Ground truth tensor has different shape than input tensors: {shapes[0]} != {B.shape(self.gt_tensor)}")

        logger.info("Input validation successful")

    @property
    def pareto_trees(self) -> List[Tree]:
        assert isinstance(self.fitnesses, np.ndarray), "Fitnesses not yet initialized. Did you run any iteration?"
        all_pareto_trees, _ = choose_pareto(self.population, self.fitnesses, len(self.population), self.objectives, self.minimize_node_count)
        return all_pareto_trees

    @property
    def pareto_fitnesses(self) -> np.ndarray:
        assert isinstance(self.fitnesses, np.ndarray), "Fitnesses not yet initialized. Did you run any iteration?"
        _, pareto_fitnesses = choose_pareto(self.population, self.fitnesses, len(self.population), self.objectives, self.minimize_node_count)
        return pareto_fitnesses

__init__(preds_source, gt_path, population_size, population_multiplier, tournament_size, minimize_node_count=True, objective_functions=(average_precision_fitness,), objectives=(maximize,), allowed_ops=(MEAN, MIN, MAX, WEIGHTED_MEAN, FAR_THRESHOLD, CLOSE_THRESHOLD), callbacks=tuple(), backend=None, seed=0, postprocessing_function=None)

Initialize the Okapi evolutionary algorithm.

Parameters:

preds_source (Union[Path, str, Iterable[Path], Iterable[str]], required): Source of model predictions; can be a path to a directory or an iterable of paths
gt_path (Union[Path, str, Iterable[Path], Iterable[str]], required): Path to ground truth data; can be a single path or an iterable of paths. Must match preds_source in order
population_size (int, required): Size of the population to evolve
population_multiplier (int, required): Factor determining how many additional trees to generate
tournament_size (int, required): Number of trees to consider in tournament selection
minimize_node_count (bool, default True): Whether the Pareto frontier should also take node count into account
objective_functions (Sequence[Callable[[Tree, Tensor], float]], default (average_precision_fitness,)): Functions that calculate the fitness values to be optimized
objectives (Sequence[Callable[[float, float], bool]], default (maximize,)): Functions that compare two fitness values and return True if the first is better than the second; usually maximize or minimize
allowed_ops (Sequence[Type[OperatorNode]], default (MEAN, MIN, MAX, WEIGHTED_MEAN, FAR_THRESHOLD, CLOSE_THRESHOLD)): Sequence of operator node types that can be used in trees
callbacks (Iterable[Callback], default tuple()): Iterable of callback objects for monitoring/modifying evolution
backend (Union[str, None], default None): Optional backend implementation for tensor operations
seed (int, default 0): Random seed for reproducibility
postprocessing_function (default None): Function applied after each operator node. Many operations can break data characteristics (for example, a vector summing to one); this can be used to restore them
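
The defaults optimize average precision. Other objectives can be supplied by passing matching objective_functions and objectives. A minimal sketch with a hand-written objective and comparator follows; the names, paths, and the NumPy conversion are assumptions, and okapi's own fitness and objective helpers can be used instead where their imports are available.

import numpy as np

from okapi import Okapi


def mse_fitness(tree, gt_tensor):
    # Objective functions receive a Tree and the ground-truth tensor and return a float.
    # np.asarray assumes the backend tensors are NumPy-compatible; adapt for other backends.
    pred = np.asarray(tree.evaluation)
    gt = np.asarray(gt_tensor)
    return float(np.mean((pred - gt) ** 2))


def minimize(first, second):
    # Objectives compare two fitness values and return True if the first is better.
    return first < second


okapi = Okapi(
    preds_source="predictions/",
    gt_path="ground_truth.pt",
    population_size=50,
    population_multiplier=2,
    tournament_size=5,
    objective_functions=(mse_fitness,),
    objectives=(minimize,),
)
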
Source code in okapi/okapi.py
def __init__(
    self,
    preds_source: Union[Path, str, Iterable[Path], Iterable[str]],
    gt_path: Union[Path, str, Iterable[Path], Iterable[str]],
    population_size: int,
    population_multiplier: int,
    tournament_size: int,
    minimize_node_count: bool = True,
    objective_functions: Sequence[Callable[[Tree, lib_types.Tensor], float]] = (average_precision_fitness,),
    objectives: Sequence[Callable[[float, float], bool]] = (maximize,),
    allowed_ops: Sequence[Type[OperatorNode]] = (MEAN, MIN, MAX, WEIGHTED_MEAN, FAR_THRESHOLD, CLOSE_THRESHOLD),
    callbacks: Iterable[Callback] = tuple(),
    backend: Union[str, None] = None,
    seed: int = 0,
    postprocessing_function=None,
):
    """
    Initialize the Okapi evolutionary algorithm.

    Args:
        preds_source: Source of model predictions, can be a path to directory or iterable of paths
        gt_path: Path to ground truth data, can be a single path or iterable of paths. Should match preds_source by order
        population_size: Size of the population to evolve
        population_multiplier: Factor determining how many additional trees to generate
        tournament_size: Number of trees to consider in tournament selection
        minimize_node_count: Whether the pareto frontier models should also consider node count.
        objective_functions: Functions that calculate the fitnesses that are to be optimized
        objectives: Functions that compare two fitnesses and return True if first is better than second. Usually maximize or minimize
        allowed_ops: Sequence of operator node types that can be used in trees
        callbacks: Iterable of callback objects for monitoring/modifying evolution
        backend: Optional backend implementation for tensor operations
        seed: Random seed for reproducibility
        postprocessing_function: Function applied after each Op Node.
        Most of the operations may break some data characteristics, for example vector summing to one. This can be used to fix that.
    """
    # Validate parameters
    if population_size <= 0:
        raise ValueError(f"population_size must be positive, got {population_size}")
    if population_multiplier <= 0:
        raise ValueError(f"population_multiplier must be positive, got {population_multiplier}")
    if tournament_size <= 0:
        raise ValueError(f"tournament_size must be positive, got {tournament_size}")
    if tournament_size > population_size:
        raise ValueError(
            f"tournament_size ({tournament_size}) cannot be larger than population_size ({population_size})"
        )

    if backend is not None:
        Backend.set_backend(backend)
    if seed is not None:
        np.random.seed(seed)
    if postprocessing_function:
        set_postprocessing_function(postprocessing_function)

    self.population_size = population_size
    self.population_multiplier = population_multiplier
    self.tournament_size = tournament_size
    self.minimize_node_count = minimize_node_count
    self.seed = seed

    self.objective_functions = objective_functions
    self.objectives = objectives
    assert len(objectives) == len(objective_functions), "The number of (optimization) objectives and objective functions is not the same"
    self.optimal_point = _get_optimal_point_based_on_list_of_objective_functions(self.objectives)

    self.callbacks = callbacks
    self.allowed_ops = allowed_ops

    self.train_tensors, self.gt_tensor = self._build_train_tensors(preds_source, gt_path)
    self.ids, self.models = list(self.train_tensors.keys()), list(self.train_tensors.values())
    self._validate_input()

    # state
    self.should_stop = False

    self.population = self._initialize_population()
    self.additional_population: List[Tree] = []  # for potential callbacks
    self.fitnesses: None | npt.NDArray[np.float64] = None

run_iteration()

Run a single iteration of the evolutionary algorithm.

This method:

1. Calculates fitness values for the current population
2. Performs tournament selection and crossover to create new trees
3. Applies mutations to some of the new trees
4. Removes duplicate trees from the population

Source code in okapi/okapi.py
def run_iteration(self):
    """
    Run a single iteration of the evolutionary algorithm.

    This method:
    1. Calculates fitness values for the current population
    2. Performs tournament selection and crossover to create new trees
    3. Applies mutations to some of the new trees
    4. Removes duplicate trees from the population
    """
    logger.info("Starting evolution iteration")
    if self.fitnesses is None:
        self.fitnesses = self._calculate_fitnesses(self.population).round(
            3
        )  # this generally unnecessarily happens again > probably not with the if

    logger.debug("Performing tournament selection and crossover")
    assert self.fitnesses.shape[0] == len(self.population)
    crossover_count = self._perform_crossovers(self.fitnesses)
    assert self.fitnesses.shape[0] == len(self.population)
    logger.debug(f"Performed {crossover_count} crossover operations")

    logger.debug("Applying mutations")
    mutation_count = self._mutate_additional_population()
    assert self.fitnesses.shape[0] == len(self.population)
    logger.info(f"Applied {mutation_count} mutations")

    joined_population = np.array(self.population + self.additional_population)  # maybe worth it to calculate fitnesses first?
    codes = np.array([tree.__repr__() for tree in joined_population])
    mask = first_uniques_mask(codes)
    self.population = list(joined_population[mask])
    self.fitnesses = self._calculate_fitnesses(self.population).round(3)
    assert self.fitnesses.shape[0] == len(self.population)

    logger.debug(f"Removed {len(joined_population) - sum(mask)} duplicate trees")
    logger.debug(f"New population size: {len(self.population)}")

    self.population, self.fitnesses = choose_pareto_then_proximity(
        self.population, self.fitnesses, self.objectives, self.population_size, self.minimize_node_count
    )

    assert self.fitnesses.shape[0] == len(self.population)

    self.additional_population = []
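
When finer control than train() is needed, run_iteration can be called directly, for example to inspect fitnesses between generations. A small sketch (okapi is an already constructed instance); note that, unlike train(), this loop does not invoke callback hooks.

for generation in range(10):
    okapi.run_iteration()
    # fitnesses has shape (len(population), number of objectives) after each iteration
    print(generation, okapi.fitnesses.max(axis=0))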

train(iterations)

Run the evolutionary algorithm for a specified number of iterations.

Parameters:

iterations (int, required): Number of evolution iterations to run
Source code in okapi/okapi.py
def train(self, iterations: int):
    """
    Run the evolutionary algorithm for a specified number of iterations.

    Args:
        iterations: Number of evolution iterations to run
    """
    logger.info(f"Starting evolution with {iterations} iterations")
    self._call_hook("on_evolution_start")

    for i in tqdm.tqdm(range(iterations)):
        logger.info(f"Generation {i + 1}/{iterations}")
        self._call_hook("on_generation_start")  # possibly move to run_iteration instead
        self.run_iteration()
        self._call_hook("on_generation_end")

        if self.should_stop:
            logger.info("Early stopping triggered")
            break

    logger.info("Evolution complete")
    self._call_hook("on_evolution_end")

Tree Class

The Tree class represents a computational tree structure for model ensemble composition.

from okapi.tree import Tree

Represents a computational tree structure for model ensemble composition.

The Tree class is a central component in OKAPI, representing a hierarchical structure of nodes that define how different models are combined. Each tree has a ValueNode as its root, and may contain multiple ValueNodes and OperatorNodes arranged in a tree structure.

ValueNodes contain tensor data (model predictions), while OperatorNodes define operations to combine these predictions (such as mean, min, max, weighted mean). The tree's evaluation produces a combined prediction by recursively applying these operations.

Trees can be manipulated through various operations like pruning, appending, and replacing nodes, making them suitable for evolutionary algorithms where trees evolve over generations.
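
In practice, trees are usually obtained from a trained Okapi instance (for example via its pareto_trees property) rather than assembled by hand. A small sketch of inspecting one:

best = okapi.pareto_trees[0]

combined_prediction = best.evaluation   # tensor produced by evaluating the ensemble
print(best.nodes_count)                 # total number of value and operator nodes
print(best.unique_value_node_ids)       # ids of the models the tree actually uses
print(repr(best))                       # compact code describing the tree structure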

Attributes:

root: The root node of the tree (must be a ValueNode)
nodes (dict[str, list]): Dictionary containing lists of all value nodes and operator nodes in the tree
mutation_chance: Probability of mutation for this tree during evolution

Source code in okapi/tree.py
class Tree:
    """
    Represents a computational tree structure for model ensemble composition.

    The Tree class is a central component in OKAPI, representing a hierarchical structure
    of nodes that define how different models are combined. Each tree has a ValueNode as its root,
    and may contain multiple ValueNodes and OperatorNodes arranged in a tree structure.

    ValueNodes contain tensor data (model predictions), while OperatorNodes define operations
    to combine these predictions (such as mean, min, max, weighted mean). The tree's evaluation
    produces a combined prediction by recursively applying these operations.

    Trees can be manipulated through various operations like pruning, appending, and replacing
    nodes, making them suitable for evolutionary algorithms where trees evolve over generations.

    Attributes:
        root: The root node of the tree (must be a ValueNode)
        nodes: Dictionary containing lists of all value nodes and operator nodes in the tree
        mutation_chance: Probability of mutation for this tree during evolution
    """

    def __init__(self, root: ValueNode, mutation_chance=0.1):
        self.root = root
        logger.debug(f"Creating new tree with root: {root}")

        if isinstance(self.root, OperatorNode):
            logger.error("Cannot initialize tree with OperatorNode as root")
            raise Exception("Cannot get evaluation of tree with OpNode as root")

        self.nodes: dict[str, list] = {"value_nodes": [], "op_nodes": []}
        self.mutation_chance = mutation_chance
        self.update_nodes()
        logger.trace(f"Tree initialized with {len(self.nodes['value_nodes'])} value nodes and {len(self.nodes['op_nodes'])} operator nodes")

    def update_nodes(self):
        """
        Update the internal collections of nodes in the tree.

        This method traverses the tree and categorizes all nodes into value nodes and operator nodes,
        updating the internal `nodes` dictionary.
        """
        logger.debug("Updating tree node collections")
        self.nodes = {"value_nodes": [], "op_nodes": []}
        root_nodes = self.root.get_nodes()
        for node in root_nodes:
            if isinstance(node, ValueNode):
                self.nodes["value_nodes"].append(node)
            else:
                self.nodes["op_nodes"].append(node)
        logger.trace(f"Updated nodes: {len(self.nodes['value_nodes'])} value nodes, {len(self.nodes['op_nodes'])} operator nodes")

    @staticmethod
    def create_tree_from_root(root: ValueNode, mutation_chance=0.1):
        """
        Create a new tree with the given root node.

        Args:
            root: The ValueNode to use as the root of the new tree
            mutation_chance: Probability of mutation for the new tree

        Returns:
            A new Tree instance
        """
        logger.debug(f"Creating tree from root node with mutation chance: {mutation_chance}")
        tree = Tree(root, mutation_chance)
        return tree

    @property
    def evaluation(self):
        """
        Calculate and return the evaluation of the tree.

        The evaluation is the result of recursively applying all operations
        in the tree, starting from the root node.

        Returns:
            The tensor resulting from evaluating the tree
        """
        # WARNING: This may not make sense for cases other than binary classification (Squeezing)
        # return B.squeeze(self.root.evaluation if self.root.evaluation is not None else self.root.calculate())
        return self.root.evaluation if self.root.evaluation is not None else self.root.calculate()

    @property
    def nodes_count(self):
        """
        Count the total number of nodes in the tree.

        Returns:
            The sum of value nodes and operator nodes
        """
        return len(self.nodes["value_nodes"]) + len(self.nodes["op_nodes"])

    def _clean_evals(self):
        """
        Reset the cached evaluation results for all value nodes in the tree.

        This forces recalculation of node evaluations when the tree structure changes.
        """
        logger.debug("Clearing cached evaluations for all value nodes")
        for node in self.nodes["value_nodes"]:
            node.evaluation = None

    def _clean_values_and_evals(self):
        for value_node in self.nodes["value_nodes"]:
            value_node.value = value_node.evaluation = None

    def recalculate(self):
        """
        Force recalculation of the tree evaluation.

        This method clears any cached evaluations and triggers a fresh calculation.
        It also updates the nodes dictionary

        Returns:
            The newly calculated evaluation of the tree
        """
        logger.debug("Recalculating tree evaluation")
        self._clean_evals()
        self.update_nodes()
        evaluation = self.evaluation
        logger.trace("Tree recalculation complete")
        return evaluation

    def copy(self):
        """
        Create a deep copy of the tree.

        Returns:
            A new Tree instance that is a deep copy of the current tree
        """
        logger.debug("Creating deep copy of tree")
        root_copy: ValueNode = cast(ValueNode, self.root.copy_subtree())
        return Tree.create_tree_from_root(root_copy)

    def prune_at(self, node: Node) -> Node:
        """
        Remove a node and its subtree from the tree.

        This method removes the specified node and all its descendants from the tree.
        If the node is the only child of an operator node, that operator node will
        also be pruned.

        Args:
            node: The node to prune from the tree

        Returns:
            The pruned node (which is no longer part of the tree). If parent was pruned, the parent will be returned.

        Raises:
            ValueError: If the node is not found in the tree or if attempting to prune the root node
        """
        logger.debug(f"Pruning node from tree: {node}")

        if node not in self.nodes["value_nodes"] and node not in self.nodes["op_nodes"]:
            logger.error(f"Attempted to prune node not in tree: {node}")
            raise ValueError("Node not found in tree")

        if node.parent is None:
            logger.error("Cannot prune root node")
            raise ValueError("Cannot prune root node")

        if isinstance(node.parent, OperatorNode) and (
            len(node.parent.children) < 2
        ):  # if only child of op node is to be pruned, remove the parent instead
            logger.debug(f"Node is the only child of operator node, pruning parent: {node.parent}")
            return self.prune_at(node.parent)

        subtree_nodes = node.get_nodes()
        node_count = len(subtree_nodes)

        logger.debug(f"Removing {node_count} nodes in subtree")
        for subtree_node in subtree_nodes:
            if isinstance(subtree_node, ValueNode):
                self.nodes["value_nodes"].remove(subtree_node)
            else:
                self.nodes["op_nodes"].remove(subtree_node)

        node.parent.remove_child(node)
        logger.debug("Pruning complete, clearing cached evaluations")
        self._clean_evals()
        return node

    def append_after(self, node: Node, new_node: Node):
        """
        Append a new node as a child of an existing node.

        The new node must be of a different type than the existing node
        (i.e., value nodes can only append operator nodes and vice versa).

        Args:
            node: The existing node to which the new node will be appended
            new_node: The new node to append

        Raises:
            ValueError: If the node is not found in the tree or if attempting to append
                       a node of the same type
        """
        logger.debug(f"Appending node {new_node} after {node}")

        if node not in self.nodes["value_nodes"] and node not in self.nodes["op_nodes"]:
            logger.error(f"Attempted to append to node not in tree: {node}")
            raise ValueError("Node not found in tree")

        if check_if_both_types_same_node_variant(type(node), type(new_node)):
            logger.error(f"Cannot append node of same type: {type(node).__name__} and {type(new_node).__name__}")
            raise ValueError("Cannot append node of the same type")

        subtree_nodes = new_node.get_nodes()
        logger.debug(f"Adding {len(subtree_nodes)} nodes from subtree")

        for subtree_node in subtree_nodes:
            if isinstance(subtree_node, ValueNode):
                self.nodes["value_nodes"].append(subtree_node)
            else:
                self.nodes["op_nodes"].append(subtree_node)

        node.add_child(new_node)
        logger.debug("Append complete, clearing cached evaluations")
        self._clean_evals()

    def replace_at(self, at: Node, replacement: Node) -> Self:
        """
        Replace a node in the tree with another node.

        The replacement node must be of the same type as the node being replaced.
        This operation preserves the parent-child relationships.

        Args:
            at: The node to be replaced
            replacement: The new node that will replace the existing node

        Returns:
            Self reference to allow method chaining

        Raises:
            AssertionError: If the replacement node is not of the same type as the node being replaced
        """
        assert (isinstance(replacement, ValueNode) and isinstance(at, ValueNode)) or (
            isinstance(replacement, OperatorNode) and isinstance(at, OperatorNode)
        ), "Replacement node must be of the same parent type (ValueNode or OperatorNode) as the node being replaced"
        at_parent = at.parent

        if at_parent is None:
            assert isinstance(self.root, ValueNode), "Root must be a value node"
            assert isinstance(replacement, ValueNode), "Replacement for root must be a value node"
            logger.warning("Node at replacement is root node")
            self.root = replacement
        else:
            at_parent.replace_child(at, replacement)

        if isinstance(at, ValueNode):
            self.nodes["value_nodes"].remove(at)
            self.nodes["value_nodes"].append(replacement)
        else:
            self.nodes["op_nodes"].remove(at)
            self.nodes["op_nodes"].append(replacement)

        self._clean_evals()
        return self

    def get_random_node(self, nodes_type: str | None = None, allow_root=True, allow_leaves=True):
        """
        Get a random node from the tree based on specified constraints.

        Args:
            nodes_type: Optional type of nodes to consider ('value_nodes' or 'op_nodes')
                       If None, a random type will be chosen
            allow_root: Whether to allow selecting the root node
            allow_leaves: Whether to allow selecting leaf nodes

        Returns:
            A randomly selected node that satisfies the constraints

        Raises:
            ValueError: If no node satisfying the constraints is found
        """
        if self.root.children == []:
            if allow_root:
                if nodes_type is None or nodes_type == "value_nodes":
                    return self.root
                else:
                    raise ValueError("Tree has only root node and nodes_type is not value_nodes")
            else:
                raise ValueError("Tree has only root node and allow_root is set to False")

        if nodes_type is not None:
            assert nodes_type in ("value_nodes", "op_nodes"), f'Unsupported node type "{nodes_type}" selected.'
            nodes_types = [
                nodes_type,
            ]
        else:
            nodes_types = list(np.random.permutation(["op_nodes", "value_nodes"]))

        for nodes_type in nodes_types:
            assert nodes_type is not None, "Nodes type cannot be None"
            order = np.arange(len(self.nodes[nodes_type]))
            for i in order:
                node = self.nodes[nodes_type][i]
                if (allow_leaves or node.children != []) and (allow_root or node != self.root):
                    return node
        raise ValueError("No node found that complies to the constraints")

    @property
    def unique_value_node_ids(self):
        """
        Get the unique IDs of all value nodes in the tree.

        Returns:
            A list of unique IDs from all value nodes
        """
        return list(set([node.id for node in self.nodes["value_nodes"]]))

    @property
    def value_nodes(self):
        return self.nodes["value_nodes"]

    @property
    def op_nodes(self):
        return self.nodes["op_nodes"]

    def save_tree_architecture(self, output_path):  # TODO: needs adjustment for weighted node
        """
        Save the tree's architecture to a file.

        This method creates a copy of the tree with tensor values removed
        and saves it to the specified path using pickle serialization.

        Args:
            output_path: Path where the tree architecture will be saved
        """
        logger.info(f"Saving tree architecture to {output_path}")
        copy_tree = self.copy()
        copy_tree._clean_values_and_evals()

        Pickle.save(output_path, copy_tree)
        logger.debug("Tree architecture saved successfully")

    @staticmethod
    def load_tree_architecture(architecture_path) -> "Tree":  # TODO: needs adjusted for weighted node
        """
        Load a tree architecture from a file.

        Args:
            architecture_path: Path to the saved tree architecture file

        Returns:
            The loaded Tree object without tensor values
        """
        logger.info(f"Loading tree architecture from {architecture_path}")
        tree = Pickle.load(architecture_path)
        logger.debug("Tree architecture loaded successfully")
        return tree

    def _load_tensors_from_path(self, preds_directory):
        current_tensors = {}
        preds_directory = Path(preds_directory)
        for value_node in self.nodes["value_nodes"]:
            node_id = value_node.id
            if node_id not in current_tensors:
                logger.debug(f"Loading tensor for node ID: {node_id}")
                current_tensors[node_id] = B.load(preds_directory / str(node_id), DEVICE)
            else:
                logger.trace(f"Using pre-loaded tensor for node ID: {node_id}")
        return current_tensors

    def _load_tensors_to_tree(self, preds_directory, current_tensors):
        if preds_directory is not None:
            preds_directory = Path(preds_directory)
            loaded_tensors = self._load_tensors_from_path(preds_directory)
            current_tensors.update(loaded_tensors)
        for value_node in self.nodes["value_nodes"]:
            node_id = value_node.id
            value_node.value = current_tensors[node_id]
        return current_tensors

    def do_pred_on_another_tensors(
        self, preds_directory: None | str | Path = None, current_tensors: None | dict = None, return_tree=False
    ) -> tuple[None | Tensor, "Tree"] | Tensor:
        assert not all(
            [current_tensors is not None, preds_directory is not None]
        ), "Either preds directory or current tensors needs to be set, not both"
        assert any(
            [current_tensors is not None, preds_directory is not None]
        ), "Either preds directory or current tensors needs to be set, none was set"

        current_tensors = {}
        copy_tree = self.copy()
        copy_tree._clean_values_and_evals()
        current_tensors = copy_tree._load_tensors_to_tree(preds_directory, current_tensors)
        if return_tree:
            return copy_tree.evaluation, copy_tree

        return copy_tree.evaluation

    @staticmethod
    def load_tree(architecture_path, preds_directory, tensors=None) -> Tuple["Tree", dict]:
        """
        Load a complete tree with tensor values from files.

        This method loads a tree architecture and then loads the associated tensor
        values for each value node from the specified directory.

        Args:
            architecture_path: Path to the saved tree architecture file
            preds_directory: Directory containing the tensor files
            tensors: Optional dictionary of pre-loaded tensors

        Returns:
            A tuple containing:
            - The loaded Tree object with tensor values
            - A dictionary of all tensors used in the tree
        """
        if tensors is None:
            tensors = {}

        logger.info(f"Loading complete tree from {architecture_path} with tensors from {preds_directory}")
        logger.debug(f"Starting with {len(tensors)} pre-loaded tensors")

        current_tensors = {}
        current_tensors.update(tensors)  # tensors argument is mutable and we do not want to modify it

        loaded = Tree.load_tree_architecture(architecture_path)
        current_tensors = loaded._load_tensors_to_tree(preds_directory, current_tensors)

        logger.info(
            f"Tree loaded successfully with {len(loaded.nodes['value_nodes'])} value nodes and {len(loaded.nodes['op_nodes'])} operator nodes"
        )
        return loaded, current_tensors

    def __repr__(self):
        """
        Get a string representation of the tree.

        Returns:
            A string representation formed by concatenating the code of all nodes
        """
        return "_".join(node.code for node in self.root.get_nodes())

evaluation property

Calculate and return the evaluation of the tree.

The evaluation is the result of recursively applying all operations in the tree, starting from the root node.

Returns:

The tensor resulting from evaluating the tree
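
Evaluations are cached on the nodes, so repeated access is cheap; recalculate() clears the cache, refreshes the node collections, and recomputes. A small sketch:

pred = tree.evaluation           # computed on first access and cached on the nodes
pred_cached = tree.evaluation    # served from the cached result
pred_fresh = tree.recalculate()  # clears cached evaluations, updates node lists, recomputes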

nodes_count property

Count the total number of nodes in the tree.

Returns:

The sum of value nodes and operator nodes

unique_value_node_ids property

Get the unique IDs of all value nodes in the tree.

Returns:

A list of unique IDs from all value nodes

__repr__()

Get a string representation of the tree.

Returns:

A string representation formed by concatenating the code of all nodes

Source code in okapi/tree.py
def __repr__(self):
    """
    Get a string representation of the tree.

    Returns:
        A string representation formed by concatenating the code of all nodes
    """
    return "_".join(node.code for node in self.root.get_nodes())

append_after(node, new_node)

Append a new node as a child of an existing node.

The new node must be of a different type than the existing node (i.e., value nodes can only append operator nodes and vice versa).

Parameters:

node (Node, required): The existing node to which the new node will be appended
new_node (Node, required): The new node to append

Raises:

ValueError: If the node is not found in the tree or if attempting to append a node of the same type

Source code in okapi/tree.py
def append_after(self, node: Node, new_node: Node):
    """
    Append a new node as a child of an existing node.

    The new node must be of a different type than the existing node
    (i.e., value nodes can only append operator nodes and vice versa).

    Args:
        node: The existing node to which the new node will be appended
        new_node: The new node to append

    Raises:
        ValueError: If the node is not found in the tree or if attempting to append
                   a node of the same type
    """
    logger.debug(f"Appending node {new_node} after {node}")

    if node not in self.nodes["value_nodes"] and node not in self.nodes["op_nodes"]:
        logger.error(f"Attempted to append to node not in tree: {node}")
        raise ValueError("Node not found in tree")

    if check_if_both_types_same_node_variant(type(node), type(new_node)):
        logger.error(f"Cannot append node of same type: {type(node).__name__} and {type(new_node).__name__}")
        raise ValueError("Cannot append node of the same type")

    subtree_nodes = new_node.get_nodes()
    logger.debug(f"Adding {len(subtree_nodes)} nodes from subtree")

    for subtree_node in subtree_nodes:
        if isinstance(subtree_node, ValueNode):
            self.nodes["value_nodes"].append(subtree_node)
        else:
            self.nodes["op_nodes"].append(subtree_node)

    node.add_child(new_node)
    logger.debug("Append complete, clearing cached evaluations")
    self._clean_evals()

copy()

Create a deep copy of the tree.

Returns:

A new Tree instance that is a deep copy of the current tree

Source code in okapi/tree.py
def copy(self):
    """
    Create a deep copy of the tree.

    Returns:
        A new Tree instance that is a deep copy of the current tree
    """
    logger.debug("Creating deep copy of tree")
    root_copy: ValueNode = cast(ValueNode, self.root.copy_subtree())
    return Tree.create_tree_from_root(root_copy)

create_tree_from_root(root, mutation_chance=0.1) staticmethod

Create a new tree with the given root node.

Parameters:

root (ValueNode, required): The ValueNode to use as the root of the new tree
mutation_chance (default 0.1): Probability of mutation for the new tree

Returns:

A new Tree instance

Source code in okapi/tree.py
@staticmethod
def create_tree_from_root(root: ValueNode, mutation_chance=0.1):
    """
    Create a new tree with the given root node.

    Args:
        root: The ValueNode to use as the root of the new tree
        mutation_chance: Probability of mutation for the new tree

    Returns:
        A new Tree instance
    """
    logger.debug(f"Creating tree from root node with mutation chance: {mutation_chance}")
    tree = Tree(root, mutation_chance)
    return tree

get_random_node(nodes_type=None, allow_root=True, allow_leaves=True)

Get a random node from the tree based on specified constraints.

Parameters:

nodes_type (str | None, default None): Optional type of nodes to consider ('value_nodes' or 'op_nodes'). If None, a random type will be chosen
allow_root (default True): Whether to allow selecting the root node
allow_leaves (default True): Whether to allow selecting leaf nodes

Returns:

A randomly selected node that satisfies the constraints

Raises:

ValueError: If no node satisfying the constraints is found
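
For example, a mutation-style operation might draw a random non-root value node:

# Raises ValueError if no node satisfies the constraints.
node = tree.get_random_node(nodes_type="value_nodes", allow_root=False)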

Source code in okapi/tree.py
def get_random_node(self, nodes_type: str | None = None, allow_root=True, allow_leaves=True):
    """
    Get a random node from the tree based on specified constraints.

    Args:
        nodes_type: Optional type of nodes to consider ('value_nodes' or 'op_nodes')
                   If None, a random type will be chosen
        allow_root: Whether to allow selecting the root node
        allow_leaves: Whether to allow selecting leaf nodes

    Returns:
        A randomly selected node that satisfies the constraints

    Raises:
        ValueError: If no node satisfying the constraints is found
    """
    if self.root.children == []:
        if allow_root:
            if nodes_type is None or nodes_type == "value_nodes":
                return self.root
            else:
                raise ValueError("Tree has only root node and nodes_type is not value_nodes")
        else:
            raise ValueError("Tree has only root node and allow_root is set to False")

    if nodes_type is not None:
        assert nodes_type in ("value_nodes", "op_nodes"), f'Unsupported node type "{nodes_type}" selected.'
        nodes_types = [
            nodes_type,
        ]
    else:
        nodes_types = list(np.random.permutation(["op_nodes", "value_nodes"]))

    for nodes_type in nodes_types:
        assert nodes_type is not None, "Nodes type cannot be None"
        order = np.arange(len(self.nodes[nodes_type]))
        for i in order:
            node = self.nodes[nodes_type][i]
            if (allow_leaves or node.children != []) and (allow_root or node != self.root):
                return node
    raise ValueError("No node found that complies to the constraints")

load_tree(architecture_path, preds_directory, tensors=None) staticmethod

Load a complete tree with tensor values from files.

This method loads a tree architecture and then loads the associated tensor values for each value node from the specified directory.

Parameters:

architecture_path (required): Path to the saved tree architecture file
preds_directory (required): Directory containing the tensor files
tensors (default None): Optional dictionary of pre-loaded tensors

Returns:

Tuple[Tree, dict]: A tuple containing the loaded Tree object with tensor values and a dictionary of all tensors used in the tree
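
Together with save_tree_architecture, this allows an evolved ensemble to be applied to a new set of predictions. A hedged sketch (best is a Tree, for example taken from okapi.pareto_trees; file and directory names are placeholders, and each prediction file must be named after the value-node id it supplies):

# Persist only the structure; tensor values are stripped before pickling.
best.save_tree_architecture("best_tree.pkl")

# Later: reload the structure and attach tensors from a directory of predictions.
tree, tensors = Tree.load_tree("best_tree.pkl", preds_directory="new_predictions/")
prediction = tree.evaluation
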
Source code in okapi/tree.py
@staticmethod
def load_tree(architecture_path, preds_directory, tensors=None) -> Tuple["Tree", dict]:
    """
    Load a complete tree with tensor values from files.

    This method loads a tree architecture and then loads the associated tensor
    values for each value node from the specified directory.

    Args:
        architecture_path: Path to the saved tree architecture file
        preds_directory: Directory containing the tensor files
        tensors: Optional dictionary of pre-loaded tensors

    Returns:
        A tuple containing:
        - The loaded Tree object with tensor values
        - A dictionary of all tensors used in the tree
    """
    if tensors is None:
        tensors = {}

    logger.info(f"Loading complete tree from {architecture_path} with tensors from {preds_directory}")
    logger.debug(f"Starting with {len(tensors)} pre-loaded tensors")

    current_tensors = {}
    current_tensors.update(tensors)  # tensors argument is mutable and we do not want to modify it

    loaded = Tree.load_tree_architecture(architecture_path)
    current_tensors = loaded._load_tensors_to_tree(preds_directory, current_tensors)

    logger.info(
        f"Tree loaded successfully with {len(loaded.nodes['value_nodes'])} value nodes and {len(loaded.nodes['op_nodes'])} operator nodes"
    )
    return loaded, current_tensors
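
A short, hypothetical sketch of calling load_tree. The import path is inferred from the okapi/tree.py location shown above, and the file paths are placeholders.

# Sketch only: paths and names below are hypothetical.
from okapi.tree import Tree  # import path assumed from the "Source code in okapi/tree.py" note

tree, tensors = Tree.load_tree(
    architecture_path="artifacts/best_tree.pkl",  # previously saved with save_tree_architecture
    preds_directory="artifacts/preds",            # directory holding the per-model tensor files
)
print(len(tree.nodes["value_nodes"]), "value nodes backed by", len(tensors), "tensors")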

load_tree_architecture(architecture_path) staticmethod

Load a tree architecture from a file.

Parameters:

architecture_path
    Path to the saved tree architecture file. Required.

Returns:

    Tree: The loaded Tree object without tensor values.

Source code in okapi/tree.py
@staticmethod
def load_tree_architecture(architecture_path) -> "Tree":  # TODO: needs adjustment for weighted node
    """
    Load a tree architecture from a file.

    Args:
        architecture_path: Path to the saved tree architecture file

    Returns:
        The loaded Tree object without tensor values
    """
    logger.info(f"Loading tree architecture from {architecture_path}")
    tree = Pickle.load(architecture_path)
    logger.debug("Tree architecture loaded successfully")
    return tree
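
A brief sketch of loading only the architecture (no tensor values); the path and the import are assumptions, not library-defined constants.

# Sketch only: loads the structure without tensors.
from okapi.tree import Tree  # import path assumed

architecture = Tree.load_tree_architecture("artifacts/best_tree.pkl")  # hypothetical path
print(len(architecture.nodes["op_nodes"]), "operator nodes in the saved architecture")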

prune_at(node)

Remove a node and its subtree from the tree.

This method removes the specified node and all its descendants from the tree. If the node is the only child of an operator node, that operator node will also be pruned.

Parameters:

node (Node)
    The node to prune from the tree. Required.

Returns:

    Node: The pruned node (which is no longer part of the tree). If the parent was pruned instead, the parent is returned.

Raises:

    ValueError: If the node is not found in the tree or if attempting to prune the root node.

Source code in okapi/tree.py
def prune_at(self, node: Node) -> Node:
    """
    Remove a node and its subtree from the tree.

    This method removes the specified node and all its descendants from the tree.
    If the node is the only child of an operator node, that operator node will
    also be pruned.

    Args:
        node: The node to prune from the tree

    Returns:
        The pruned node (which is no longer part of the tree). If parent was pruned, the parent will be returned.

    Raises:
        ValueError: If the node is not found in the tree or if attempting to prune the root node
    """
    logger.debug(f"Pruning node from tree: {node}")

    if node not in self.nodes["value_nodes"] and node not in self.nodes["op_nodes"]:
        logger.error(f"Attempted to prune node not in tree: {node}")
        raise ValueError("Node not found in tree")

    if node.parent is None:
        logger.error("Cannot prune root node")
        raise ValueError("Cannot prune root node")

    if isinstance(node.parent, OperatorNode) and (
        len(node.parent.children) < 2
    ):  # if only child of op node is to be pruned, remove the parent instead
        logger.debug(f"Node is the only child of operator node, pruning parent: {node.parent}")
        return self.prune_at(node.parent)

    subtree_nodes = node.get_nodes()
    node_count = len(subtree_nodes)

    logger.debug(f"Removing {node_count} nodes in subtree")
    for subtree_node in subtree_nodes:
        if isinstance(subtree_node, ValueNode):
            self.nodes["value_nodes"].remove(subtree_node)
        else:
            self.nodes["op_nodes"].remove(subtree_node)

    node.parent.remove_child(node)
    logger.debug("Pruning complete, clearing cached evaluations")
    self._clean_evals()
    return node
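
A small sketch showing prune_at together with get_random_node; `tree` is a hypothetical, already-populated Tree instance.

# Sketch only: remove a random non-root subtree.
victim = tree.get_random_node(allow_root=False)
removed = tree.prune_at(victim)
# `removed` may be the victim's parent if the victim was that operator node's only child.
assert removed not in tree.nodes["value_nodes"] and removed not in tree.nodes["op_nodes"]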

recalculate()

Force recalculation of the tree evaluation.

This method clears any cached evaluations and triggers a fresh calculation. It also updates the nodes dictionary.

Returns:

    The newly calculated evaluation of the tree.

Source code in okapi/tree.py
def recalculate(self):
    """
    Force recalculation of the tree evaluation.

    This method clears any cached evaluations and triggers a fresh calculation.
    It also updates the nodes dictionary

    Returns:
        The newly calculated evaluation of the tree
    """
    logger.debug("Recalculating tree evaluation")
    self._clean_evals()
    self.update_nodes()
    evaluation = self.evaluation
    logger.trace("Tree recalculation complete")
    return evaluation
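
A one-line sketch; `tree` is hypothetical. Calling recalculate is equivalent to clearing the cached results, refreshing the node collections, and reading tree.evaluation again.

# Sketch only: force a fresh evaluation after structural edits.
evaluation = tree.recalculate()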

replace_at(at, replacement)

Replace a node in the tree with another node.

The replacement node must be of the same type as the node being replaced. This operation preserves the parent-child relationships.

Parameters:

at (Node)
    The node to be replaced. Required.

replacement (Node)
    The new node that will replace the existing node. Required.

Returns:

    Self: Self reference to allow method chaining.

Raises:

    AssertionError: If the replacement node is not of the same type as the node being replaced.

Source code in okapi/tree.py
def replace_at(self, at: Node, replacement: Node) -> Self:
    """
    Replace a node in the tree with another node.

    The replacement node must be of the same type as the node being replaced.
    This operation preserves the parent-child relationships.

    Args:
        at: The node to be replaced
        replacement: The new node that will replace the existing node

    Returns:
        Self reference to allow method chaining

    Raises:
        AssertionError: If the replacement node is not of the same type as the node being replaced
    """
    assert (isinstance(replacement, ValueNode) and isinstance(at, ValueNode)) or (
        isinstance(replacement, OperatorNode) and isinstance(at, OperatorNode)
    ), "Replacement node must be of the same parent type (ValueNode or OperatorNode) as the node being replaced"
    at_parent = at.parent

    if at_parent is None:
        assert isinstance(self.root, ValueNode), "Root must be a value node"
        assert isinstance(replacement, ValueNode), "Replacement for root must be a value node"
        logger.warning("Node at replacement is root node")
        self.root = replacement
    else:
        at_parent.replace_child(at, replacement)

    if isinstance(at, ValueNode):
        self.nodes["value_nodes"].remove(at)
        self.nodes["value_nodes"].append(replacement)
    else:
        self.nodes["op_nodes"].remove(at)
        self.nodes["op_nodes"].append(replacement)

    self._clean_evals()
    return self
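
A sketch of a crossover-style swap using replace_at; `tree` and `other_tree` are hypothetical Tree instances, and copying the donor first relies on Tree.copy(), which save_tree_architecture below also uses.

# Sketch only: swap in a value node taken from a copied donor tree.
donor = other_tree.copy()                                  # copy so the donor tree is left untouched
old_node = tree.get_random_node(nodes_type="value_nodes")
new_node = donor.get_random_node(nodes_type="value_nodes")
tree.replace_at(old_node, new_node).recalculate()          # replace_at returns self, so calls chain

Because replace_at returns the tree itself, it can be chained with recalculate to refresh the cached evaluation in one expression.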

save_tree_architecture(output_path)

Save the tree's architecture to a file.

This method creates a copy of the tree with tensor values removed and saves it to the specified path using pickle serialization.

Parameters:

output_path
    Path where the tree architecture will be saved. Required.

Source code in okapi/tree.py
def save_tree_architecture(self, output_path):  # TODO: needs adjustment for weighted node
    """
    Save the tree's architecture to a file.

    This method creates a copy of the tree with tensor values removed
    and saves it to the specified path using pickle serialization.

    Args:
        output_path: Path where the tree architecture will be saved
    """
    logger.info(f"Saving tree architecture to {output_path}")
    copy_tree = self.copy()
    copy_tree._clean_values_and_evals()

    Pickle.save(output_path, copy_tree)
    logger.debug("Tree architecture saved successfully")

update_nodes()

Update the internal collections of nodes in the tree.

This method traverses the tree and categorizes all nodes into value nodes and operator nodes, updating the internal nodes dictionary.

Source code in okapi/tree.py
def update_nodes(self):
    """
    Update the internal collections of nodes in the tree.

    This method traverses the tree and categorizes all nodes into value nodes and operator nodes,
    updating the internal `nodes` dictionary.
    """
    logger.debug("Updating tree node collections")
    self.nodes = {"value_nodes": [], "op_nodes": []}
    root_nodes = self.root.get_nodes()
    for node in root_nodes:
        if isinstance(node, ValueNode):
            self.nodes["value_nodes"].append(node)
        else:
            self.nodes["op_nodes"].append(node)
    logger.trace(f"Updated nodes: {len(self.nodes['value_nodes'])} value nodes, {len(self.nodes['op_nodes'])} operator nodes")