All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
DlHitTrackShowerIdAlgorithm.cc
Go to the documentation of this file.
1 /**
2  * @file larpandoradlcontent/LArTrackShowerId/DlHitTrackShowerIdAlgorithm.cc
3  *
4  * @brief Implementation of the deep learning track shower id algorithm.
5  *
6  * $Log: $
7  */
8 
9 #include "Pandora/AlgorithmHeaders.h"
10 
11 #include <torch/script.h>
12 #include <torch/torch.h>
13 
15 
21 
23 
24 #include <chrono>
25 
26 using namespace pandora;
27 using namespace lar_content;
28 
29 namespace lar_dl_content
30 {
31 
32 DlHitTrackShowerIdAlgorithm::DlHitTrackShowerIdAlgorithm() :
33  m_imageHeight(256),
34  m_imageWidth(256),
35  m_tileSize(128.f),
36  m_visualize(false),
37  m_useTrainingMode(false),
38  m_trainingOutputFile("")
39 {
40 }
41 
42 //------------------------------------------------------------------------------------------------------------------------------------------
43 
45 {
46 }
47 
48 //------------------------------------------------------------------------------------------------------------------------------------------
49 
51 {
53  return this->Train();
54  else
55  return this->Infer();
56 }
57 
58 //------------------------------------------------------------------------------------------------------------------------------------------
59 
61 {
62  const int SHOWER{1}, TRACK{2};
63  for (const std::string listName : m_caloHitListNames)
64  {
65  const CaloHitList *pCaloHitList(nullptr);
66  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*this, listName, pCaloHitList));
67  const MCParticleList *pMCParticleList(nullptr);
68  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetCurrentList(*this, pMCParticleList));
69 
70  const HitType view{pCaloHitList->front()->GetHitType()};
71 
72  if (!(view == TPC_VIEW_U || view == TPC_VIEW_V || view == TPC_VIEW_W))
73  return STATUS_CODE_NOT_ALLOWED;
74 
75  std::string trainingOutputFileName(m_trainingOutputFile);
76 
77  if (view == TPC_VIEW_U)
78  trainingOutputFileName += "_CaloHitListU.csv";
79  else if (view == TPC_VIEW_V)
80  trainingOutputFileName += "_CaloHitListV.csv";
81  else if (view == TPC_VIEW_W)
82  trainingOutputFileName += "_CaloHitListW.csv";
83 
85  // Only care about reconstructability with respect to the current view, so skip good view check
86  parameters.m_minHitsForGoodView = 0;
87  // Turn off max photo propagation for now, only care about killing off daughters of neutrons
88  parameters.m_maxPhotonPropagation = std::numeric_limits<float>::max();
89  LArMCParticleHelper::MCContributionMap targetMCParticleToHitsMap;
90  LArMCParticleHelper::SelectReconstructableMCParticles(
91  pMCParticleList, pCaloHitList, parameters, LArMCParticleHelper::IsBeamNeutrinoFinalState, targetMCParticleToHitsMap);
92 
93  LArMvaHelper::MvaFeatureVector featureVector;
94  for (const CaloHit *pCaloHit : *pCaloHitList)
95  {
96  int tag{TRACK};
97  float inputEnergy{0.f};
98 
99  try
100  {
101  const MCParticle *const pMCParticle(MCParticleHelper::GetMainMCParticle(pCaloHit));
102  // Throw away non-reconstructable hits
103  if (targetMCParticleToHitsMap.find(pMCParticle) == targetMCParticleToHitsMap.end())
104  continue;
105  if (LArMCParticleHelper::IsDescendentOf(pMCParticle, 2112))
106  continue;
107  inputEnergy = pCaloHit->GetInputEnergy();
108  if (inputEnergy < 0.f)
109  continue;
110 
111  const int pdg{std::abs(pMCParticle->GetParticleId())};
112  if (pdg == 11 || pdg == 22)
113  tag = SHOWER;
114  else
115  tag = TRACK;
116  }
117  catch (const StatusCodeException &)
118  {
119  continue;
120  }
121 
122  featureVector.push_back(static_cast<double>(pCaloHit->GetPositionVector().GetX()));
123  featureVector.push_back(static_cast<double>(pCaloHit->GetPositionVector().GetZ()));
124  featureVector.push_back(static_cast<double>(tag));
125  featureVector.push_back(static_cast<double>(inputEnergy));
126  }
127  // Add number of hits to end of vector than rotate (more efficient than direct insert at front)
128  featureVector.push_back(static_cast<double>(featureVector.size() / 4));
129  std::rotate(featureVector.rbegin(), featureVector.rbegin() + 1, featureVector.rend());
130 
131  PANDORA_RETURN_RESULT_IF(pandora::STATUS_CODE_SUCCESS, !=, LArMvaHelper::ProduceTrainingExample(trainingOutputFileName, true, featureVector));
132  }
133 
134  return STATUS_CODE_SUCCESS;
135 }
136 
137 //------------------------------------------------------------------------------------------------------------------------------------------
138 
140 {
141  const float eps{1.1920929e-7}; // Python float epsilon, used in image padding
142 
143  if (m_visualize)
144  {
145  PANDORA_MONITORING_API(SetEveDisplayParameters(this->GetPandora(), true, DETECTOR_VIEW_XZ, -1.f, 1.f, 1.f));
146  }
147 
148  for (const std::string listName : m_caloHitListNames)
149  {
150  const CaloHitList *pCaloHitList(nullptr);
151  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*this, listName, pCaloHitList));
152 
153  const HitType view{pCaloHitList->front()->GetHitType()};
154 
155  if (!(view == TPC_VIEW_U || view == TPC_VIEW_V || view == TPC_VIEW_W))
156  return STATUS_CODE_NOT_ALLOWED;
157 
158  LArDLHelper::TorchModel &model{view == TPC_VIEW_U ? m_modelU : (view == TPC_VIEW_V ? m_modelV : m_modelW)};
159 
160  // Get bounds of hit region
161  float xMin{};
162  float xMax{};
163  float zMin{};
164  float zMax{};
165  this->GetHitRegion(*pCaloHitList, xMin, xMax, zMin, zMax);
166  const float xRange = (xMax + eps) - (xMin - eps);
167  int nTilesX = static_cast<int>(std::ceil(xRange / m_tileSize));
168 
169  PixelToTileMap sparseMap;
170  this->GetSparseTileMap(*pCaloHitList, xMin, zMin, nTilesX, sparseMap);
171  const int nTiles = sparseMap.size();
172 
173  CaloHitList trackHits, showerHits, otherHits;
174  // Process tile
175  // ATTN: Be sure to reset all values to zero after each tile has been processed
176  float **weights = new float *[m_imageHeight];
177  for (int r = 0; r < m_imageHeight; ++r)
178  weights[r] = new float[m_imageWidth]();
179  for (int i = 0; i < nTiles; ++i)
180  {
181  for (const CaloHit *pCaloHit : *pCaloHitList)
182  {
183  const float x(pCaloHit->GetPositionVector().GetX());
184  const float z(pCaloHit->GetPositionVector().GetZ());
185  // Determine which tile the hit will be assigned to
186  const int tileX = static_cast<int>(std::floor((x - xMin) / m_tileSize));
187  const int tileZ = static_cast<int>(std::floor((z - zMin) / m_tileSize));
188  const int tile = sparseMap.at(tileZ * nTilesX + tileX);
189  if (tile == i)
190  {
191  // Determine hit position within the tile
192  const float localX = std::fmod(x - xMin, m_tileSize);
193  const float localZ = std::fmod(z - zMin, m_tileSize);
194  // Determine hit pixel within the tile
195  const int pixelX = static_cast<int>(std::floor(localX * m_imageWidth / m_tileSize));
196  const int pixelZ = (m_imageHeight - 1) - static_cast<int>(std::floor(localZ * m_imageHeight / m_tileSize));
197  weights[pixelZ][pixelX] += pCaloHit->GetInputEnergy();
198  }
199  }
200 
201  // Find min and max charge to allow normalisation
202  float chargeMin{std::numeric_limits<float>::max()}, chargeMax{-std::numeric_limits<float>::max()};
203  for (int r = 0; r < m_imageHeight; ++r)
204  {
205  for (int c = 0; c < m_imageWidth; ++c)
206  {
207  if (weights[r][c] > chargeMax)
208  chargeMax = weights[r][c];
209  if (weights[r][c] < chargeMin)
210  chargeMin = weights[r][c];
211  }
212  }
213  float chargeRange{chargeMax - chargeMin};
214  if (chargeRange <= 0.f)
215  chargeRange = 1.f;
216 
217  // Populate accessor based on normalised weights
218  CaloHitToPixelMap caloHitToPixelMap;
221  auto accessor = input.accessor<float, 4>();
222  for (const CaloHit *pCaloHit : *pCaloHitList)
223  {
224  const float x(pCaloHit->GetPositionVector().GetX());
225  const float z(pCaloHit->GetPositionVector().GetZ());
226  // Determine which tile the hit will be assigned to
227  const int tileX = static_cast<int>(std::floor((x - xMin) / m_tileSize));
228  const int tileZ = static_cast<int>(std::floor((z - zMin) / m_tileSize));
229  const int tile = sparseMap.at(tileZ * nTilesX + tileX);
230  if (tile == i)
231  {
232  // Determine hit position within the tile
233  const float localX = std::fmod(x - xMin, m_tileSize);
234  const float localZ = std::fmod(z - zMin, m_tileSize);
235  // Determine hit pixel within the tile
236  const int pixelX = static_cast<int>(std::floor(localX * m_imageWidth / m_tileSize));
237  const int pixelZ = (m_imageHeight - 1) - static_cast<int>(std::floor(localZ * m_imageHeight / m_tileSize));
238  accessor[0][0][pixelZ][pixelX] = (weights[pixelZ][pixelX] - chargeMin) / chargeRange;
239  caloHitToPixelMap.insert(std::make_pair(pCaloHit, std::make_tuple(tileZ, tileX, pixelZ, pixelX)));
240  }
241  }
242  // Reset weights
243  for (int r = 0; r < this->m_imageHeight; ++r)
244  for (int c = 0; c < this->m_imageWidth; ++c)
245  weights[r][c] = 0.f;
246 
247  // Run the input through the trained model and get the output accessor
249  inputs.push_back(input);
251  LArDLHelper::Forward(model, inputs, output);
252  auto outputAccessor = output.accessor<float, 4>();
253 
254  for (const CaloHit *pCaloHit : *pCaloHitList)
255  {
256  auto found{caloHitToPixelMap.find(pCaloHit)};
257  if (found == caloHitToPixelMap.end())
258  continue;
259  auto pixelMap = found->second;
260  const int tileZ(std::get<0>(pixelMap));
261  const int tileX(std::get<1>(pixelMap));
262  const int tile = sparseMap.at(tileZ * nTilesX + tileX);
263  if (tile == i)
264  { // Make sure we're looking at a hit in the correct tile
265  const int pixelZ(std::get<2>(pixelMap));
266  const int pixelX(std::get<3>(pixelMap));
267 
268  // Apply softmax to loss to get actual probability
269  float probShower = exp(outputAccessor[0][1][pixelZ][pixelX]);
270  float probTrack = exp(outputAccessor[0][2][pixelZ][pixelX]);
271  float probNull = exp(outputAccessor[0][0][pixelZ][pixelX]);
272  if (probShower > probTrack && probShower > probNull)
273  showerHits.push_back(pCaloHit);
274  else if (probTrack > probShower && probTrack > probNull)
275  trackHits.push_back(pCaloHit);
276  else
277  otherHits.push_back(pCaloHit);
278  float recipSum = 1.f / (probShower + probTrack);
279  // Adjust probabilities to ignore null hits and update LArCaloHit
280  probShower *= recipSum;
281  probTrack *= recipSum;
282  LArCaloHit *pLArCaloHit{const_cast<LArCaloHit *>(dynamic_cast<const LArCaloHit *>(pCaloHit))};
283  pLArCaloHit->SetShowerProbability(probShower);
284  pLArCaloHit->SetTrackProbability(probTrack);
285  }
286  }
287  }
288  for (int r = 0; r < this->m_imageHeight; ++r)
289  delete[] weights[r];
290  delete[] weights;
291 
292  if (m_visualize)
293  {
294  const std::string trackListName("TrackHits_" + listName);
295  const std::string showerListName("ShowerHits_" + listName);
296  const std::string otherListName("OtherHits_" + listName);
297  PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &trackHits, trackListName, BLUE));
298  PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &showerHits, showerListName, RED));
299  PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &otherHits, otherListName, BLACK));
300  }
301  }
302 
303  if (m_visualize)
304  {
305  PANDORA_MONITORING_API(ViewEvent(this->GetPandora()));
306  }
307 
308  return STATUS_CODE_SUCCESS;
309 }
310 
311 //------------------------------------------------------------------------------------------------------------------------------------------
312 
313 void DlHitTrackShowerIdAlgorithm::GetHitRegion(const CaloHitList &caloHitList, float &xMin, float &xMax, float &zMin, float &zMax)
314 {
315  xMin = std::numeric_limits<float>::max();
316  xMax = -std::numeric_limits<float>::max();
317  zMin = std::numeric_limits<float>::max();
318  zMax = -std::numeric_limits<float>::max();
319  for (const CaloHit *pCaloHit : caloHitList)
320  {
321  const float x(pCaloHit->GetPositionVector().GetX());
322  const float z(pCaloHit->GetPositionVector().GetZ());
323  if (x < xMin)
324  xMin = x;
325  if (x > xMax)
326  xMax = x;
327  if (z < zMin)
328  zMin = z;
329  if (z > zMax)
330  zMax = z;
331  }
332 }
333 
334 //------------------------------------------------------------------------------------------------------------------------------------------
335 
337  const CaloHitList &caloHitList, const float xMin, const float zMin, const int nTilesX, PixelToTileMap &sparseMap)
338 {
339  // Identify the tiles that actually contain hits
340  std::map<int, bool> tilePopulationMap;
341  for (const CaloHit *pCaloHit : caloHitList)
342  {
343  const float x(pCaloHit->GetPositionVector().GetX());
344  const float z(pCaloHit->GetPositionVector().GetZ());
345  // Determine which tile the hit will be assigned to
346  const int tileX = static_cast<int>(std::floor((x - xMin) / m_tileSize));
347  const int tileZ = static_cast<int>(std::floor((z - zMin) / m_tileSize));
348  const int tile = tileZ * nTilesX + tileX;
349  tilePopulationMap.insert(std::make_pair(tile, true));
350  }
351 
352  int nextTile = 0;
353  for (auto element : tilePopulationMap)
354  {
355  if (element.second)
356  {
357  sparseMap.insert(std::make_pair(element.first, nextTile));
358  ++nextTile;
359  }
360  }
361 }
362 
363 //------------------------------------------------------------------------------------------------------------------------------------------
364 
365 StatusCode DlHitTrackShowerIdAlgorithm::ReadSettings(const TiXmlHandle xmlHandle)
366 {
367  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "UseTrainingMode", m_useTrainingMode));
368 
369  if (m_useTrainingMode)
370  {
371  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle, "TrainingOutputFileName", m_trainingOutputFile));
372  }
373  else
374  {
375  bool modelLoaded{false};
376  PANDORA_RETURN_RESULT_IF_AND_IF(
377  STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "ModelFileNameU", m_modelFileNameU));
378  if (!m_modelFileNameU.empty())
379  {
380  m_modelFileNameU = LArFileHelper::FindFileInPath(m_modelFileNameU, "FW_SEARCH_PATH");
381  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, LArDLHelper::LoadModel(m_modelFileNameU, m_modelU));
382  modelLoaded = true;
383  }
384  PANDORA_RETURN_RESULT_IF_AND_IF(
385  STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "ModelFileNameV", m_modelFileNameV));
386  if (!m_modelFileNameV.empty())
387  {
388  m_modelFileNameV = LArFileHelper::FindFileInPath(m_modelFileNameV, "FW_SEARCH_PATH");
389  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, LArDLHelper::LoadModel(m_modelFileNameV, m_modelV));
390  modelLoaded = true;
391  }
392  PANDORA_RETURN_RESULT_IF_AND_IF(
393  STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "ModelFileNameW", m_modelFileNameW));
394  if (!m_modelFileNameW.empty())
395  {
396  m_modelFileNameW = LArFileHelper::FindFileInPath(m_modelFileNameW, "FW_SEARCH_PATH");
397  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, LArDLHelper::LoadModel(m_modelFileNameW, m_modelW));
398  modelLoaded = true;
399  }
400  if (!modelLoaded)
401  {
402  std::cout << "Error: Inference requested, but no model files were successfully loaded" << std::endl;
403  return STATUS_CODE_INVALID_PARAMETER;
404  }
405  }
406 
407  PANDORA_RETURN_RESULT_IF_AND_IF(
408  STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(xmlHandle, "CaloHitListNames", m_caloHitListNames));
409  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "ImageHeight", m_imageHeight));
410  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "ImageWidth", m_imageWidth));
411  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "TileSize", m_tileSize));
412  if (m_imageHeight <= 0.f || m_imageWidth <= 0.f || m_tileSize <= 0.f)
413  {
414  std::cout << "Error: Invalid image size specification" << std::endl;
415  return STATUS_CODE_INVALID_PARAMETER;
416  }
417  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "Visualize", m_visualize));
418 
419  return STATUS_CODE_SUCCESS;
420 }
421 
422 } // namespace lar_dl_content
process_name opflash particleana ie ie ie z
Header file for the pfo helper class.
std::unordered_map< const pandora::MCParticle *, pandora::CaloHitList > MCContributionMap
LArDLHelper::TorchModel m_modelW
Model for the W view.
bool m_visualize
Whether to visualize the track shower ID scores.
MvaTypes::MvaFeatureVector MvaFeatureVector
Definition: LArMvaHelper.h:72
var pdg
Definition: selectors.fcl:14
process_name opflash particleana ie x
Header file for the lar calo hit class.
void SetShowerProbability(const float probability)
Set the probability that the hit is shower-like.
Definition: LArCaloHit.h:234
LArDLHelper::TorchModel m_modelU
Model for the U view.
torch::jit::script::Module TorchModel
Definition: LArDLHelper.h:25
std::string m_modelFileNameW
Model file name for W view.
Header file for the lar monitoring helper class.
LAr calo hit class.
Definition: LArCaloHit.h:39
unsigned int m_minHitsForGoodView
the minimum number of Hits for a good view
float m_maxPhotonPropagation
the maximum photon propagation length
std::string m_modelFileNameU
Model file name for U view.
T abs(T value)
pandora::StringVector m_caloHitListNames
Names of the input calo hit lists.
Header file for the lar monte carlo particle helper class.
pandora::StatusCode Infer()
Run network inference.
std::string m_trainingOutputFile
Output file name for training examples.
LArDLHelper::TorchModel m_modelV
Model for the V view.
void GetHitRegion(const pandora::CaloHitList &caloHitList, float &xMin, float &xMax, float &zMin, float &zMax)
Identify the XZ range containing the hits for an event.
Header file for the file helper class.
std::string m_modelFileNameV
Model file name for V view.
static void Forward(TorchModel &model, const TorchInputVector &input, TorchOutput &output)
Run a deep learning model.
Definition: LArDLHelper.cc:41
pandora::StatusCode ReadSettings(const pandora::TiXmlHandle xmlHandle)
required by fuzzyCluster table::sbnd_g4_services gaushitTruthMatch pandora
Definition: reco_sbnd.fcl:182
static pandora::StatusCode LoadModel(const std::string &filename, TorchModel &model)
Loads a deep learning model.
Definition: LArDLHelper.cc:16
Header file for the deep learning track shower id algorithm.
BEGIN_PROLOG sequence::SlidingWindowTriggerPatternsOppositeWindows END_PROLOG simSlidingORM6O6 effSlidingORW output
pandora::StatusCode Train()
Produce files that act as inputs to network training.
void GetSparseTileMap(const pandora::CaloHitList &caloHitList, const float xMin, const float zMin, const int nTilesX, PixelToTileMap &sparseMap)
Populate a map between pixels and tiles.
static void InitialiseInput(const at::IntArrayRef dimensions, TorchInput &tensor)
Create a torch input tensor.
Definition: LArDLHelper.cc:34
esac echo uname r
std::map< const pandora::CaloHit *, std::tuple< int, int, int, int > > CaloHitToPixelMap
BEGIN_PROLOG could also be cout
std::vector< torch::jit::IValue > TorchInputVector
Definition: LArDLHelper.h:27