From c3497de21263c7a4788318a2d8dcb7436821147c Mon Sep 17 00:00:00 2001 From: Giovanni Marchiori <39376142+giovannimarchiori@users.noreply.github.com> Date: Tue, 2 Jul 2024 13:58:44 +0200 Subject: [PATCH] Implement an algorithm for photon/pi0 discrimination based on input features and a trained model in ONNX format (#85) * run MVA photon ID in gaudi * save photon score instead of pi0 score * fix compilation warning * use [[maybe_unused]] for EventContext * fix some typos and indentation * parse json files with nlohmann * handle malformed JSON files --- RecFCCeeCalorimeter/CMakeLists.txt | 3 +- .../src/components/PhotonIDTool.cpp | 419 ++++++++++++++++++ .../src/components/PhotonIDTool.h | 115 +++++ 3 files changed, 536 insertions(+), 1 deletion(-) create mode 100644 RecFCCeeCalorimeter/src/components/PhotonIDTool.cpp create mode 100644 RecFCCeeCalorimeter/src/components/PhotonIDTool.h diff --git a/RecFCCeeCalorimeter/CMakeLists.txt b/RecFCCeeCalorimeter/CMakeLists.txt index c6d0beeb..d7a497a8 100644 --- a/RecFCCeeCalorimeter/CMakeLists.txt +++ b/RecFCCeeCalorimeter/CMakeLists.txt @@ -27,7 +27,8 @@ gaudi_add_module(k4RecFCCeeCalorimeterPlugins DD4hep::DDG4 ROOT::Core ROOT::Hist - ${ONNXRUNTIME_LIBRARY} + ${ONNXRUNTIME_LIBRARY} + nlohmann_json::nlohmann_json ) install(TARGETS k4RecFCCeeCalorimeterPlugins EXPORT k4RecCalorimeterTargets diff --git a/RecFCCeeCalorimeter/src/components/PhotonIDTool.cpp b/RecFCCeeCalorimeter/src/components/PhotonIDTool.cpp new file mode 100644 index 00000000..1e07dbd8 --- /dev/null +++ b/RecFCCeeCalorimeter/src/components/PhotonIDTool.cpp @@ -0,0 +1,419 @@ +#include "PhotonIDTool.h" + +// our EDM +#include "edm4hep/Cluster.h" +#include "edm4hep/ClusterCollection.h" + +#include + +#include "nlohmann/json.hpp" + +using json = nlohmann::json; + + +DECLARE_COMPONENT(PhotonIDTool) + +// convert vector data with given shape into ONNX runtime tensor +template +Ort::Value vec_to_tensor(std::vector &data, const std::vector &shape) +{ + Ort::MemoryInfo mem_info = + Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); + auto tensor = Ort::Value::CreateTensor(mem_info, data.data(), data.size(), shape.data(), shape.size()); + return tensor; +} + +PhotonIDTool::PhotonIDTool(const std::string &name, + ISvcLocator *svcLoc) + : Gaudi::Algorithm(name, svcLoc) +{ + declareProperty("inClusters", m_inClusters, "Input cluster collection"); + declareProperty("outClusters", m_outClusters, "Output cluster collection"); +} + +StatusCode PhotonIDTool::initialize() +{ + // Initialize base class + { + StatusCode sc = Gaudi::Algorithm::initialize(); + if (sc.isFailure()) + { + return sc; + } + } + + // read the files defining the model + StatusCode sc = readMVAFiles(m_mvaInputsFile, m_mvaModelFile); + if (sc.isFailure()) + { + error() << "Initialization of photon ID tool config files not successful!" << endmsg; + return sc; + } + + // read from the metadata the names of the shape parameters in the input clusters + std::vector shapeParameters = m_inShapeParameterHandle.get({}); + debug() << "Variables in shapeParameters of input clusters:" << endmsg; + for (const auto &str : shapeParameters) { + debug() << str << endmsg; + } + + // check if the shape parameters contain the inputs needed for the inference + m_inputPositionsInShapeParameters.clear(); + for (const auto &feature : m_internal_input_names) { + + if (feature == "ecl") { + // for the cluster energy, check if we have rawE in decorations + // this is for cluster that have been passed through the MVA calibration + // otherwise, we will use the energy of the cluster object + auto it = std::find(shapeParameters.begin(), shapeParameters.end(), "rawE"); + if (it != shapeParameters.end()) + { + int position = std::distance(shapeParameters.begin(), it); + m_inputPositionsInShapeParameters.push_back(position); + info() << "Feature " << feature << " found in position " << position << " of shapeParameters" << endmsg; + } + else { + m_inputPositionsInShapeParameters.push_back(-1); + } + } + else { + // for the other features, check if they are in the shape parameters + auto it = std::find(shapeParameters.begin(), shapeParameters.end(), feature); + if (it != shapeParameters.end()) + { + int position = std::distance(shapeParameters.begin(), it); + m_inputPositionsInShapeParameters.push_back(position); + info() << "Feature " << feature << " found in position " << position << " of shapeParameters" << endmsg; + } + else + { + // at least one of the inputs of the MVA was not found in the shapeParameters + // so we can stop checking the others + m_inputPositionsInShapeParameters.clear(); + error() << "Feature " << feature << " not found, aborting..." << endmsg; + return StatusCode::FAILURE; + } + } + } + + // append the MVA score to the output shape parameters + shapeParameters.push_back("photonIDscore"); + m_outShapeParameterHandle.put(shapeParameters); + + info() << "Initialized the photonID MVA tool" << endmsg; + return StatusCode::SUCCESS; +} + +StatusCode PhotonIDTool::execute([[maybe_unused]] const EventContext &evtCtx) const +{ + verbose() << "-------------------------------------------" << endmsg; + + // Get the input collection with clusters + const edm4hep::ClusterCollection *inClusters = m_inClusters.get(); + + // Initialize output clusters + edm4hep::ClusterCollection *outClusters = initializeOutputClusters(inClusters); + if (!outClusters) + { + error() << "Something went wrong in initialization of the output cluster collection, exiting!" << endmsg; + return StatusCode::FAILURE; + } + if (inClusters->size() != outClusters->size()) + { + error() << "Sizes of input and output cluster collections does not match, exiting!" << endmsg; + return StatusCode::FAILURE; + } + + // Run inference + { + StatusCode sc = applyMVAtoClusters(inClusters, outClusters); + if (sc.isFailure()) + { + return sc; + } + } + + return StatusCode::SUCCESS; +} + +StatusCode PhotonIDTool::finalize() +{ + if (m_ortSession) + delete m_ortSession; + if (m_ortEnv) + delete m_ortEnv; + + return Gaudi::Algorithm::finalize(); +} + +edm4hep::ClusterCollection *PhotonIDTool::initializeOutputClusters( + const edm4hep::ClusterCollection *inClusters) const +{ + edm4hep::ClusterCollection *outClusters = m_outClusters.createAndPut(); + + for (auto const &inCluster : *inClusters) + { + auto outCluster = inCluster.clone(); + outClusters->push_back(outCluster); + } + + return outClusters; +} + +StatusCode PhotonIDTool::readMVAFiles(const std::string& mvaInputsFileName, + const std::string& mvaModelFileName) +{ + // 1. read the file with the list of input features + // Open the JSON file + std::ifstream file(mvaInputsFileName); + if (!file.is_open()) { + error() << "Error opening file: " << mvaInputsFileName << endmsg; + return StatusCode::FAILURE; + } + + // Parse the JSON file + json j; + try { + file >> j; + } catch (const nlohmann::json::exception& e) { + error() << "Error parsing JSON: " << e.what() << endmsg; + return StatusCode::FAILURE; + } + file.close(); + + // Access the data and print to screen + std::string timeStamp; + if (!j.contains("timeStamp")) { + error() << "Error: timeStamp key not found in JSON" << endmsg; + return StatusCode::FAILURE; + } + else { + timeStamp = j["timeStamp"]; + } + + std::string clusterCollection; + if (!j.contains("clusterCollection")) { + error() << "Error: clusterCollection key not found in JSON" << endmsg; + return StatusCode::FAILURE; + } + else { + clusterCollection = j["clusterCollection"]; + } + + std::string trainingTool; + if (!j.contains("trainingTool")) { + error() << "Error: trainingTool key not found in JSON" << endmsg; + return StatusCode::FAILURE; + } + else { + trainingTool = j["trainingTool"]; + } + + info() << "Using the following photon-ID training:" << endmsg; + info() << " Timestamp: " << timeStamp << endmsg; + info() << " Training tool used: " << trainingTool << endmsg; + info() << " Input cluster collection: " << clusterCollection << endmsg; + if (!j.contains("shapeParameters")) { + error() << "Error: shapeParameters key not found in JSON" << endmsg; + return StatusCode::FAILURE; + } + else { + try { + const auto& shape_params = j["shapeParameters"]; + if (!shape_params.is_array()) { + throw std::runtime_error("shapeParameters is not an array"); + } + for (const auto& param : shape_params) { + if (!param.is_string()) { + throw std::runtime_error("shapeParameters contains non-string values"); + } + m_internal_input_names.push_back(param.get()); + } + } catch (const std::exception& e) { + error() << "Error: " << e.what() << endmsg; + return StatusCode::FAILURE; + } + } + info() << " Input shape parameters:" << endmsg; + for (const auto &str : m_internal_input_names) { + info() << " " << str << endmsg; + } + if (!j.contains("trainingParameters")) { + error() << "Error: trainingParameters key not found in JSON" << endmsg; + return StatusCode::FAILURE; + } + else { + info() << " Training parameters:" << endmsg; + for (const auto ¶m : j["trainingParameters"].items()) { + std::string key = param.key(); + std::string value; + if (param.value().is_string()) { + value = param.value().get(); + } + else if (param.value().is_number()) { + value = std::to_string(param.value().get()); + } + else if (param.value().is_null()) { + value = "null"; + } + else { + value = "invalid"; + } + info() << " " << key << " : " << value << endmsg; + } + } + + + // 2. - read the file with the MVA model and setup the ONNX runtime + // set ONNX logging level based on output level of this alg + OrtLoggingLevel loggingLevel = ORT_LOGGING_LEVEL_WARNING; + MSG::Level outputLevel = this->msgStream().level(); + switch (outputLevel) + { + case MSG::Level::FATAL: // 6 + loggingLevel = ORT_LOGGING_LEVEL_FATAL; // 4 + break; + case MSG::Level::ERROR: // 5 + loggingLevel = ORT_LOGGING_LEVEL_ERROR; // 3 + break; + case MSG::Level::WARNING: // 4 + loggingLevel = ORT_LOGGING_LEVEL_WARNING; // 2 + break; + case MSG::Level::INFO: // 3 + loggingLevel = ORT_LOGGING_LEVEL_WARNING; // 2 (ORT_LOGGING_LEVEL_INFO too verbose..) + break; + case MSG::Level::DEBUG: // 2 + loggingLevel = ORT_LOGGING_LEVEL_INFO; // 1 + break; + case MSG::Level::VERBOSE: // 1 + loggingLevel = ORT_LOGGING_LEVEL_VERBOSE; // 0 + break; + default: + break; + } + try + { + m_ortEnv = new Ort::Env(loggingLevel, "ONNX runtime environment for photonID"); + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + m_ortSession = new Ort::Experimental::Session(*m_ortEnv, const_cast(mvaModelFileName), session_options); + // m_ortSession = new Ort::Session(*m_ortEnv, const_cast(mvaModelFileName), session_options); + } + catch (const Ort::Exception &exception) + { + error() << "ERROR setting up ONNX runtime environment: " << exception.what() << endmsg; + return StatusCode::FAILURE; + } + + // print name/shape of inputs + // use default allocator (CPU) + Ort::AllocatorWithDefaultOptions allocator; + debug() << "Input Node Name/Shape (" << m_ortSession->GetInputCount() << "):" << endmsg; + for (std::size_t i = 0; i < m_ortSession->GetInputCount(); i++) + { + // for old ONNX runtime version + // m_input_names.emplace_back(m_ortSession->GetInputName(i, allocator)); + // for new runtime version + m_input_names.emplace_back(m_ortSession->GetInputNameAllocated(i, allocator).get()); + m_input_shapes = m_ortSession->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); + debug() << "\t" << m_input_names.at(i) << " : "; + for (std::size_t k = 0; k < m_input_shapes.size() - 1; k++) + { + debug() << m_input_shapes[k] << "x"; + } + debug() << m_input_shapes[m_input_shapes.size() - 1] << endmsg; + } + // some models might have negative shape values to indicate dynamic shape, e.g., for variable batch size. + for (auto &s : m_input_shapes) + { + if (s < 0) + { + s = 1; + } + } + + // print name/shape of outputs + debug() << "Output Node Name/Shape (" << m_ortSession->GetOutputCount() << "):" << endmsg; + for (std::size_t i = 0; i < m_ortSession->GetOutputCount(); i++) + { + // for old ONNX runtime version + // m_output_names.emplace_back(m_ortSession->GetOutputName(i, allocator)); + // for new runtime version + m_output_names.emplace_back(m_ortSession->GetOutputNameAllocated(i, allocator).get()); + m_output_shapes = m_ortSession->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); + debug() << m_output_shapes.size() << endmsg; + debug() << "\t" << m_output_names.at(i) << " : "; + for (std::size_t k = 0; k < m_output_shapes.size() - 1; k++) + { + debug() << m_output_shapes[k] << "x"; + } + debug() << m_output_shapes[m_output_shapes.size() - 1] << endmsg; + } + + debug() << "PhotonID config files read out successfully" << endmsg; + + return StatusCode::SUCCESS; +} + +StatusCode PhotonIDTool::applyMVAtoClusters(const edm4hep::ClusterCollection *inClusters, + edm4hep::ClusterCollection *outClusters) const +{ + size_t numShapeVars = m_internal_input_names.size(); + std::vector mvaInputs(numShapeVars); + + // loop over the input clusters and perform the inference + for (unsigned int j = 0; j < inClusters->size(); ++j) + { + // read the values of the input features + for (unsigned int i = 0; i < m_inputPositionsInShapeParameters.size(); i++) { + int position = m_inputPositionsInShapeParameters[i]; + if (position == -1) + mvaInputs[i] = (inClusters->at(j)).getEnergy(); + else + mvaInputs[i] = (inClusters->at(j)).getShapeParameters(position); + } + + // print the values of the input features + verbose() << "MVA inputs:" << endmsg; + for (unsigned short int k = 0; k < numShapeVars; ++k) + { + verbose() << "var " << k << " : " << mvaInputs[k] << endmsg; + } + + // run the MVA and save the output score in output + float score= -1.0; + // Create a single Ort tensor + std::vector input_tensors; + input_tensors.emplace_back(vec_to_tensor(mvaInputs, m_input_shapes)); + + // pass data through model + try + { + std::vector output_tensors = m_ortSession->Run(m_input_names, + input_tensors, + m_output_names, + Ort::RunOptions{nullptr}); + + // double-check the dimensions of the output tensors + // NOTE: the number of output tensors is equal to the number of output nodes specified in the Run() call + // assert(output_tensors.size() == output_names.size() && output_tensors[0].IsTensor()); + // the probabilities are in the 2nd entry of the output + debug() << output_tensors.size() << endmsg; + debug() << output_tensors[1].GetTensorTypeAndShapeInfo().GetShape() << endmsg; + float *outputData = output_tensors[1].GetTensorMutableData(); + for (int i=0; i<2; i++) + debug() << i << " " << outputData[i] << endmsg; + score = outputData[1]; + } + catch (const Ort::Exception &exception) + { + error() << "ERROR running model inference: " << exception.what() << endmsg; + return StatusCode::FAILURE; + } + + verbose() << "Photon ID score: " << score << endmsg; + outClusters->at(j).addToShapeParameters(score); + } + + return StatusCode::SUCCESS; +} diff --git a/RecFCCeeCalorimeter/src/components/PhotonIDTool.h b/RecFCCeeCalorimeter/src/components/PhotonIDTool.h new file mode 100644 index 00000000..7edaab3c --- /dev/null +++ b/RecFCCeeCalorimeter/src/components/PhotonIDTool.h @@ -0,0 +1,115 @@ +#ifndef RECFCCEECALORIMETER_PHOTONIDTOOL_H +#define RECFCCEECALORIMETER_PHOTONIDTOOL_H + +// Key4HEP +#include "k4FWCore/DataHandle.h" +#include "k4FWCore/MetaDataHandle.h" + +// Gaudi +#include "GaudiKernel/Algorithm.h" +#include "GaudiKernel/ToolHandle.h" +#include "GaudiKernel/MsgStream.h" +class IGeoSvc; + +// EDM4HEP +namespace edm4hep { + class Cluster; + class ClusterCollection; +} + +// ONNX +#include "onnxruntime/core/session/experimental_onnxruntime_cxx_api.h" + +/** @class PhotonIDTool + * + * Apply a binary MVA classifier to discriminate between photons and pi0s. + * It takes a cluster collection in inputs, runs the inference using as inputs + * the variables in the shapeParameters of the input clusters, decorates the + * cluster with the photon probability (appended to the shapeParameters vector) + * and saves the cluster in a new output collection. + * + * @author Giovanni Marchiori + */ + +class PhotonIDTool : public Gaudi::Algorithm { + +public: + PhotonIDTool(const std::string& name, ISvcLocator* svcLoc); + + virtual StatusCode initialize(); + + virtual StatusCode execute(const EventContext& evtCtx) const; + + virtual StatusCode finalize(); + +private: + /** + * Initialize output calorimeter cluster collection. + * + * @param[in] inClusters Pointer to the input cluster collection. + * + * @return Pointer to the output cluster collection. + */ + edm4hep::ClusterCollection* initializeOutputClusters(const edm4hep::ClusterCollection* inClusters) const; + + /** + * Load file with MVA model into memory. + * + * @return Status code. + */ + StatusCode readMVAFiles(const std::string& mvaInputsFileName, + const std::string& mvaModelFileName); + + /** + * Calculate the MVA score for the input clusters and adds it + * as a new shapeParameter of the output clusters + * + * @param[in] inClusters Pointer to the input cluster collection. + * @param[out] outClusters Pointer to the output cluster collection. + * + * @return Status code. + */ + StatusCode applyMVAtoClusters(const edm4hep::ClusterCollection* inClusters, + edm4hep::ClusterCollection* outClusters) const; + + /// Handle for input calorimeter clusters collection + mutable DataHandle m_inClusters { + "inClusters", Gaudi::DataHandle::Reader, this + }; + + /// Handle for output calorimeter clusters collection + mutable DataHandle m_outClusters { + "outClusters", Gaudi::DataHandle::Writer, this + }; + + /// Handles for the cluster shower shape metadata to read and to write + MetaDataHandle> m_inShapeParameterHandle{ + m_inClusters, + edm4hep::labels::ShapeParameterNames, + Gaudi::DataHandle::Reader}; + MetaDataHandle> m_outShapeParameterHandle{ + m_outClusters, + edm4hep::labels::ShapeParameterNames, + Gaudi::DataHandle::Writer}; + + /// Files with the MVA model and list of inputs + Gaudi::Property m_mvaModelFile { + this, "mvaModelFile", {}, "ONNX file with the mva model"}; + Gaudi::Property m_mvaInputsFile { + this, "mvaInputsFile", {}, "JSON file with the mva inputs"}; + + // the ONNX runtime session for running the inference, + // the environment, and the input and output shapes and names + Ort::Experimental::Session* m_ortSession = nullptr; + Ort::Env* m_ortEnv = nullptr; + std::vector m_input_shapes; + std::vector m_output_shapes; + std::vector m_input_names; + std::vector m_output_names; + std::vector m_internal_input_names; + + // the indices of the shapeParameters containing the inputs to the model (-1 if not found) + std::vector m_inputPositionsInShapeParameters; +}; + +#endif /* RECFCCEECALORIMETER_PHOTONIDTOOL_H */