"""
This module provides classes that interface with classifiers trained in other
frameworks.
"""
from abc import ABC, abstractmethod
from lxml import etree
import numpy as np
class Classifier(ABC):
    """
    Abstract classifier trained with another framework and loaded into nnfwtbn.

    Concrete subclasses must implement :meth:`predict`.
    """

    @abstractmethod
    def predict(self, dataframe):
        """
        Returns an array with the predicted values.

        ``dataframe`` is the input data to evaluate. This base implementation
        does nothing and returns ``None``; subclasses must override it.
        """
class TmvaBdt(Classifier):
    """
    Experimental class to use BDTs from TMVA. The class has the following
    limitations:

     - The XML file must contain exactly one classifier.
     - The boosting method must be AdaBoost.
     - Fisher cuts cannot be used.
    """

    def __init__(self, filename):
        """
        Loads the BDT from an XML file.

        Raises a ValueError if the file declares a boost type other than
        AdaBoost or if the UseFisherCuts option is not "False".
        """
        # Open in binary mode so that the XML parser decodes the document
        # according to its own encoding declaration instead of the locale.
        with open(filename, "rb") as xml_file:
            xml = etree.parse(xml_file)

        # Checks against unsupported features
        boost_type = xml.xpath("//Option[@name='BoostType']")[0].text
        if boost_type != "AdaBoost":
            raise ValueError("Cannot handle boost type %r." % boost_type)

        fisher_cuts = xml.xpath("//Option[@name='UseFisherCuts']")[0].text
        if fisher_cuts != "False":
            raise ValueError("Cannot handle Fisher cuts.")

        self.xml = xml

    def predict(self, dataframe):
        """
        Evaluate the BDT on the given dataframe. The method returns an array
        with the BDT scores.

        The score of an event is the sum over all trees of the tree's boost
        weight times the type of the leaf the event ends up in, divided by
        the sum of all boost weights.
        """
        n_events = len(dataframe)

        # Prepare input variables, keyed by the index used in the tree nodes
        variables = {
            int(var.get("VarIndex")): dataframe[var.get("Expression")]
            for var in self.xml.xpath("//Variable")
        }

        # Prepare result array
        response = np.zeros(n_events)
        sum_weights = 0

        # Loop over trees
        for tree in self.xml.xpath("//BinaryTree"):
            tree_weight = float(tree.get("boostWeight"))
            sum_weights += tree_weight

            # Loop over terminal nodes of tree (nodes with a non-zero type)
            for leaf in tree.xpath(".//Node[@nType!=0]"):
                ancestors = leaf.xpath("ancestor::Node")
                mask = np.ones(n_events, dtype='bool')

                # Trace path from root to leaf and record surviving events.
                # Each ancestor is paired with the node that follows it on
                # the path, whose "pos" tells which branch was taken.
                for node, next_node in zip(ancestors, ancestors[1:] + [leaf]):
                    variable = variables[int(node.get("IVar"))]
                    cut = float(node.get("Cut"))
                    cut_type = int(node.get("cType"))
                    next_type = {"l": 0, "r": 1}[next_node.get("pos")]

                    # Actual evaluation of node cut: with cType 1, events
                    # with variable >= cut go right and the others go left;
                    # with cType 0 the directions are swapped.
                    mask &= (next_type == cut_type) ^ (variable < cut)

                leaf_type = int(leaf.get("nType"))

                # Record prediction of tree
                response[mask] += tree_weight * leaf_type

        return response / sum_weights