Codebase list rabit / 04ced897-8765-4867-ad0c-ad282edb46cf/main test / speed_test.cc
04ced897-8765-4867-ad0c-ad282edb46cf/main

Tree @04ced897-8765-4867-ad0c-ad282edb46cf/main (Download .tar.gz)

speed_test.cc @04ced897-8765-4867-ad0c-ad282edb46cf/mainraw · history · blame

// This program is used to test the speed of rabit API
#include <rabit/rabit.h>
#include <rabit/internal/timer.h>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <string>
#include <vector>
#include <time.h>

using namespace rabit;

// Per-process timing accumulators, reported by PrintStats in seconds.
// Namespace-scope doubles are zero-initialized; main() also resets the
// first three before the benchmark loop starts.
double max_tdiff;    // time spent in Allreduce<op::Max> calls
double sum_tdiff;    // time spent in Allreduce<op::Sum> calls
double bcast_tdiff;  // time spent in Broadcast calls
double tot_tdiff;    // elapsed time of the whole benchmark loop

// Times one Allreduce<op::Max> over a buffer of n floats and adds the
// elapsed time to max_tdiff.
//
// The buffer is filled with values derived from the element index and this
// process's rank. Filling happens before the timer starts, so only the
// Allreduce call itself is measured.
//
// n: number of float elements in the buffer.
inline void TestMax(size_t n) {
  int rank = rabit::GetRank();
  std::vector<float> ndata(n);
  for (size_t i = 0; i < ndata.size(); ++i) {
    ndata[i] = (i * (rank+1)) % 111;
  }
  double tstart = utils::GetTime();
  // Use data() instead of &ndata[0]: indexing an empty vector is undefined
  // behaviour, and n == 0 is reachable (e.g. a non-numeric <ndata> argument).
  rabit::Allreduce<op::Max>(ndata.data(), ndata.size());
  max_tdiff += utils::GetTime() - tstart;
}

// Times one Allreduce<op::Sum> over a buffer of n floats and adds the
// elapsed time to sum_tdiff.
//
// The buffer is filled with values derived from the element index and this
// process's rank (reduced modulo 131). Filling happens before the timer
// starts, so only the Allreduce call itself is measured.
//
// n: number of float elements in the buffer.
inline void TestSum(size_t n) {
  int rank = rabit::GetRank();
  const int z = 131;
  std::vector<float> ndata(n);
  for (size_t i = 0; i < ndata.size(); ++i) {
    ndata[i] = (i * (rank+1)) % z;
  }
  double tstart = utils::GetTime();
  // Use data() instead of &ndata[0]: indexing an empty vector is undefined
  // behaviour, and n == 0 is reachable (e.g. a non-numeric <ndata> argument).
  rabit::Allreduce<op::Sum>(ndata.data(), ndata.size());
  sum_tdiff += utils::GetTime() - tstart;
}

// Times one Broadcast of an n-byte buffer from `root` and adds the elapsed
// time to bcast_tdiff.
//
// Every process builds the same n-byte pattern (byte values cycling through
// 1..126), but only the root copies it into the buffer that is broadcast;
// the other processes pass an n-byte buffer of zero bytes. Only the
// Broadcast call itself is timed.
//
// n:    number of bytes to broadcast.
// root: rank of the process whose buffer is sent.
inline void TestBcast(size_t n, int root) {
  const int my_rank = rabit::GetRank();
  std::string pattern(n, '\0');
  for (size_t pos = 0; pos < n; ++pos) {
    pattern[pos] = static_cast<char>(pos % 126 + 1);
  }
  std::string res(n, '\0');
  if (my_rank == root) {
    res = pattern;
  }
  const double t_begin = utils::GetTime();
  rabit::Broadcast(&res[0], res.length(), root);
  const double t_end = utils::GetTime();
  bcast_tdiff += t_end - t_begin;
}

// Reports the mean and standard deviation of `tdiff` across all processes
// and, when n != 0, the resulting throughput.
//
// Every process must call this function, since it performs two Allreduce
// operations; only rank 0 prints.
//
// name:  label used as the prefix of the printed lines.
// tdiff: this process's accumulated time (printed with unit "sec").
// n:     number of elements per call; 0 suppresses the throughput line.
// nrep:  number of repetitions that contributed to tdiff.
// size:  size in bytes of one element.
inline void PrintStats(const char *name, double tdiff, int n, int nrep, size_t size) {
  const int world = rabit::GetWorldSize();

  // Mean: sum the per-process times, then divide by the process count.
  double total = tdiff;
  rabit::Allreduce<op::Sum>(&total, 1);
  const double mean = total / world;

  // Standard deviation: sum of squared deviations from the mean.
  const double deviation = tdiff - mean;
  double sq_dev = deviation * deviation;
  rabit::Allreduce<op::Sum>(&sq_dev, 1);
  const double stddev = std::sqrt(sq_dev / world);

  if (rabit::GetRank() != 0) {
    return;
  }
  rabit::TrackerPrintf("%s: mean=%g, std=%g sec\n", name, mean, stddev);
  if (n == 0) {
    return;
  }
  // Total bytes processed = n * nrep * size, accumulated in double.
  double nbytes = n;
  nbytes *= nrep * size;
  rabit::TrackerPrintf("%s-speed: %g MB/sec\n", name, (nbytes / mean) / 1024 / 1024 );
}

// Entry point: runs the Max/Sum/Broadcast benchmarks `nrepeat` times on
// buffers of `ndata` elements and prints timing statistics.
//
// Usage: <ndata> <nrepeat>
//   ndata   - number of elements per operation (must be >= 0)
//   nrepeat - number of benchmark iterations (must be >= 1)
int main(int argc, char *argv[]) {
  if (argc < 3) {
    printf("Usage: <ndata> <nrepeat>\n");
    return 0;
  }
  // Fixed seed: each process draws broadcast roots from the same rand()
  // sequence below.
  srand(0);
  int n = atoi(argv[1]);
  int nrep = atoi(argv[2]);
  // A negative n would be converted to a huge size_t when passed to the
  // Test* functions, so reject it up front.
  utils::Check(n >= 0, "ndata must be non-negative");
  utils::Check(nrep >= 1, "need to at least repeat running once");
  rabit::Init(argc, argv);
  int nproc = rabit::GetWorldSize();
  // NOTE(review): `name` is never read afterwards; the call is kept so the
  // sequence of rabit calls is unchanged.
  std::string name = rabit::GetProcessorName();
  max_tdiff = sum_tdiff = bcast_tdiff = 0;
  double tstart = utils::GetTime();
  for (int i = 0; i < nrep; ++i) {
    TestMax(n);
    TestSum(n);
    TestBcast(n, rand() % nproc);
  }
  tot_tdiff = utils::GetTime() - tstart;
  // use allreduce to get the mean and std of time across processes
  PrintStats("max_tdiff", max_tdiff, n, nrep, sizeof(float));
  PrintStats("sum_tdiff", sum_tdiff, n, nrep, sizeof(float));
  PrintStats("bcast_tdiff", bcast_tdiff, n, nrep, sizeof(char));
  // n == 0 here: only mean/std are printed for the total, no throughput line.
  PrintStats("tot_tdiff", tot_tdiff, 0, nrep, sizeof(float));
  rabit::Finalize();
  return 0;
}