Codebase list votca-xtp / e70a901 include / votca / xtp / checkpointwriter.h
e70a901

Tree @e70a901 (Download .tar.gz)

checkpointwriter.h @e70a901raw · history · blame

/*
 * Copyright 2009-2018 The VOTCA Development Team (http://www.votca.org)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

#ifndef VOTCA_XTP_CHECKPOINT_WRITER_H
#define VOTCA_XTP_CHECKPOINT_WRITER_H

#include <votca/xtp/eigen.h>

#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>

#include <H5Cpp.h>

#include <votca/tools/vec.h>
#include <votca/xtp/checkpoint_utils.h>

namespace votca { namespace xtp {

using namespace checkpoint_utils;

class CheckpointWriter {
public:
CheckpointWriter(const CptLoc& loc): CheckpointWriter(loc, "/"){};

CheckpointWriter(const CptLoc& loc, const std::string& path):
    _loc(loc), _path(path){};

    // see the following links for details
    // https://stackoverflow.com/a/8671617/1186564
    template <typename T>
        typename std::enable_if<!std::is_fundamental<T>::value>::type
        operator()(const T& data, const std::string& name)const{
        try {
            WriteData(_loc, data, name);
        } catch (H5::Exception& error){
            std::stringstream message;
            message << "Could not write " << name
                    << " to " << _loc.getFileName() << ":" << _path;

            throw std::runtime_error(message.str());
        }
    }

    // Use this overload iff T is a fundamental type
    // int, double, unsigned int, etc, but not bool
    template<typename T>
        typename std::enable_if<std::is_fundamental<T>::value && !std::is_same<T, bool>::value>::type
        operator()(const T& v, const std::string& name)const{
        try{
            WriteScalar(_loc, v, name);
        } catch (H5::Exception& error){
            std::stringstream message;
            message << "Could not write " << name << " to "
                    << _loc.getFileName() << ":" << _path << std::endl;

            throw std::runtime_error(message.str());
        }
    }

    void operator()(const bool& v, const std::string& name)const{
        int temp=static_cast<int>(v);
        try{
            WriteScalar(_loc, temp, name);
        } catch (H5::Exception& error){
            std::stringstream message;
            message << "Could not write " << name << " to "
                    << _loc.getFileName() << ":" << _path << std::endl;

            throw std::runtime_error(message.str());

        }
    }

    void operator()(const std::string& v, const std::string& name)const{
        try{
            WriteScalar(_loc, v, name);
        } catch (H5::Exception& error){
            std::stringstream message;
            message << "Could not write " << name << " to "
                    << _loc.getFileName() << ":" << _path << std::endl;

            throw std::runtime_error(message.str());

        }
    }

    CheckpointWriter openChild(const std::string& childName)const{
        try{
            return CheckpointWriter(_loc.openGroup(childName), _path+"/"+childName);
        } catch (H5::Exception& e){
            try{
                return CheckpointWriter(_loc.createGroup(childName), _path+"/"+childName);
            } catch (H5::Exception& e){
                std::stringstream message;
                message << "Could not open or create" << _loc.getFileName()
                        << ":/" << _path << "/" << childName << std::endl;

                throw std::runtime_error(message.str());
            }
        }
    }

private:
    const CptLoc _loc;
    const std::string _path;
    template <typename T>
        void WriteScalar(const CptLoc& loc, const T& value,
                         const std::string& name) const{

        hsize_t dims[1] = {1};
        H5::DataSpace dp(1, dims);
        const H5::DataType* dataType = InferDataType<T>::get();
        H5::Attribute attr;
        try{
            attr = loc.createAttribute(name, *dataType, dp);
        } catch (H5::AttributeIException& error){
            attr = loc.openAttribute(name);
        }
        attr.write(*dataType, &value);
    }

    void WriteScalar(const CptLoc& loc, const std::string& value,
                     const std::string& name) const{

        hsize_t dims[1] = {1};
        H5::DataSpace dp(1, dims);
        const H5::DataType* strType = InferDataType<std::string>::get();

        H5::Attribute attr;

        try{
            attr = loc.createAttribute(name, *strType, dp);
        } catch (H5::AttributeIException& error){
            attr = loc.openAttribute(name);
        }
        attr.write(*strType, &value);
    }

    template <typename T>
        void WriteData(const CptLoc& loc, const Eigen::MatrixBase<T>& matrix,
                       const std::string& name) const{

        hsize_t matRows = hsize_t(matrix.rows());
        hsize_t matCols = hsize_t(matrix.cols());

        hsize_t dims[2] = {matRows, matCols};  // eigen vectors are n,1 matrices

        if (dims[1] == 0) dims[1] = 1;

        H5::DataSpace dp(2, dims);
        const H5::DataType* dataType = InferDataType<typename T::Scalar>::get();
        H5::DataSet dataset;
        try{
            dataset = loc.createDataSet(name.c_str(), *dataType, dp);
        } catch (H5::GroupIException& error){
            dataset = loc.openDataSet(name.c_str());
        }

        hsize_t matColSize = matrix.derived().outerStride();

        hsize_t fileRows = matCols;

        hsize_t fStride[2] = {1, fileRows};
        hsize_t fCount[2] = {1,1};
        hsize_t fBlock[2] = {1, fileRows};

        hsize_t mStride[2] = {matColSize, 1};
        hsize_t mCount[2] = {1,1};
        hsize_t mBlock[2] = {matCols, 1};

        hsize_t mDim[2] = {matCols, matColSize};
        H5::DataSpace mspace(2, mDim);

        for (hsize_t i = 0; i < matRows; i++){
            hsize_t fStart[2] = {i, 0};
            hsize_t mStart[2] = {0, i};
            dp.selectHyperslab(H5S_SELECT_SET, fCount, fStart, fStride, fBlock);
            mspace.selectHyperslab(H5S_SELECT_SET, mCount, mStart, mStride, mBlock);
            dataset.write(matrix.derived().data(), *dataType, mspace, dp);
        }
    }

    template <typename T>
        typename std::enable_if<std::is_fundamental<T>::value>::type
        WriteData(const CptLoc& loc, const std::vector<T> v,
                  const std::string& name) const{
        hsize_t dims[2] = {(hsize_t)v.size(), 1};

        const H5::DataType* dataType = InferDataType<T>::get();
        H5::DataSet dataset;
        H5::DataSpace dp(2, dims);
        try{
            dataset = loc.createDataSet(name.c_str(), *dataType, dp);
        } catch (H5::GroupIException& error){
            dataset = loc.openDataSet(name.c_str());
        }
        dataset.write(&(v[0]), *dataType);
    }

    void WriteData(const CptLoc& loc, const std::vector<Eigen::Vector3d>& v,
                   const std::string& name) const{

        size_t c = 0;
        std::string r;
        CptLoc parent;
        try{
            parent = loc.createGroup(name);
        } catch (H5::GroupIException& error){
            parent = loc.openGroup(name);
        }
        for (auto const& x: v){
            r = std::to_string(c);
            WriteData(parent, x, "ind"+r);
            ++c;
        }
    }

    template <typename T1, typename T2>
        void WriteData(const CptLoc& loc, const std::map<T1, std::vector<T2>> map,
                       const std::string& name) const{

        size_t c = 0;
        std::string r;
        // Iterate over the map and write map as a number of vectors with T1 as index
        for (auto const& x : map) {
            r = std::to_string(c);
            CptLoc tempGr;
            try{
                tempGr = loc.createGroup(name);
            } catch (H5::GroupIException& error){
                tempGr = loc.openGroup(name);
            }
            WriteData(tempGr, x.second, "index" + r);
            ++c;
        }
    }

};
}  // namespace xtp
}  // namespace votca
#endif  // VOTCA_XTP_CHECKPOINT_WRITER_H