// votca-xtp codebase, tag debian/2021.1-1: src/libxtp/statetracker.cc

/*
 *            Copyright 2009-2020 The VOTCA Development Team
 *                       (http://www.votca.org)
 *
 *      Licensed under the Apache License, Version 2.0 (the "License")
 *
 * You may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *              http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Local VOTCA includes
#include "votca/xtp/statetracker.h"
#include "votca/xtp/filterfactory.h"

namespace votca {
namespace xtp {
using std::flush;

/**
 * Configures the tracker from @p options.
 *
 * The "filters" option is split on spaces, commas, semicolons and newlines;
 * one filter is created per name through the filter factory. After all
 * filters exist, each one is initialized from the option subtree whose key
 * is that filter's Identify() string.
 */
void StateTracker::Initialize(const tools::Property& options) {
  const std::string filter_spec = options.get("filters").as<std::string>();
  tools::Tokenizer tokenizer(filter_spec, " ,;\n");
  const std::vector<std::string> names = tokenizer.ToVector();

  // Filter types have to be registered before the factory can create them.
  FilterFactory::RegisterAll();

  // First pass: create every requested filter.
  for (const auto& name : names) {
    _filters.push_back(Filter().Create(name));
  }

  // Second pass: hand each filter its own option subtree.
  for (auto& filter : _filters) {
    filter->Initialize(options.get(filter->Identify()));
  }
}

void StateTracker::PrintInfo() const {
  XTP_LOG(Log::error, *_log)
      << "Initial state: " << _statehist[0].ToString() << flush;
  if (_statehist.size() > 1) {
    XTP_LOG(Log::error, *_log)
        << "Last state: " << _statehist.back().ToString() << flush;
  }

  if (_filters.empty()) {
    XTP_LOG(Log::error, *_log) << "WARNING: No tracker is used " << flush;
  } else {
    for (const auto& filter : _filters) {
      filter->Info(*_log);
    }
  }
}

std::vector<Index> StateTracker::ComparePairofVectors(
    std::vector<Index>& vec1, std::vector<Index>& vec2) const {
  std::vector<Index> result(std::min(vec1, vec2));
  std::sort(vec1.begin(), vec1.end());
  std::sort(vec2.begin(), vec2.end());
  std::vector<Index>::iterator it = std::set_intersection(
      vec1.begin(), vec1.end(), vec2.begin(), vec2.end(), result.begin());
  result.resize(it - result.begin());
  return result;
}

std::vector<Index> StateTracker::CollapseResults(
    std::vector<std::vector<Index> >& results) const {
  if (results.empty()) {
    return std::vector<Index>(0);
  } else {
    std::vector<Index> result = results[0];
    for (Index i = 1; i < Index(results.size()); i++) {
      result = ComparePairofVectors(result, results[i]);
    }
    return result;
  }
}

/**
 * Determines the next state for @p orbitals without modifying the tracker.
 *
 * - Without filters, the first entry of the state history is returned.
 * - Otherwise each filter contributes a list of candidate indices, computed
 *   for the type of the first history entry. While the history holds fewer
 *   than two entries, filters that need a reference state are skipped, and
 *   a message is logged for each one.
 * - The candidate lists are intersected. The first surviving index, with
 *   the type of the last history entry, becomes the new state.
 * - If nothing survives, the last history entry is reused.
 */
QMState StateTracker::CalcState(const Orbitals& orbitals) const {
  if (_filters.empty()) {
    return _statehist[0];
  }

  const bool only_initial_state = _statehist.size() < 2;
  std::vector<std::vector<Index> > candidates;
  for (const auto& filter : _filters) {
    const bool skip = only_initial_state && filter->NeedsInitialState();
    if (skip) {
      XTP_LOG(Log::error, *_log)
          << "Filter " << filter->Identify()
          << " not used in first iteration as it needs a reference state"
          << flush;
      continue;
    }
    candidates.push_back(filter->CalcIndeces(orbitals, _statehist[0].Type()));
  }

  const std::vector<Index> survivors = CollapseResults(candidates);
  if (survivors.empty()) {
    QMState fallback = _statehist.back();
    XTP_LOG(Log::error, *_log)
        << "No State found by tracker using last state: "
        << fallback.ToString() << flush;
    return fallback;
  }

  QMState next(_statehist.back().Type(), survivors[0], false);
  XTP_LOG(Log::error, *_log) << "Next State is: " << next.ToString() << flush;
  return next;
}

/**
 * Computes the next state with CalcState(), appends it to the state history,
 * and passes it, together with @p orbitals, to every filter's UpdateHist().
 *
 * @return the newly determined state
 */
QMState StateTracker::CalcStateAndUpdate(const Orbitals& orbitals) {
  QMState next = CalcState(orbitals);
  _statehist.push_back(next);
  for (const auto& filter : _filters) {
    filter->UpdateHist(orbitals, next);
  }
  return next;
}

/**
 * Serializes the tracker into the checkpoint writer @p w.
 *
 * The state history is written as a list of strings (QMState::ToString())
 * under the key "statehist". Each filter then writes into a child group
 * named after its Identify() string.
 */
void StateTracker::WriteToCpt(CheckpointWriter& w) const {
  std::vector<std::string> history_as_text(_statehist.size());
  for (std::size_t i = 0; i < _statehist.size(); ++i) {
    history_as_text[i] = _statehist[i].ToString();
  }
  w(history_as_text, "statehist");

  for (const auto& filter : _filters) {
    CheckpointWriter child = w.openChild(filter->Identify());
    filter->WriteToCpt(child);
  }
}

/**
 * Restores the tracker from the checkpoint reader @p r.
 *
 * Any existing state history is replaced by the string-encoded list stored
 * under "statehist". Each already-configured filter then reads from the
 * child group named after its Identify() string.
 */
void StateTracker::ReadFromCpt(CheckpointReader& r) {
  std::vector<std::string> history_as_text;
  r(history_as_text, "statehist");

  _statehist.clear();
  _statehist.reserve(history_as_text.size());
  for (const auto& text : history_as_text) {
    _statehist.emplace_back(text);
  }

  for (auto& filter : _filters) {
    CheckpointReader child = r.openChild(filter->Identify());
    filter->ReadFromCpt(child);
  }
}

}  // namespace xtp
}  // namespace votca