Codebase list rabit / 62854a4
New upstream version 0.0~git20200628.74bf00a Mo Zhou 3 years ago
10 changed file(s) with 74 addition(s) and 78 deletion(s). Raw diff Collapse all Expand all
6464 /*! \brief error message buffer length */
6565 const int kPrintBuffer = 1 << 12;
6666
67 /*! \brief we may want to keep the process alive when there are multiple workers
68 * co-locate in the same process */
69 extern bool STOP_PROCESS_ON_ERROR;
70
6771 /* \brief Case-insensitive string comparison */
6872 inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) {
6973 #ifdef _MSC_VER
8488 * \param msg error message
8589 */
8690 inline void HandleAssertError(const char *msg) {
87 fprintf(stderr,
88 "AssertError:%s, rabit is configured to keep process running\n", msg);
89 throw dmlc::Error(msg);
91 if (STOP_PROCESS_ON_ERROR) {
92 fprintf(stderr, "AssertError:%s, shutting down process\n", msg);
93 exit(-1);
94 } else {
95 fprintf(stderr, "AssertError:%s, rabit is configured to keep process running\n", msg);
96 throw dmlc::Error(msg);
97 }
9098 }
9199 /*!
92100 * \brief handling of Check error, caused by inappropriate input
93101 * \param msg error message
94102 */
95103 inline void HandleCheckError(const char *msg) {
96 fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
97 throw dmlc::Error(msg);
104 if (STOP_PROCESS_ON_ERROR) {
105 fprintf(stderr, "%s, shutting down process\n", msg);
106 exit(-1);
107 } else {
108 fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
109 throw dmlc::Error(msg);
110 }
98111 }
99112 inline void HandlePrint(const char *msg) {
100113 printf("%s", msg);
55 * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
66 */
77 #define NOMINMAX
8 #include "allreduce_base.h"
98 #include <rabit/base.h>
109 #include <netinet/tcp.h>
1110 #include <cstring>
1211 #include <map>
12 #include "allreduce_base.h"
1313
1414 namespace rabit {
15
16 namespace utils {
17 bool STOP_PROCESS_ON_ERROR = true;
18 }
19
1520 namespace engine {
1621 // constructor
1722 AllreduceBase::AllreduceBase(void) {
2732 version_number = 0;
2833 // 32 K items
2934 reduce_ring_mincount = 32 << 10;
30 // 1M reducer size each time
31 tree_reduce_minsize = 1 << 20;
3235 // tracker URL
3336 task_id = "NULL";
3437 err_link = NULL;
4245 env_vars.push_back("DMLC_TRACKER_URI");
4346 env_vars.push_back("DMLC_TRACKER_PORT");
4447 env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
48 env_vars.push_back("DMLC_WORKER_STOP_PROCESS_ON_ERROR");
4549 }
4650
4751 // initialization function
182186 if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
183187 if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
184188 if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val);
185 if (!strcmp(name, "rabit_tree_reduce_minsize")) tree_reduce_minsize = atoi(val);
186189 if (!strcmp(name, "rabit_reduce_ring_mincount")) {
187190 reduce_ring_mincount = atoi(val);
188191 utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0");
192195 }
193196 if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
194197 connect_retry = atoi(val);
198 }
199 if (!strcmp(name, "DMLC_WORKER_STOP_PROCESS_ON_ERROR")) {
200 if (!strcmp(val, "true")) {
201 rabit::utils::STOP_PROCESS_ON_ERROR = true;
202 } else if (!strcmp(val, "false")) {
203 rabit::utils::STOP_PROCESS_ON_ERROR = false;
204 } else {
205 throw std::runtime_error("invalid value of DMLC_WORKER_STOP_PROCESS_ON_ERROR");
206 }
195207 }
196208 if (!strcmp(name, "rabit_bootstrap_cache")) {
197209 rabit_bootstrap_cache = utils::StringToBool(val);
491503 size_t size_up_out = 0;
492504 // size of message we received, and send in the down pass
493505 size_t size_down_in = 0;
494 // minimal size of each reducer
495 const size_t eachreduce = (tree_reduce_minsize / type_nbytes * type_nbytes);
496
497506 // initialize the link ring-buffer and pointer
498507 for (int i = 0; i < nlink; ++i) {
499508 if (i != parent_index) {
550559 // read data from childs
551560 for (int i = 0; i < nlink; ++i) {
552561 if (i != parent_index && watcher.CheckRead(links[i].sock)) {
553 // make sure to receive minimal reducer size
554 // since each child reduce and sends the minimal reducer size
555 while (links[i].size_read < total_size
556 && links[i].size_read - size_up_reduce < eachreduce) {
557 ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
558 if (ret != kSuccess) {
559 return ReportError(&links[i], ret);
560 }
562 ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
563 if (ret != kSuccess) {
564 return ReportError(&links[i], ret);
561565 }
562566 }
563567 }
577581 utils::Assert(buffer_size != 0, "must assign buffer_size");
578582 // round to type_n4bytes
579583 max_reduce = (max_reduce / type_nbytes * type_nbytes);
580
581 // if max reduce is less than total size, we reduce multiple times of
582 // eachreduce size
583 if (max_reduce < total_size)
584 max_reduce = max_reduce - max_reduce % eachreduce;
585
586584 // peform reduce, can be at most two rounds
587585 while (size_up_reduce < max_reduce) {
588586 // start position
606604 // pass message up to parent, can pass data that are already been reduced
607605 if (size_up_out < size_up_reduce) {
608606 ssize_t len = links[parent_index].sock.
609 Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
607 Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
610608 if (len != -1) {
611609 size_up_out += static_cast<size_t>(len);
612610 } else {
619617 // read data from parent
620618 if (watcher.CheckRead(links[parent_index].sock) &&
621619 total_size > size_down_in) {
622 size_t left_size = total_size-size_down_in;
623 size_t reduce_size_min = std::min(left_size, eachreduce);
624 size_t recved = 0;
625 while (recved < reduce_size_min) {
626 ssize_t len = links[parent_index].sock.
627 Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
628
629 if (len == 0) {
630 links[parent_index].sock.Close();
631 return ReportError(&links[parent_index], kRecvZeroLen);
632 }
633 if (len != -1) {
634 size_down_in += static_cast<size_t>(len);
635 utils::Assert(size_down_in <= size_up_out,
636 "Allreduce: boundary error");
637 recved+=len;
638
639 // if it receives more data than each reduce, it means the next block is sent.
640 // we double the reduce_size_min or add to left_size
641 while (recved > reduce_size_min) {
642 reduce_size_min += std::min(left_size-reduce_size_min, eachreduce);
643 }
644 } else {
645 ReturnType ret = Errno2Return();
646 if (ret != kSuccess) {
647 return ReportError(&links[parent_index], ret);
648 }
620 ssize_t len = links[parent_index].sock.
621 Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
622 if (len == 0) {
623 links[parent_index].sock.Close();
624 return ReportError(&links[parent_index], kRecvZeroLen);
625 }
626 if (len != -1) {
627 size_down_in += static_cast<size_t>(len);
628 utils::Assert(size_down_in <= size_up_out,
629 "Allreduce: boundary error");
630 } else {
631 ReturnType ret = Errno2Return();
632 if (ret != kSuccess) {
633 return ReportError(&links[parent_index], ret);
649634 }
650635 }
651636 }
564564 int reduce_method;
565565 // mininum count of cells to use ring based method
566566 size_t reduce_ring_mincount;
567 // minimul block size per tree reduce
568 size_t tree_reduce_minsize;
569567 // current rank
570568 int rank;
571569 // world size
166166 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
167167 * \param _file caller file name used to generate unique cache key
168168 * \param _line caller line number used to generate unique cache key
169 * \param _caller caller function name used to generate unique cache key
170 */
169 * \param _caller caller function name used to generate unique cache key
170 */
171171 void AllreduceRobust::Allgather(void *sendrecvbuf,
172172 size_t total_size,
173173 size_t slice_begin,
517517 }
518518 // execute checkpoint, note: when checkpoint existing, load will not happen
519519 _assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
520 ActionSummary::kSpecialOp, cur_cache_seq),
521 "check point must return true");
520 ActionSummary::kSpecialOp, cur_cache_seq),
521 "check point must return true");
522522 // this is the critical region where we will change all the stored models
523523 // increase version number
524524 version_number += 1;
549549 delta = utils::GetTime() - start;
550550 // log checkpoint ack latency
551551 if (rabit_debug) {
552 utils::HandleLogInfo(
553 "[%d] checkpoint ack finished version %d, take %f seconds\n", rank,
554 version_number, delta);
552 utils::HandleLogInfo("[%d] checkpoint ack finished version %d, take %f seconds\n",
553 rank, version_number, delta);
555554 }
556555 }
557556 /*!
1111 #include "rabit/internal/engine.h"
1212
1313 namespace rabit {
14
15 namespace utils {
16 bool STOP_PROCESS_ON_ERROR = true;
17 }
18
1419 namespace engine {
1520 /*! \brief EmptyEngine */
1621 class EmptyEngine : public IEngine {
66 * \author Tianqi Chen
77 */
88 #define NOMINMAX
9 #include <rabit/base.h>
910 #include <mpi.h>
10 #include <rabit/base.h>
1111 #include <cstdio>
12 #include <string>
1312 #include "rabit/internal/engine.h"
1413 #include "rabit/internal/utils.h"
1514
1615 namespace rabit {
16
17 namespace utils {
18 bool STOP_PROCESS_ON_ERROR = true;
19 }
20
1721 namespace engine {
1822 /*! \brief implementation of engine using MPI */
1923 class MPIEngine : public IEngine {
22 add_executable(
33 unit_tests
44 test_io.cc
5 test_utils.cc
65 allreduce_robust_test.cc
76 allreduce_base_test.cc
87 allreduce_mock_test.cc
1616 char* argv[] = {cmd};
1717 m.Init(1, argv);
1818 m.rank = 0;
19 EXPECT_THROW(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), dmlc::Error);
19 EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
2020 }
2121
2222 TEST(allreduce_mock, mock_broadcast)
3131 m.rank = 0;
3232 m.version_number=1;
3333 m.seq_counter=2;
34 EXPECT_THROW(m.Broadcast(nullptr,0,0), dmlc::Error);
34 EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
3535 }
22
33 #include <string>
44 #include <iostream>
5 #include <dmlc/logging.h>
65 #include "../../src/allreduce_mock.h"
76
87 TEST(allreduce_mock, mock_allreduce)
1716 char* argv[] = {cmd};
1817 m.Init(1, argv);
1918 m.rank = 0;
20 EXPECT_THROW({m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr);}, dmlc::Error);
19 EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
2120 }
2221
2322 TEST(allreduce_mock, mock_broadcast)
3231 m.rank = 0;
3332 m.version_number=1;
3433 m.seq_counter=2;
35 EXPECT_THROW({m.Broadcast(nullptr,0,0);}, dmlc::Error);
34 EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
3635 }
3736
3837 TEST(allreduce_mock, mock_gather)
4746 m.rank = 3;
4847 m.version_number=13;
4948 m.seq_counter=22;
50 EXPECT_THROW({m.Allgather(nullptr,0,0,0,0);}, dmlc::Error);
49 EXPECT_EXIT(m.Allgather(nullptr,0,0,0,0), ::testing::ExitedWithCode(255), "");
5150 }
+0
-6
test/cpp/test_utils.cc less more
0 #include <gtest/gtest.h>
1 #include <rabit/internal/utils.h>
2
3 TEST(Utils, Assert) {
4 EXPECT_THROW({rabit::utils::Assert(false, "foo");}, dmlc::Error);
5 }