// Source: votca-xtp, tag debian/1.6.3-1,
//         include/votca/xtp/checkpointwriter.h

/*
 * Copyright 2009-2019 The VOTCA Development Team (http://www.votca.org)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

#pragma once
#ifndef VOTCA_XTP_CHECKPOINT_WRITER_H
#define VOTCA_XTP_CHECKPOINT_WRITER_H

#include <H5Cpp.h>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>
#include <votca/tools/linalg.h>
#include <votca/xtp/checkpoint_utils.h>
#include <votca/xtp/checkpointtable.h>
#include <votca/xtp/eigen.h>

namespace votca {
namespace xtp {

using namespace checkpoint_utils;

class CheckpointWriter {
 public:
  CheckpointWriter(const CptLoc& loc) : CheckpointWriter(loc, "/"){};

  CheckpointWriter(const CptLoc& loc, const std::string& path)
      : _loc(loc), _path(path){};

  // see the following links for details
  // https://stackoverflow.com/a/8671617/1186564
  template <typename T>
  typename std::enable_if<!std::is_fundamental<T>::value>::type operator()(
      const T& data, const std::string& name) const {
    try {
      WriteData(_loc, data, name);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not write " << name << " to " << _loc.getFileName()
              << ":" << _path;

      throw std::runtime_error(message.str());
    }
  }

  // Use this overload if T is a fundamental type
  // int, double, unsigned, etc, but not bool
  template <typename T>
  typename std::enable_if<std::is_fundamental<T>::value &&
                          !std::is_same<T, bool>::value>::type
      operator()(const T& v, const std::string& name) const {
    try {
      WriteScalar(_loc, v, name);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not write " << name << " to " << _loc.getFileName()
              << ":" << _path << std::endl;

      throw std::runtime_error(message.str());
    }
  }

  void operator()(const bool& v, const std::string& name) const {
    Index temp = static_cast<Index>(v);
    try {
      WriteScalar(_loc, temp, name);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not write " << name << " to " << _loc.getFileName()
              << ":" << _path << std::endl;

      throw std::runtime_error(message.str());
    }
  }

  void operator()(const std::string& v, const std::string& name) const {
    try {
      WriteScalar(_loc, v, name);
    } catch (H5::Exception&) {
      std::stringstream message;
      message << "Could not write " << name << " to " << _loc.getFileName()
              << ":" << _path << std::endl;

      throw std::runtime_error(message.str());
    }
  }

  CheckpointWriter openChild(const std::string& childName) const {
    try {
      return CheckpointWriter(_loc.openGroup(childName),
                              _path + "/" + childName);
    } catch (H5::Exception&) {
      try {
        return CheckpointWriter(_loc.createGroup(childName),
                                _path + "/" + childName);
      } catch (H5::Exception&) {
        std::stringstream message;
        message << "Could not open or create" << _loc.getFileName() << ":/"
                << _path << "/" << childName << std::endl;

        throw std::runtime_error(message.str());
      }
    }
  }

  template <typename T>
  CptTable openTable(const std::string& name, const T& obj, std::size_t nRows,
                     bool compact = false) {
    CptTable table;
    try {
      table = CptTable(name, sizeof(typename T::data), _loc);
      obj.SetupCptTable(table);
    } catch (H5::Exception&) {
      try {
        table = CptTable(name, sizeof(typename T::data), nRows);
        obj.SetupCptTable(table);
        table.initialize(_loc, compact);
      } catch (H5::Exception&) {
        std::stringstream message;
        message << "Could not open table " << name << " in "
                << _loc.getFileName() << ":" << _path << std::endl;
        throw std::runtime_error(message.str());
      }
    }

    return table;
  }

 private:
  const CptLoc _loc;
  const std::string _path;
  template <typename T>
  void WriteScalar(const CptLoc& loc, const T& value,
                   const std::string& name) const {

    hsize_t dims[1] = {1};
    H5::DataSpace dp(1, dims);
    const H5::DataType* dataType = InferDataType<T>::get();
    H5::Attribute attr;
    try {
      attr = loc.createAttribute(name, *dataType, dp);
    } catch (H5::AttributeIException&) {
      attr = loc.openAttribute(name);
    }
    attr.write(*dataType, &value);
  }

  void WriteScalar(const CptLoc& loc, const std::string& value,
                   const std::string& name) const {

    hsize_t dims[1] = {1};
    H5::DataSpace dp(1, dims);
    const H5::DataType* strType = InferDataType<std::string>::get();

    H5::Attribute attr;

    try {
      attr = loc.createAttribute(name, *strType, dp);
    } catch (H5::AttributeIException&) {
      attr = loc.openAttribute(name);
    }
    attr.write(*strType, &value);
  }

  template <typename T>
  void WriteData(const CptLoc& loc, const Eigen::MatrixBase<T>& matrix,
                 const std::string& name) const {

    hsize_t matRows = hsize_t(matrix.rows());
    hsize_t matCols = hsize_t(matrix.cols());

    hsize_t dims[2] = {matRows, matCols};  // eigen vectors are n,1 matrices

    if (dims[1] == 0) {
      dims[1] = 1;
    }

    H5::DataSpace dp(2, dims);
    const H5::DataType* dataType = InferDataType<typename T::Scalar>::get();
    H5::DataSet dataset;
    try {
      dataset = loc.createDataSet(name.c_str(), *dataType, dp);
    } catch (H5::GroupIException&) {
      dataset = loc.openDataSet(name.c_str());
    }

    hsize_t matColSize = matrix.derived().outerStride();

    hsize_t fileRows = matCols;

    hsize_t fStride[2] = {1, fileRows};
    hsize_t fCount[2] = {1, 1};
    hsize_t fBlock[2] = {1, fileRows};

    hsize_t mStride[2] = {matColSize, 1};
    hsize_t mCount[2] = {1, 1};
    hsize_t mBlock[2] = {matCols, 1};

    hsize_t mDim[2] = {matCols, matColSize};
    H5::DataSpace mspace(2, mDim);

    for (hsize_t i = 0; i < matRows; i++) {
      hsize_t fStart[2] = {i, 0};
      hsize_t mStart[2] = {0, i};
      dp.selectHyperslab(H5S_SELECT_SET, fCount, fStart, fStride, fBlock);
      mspace.selectHyperslab(H5S_SELECT_SET, mCount, mStart, mStride, mBlock);
      dataset.write(matrix.derived().data(), *dataType, mspace, dp);
    }
  }

  template <typename T>
  typename std::enable_if<std::is_fundamental<T>::value>::type WriteData(
      const CptLoc& loc, const std::vector<T> v,
      const std::string& name) const {
    hsize_t dims[2] = {(hsize_t)v.size(), 1};

    const H5::DataType* dataType = InferDataType<T>::get();
    H5::DataSet dataset;
    H5::DataSpace dp(2, dims);
    try {
      dataset = loc.createDataSet(name.c_str(), *dataType, dp);
    } catch (H5::GroupIException&) {
      dataset = loc.openDataSet(name.c_str());
    }
    dataset.write(v.data(), *dataType);
  }

  void WriteData(const CptLoc& loc, const std::vector<std::string>& v,
                 const std::string& name) const {

    hsize_t dims[1] = {(hsize_t)v.size()};

    std::vector<const char*> c_str_copy;
    c_str_copy.reserve(v.size());
    for (const std::string& s : v) {
      c_str_copy.push_back(s.c_str());
    }
    const H5::DataType* dataType = InferDataType<std::string>::get();
    H5::DataSet dataset;
    H5::DataSpace dp(1, dims);
    try {
      dataset = loc.createDataSet(name.c_str(), *dataType, dp);
    } catch (H5::GroupIException&) {
      dataset = loc.openDataSet(name.c_str());
    }
    dataset.write(c_str_copy.data(), *dataType);
  }

  void WriteData(const CptLoc& loc, const std::vector<Eigen::Vector3d>& v,
                 const std::string& name) const {

    size_t c = 0;
    std::string r;
    CptLoc parent;
    try {
      parent = loc.createGroup(name);
    } catch (H5::GroupIException&) {
      parent = loc.openGroup(name);
    }
    for (auto const& x : v) {
      r = std::to_string(c);
      WriteData(parent, x, "ind" + r);
      ++c;
    }
  }

  void WriteData(const CptLoc& loc, const tools::EigenSystem& sys,
                 const std::string& name) const {

    CptLoc parent;
    try {
      parent = loc.createGroup(name);
    } catch (H5::GroupIException&) {
      parent = loc.openGroup(name);
    }

    WriteData(parent, sys.eigenvalues(), "eigenvalues");
    WriteData(parent, sys.eigenvectors(), "eigenvectors");
    WriteData(parent, sys.eigenvectors2(), "eigenvectors2");
    WriteScalar(parent, Index(sys.info()), "info");
  }

  template <typename T1, typename T2>
  void WriteData(const CptLoc& loc, const std::map<T1, std::vector<T2>> map,
                 const std::string& name) const {

    size_t c = 0;
    std::string r;
    // Iterate over the map and write map as a number of vectors with T1 as
    // index
    for (auto const& x : map) {
      r = std::to_string(c);
      CptLoc tempGr;
      try {
        tempGr = loc.createGroup(name);
      } catch (H5::GroupIException&) {
        tempGr = loc.openGroup(name);
      }
      WriteData(tempGr, x.second, "index" + r);
      ++c;
    }
  }
};
}  // namespace xtp
}  // namespace votca
#endif  // VOTCA_XTP_CHECKPOINT_WRITER_H