Codebase list votca-xtp / debian/1.6.2-2 include / votca / xtp / checkpointreader.h
debian/1.6.2-2

Tree @debian/1.6.2-2 (Download .tar.gz)

checkpointreader.h @debian/1.6.2-2raw · history · blame

/*
 * Copyright 2009-2019 The VOTCA Development Team (http://www.votca.org)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
#pragma once
#ifndef VOTCA_XTP_CHECKPOINT_READER_H
#define VOTCA_XTP_CHECKPOINT_READER_H

#include <H5Cpp.h>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>
#include <votca/tools/linalg.h>
#include <votca/xtp/checkpoint_utils.h>
#include <votca/xtp/checkpointtable.h>
#include <votca/xtp/eigen.h>

namespace votca {
namespace xtp {
using namespace checkpoint_utils;

class CheckpointReader {
 public:
  CheckpointReader(const CptLoc& loc) : CheckpointReader(loc, "/"){};

  CheckpointReader(const CptLoc& loc, const std::string path)
      : _loc(loc), _path(path){};

  template <typename T>
  typename std::enable_if<!std::is_fundamental<T>::value>::type operator()(
      T& var, const std::string& name) const {
    try {
      ReadData(_loc, var, name);
    } catch (H5::Exception&) {
      std::stringstream message;

      message << "Could not read " << name << " from " << _loc.getFileName()
              << ":" << _path << std::endl;

      throw std::runtime_error(message.str());
    }
  }

  template <typename T>
  typename std::enable_if<std::is_fundamental<T>::value &&
                          !std::is_same<T, bool>::value>::type
      operator()(T& var, const std::string& name) const {
    try {
      ReadScalar(_loc, var, name);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not read " << name << " from " << _loc.getFileName()
              << ":" << _path << "/" << std::endl;

      throw std::runtime_error(message.str());
    }
  }

  void operator()(bool& v, const std::string& name) const {
    Index temp = Index(v);
    try {
      ReadScalar(_loc, temp, name);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not read " << name << " from " << _loc.getFileName()
              << ":" << _path << std::endl;

      throw std::runtime_error(message.str());
    }
    v = static_cast<bool>(temp);
  }

  void operator()(std::string& var, const std::string& name) const {
    try {
      ReadScalar(_loc, var, name);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not read " << name << " from " << _loc.getFileName()
              << ":" << _path << std::endl;

      throw std::runtime_error(message.str());
    }
  }

  CheckpointReader openChild(const std::string& childName) const {
    try {
      return CheckpointReader(_loc.openGroup(childName),
                              _path + "/" + childName);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not open " << _loc.getFileName() << ":/" << _path << "/"
              << childName << std::endl;

      throw std::runtime_error(message.str());
    }
  }

  Index getNumDataSets() const { return _loc.getNumObjs(); }

  CptLoc getLoc() { return _loc; }

  template <typename T>
  CptTable openTable(const std::string& name, const T& obj) {
    try {
      CptTable table = CptTable(name, sizeof(typename T::data), _loc);
      obj.SetupCptTable(table);
      return table;
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not open table " << name << " in " << _loc.getFileName()
              << ":" << _path << std::endl;
      throw std::runtime_error(message.str());
    }
  }

 private:
  const CptLoc _loc;
  const std::string _path;
  void ReadScalar(const CptLoc& loc, std::string& var,
                  const std::string& name) const {
    const H5::DataType* strType = InferDataType<std::string>::get();

    H5::Attribute attr = loc.openAttribute(name);

    H5std_string readbuf("");

    attr.read(*strType, readbuf);

    var = readbuf;
  }

  template <typename T>
  void ReadScalar(const CptLoc& loc, T& value, const std::string& name) const {

    H5::Attribute attr = loc.openAttribute(name);
    const H5::DataType* dataType = InferDataType<T>::get();

    attr.read(*dataType, &value);
  }

  template <typename T>
  void ReadData(const CptLoc& loc, Eigen::MatrixBase<T>& matrix,
                const std::string& name) const {

    const H5::DataType* dataType = InferDataType<typename T::Scalar>::get();

    H5::DataSet dataset = loc.openDataSet(name);

    H5::DataSpace dp = dataset.getSpace();

    hsize_t dims[2];
    dp.getSimpleExtentDims(dims, nullptr);  // ndims is always 2 for us

    hsize_t matRows = dims[0];
    hsize_t matCols = dims[1];

    matrix.derived().resize(matRows, matCols);
    if (matrix.size() == 0) {
      return;
    }

    hsize_t matColSize = matrix.derived().outerStride();

    hsize_t fileRows = matCols;

    hsize_t fStride[2] = {1, fileRows};
    hsize_t fCount[2] = {1, 1};
    hsize_t fBlock[2] = {1, fileRows};

    hsize_t mStride[2] = {matColSize, 1};
    hsize_t mCount[2] = {1, 1};
    hsize_t mBlock[2] = {matCols, 1};

    hsize_t mDim[2] = {matCols, matColSize};
    H5::DataSpace mspace(2, mDim);

    for (hsize_t i = 0; i < matRows; i++) {
      hsize_t fStart[2] = {i, 0};
      hsize_t mStart[2] = {0, i};
      dp.selectHyperslab(H5S_SELECT_SET, fCount, fStart, fStride, fBlock);
      mspace.selectHyperslab(H5S_SELECT_SET, mCount, mStart, mStride, mBlock);
      dataset.read(matrix.derived().data(), *dataType, mspace, dp);
    }
  }

  template <typename T>
  typename std::enable_if<std::is_fundamental<T>::value>::type ReadData(
      const CptLoc& loc, std::vector<T>& v, const std::string& name) const {

    H5::DataSet dataset = loc.openDataSet(name);
    H5::DataSpace dp = dataset.getSpace();

    const H5::DataType* dataType = InferDataType<T>::get();

    hsize_t dims[2];
    dp.getSimpleExtentDims(dims, nullptr);

    v.resize(dims[0]);
    if (v.empty()) {
      return;
    }
    try {
      dataset.read(&(v[0]), *dataType);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not read " << name << " from " << _loc.getFileName()
              << ":" << _path << std::endl;
      throw std::runtime_error(message.str());
    }
  }

  void ReadData(const CptLoc& loc, std::vector<std::string>& v,
                const std::string& name) const {

    H5::DataSet dataset = loc.openDataSet(name);
    H5::DataSpace dp = dataset.getSpace();

    const H5::DataType* dataType = InferDataType<std::string>::get();

    hsize_t dims[2];
    dp.getSimpleExtentDims(dims, nullptr);

    std::vector<char*> temp(dims[0]);
    if (temp.empty()) {
      return;
    }
    try {
      dataset.read(temp.data(), *dataType);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not read " << name << " from " << _loc.getFileName()
              << ":" << _path << std::endl;
      throw std::runtime_error(message.str());
    }
    v.reserve(dims[0]);
    for (char* s : temp) {
      v.push_back(std::string(s));
      free(s);
    }
  }

  void ReadData(const CptLoc& loc, tools::EigenSystem& sys,
                const std::string& name) const {

    CptLoc parent = loc.openGroup(name);
    ReadData(parent, sys.eigenvalues(), "eigenvalues");
    ReadData(parent, sys.eigenvectors(), "eigenvectors");
    ReadData(parent, sys.eigenvectors2(), "eigenvectors2");
    Index info;
    ReadScalar(parent, info, "info");
    sys.info() = static_cast<Eigen::ComputationInfo>(info);
  }

  void ReadData(const CptLoc& loc, std::vector<Eigen::Vector3d>& v,
                const std::string& name) const {

    CptLoc parent = loc.openGroup(name);
    size_t count = parent.getNumObjs();

    v.resize(count);

    size_t c = 0;
    for (auto& vec : v) {
      ReadData(parent, vec, "ind" + std::to_string(c));
      ++c;
    }
  }
};

}  // namespace xtp
}  // namespace votca

#endif  // VOTCA_XTP_CHECKPOINT_READER_H