All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
LArAdaBoostDecisionTree.cc
Go to the documentation of this file.
1 /**
2  * @file larpandoracontent/LArObjects/LArAdaBoostDecisionTree.cc
3  *
4  * @brief Implementation of the lar adaptive boost decision tree class.
5  *
6  * $Log: $
7  */
8 
9 #include "Helpers/XmlHelper.h"
10 
12 
13 using namespace pandora;
14 
15 namespace lar_content
16 {
17 
18 AdaBoostDecisionTree::AdaBoostDecisionTree() : m_pStrongClassifier(nullptr)
19 {
20 }
21 
22 //------------------------------------------------------------------------------------------------------------------------------------------
23 
25 {
27 }
28 
29 //------------------------------------------------------------------------------------------------------------------------------------------
30 
32 {
33  if (this != &rhs)
35 
36  return *this;
37 }
38 
39 //------------------------------------------------------------------------------------------------------------------------------------------
40 
42 {
43  delete m_pStrongClassifier;
44 }
45 
46 //------------------------------------------------------------------------------------------------------------------------------------------
47 
48 StatusCode AdaBoostDecisionTree::Initialize(const std::string &bdtXmlFileName, const std::string &bdtName)
49 {
51  {
52  std::cout << "AdaBoostDecisionTree: AdaBoostDecisionTree was already initialized" << std::endl;
53  return STATUS_CODE_ALREADY_INITIALIZED;
54  }
55 
56  TiXmlDocument xmlDocument(bdtXmlFileName);
57 
58  if (!xmlDocument.LoadFile())
59  {
60  std::cout << "AdaBoostDecisionTree::Initialize - Invalid xml file." << std::endl;
61  return STATUS_CODE_INVALID_PARAMETER;
62  }
63 
64  const TiXmlHandle xmlDocumentHandle(&xmlDocument);
65  TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
66 
67  while (pContainerXmlNode)
68  {
69  if (pContainerXmlNode->ValueStr() != "AdaBoostDecisionTree")
70  return STATUS_CODE_FAILURE;
71 
72  const TiXmlHandle currentHandle(pContainerXmlNode);
73 
74  std::string currentName;
75  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle, "Name", currentName));
76 
77  if (currentName.empty() || (currentName.size() > 1000))
78  {
79  std::cout << "AdaBoostDecisionTree::Initialize - Implausible AdaBoostDecisionTree name extracted from xml." << std::endl;
80  return STATUS_CODE_INVALID_PARAMETER;
81  }
82 
83  if (currentName == bdtName)
84  break;
85 
86  pContainerXmlNode = pContainerXmlNode->NextSibling();
87  }
88 
89  if (!pContainerXmlNode)
90  {
91  std::cout << "AdaBoostDecisionTree: Could not find an AdaBoostDecisionTree of name " << bdtName << std::endl;
92  return STATUS_CODE_NOT_FOUND;
93  }
94 
95  const TiXmlHandle xmlHandle(pContainerXmlNode);
96 
97  try
98  {
99  m_pStrongClassifier = new StrongClassifier(&xmlHandle);
100  }
101  catch (StatusCodeException &statusCodeException)
102  {
103  delete m_pStrongClassifier;
104 
105  if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
106  std::cout << "AdaBoostDecisionTree: Initialization failure, unknown component in xml file." << std::endl;
107 
108  if (STATUS_CODE_FAILURE == statusCodeException.GetStatusCode())
109  std::cout << "AdaBoostDecisionTree: Node definition does not contain expected leaf or branch variables." << std::endl;
110 
111  return statusCodeException.GetStatusCode();
112  }
113 
114  return STATUS_CODE_SUCCESS;
115 }
116 
117 //------------------------------------------------------------------------------------------------------------------------------------------
118 
120 {
121  return ((this->CalculateScore(features) > 0.) ? true : false);
122 }
123 
124 //------------------------------------------------------------------------------------------------------------------------------------------
125 
127 {
128  return this->CalculateScore(features);
129 }
130 
131 //------------------------------------------------------------------------------------------------------------------------------------------
132 
134 {
135  // ATTN: BDT score, once normalised by total weight, is confined to the range -1 to +1. This linear mapping places the score in the
136  // range 0 to 1 so that it may be interpreted as a probability.
137  return (this->CalculateScore(features) + 1.) * 0.5;
138 }
139 
140 //------------------------------------------------------------------------------------------------------------------------------------------
141 
143 {
144  if (!m_pStrongClassifier)
145  {
146  std::cout << "AdaBoostDecisionTree: Attempting to use an uninitialized bdt" << std::endl;
147  throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
148  }
149 
150  try
151  {
152  // TODO: Add consistency check for number of features, bearing in mind not all features in a bdt may be used
153  return m_pStrongClassifier->Predict(features);
154  }
155  catch (StatusCodeException &statusCodeException)
156  {
157  if (STATUS_CODE_NOT_FOUND == statusCodeException.GetStatusCode())
158  {
159  std::cout << "AdaBoostDecisionTree: Caught exception thrown when trying to cut on an unknown variable." << std::endl;
160  }
161  else if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
162  {
163  std::cout << "AdaBoostDecisionTree: Caught exception thrown when classifier weights sum to zero indicating defunct classifier."
164  << std::endl;
165  }
166  else if (STATUS_CODE_OUT_OF_RANGE == statusCodeException.GetStatusCode())
167  {
168  std::cout << "AdaBoostDecisionTree: Caught exception thrown when heirarchy in decision tree is incomplete." << std::endl;
169  }
170  else
171  {
172  std::cout << "AdaBoostDecisionTree: Unexpected exception thrown." << std::endl;
173  }
174 
175  throw statusCodeException;
176  }
177 }
178 
179 //------------------------------------------------------------------------------------------------------------------------------------------
180 //------------------------------------------------------------------------------------------------------------------------------------------
181 
182 AdaBoostDecisionTree::Node::Node(const TiXmlHandle *const pXmlHandle) :
183  m_nodeId(0),
184  m_parentNodeId(0),
185  m_leftChildNodeId(0),
186  m_rightChildNodeId(0),
187  m_isLeaf(false),
188  m_threshold(0.),
189  m_variableId(0),
190  m_outcome(false)
191 {
192  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "NodeId", m_nodeId));
193  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "ParentNodeId", m_parentNodeId));
194 
195  const StatusCode leftChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "LeftChildNodeId", m_leftChildNodeId));
196  const StatusCode rightChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "RightChildNodeId", m_rightChildNodeId));
197  const StatusCode thresholdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Threshold", m_threshold));
198  const StatusCode variableIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "VariableId", m_variableId));
199  const StatusCode outcomeStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Outcome", m_outcome));
200 
201  if (STATUS_CODE_SUCCESS == leftChildNodeIdStatusCode || STATUS_CODE_SUCCESS == rightChildNodeIdStatusCode ||
202  STATUS_CODE_SUCCESS == thresholdStatusCode || STATUS_CODE_SUCCESS == variableIdStatusCode)
203  {
204  m_isLeaf = false;
205  m_outcome = false;
206  }
207  else if (outcomeStatusCode == STATUS_CODE_SUCCESS)
208  {
209  m_isLeaf = true;
210  m_leftChildNodeId = std::numeric_limits<int>::max();
211  m_rightChildNodeId = std::numeric_limits<int>::max();
212  m_threshold = std::numeric_limits<double>::max();
213  m_variableId = std::numeric_limits<int>::max();
214  }
215  else
216  {
217  throw StatusCodeException(STATUS_CODE_FAILURE);
218  }
219 }
220 
221 //------------------------------------------------------------------------------------------------------------------------------------------
222 
224  m_nodeId(rhs.m_nodeId),
225  m_parentNodeId(rhs.m_parentNodeId),
226  m_leftChildNodeId(rhs.m_leftChildNodeId),
227  m_rightChildNodeId(rhs.m_rightChildNodeId),
228  m_isLeaf(rhs.m_isLeaf),
229  m_threshold(rhs.m_threshold),
230  m_variableId(rhs.m_variableId),
231  m_outcome(rhs.m_outcome)
232 {
233 }
234 
235 //------------------------------------------------------------------------------------------------------------------------------------------
236 
238 {
239  if (this != &rhs)
240  {
241  m_nodeId = rhs.m_nodeId;
242  m_parentNodeId = rhs.m_parentNodeId;
243  m_leftChildNodeId = rhs.m_leftChildNodeId;
244  m_rightChildNodeId = rhs.m_rightChildNodeId;
245  m_isLeaf = rhs.m_isLeaf;
246  m_threshold = rhs.m_threshold;
247  m_variableId = rhs.m_variableId;
248  m_outcome = rhs.m_outcome;
249  }
250 
251  return *this;
252 }
253 
254 //------------------------------------------------------------------------------------------------------------------------------------------
255 
257 {
258 }
259 
260 //------------------------------------------------------------------------------------------------------------------------------------------
261 //------------------------------------------------------------------------------------------------------------------------------------------
262 
263 AdaBoostDecisionTree::WeakClassifier::WeakClassifier(const TiXmlHandle *const pXmlHandle) : m_weight(0.), m_treeId(0)
264 {
265  for (TiXmlElement *pHeadTiXmlElement = pXmlHandle->FirstChildElement().ToElement(); pHeadTiXmlElement != NULL;
266  pHeadTiXmlElement = pHeadTiXmlElement->NextSiblingElement())
267  {
268  if ("TreeIndex" == pHeadTiXmlElement->ValueStr())
269  {
270  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeIndex", m_treeId));
271  }
272  else if ("TreeWeight" == pHeadTiXmlElement->ValueStr())
273  {
274  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeWeight", m_weight));
275  }
276  else if ("Node" == pHeadTiXmlElement->ValueStr())
277  {
278  const TiXmlHandle nodeHandle(pHeadTiXmlElement);
279  const Node *pNode = new Node(&nodeHandle);
280  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
281  }
282  }
283 }
284 
285 //------------------------------------------------------------------------------------------------------------------------------------------
286 
287 AdaBoostDecisionTree::WeakClassifier::WeakClassifier(const WeakClassifier &rhs) : m_weight(rhs.m_weight), m_treeId(rhs.m_treeId)
288 {
289  for (const auto &mapEntry : rhs.m_idToNodeMap)
290  {
291  const Node *pNode = new Node(*(mapEntry.second));
292  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
293  }
294 }
295 
296 //------------------------------------------------------------------------------------------------------------------------------------------
297 
299 {
300  if (this != &rhs)
301  {
302  for (const auto &mapEntry : rhs.m_idToNodeMap)
303  {
304  const Node *pNode = new Node(*(mapEntry.second));
305  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
306  }
307 
308  m_weight = rhs.m_weight;
309  m_treeId = rhs.m_treeId;
310  }
311 
312  return *this;
313 }
314 
315 //------------------------------------------------------------------------------------------------------------------------------------------
316 
318 {
319  for (const auto &mapEntry : m_idToNodeMap)
320  delete mapEntry.second;
321 }
322 
323 //------------------------------------------------------------------------------------------------------------------------------------------
324 
326 {
327  return this->EvaluateNode(0, features);
328 }
329 
330 //------------------------------------------------------------------------------------------------------------------------------------------
331 
333 {
334  const Node *pActiveNode(nullptr);
335 
336  if (m_idToNodeMap.find(nodeId) != m_idToNodeMap.end())
337  {
338  pActiveNode = m_idToNodeMap.at(nodeId);
339  }
340  else
341  {
342  throw StatusCodeException(STATUS_CODE_OUT_OF_RANGE);
343  }
344 
345  if (pActiveNode->IsLeaf())
346  return pActiveNode->GetOutcome();
347 
348  if (static_cast<int>(features.size()) <= pActiveNode->GetVariableId())
349  throw StatusCodeException(STATUS_CODE_NOT_FOUND);
350 
351  if (features.at(pActiveNode->GetVariableId()).Get() <= pActiveNode->GetThreshold())
352  {
353  return this->EvaluateNode(pActiveNode->GetLeftChildNodeId(), features);
354  }
355  else
356  {
357  return this->EvaluateNode(pActiveNode->GetRightChildNodeId(), features);
358  }
359 }
360 
361 //------------------------------------------------------------------------------------------------------------------------------------------
362 //------------------------------------------------------------------------------------------------------------------------------------------
363 
364 AdaBoostDecisionTree::StrongClassifier::StrongClassifier(const TiXmlHandle *const pXmlHandle)
365 {
366  TiXmlElement *pCurrentXmlElement = pXmlHandle->FirstChild().Element();
367 
368  while (pCurrentXmlElement)
369  {
370  if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
371  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
372 
373  pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
374  }
375 }
376 
377 //------------------------------------------------------------------------------------------------------------------------------------------
378 
380 {
381  for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
382  m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
383 }
384 
385 //------------------------------------------------------------------------------------------------------------------------------------------
386 
388 {
389  if (this != &rhs)
390  {
391  for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
392  m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
393  }
394 
395  return *this;
396 }
397 
398 //------------------------------------------------------------------------------------------------------------------------------------------
399 
401 {
402  for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
403  delete pWeakClassifier;
404 }
405 
406 //------------------------------------------------------------------------------------------------------------------------------------------
407 
409 {
410  double score(0.), weights(0.);
411 
412  for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
413  {
414  weights += pWeakClassifier->GetWeight();
415 
416  if (pWeakClassifier->Predict(features))
417  {
418  score += pWeakClassifier->GetWeight();
419  }
420  else
421  {
422  score -= pWeakClassifier->GetWeight();
423  }
424  }
425 
426  if (weights > std::numeric_limits<double>::epsilon())
427  {
428  score /= weights;
429  }
430  else
431  {
432  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
433  }
434 
435  return score;
436 }
437 
438 //------------------------------------------------------------------------------------------------------------------------------------------
439 
440 StatusCode AdaBoostDecisionTree::StrongClassifier::ReadComponent(TiXmlElement *pCurrentXmlElement)
441 {
442  const std::string componentName(pCurrentXmlElement->ValueStr());
443  TiXmlHandle currentHandle(pCurrentXmlElement);
444 
445  if ((std::string("Name") == componentName) || (std::string("Timestamp") == componentName))
446  return STATUS_CODE_SUCCESS;
447 
448  if (std::string("DecisionTree") == componentName)
449  {
450  m_weakClassifiers.emplace_back(new WeakClassifier(&currentHandle));
451  return STATUS_CODE_SUCCESS;
452  }
453 
454  return STATUS_CODE_INVALID_PARAMETER;
455 }
456 
457 } // namespace lar_content
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
int GetLeftChildNodeId() const
Return left child node id.
int GetVariableId() const
Return cut variable.
MvaTypes::MvaFeatureVector MvaFeatureVector
Definition: LArMvaHelper.h:72
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
BEGIN_PROLOG or score(default)}sbnd_crttrackmatchingalg_crID
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
WeakClassifier class containing a decision tree and a weight.
double m_threshold
Threshold used for decision if decision node.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
double GetThreshold() const
Return node threshold.
StrongClassifier class used in application of adaptive boost decision tree.
Header file for the lar adaptive boosted decision tree class.
Node & operator=(const Node &rhs)
Assignment operator.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate node and return outcome.
Node class used for representing a decision tree.
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model...
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read xml element and if weak classifier add to member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetRightChildNodeId() const
Return right child node id.
required by fuzzyCluster table::sbnd_g4_services gaushitTruthMatch pandora
Definition: reco_sbnd.fcl:182
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model...
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName)
Initialize the bdt model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
BEGIN_PROLOG could also be cout
int m_variableId
Variable cut on for decision if decision node.
bool IsLeaf() const
Return whether the node is a leaf.