import os
import tempfile
import unittest
import numpy as np
import pandas as pd
from nnfwtbn.variable import Variable, RangeBlindingStrategy
[docs]class VariableTestCase(unittest.TestCase):
"""
Test the implementation of the variable class.
"""
[docs] def test_init_store(self):
"""
Check that all arguments are stored in the object.
"""
blinding = RangeBlindingStrategy(100, 125)
variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV", blinding)
self.assertEqual(variable.name, "MMC")
self.assertIsNotNone(variable.definition)
self.assertEqual(variable.unit, "GeV")
self.assertEqual(variable.blinding, blinding)
[docs] def test_init_definition_string(self):
"""
Check that a string used as the variable definition is wrapped into a
lambda.
"""
variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV")
self.assertTrue(callable(variable.definition))
[docs] def test_init_blinding_type(self):
"""
Check that an error is thrown if the blinding object is not an
instance of the abstract blinding class.
"""
self.assertRaises(TypeError, "MMC", "ditau_mmc_mlm_m", "GeV", "blind")
[docs] def test_repr(self):
"""
Check that the string representation contains the name of the
variable.
"""
variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV")
self.assertEqual(repr(variable), "<Variable 'MMC' [GeV]>")
variable = Variable(r"$\Delta \eta$",
lambda df: df.jet_0_eta - df.jet_1_eta)
self.assertEqual(repr(variable), r"<Variable '$\\Delta \\eta$'>")
[docs] def test_equal_same_values(self):
"""
Check the equal operator for variables which are created with the same values.
"""
variable1 = Variable("name", "branch", "unit")
variable2 = Variable("name", "branch", "unit")
self.assertTrue(variable1 == variable2)
self.assertTrue(variable2 == variable1)
[docs] def test_equal_different_name(self):
"""
Check that variables with a different name are not equal.
"""
variable1 = Variable("name1", "branch", "unit")
variable2 = Variable("name2", "branch", "unit")
self.assertFalse(variable1 == variable2)
self.assertFalse(variable2 == variable1)
[docs] def test_equal_different_definition(self):
"""
Check that variables with a different definition are not equal.
"""
variable1 = Variable("name", "branch1", "unit")
variable2 = Variable("name", "branch2", "unit")
self.assertFalse(variable1 == variable2)
self.assertFalse(variable2 == variable1)
[docs] def test_equal_different_unit(self):
"""
Check that variables with a different unit are not equal.
"""
variable1 = Variable("name", "branch", "unit1")
variable2 = Variable("name", "branch", "unit2")
self.assertFalse(variable1 == variable2)
self.assertFalse(variable2 == variable1)
# TODO implement test case for different blinding strategies
# def test_equal_different_blinding(self):
# pass
[docs] def generate_df(self):
"""
Generate a toy dataframe.
"""
return pd.DataFrame({
"x": np.arange(5),
"y": np.arange(5)**2
})
[docs] def test_call_column(self):
"""
Check that calling the variable extracts the given column name.
"""
df = self.generate_df()
variable = Variable("$y$", "y")
y_col = variable(df)
self.assertListEqual(list(y_col), [0, 1, 4, 9, 16])
[docs] def test_call_lambda(self):
"""
Check that calling the variable called the given lambda.
"""
df = self.generate_df()
variable = Variable("$x + y$", lambda d: d.x + d.y)
sum = variable(df)
self.assertListEqual(list(sum), [0, 2, 6, 12, 20])
[docs] def test_saving_and_loading(self):
"""
Test that saving and loading a variable doesn't change the variable.
"""
variable1 = Variable("MMC", "ditau_mmc_mlm_m", "GeV")
fd, path = tempfile.mkstemp()
try:
variable1.save_to_h5(path, "variable")
variable2 = Variable.load_from_h5(path, "variable")
finally:
# close file descriptor and delete file
os.close(fd)
os.remove(path)
self.assertTrue(variable1 == variable2)
[docs]class RangeBlindingTestCase(unittest.TestCase):
"""
Test the implementation of the RangeBlinding class.
"""
[docs] def generate_df(self):
"""
Returns a toy dataframe.
"""
return pd.DataFrame({
"ditau_mmc_mlm_m": np.linspace(0, 400, 400),
"x": np.linspace(0, 1, 400),
})
[docs] def test_init_store(self):
"""
Check that the constructor stores all arguments.
"""
blinding = RangeBlindingStrategy(100, 125)
self.assertEqual(blinding.start, 100)
self.assertEqual(blinding.end, 125)
[docs] def test_event_blinding(self):
"""
Check that events in the given region are removed.
"""
blinding_strategy = RangeBlindingStrategy(100, 125)
variable = Variable("MMC", "ditau_mmc_mlm_m")
df = self.generate_df()
blinding = blinding_strategy(variable, bins=30, range=(50, 200))
blinded_df = blinding(df)
# All events outside
self.assertTrue((
(blinded_df.ditau_mmc_mlm_m < 100) |
(blinded_df.ditau_mmc_mlm_m > 125)).all())
# No events inside
self.assertFalse((
(blinded_df.ditau_mmc_mlm_m > 100) &
(blinded_df.ditau_mmc_mlm_m < 125)).any())
# Boundary not enlarged
self.assertTrue((
(blinded_df.ditau_mmc_mlm_m > 100) &
(blinded_df.ditau_mmc_mlm_m < 130)).any())
[docs] def test_bin_border(self):
"""
Check that the blind range is extended to match the bin borders.
"""
blinding_strategy = RangeBlindingStrategy(100, 125)
variable = Variable("MMC", "ditau_mmc_mlm_m")
df = self.generate_df()
blinding = blinding_strategy(variable, bins=15, range=(50, 200))
blinded_df = blinding(df)
# All events outside
self.assertTrue((
(blinded_df.ditau_mmc_mlm_m < 100) |
(blinded_df.ditau_mmc_mlm_m > 130)).all())
# No events inside
self.assertFalse((
(blinded_df.ditau_mmc_mlm_m > 100) &
(blinded_df.ditau_mmc_mlm_m < 130)).any())
[docs] def test_bin_border_left(self):
"""
Check that the blinding does not break if the blinding is left of the
first bin.
"""
blinding_strategy = RangeBlindingStrategy(10, 125)
variable = Variable("MMC", "ditau_mmc_mlm_m")
df = self.generate_df()
blinding = blinding_strategy(variable, bins=15, range=(50, 200))
blinded_df = blinding(df)
# All events outside
self.assertTrue((
(blinded_df.ditau_mmc_mlm_m < 10) |
(blinded_df.ditau_mmc_mlm_m > 130)).all())
# No events inside
self.assertFalse((
(blinded_df.ditau_mmc_mlm_m > 10) &
(blinded_df.ditau_mmc_mlm_m < 130)).any())
[docs] def test_bin_border_right(self):
"""
Check that the blinding does not break if the blinding is left of the
first bin.
"""
blinding_strategy = RangeBlindingStrategy(100, 225)
variable = Variable("MMC", "ditau_mmc_mlm_m")
df = self.generate_df()
blinding = blinding_strategy(variable, bins=15, range=(50, 200))
blinded_df = blinding(df)
# All events outside
self.assertTrue((
(blinded_df.ditau_mmc_mlm_m < 100) |
(blinded_df.ditau_mmc_mlm_m > 225)).all())
# No events inside
self.assertFalse((
(blinded_df.ditau_mmc_mlm_m > 100) &
(blinded_df.ditau_mmc_mlm_m < 225)).any())