Codebase list votca-xtp / debian/1.5-1 include / votca / xtp / checkpointreader.h
debian/1.5-1

Tree @debian/1.5-1 (Download .tar.gz)

checkpointreader.h @debian/1.5-1 · raw · history · blame

/*
 * Copyright 2009-2018 The VOTCA Development Team (http://www.votca.org)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
#ifndef VOTCA_XTP_CHECKPOINT_READER_H
#define VOTCA_XTP_CHECKPOINT_READER_H

#include <votca/xtp/eigen.h>
#include <H5Cpp.h>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>
#include <votca/tools/vec.h>

#include <votca/xtp/checkpoint_utils.h>

namespace votca {
    namespace xtp {
        using namespace checkpoint_utils;

class CheckpointReader{
public:
CheckpointReader(const CptLoc& loc): CheckpointReader(loc, "/"){};

CheckpointReader(const CptLoc& loc, const std::string path):
    _loc(loc), _path(path){};

    template<typename T>
    typename std::enable_if<!std::is_fundamental<T>::value>::type
    operator()(T& var, const std::string& name)const{
        try{
            ReadData(_loc, var, name);
        } catch (H5::Exception& error){
            std::stringstream message;

            message << "Could not read " << name << " from "
                    << _loc.getFileName() << ":" << _path << std::endl;

            throw std::runtime_error(message.str());
        }
    }

    template<typename T>
    typename std::enable_if<std::is_fundamental<T>::value && !std::is_same<T, bool>::value>::type
    operator()(T& var, const std::string& name)const{
        try{
            ReadScalar(_loc, var, name);
        } catch (H5::Exception& error){
            std::stringstream message;
            message << "Could not read " << name << " from "
                    << _loc.getFileName() << ":" << _path << "/" << std::endl;

            throw std::runtime_error(message.str());
        }
    }

    void operator()(bool& v, const std::string& name)const{
        int temp = int(v);
        try{
            ReadScalar(_loc, temp, name);
        } catch (H5::Exception& error){
            std::stringstream message;
            message << "Could not read " << name << " from "
                    << _loc.getFileName() << ":" << _path << std::endl;

            throw std::runtime_error(message.str());
        }
        v = static_cast<bool>(temp);
    }

    void operator()(std::string& var, const std::string& name)const{
        try{
            ReadScalar(_loc, var, name);
        } catch (H5::Exception& error){
            std::stringstream message;
            message << "Could not read " << name << " from "
                    << _loc.getFileName() << ":" << _path << std::endl;

            throw std::runtime_error(message.str());
        }
    }

    CheckpointReader openChild(const std::string& childName)const{
        try{
            return CheckpointReader(_loc.openGroup(childName), _path+"/"+childName);
        } catch (H5::Exception& e){
            std::stringstream message;
            message << "Could not open " << _loc.getFileName() << ":/"
                    << _path << "/" << childName << std::endl;

            throw std::runtime_error(message.str());
        }
    }

    int getNumDataSets()const{
        return _loc.getNumObjs();
    }

private:
    const CptLoc _loc;
    const std::string _path;
    void ReadScalar(const CptLoc& loc, std::string& var, const std::string& name)const{
        const H5::DataType* strType = InferDataType<std::string>::get();

        H5::Attribute attr = loc.openAttribute(name);

        H5std_string readbuf("");

        attr.read(*strType, readbuf);

        var = readbuf;
    }

    template <typename T>
        void ReadScalar(const CptLoc& loc, T& value,
                        const std::string& name) const{

        H5::Attribute attr = loc.openAttribute(name);
        const H5::DataType* dataType = InferDataType<T>::get();

        attr.read(*dataType, &value);
    }

    template <typename T>
        void ReadData(const CptLoc& loc, Eigen::MatrixBase<T>& matrix,
                      const std::string& name) const{

        const H5::DataType* dataType = InferDataType<typename T::Scalar>::get();

        H5::DataSet dataset = loc.openDataSet(name);

        H5::DataSpace dp = dataset.getSpace();

        hsize_t dims[2];
        dp.getSimpleExtentDims(dims, NULL); // ndims is always 2 for us

        hsize_t matRows = dims[0];
        hsize_t matCols = dims[1];

        matrix.derived().resize(matRows, matCols);

        hsize_t matColSize = matrix.derived().outerStride();

        hsize_t fileRows = matCols;

        hsize_t fStride[2] = {1, fileRows};
        hsize_t fCount[2] = {1,1};
        hsize_t fBlock[2] = {1, fileRows};

        hsize_t mStride[2] = {matColSize, 1};
        hsize_t mCount[2] = {1,1};
        hsize_t mBlock[2] = {matCols, 1};

        hsize_t mDim[2] = {matCols, matColSize};
        H5::DataSpace mspace(2, mDim);

        for (hsize_t i = 0; i < matRows; i++){
            hsize_t fStart[2] = {i, 0};
            hsize_t mStart[2] = {0, i};
            dp.selectHyperslab(H5S_SELECT_SET, fCount, fStart, fStride, fBlock);
            mspace.selectHyperslab(H5S_SELECT_SET, mCount, mStart, mStride, mBlock);
            dataset.read(matrix.derived().data(), *dataType, mspace, dp);
        }

    }

    template <typename T>
        typename std::enable_if<std::is_fundamental<T>::value>::type
        ReadData(const CptLoc& loc, std::vector<T>& v,
                 const std::string& name) const{

        H5::DataSet dataset = loc.openDataSet(name);
        H5::DataSpace dp = dataset.getSpace();

        const H5::DataType* dataType = InferDataType<T>::get();

        hsize_t dims[2];
        dp.getSimpleExtentDims(dims, NULL);

        v.resize(dims[0]);

        dataset.read(&(v[0]), *dataType);
    }

    void ReadData(const CptLoc& loc, std::vector<Eigen::Vector3d>& v,
                  const std::string& name)const{

        CptLoc parent = loc.openGroup(name);
        size_t count = parent.getNumObjs();

        v.resize(count);

        size_t c = 0;
        for (auto &vec: v){
            ReadData(parent, vec, "ind"+std::to_string(c));
            ++c;
        }
    }
};

}  // namespace xtp
}  // namespace votca

#endif  // VOTCA_XTP_CHECKPOINT_READER_H