All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
LArAdaBoostDecisionTree.h
Go to the documentation of this file.
1 /**
2  * @file larpandoracontent/LArObjects/LArAdaBoostDecisionTree.h
3  *
4  * @brief Header file for the lar adaptive boosted decision tree class.
5  *
6  * $Log: $
7  */
8 #ifndef LAR_ADABOOST_DECISION_TREE_H
9 #define LAR_ADABOOST_DECISION_TREE_H 1
10 

#include "Pandora/StatusCodes.h"

#include <functional>
#include <map>
#include <string>
#include <vector>

20 
21 namespace lar_content
22 {
23 
24 /**
25  * @brief AdaBoostDecisionTree class
26  */
28 {
29 public:
30  /**
31  * @brief Constructor.
32  */
34 
35  /**
36  * @brief Copy constructor
37  *
38  * @param rhs the AdaBoostDecisionTree to copy
39  */
41 
42  /**
43  * @brief Assignment operator
44  *
45  * @param rhs the AdaBoostDecisionTree to assign
46  */
48 
49  /**
50  * @brief Destructor
51  */
53 
54  /**
55  * @brief Initialize the bdt model
56  *
57  * @param parameterLocation the location of the model
58  * @param bdtName the name of the model
59  *
60  * @return success
61  */
62  pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName);
63 
64  /**
65  * @brief Classify the set of input features based on the trained model
66  *
67  * @param features the input features
68  *
69  * @return the classification
70  */
71  bool Classify(const LArMvaHelper::MvaFeatureVector &features) const;
72 
73  /**
74  * @brief Calculate the classification score for a set of input features, based on the trained model
75  *
76  * @param features the input features
77  *
78  * @return the classification score
79  */
81 
82  /**
83  * @brief Calculate the classification probability for a set of input features, based on the trained model
84  *
85  * @param features the input features
86  *
87  * @return the classification probability
88  */
89  double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const;
90 
91 private:
92  /**
93  * @brief Node class used for representing a decision tree
94  */
95  class Node
96  {
97  public:
98  /**
99  * @brief Constructor using xml handle to set member variables
100  *
101  * @param pXmlHandle xml handle to use when setting member variables
102  */
103  Node(const pandora::TiXmlHandle *const pXmlHandle);
104 
105  /**
106  * @brief Copy constructor
107  *
108  * @param rhs the node to copy
109  */
110  Node(const Node &rhs);
111 
112  /**
113  * @brief Assignment operator
114  *
115  * @param rhs the node to assign
116  */
117  Node &operator=(const Node &rhs);
118 
119  /**
120  * @brief Destructor
121  */
122  ~Node();
123 
124  /**
125  * @brief Return node id
126  *
127  * @return node id
128  */
129  int GetNodeId() const;
130 
131  /**
132  * @brief Return parent node id
133  *
134  * @return parent node id
135  */
136  int GetParentNodeId() const;
137 
138  /**
139  * @brief Return left child node id
140  *
141  * @return left child node id
142  */
143  int GetLeftChildNodeId() const;
144 
145  /**
146  * @brief Return right child node id
147  *
148  * @return right child node id
149  */
150  int GetRightChildNodeId() const;
151 
152  /**
153  * @brief Return is the node a leaf
154  *
155  * @return is node a leaf
156  */
157  bool IsLeaf() const;
158 
159  /**
160  * @brief Return node threshold
161  *
162  * @return threshold cut
163  */
164  double GetThreshold() const;
165 
166  /**
167  * @brief Return cut variable
168  *
169  * @return variable cut on
170  */
171  int GetVariableId() const;
172 
173  /**
174  * @brief Return outcome
175  *
176  * @return outcome of cut
177  */
178  bool GetOutcome() const;
179 
180  private:
181  int m_nodeId; ///< Node id
182  int m_parentNodeId; ///< Parent node id
183  int m_leftChildNodeId; ///< Left child node id
184  int m_rightChildNodeId; ///< Right child node id
185  bool m_isLeaf; ///< Is node a leaf
186  double m_threshold; ///< Threshold used for decision if decision node
187  int m_variableId; ///< Variable cut on for decision if decision node
188  bool m_outcome; ///< Outcome if leaf node
189  };
190 
191  typedef std::map<int, const Node *> IdToNodeMap;
192 
193  /**
194  * @brief WeakClassifier class containing a decision tree and a weight
195  */
197  {
198  public:
199  /**
200  * @brief Constructor using xml handle to set member variables
201  *
202  * @param pXmlHandle xml handle to use when setting member variables
203  */
204  WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle);
205 
206  /**
207  * @brief Copy constructor
208  *
209  * @param rhs the weak classifier to copy
210  */
211  WeakClassifier(const WeakClassifier &rhs);
212 
213  /**
214  * @brief Assignment operator
215  *
216  * @param rhs the weak classifier to assign
217  */
219 
220  /**
221  * @brief Destructor
222  */
223  ~WeakClassifier();
224 
225  /**
226  * @brief Predict signal or background based on trained data
227  *
228  * @param features the input features
229  *
230  * @return is signal or background
231  */
232  bool Predict(const LArMvaHelper::MvaFeatureVector &features) const;
233 
234  /**
235  * @brief Evalute node and return outcome
236  *
237  * @param nodeId current node id
238  * @param features the input features
239  *
240  * @return is signal or background node
241  */
242  bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const;
243 
244  /**
245  * @brief Get boost weight for weak classifier
246  *
247  * @return weight for decision tree
248  */
249  double GetWeight() const;
250 
251  /**
252  * @brief Get tree id for weak classifier
253  *
254  * @return tree id
255  */
256  int GetTreeId() const;
257 
258  private:
259  IdToNodeMap m_idToNodeMap; ///< Decision tree nodes
260  double m_weight; ///< Boost weight
261  int m_treeId; ///< Decision tree id
262  };
263 
264  typedef std::vector<const WeakClassifier *> WeakClassifiers;
265 
266  /**
267  * @brief StrongClassifier class used in application of adaptive boost decision tree
268  */
270  {
271  public:
272  /**
273  * @brief Constructor using xml handle to set member variables
274  *
275  * @param pXmlHandle xml handle to use when setting member variables
276  */
277  StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle);
278 
279  /**
280  * @brief Copy constructor
281  *
282  * @param rhs the strong classifier to copy
283  */
285 
286  /**
287  * @brief Assignment operator
288  *
289  * @param rhs the strong classifier to assign
290  */
292 
293  /**
294  * @brief Destructor
295  */
297 
298  /**
299  * @brief Predict signal or background based on trained data
300  *
301  * @param features the input features
302  *
303  * @return return score produced from trained model
304  */
305  double Predict(const LArMvaHelper::MvaFeatureVector &features) const;
306 
307  private:
308  /**
309  * @brief Read xml element and if weak classifier add to member variables
310  */
311  pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement);
312 
313  WeakClassifiers m_weakClassifiers; ///< Vector of weak classifers
314  };
315 
316  /**
317  * @brief Calculate score for input features using strong classifier
318  *
319  * @param features the input features
320  *
321  * @return score
322  */
323  double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const;
324 
325  StrongClassifier *m_pStrongClassifier; ///< Strong adaptive boost tree classifier
326 };
327 
328 //------------------------------------------------------------------------------------------------------------------------------------------
329 
331 {
332  return m_nodeId;
333 }
334 
335 //------------------------------------------------------------------------------------------------------------------------------------------
336 
338 {
339  return m_parentNodeId;
340 }
341 
342 //------------------------------------------------------------------------------------------------------------------------------------------
343 
345 {
346  return m_leftChildNodeId;
347 }
348 
349 //------------------------------------------------------------------------------------------------------------------------------------------
350 
352 {
353  return m_rightChildNodeId;
354 }
355 
356 //------------------------------------------------------------------------------------------------------------------------------------------
357 
359 {
360  return m_isLeaf;
361 }
362 
363 //------------------------------------------------------------------------------------------------------------------------------------------
364 
366 {
367  return m_threshold;
368 }
369 
370 //------------------------------------------------------------------------------------------------------------------------------------------
371 
373 {
374  return m_variableId;
375 }
376 
377 //------------------------------------------------------------------------------------------------------------------------------------------
378 
380 {
381  return m_outcome;
382 }
383 
384 //------------------------------------------------------------------------------------------------------------------------------------------
385 
387 {
388  return m_weight;
389 }
390 
391 //------------------------------------------------------------------------------------------------------------------------------------------
392 
394 {
395  return m_treeId;
396 }
397 
398 } // namespace lar_content
399 
400 #endif // #ifndef LAR_ADABOOST_DECISION_TREE_H
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
int GetLeftChildNodeId() const
Return left child node id.
int GetVariableId() const
Return cut variable.
MvaTypes::MvaFeatureVector MvaFeatureVector
Definition: LArMvaHelper.h:72
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double GetWeight() const
Get boost weight for weak classifier.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
MvaInterface class.
WeakClassifier class containing a decision tree and a weight.
double m_threshold
Threshold used for decision if decision node.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
double GetThreshold() const
Return node threshold.
StrongClassifier class used in application of adaptive boost decision tree.
Node & operator=(const Node &rhs)
Assignment operator.
int GetParentNodeId() const
Return parent node id.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate node and return outcome.
std::vector< const WeakClassifier * > WeakClassifiers
Node class used for representing a decision tree.
std::map< int, const Node * > IdToNodeMap
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model...
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read xml element and if weak classifier add to member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetRightChildNodeId() const
Return right child node id.
int GetTreeId() const
Get tree id for weak classifier.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model...
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName)
Initialize the bdt model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
Header file for the lar multivariate analysis interface class.
int m_variableId
Variable cut on for decision if decision node.
bool IsLeaf() const
Return is the node a leaf.