Source code for nnfwtbn.interface

"""
This module provides classes to interface with classifiers trained in other
frameworks.
"""

from abc import ABC, abstractmethod

from lxml import etree
import numpy as np

class Classifier(ABC):
    """
    Abstract classifier trained with another framework and loaded into nnfwtbn.

    Concrete subclasses must implement predict().
    """

    @abstractmethod
    def predict(self, dataframe):
        """
        Returns an array with the predicted values.

        Subclasses override this method. The base implementation does
        nothing and returns None, so it is safe to call via super().
        """
class TmvaBdt(Classifier):
    """
    Experimental class to use BDTs from TMVA.

    The class has the following limitations:

    - The XML file must contain exactly one classifier.
    - The boosting method must be AdaBoost.
    - Fisher cuts cannot be used.
    """

    def __init__(self, filename):
        """
        Loads the BDT from an XML file.

        Raises an Exception if the file's "BoostType" option is not
        "AdaBoost" or its "UseFisherCuts" option is not "False". Raises an
        IndexError if either option is missing from the file.
        """
        # Binary mode hands the raw bytes to the XML parser, so the document
        # is not decoded with the locale's default text encoding first.
        with open(filename, "rb") as xml_file:
            xml = etree.parse(xml_file)

        # Checks against unsupported features
        boost_type = xml.xpath("//Option[@name='BoostType']")[0].text
        if boost_type != "AdaBoost":
            raise Exception("Cannot handle boost type %r." % boost_type)

        fisher_cuts = xml.xpath("//Option[@name='UseFisherCuts']")[0].text
        if fisher_cuts != "False":
            raise Exception("Cannot handle Fisher cuts.")

        self.xml = xml

    def predict(self, dataframe):
        """
        Evaluate the BDT on the given dataframe.

        The method returns an array with the BDT scores. For each event the
        score is the sum over all trees of the tree's boostWeight times the
        nType of the leaf selected for the event, divided by the sum of all
        boostWeights.

        The dataframe is indexed with the "Expression" attribute of every
        Variable element in the XML file, so it must provide those columns.
        """
        n_events = len(dataframe)

        # Prepare input variables, keyed by the index the tree nodes use
        variables = {int(var.get("VarIndex")): dataframe[var.get("Expression")]
                     for var in self.xml.xpath("//Variable")}

        # Prepare result array
        response = np.zeros(n_events)
        sum_weights = 0

        # Loop over trees
        for tree in self.xml.xpath("//BinaryTree"):
            tree_weight = float(tree.get("boostWeight"))
            sum_weights += tree_weight

            # Loop over terminal nodes of tree (nodes with non-zero nType)
            leaves = tree.xpath(".//Node[@nType!=0]")
            for leaf in leaves:
                # Ancestors are returned in document order, i.e. root first
                ancestors = leaf.xpath("ancestor::Node")
                mask = np.ones(n_events, dtype='bool')

                # Trace path from root to leaf and record surviving events
                for node, next_node in zip(ancestors, ancestors[1:] + [leaf]):
                    variable = variables[int(node.get("IVar"))]
                    cut = float(node.get("Cut"))
                    cut_type = int(node.get("cType"))
                    next_type = {"l": 0, "r": 1}[next_node.get("pos")]

                    # Actual evaluation of node cut: when the side of the
                    # next node equals cType, the event must NOT be below
                    # the cut to follow the path; otherwise it must be
                    # below the cut.
                    mask &= (next_type == cut_type) ^ (variable < cut)

                # NOTE(review): the leaf's nType is used as the tree's vote;
                # presumably this matches TMVA only when UseYesNoLeaf is
                # enabled -- confirm.
                leaf_type = int(leaf.get("nType"))

                # Record prediction of tree
                response[mask] += tree_weight * leaf_type

        # NOTE: without any BinaryTree element sum_weights stays 0 and the
        # division yields NaN scores for a non-empty dataframe.
        return response / sum_weights