Codebase list rabit / fcbf9e0
New upstream version 0.0~git20200127.2f7fcff Mo Zhou 4 years ago
77 changed file(s) with 11160 addition(s) and 0 deletion(s). Raw diff Collapse all Expand all
0 # Compiled Object files
1 *.slo
2 *.lo
3 *.o
4 *.obj
5
6 # Precompiled Headers
7 *.gch
8 *.pch
9 *.lnk
10 # Compiled Dynamic libraries
11 *.so
12 *.dylib
13 *.dll
14
15 # Fortran module files
16 *.mod
17
18 # Compiled Static libraries
19 *.lai
20 *.la
21 *.a
22 *.lib
23
24 # Executables
25 *.miss
26 *.exe
27 *.out
28 *.app
29 *~
30 *.pyc
31 *.mpi
32 *.exe
33 *tmp*
34 *.rabit
35 *.mock
36 recommonmark
37 recom
38 _*
39
40 #mpi lib
41 mpich/
42 mpich-3.2/
43
44 # Jetbrain
45 .idea
46 cmake-build-debug/
47 .vscode/
48
49 # cmake
50 build/
51 compile_commands.json
0 sudo: true
1
2 os:
3 - linux
4 - osx
5
6 osx_image: xcode10.2
7
8 dist: xenial
9
10 language: cpp
11
12 # Use Build Matrix to do lint and build separately
13 env:
14 matrix:
15 - TASK=lint LINT_LANG=cpp
16 - TASK=lint LINT_LANG=python
17 - TASK=doc
18 # - TASK=build
19 - TASK=mpi-build
20 - TASK=cmake-test
21
22 matrix:
23 exclude:
24 - os: osx
25 env: TASK=lint LINT_LANG=cpp
26 - os: osx
27 env: TASK=lint LINT_LANG=python
28 - os: osx
29 env: TASK=doc
30 - os: osx
31 env: TASK=build
32
33 # dependent apt packages
34 addons:
35 apt:
36 sources:
37 - llvm-toolchain-trusty-5.0
38 - ubuntu-toolchain-r-test
39 - george-edison55-precise-backports
40 packages:
41 - doxygen
42 - wget
43 - git
44 - libcurl4-openssl-dev
45 - unzip
46 - python-numpy
47 - gcc-4.8
48 - g++-4.8
49 - openssh-client
50 - openssh-server
51 - python3
52 - python3-setuptools
53 - python3-pip
54 - tree
55 homebrew:
56 packages:
57 - gcc49
58 - openssl
59 - libgit2
60 - python3
61 update: true
62
63 before_install:
64 - git clone https://github.com/dmlc/dmlc-core
65 - export TRAVIS=dmlc-core/scripts/travis/
66 - source ${TRAVIS}/travis_setup_env.sh
67 - ${TRAVIS}/travis_osx_install.sh
68 - source ./scripts/travis_setup.sh
69
70 script: scripts/travis_script.sh
71
72 cache:
73 directories:
74 - ${HOME}/.cache/usr
75 - ${HOME}/.cache/pip
76 - mpich
77
78 before_cache:
79 - ${TRAVIS}/travis_before_cache.sh
80
81 after_success:
82 - tree build
83 - bash <(curl -s https://codecov.io/bash) -a '-o src/ src/*.c'
84
85 notifications:
86 # Emails are sent to the committer's git-configured email address by default,
87 email:
88 on_success: change
89 on_failure: always
0 cmake_minimum_required(VERSION 3.3)
1
2 project(rabit VERSION 0.3.0 LANGUAGES CXX)
3
# CMP0077 exists from CMake 3.13 on. Enabling it lets a user or parent project
# preset `RABIT_BUILD_DMLC` and the other options as plain CMake variables.
if(POLICY CMP0077)
  cmake_policy(SET CMP0077 NEW)
endif()
8
9 option(RABIT_BUILD_TESTS "Build rabit tests" OFF)
10 option(RABIT_BUILD_MPI "Build MPI" OFF)
11 option(RABIT_BUILD_DMLC "Include DMLC_CORE in build" OFF)
12 option(RABIT_WITH_R_LIB "Fit the strict environment of R" OFF)
13
# Root of the dmlc-core checkout; its `include/` directory is added to every
# rabit target further down.
#
# This is a path, so it is declared as a PATH cache entry rather than with
# option(), which only models booleans. A value given with -DDMLC_ROOT=<dir>
# is kept; the default is applied only when nothing usable was supplied
# (empty, unset, or a stale boolean OFF left over from an older cache).
set(DMLC_ROOT "" CACHE PATH "Specify root of external dmlc core.")
if(NOT DMLC_ROOT)
  # by default point to xgboost/dmlc-core
  set(DMLC_ROOT "${CMAKE_CURRENT_LIST_DIR}/../dmlc-core")
endif()
17
# moved from xgboost build
#
# Defines the rabit library targets and records their names in `rabit_libs`.
# For R / MinGW / Windows builds only a single `rabit` library backed by the
# empty engine is created; otherwise the full set of engine variants is built.
if(R_LIB OR MINGW OR WIN32)
  add_library(rabit
    src/engine_empty.cc
    src/c_api.cc)

  set(rabit_libs rabit)
else()
  find_package(Threads REQUIRED)

  add_library(rabit_empty
    src/engine_empty.cc
    src/c_api.cc)
  add_library(rabit_base
    src/allreduce_base.cc
    src/engine_base.cc
    src/c_api.cc)
  add_library(rabit
    src/allreduce_base.cc
    src/allreduce_robust.cc
    src/engine.cc
    src/c_api.cc)
  add_library(rabit_mock_static
    src/allreduce_base.cc
    src/allreduce_robust.cc
    src/engine_mock.cc
    src/c_api.cc)
  add_library(rabit_mock SHARED
    src/allreduce_base.cc
    src/allreduce_robust.cc
    src/engine_mock.cc
    src/c_api.cc)

  # Only the variants built from allreduce_robust.cc link the thread library.
  foreach(threaded_lib rabit rabit_mock_static rabit_mock)
    target_link_libraries(${threaded_lib} Threads::Threads)
  endforeach()

  set(rabit_libs rabit rabit_base rabit_empty rabit_mock rabit_mock_static)
endif()

# Every library defined above is built as position-independent C++11.
set_target_properties(${rabit_libs}
  PROPERTIES
    CXX_STANDARD 11
    CXX_STANDARD_REQUIRED ON
    POSITION_INDEPENDENT_CODE ON)
44
# Optional MPI backend: builds `rabit_mpi` from src/engine_mpi.cc and adds it
# to `rabit_libs` so it receives the shared include directories and install
# rules defined later in this file.
if(RABIT_BUILD_MPI)
  find_package(MPI REQUIRED)
  if(NOT MPI_CXX_FOUND)
    message(FATAL_ERROR "CXX Interface for MPI is required for building MPI backend.")
  endif()

  add_library(rabit_mpi src/engine_mpi.cc)
  # The MPI header directories go on the include path. They were previously
  # passed to add_library(), where they were treated as source files.
  # PRIVATE keeps these absolute paths out of the exported target interface.
  target_include_directories(rabit_mpi PRIVATE ${MPI_CXX_INCLUDE_PATH})
  target_link_libraries(rabit_mpi ${MPI_CXX_LIBRARIES})
  # Same language level and PIC setting as the other rabit libraries.
  set_target_properties(rabit_mpi
    PROPERTIES
      CXX_STANDARD 11
      CXX_STANDARD_REQUIRED ON
      POSITION_INDEPENDENT_CODE ON)
  list(APPEND rabit_libs rabit_mpi)
endif()
54
# Mirror the GNU install layout inside the build tree: archives and shared
# libraries go to <build>/<libdir>, executables to <build>/<bindir>.
include(GNUInstallDirs)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_LIBDIR}")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_LIBDIR}")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_BINDIR}")
60
# we use this to get code coverage
#
# Adds the gcov instrumentation flags to every rabit library when building
# Debug with GCC. target_compile_options() requires a scope keyword; it was
# missing, which made this branch a configure error whenever it was taken.
# PRIVATE is used because the flags only affect how rabit's own sources compile.
# NOTE(review): CMAKE_CONFIGURATION_TYPES is normally empty for single-config
# generators unless passed explicitly; confirm whether CMAKE_BUILD_TYPE was
# intended here.
if((CMAKE_CONFIGURATION_TYPES STREQUAL "Debug") AND (CMAKE_CXX_COMPILER_ID MATCHES GNU))
  foreach(lib ${rabit_libs})
    target_compile_options(${lib}
      PRIVATE
        -fprofile-arcs
        -ftest-coverage)
  endforeach()
endif()
69
# When dmlc-core is built as part of rabit, use the copy inside this source tree.
if(RABIT_BUILD_DMLC)
  set(DMLC_ROOT "${CMAKE_CURRENT_LIST_DIR}/dmlc-core")
endif()

# Report which dmlc-core location ended up being used.
if(DMLC_ROOT)
  message("DMLC_ROOT point to " ${DMLC_ROOT})
endif()

# Each rabit library exposes rabit's own headers and the dmlc-core headers to
# anything that links against it from the build tree.
foreach(rabit_target IN LISTS rabit_libs)
  target_include_directories(${rabit_target}
    PUBLIC
      "$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"
      "$<BUILD_INTERFACE:${DMLC_ROOT}/include>")
endforeach()
83
# Tests: the unit tests live in test/cpp; the mock-engine integration tests
# and the optional MPI speed test are defined here.
if(RABIT_BUILD_TESTS)
  enable_testing()
  add_subdirectory(${rabit_SOURCE_DIR}/test/cpp)

  # rabit mock based integration tests
  list(REMOVE_ITEM rabit_libs "rabit_mock_static") # remove here to avoid installing it
  set(tests lazy_recover local_recover model_recover)

  foreach(test ${tests})
    add_executable(${test} test/${test}.cc)
    target_link_libraries(${test} rabit_mock_static)
    set_target_properties(${test} PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED ON)
    install(TARGETS ${test} DESTINATION test) # Why are we installing these??
  endforeach()

  if(RABIT_BUILD_MPI)
    add_executable(speed_test_mpi test/speed_test.cc)
    target_link_libraries(speed_test_mpi rabit_mpi)
    # Require C++11 here as well, matching the other test executables.
    set_target_properties(speed_test_mpi PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED ON)
    install(TARGETS speed_test_mpi DESTINATION test)
  endif()
endif()
105
# Installation (https://github.com/forexample/package-example) {

# Layout. This works for all platforms:
#   * <prefix>/lib/cmake/<PROJECT-NAME>
#   * <prefix>/lib/
#   * <prefix>/include/
#
# The rabit source tree stays the default install prefix, but it is applied
# only when rabit is the top-level project and the user did not choose a
# prefix. An explicit -DCMAKE_INSTALL_PREFIX=<dir> used to be overwritten
# unconditionally. The value is written to the cache so that it survives
# re-configuration, when CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT is no
# longer set.
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT
   AND ("${CMAKE_SOURCE_DIR}" STREQUAL "${rabit_SOURCE_DIR}"))
  set(CMAKE_INSTALL_PREFIX "${rabit_SOURCE_DIR}"
      CACHE PATH "Install prefix (defaults to the rabit source tree)" FORCE)
endif()
set(config_install_dir "lib/cmake/${PROJECT_NAME}")
set(include_install_dir "include")

# Generated package files are written below the build directory.
set(generated_dir "${CMAKE_CURRENT_BINARY_DIR}/generated")

# Configuration
set(version_config "${generated_dir}/${PROJECT_NAME}ConfigVersion.cmake")
set(project_config "${generated_dir}/${PROJECT_NAME}Config.cmake")
set(TARGETS_EXPORT_NAME "${PROJECT_NAME}Targets")
set(namespace "${PROJECT_NAME}::")
123
# CMakePackageConfigHelpers provides the functions
# 'write_basic_package_version_file' and 'configure_package_config_file'.
include(CMakePackageConfigHelpers)

# Generate '<PROJECT-NAME>ConfigVersion.cmake'.
# No VERSION argument is given, so PROJECT_VERSION is used.
write_basic_package_version_file("${version_config}"
  COMPATIBILITY SameMajorVersion)

# Generate '<PROJECT-NAME>Config.cmake' from the template.
# The template reads:
#   * TARGETS_EXPORT_NAME
#   * PROJECT_NAME
configure_package_config_file("cmake/Config.cmake.in" "${project_config}"
  INSTALL_DESTINATION "${config_install_dir}")
143
# Libraries, recorded in the export set:
#   * <prefix>/lib/librabit.a
#   * <prefix>/lib/librabit_base
#   * <prefix>/lib/librabit_empty
#   * header location after install: <prefix>/include/rabit/rabit.h
#   * headers can be included by C++ code `#include <rabit/rabit.h>`
install(TARGETS ${rabit_libs}
  EXPORT "${TARGETS_EXPORT_NAME}"
  LIBRARY DESTINATION "lib"
  ARCHIVE DESTINATION "lib"
  RUNTIME DESTINATION "bin"
  INCLUDES DESTINATION "${include_install_dir}")

# Headers: every *.h file below include/ is installed.
install(DIRECTORY "include/"
  DESTINATION "${include_install_dir}"
  FILES_MATCHING PATTERN "*.h")

# Package config files:
#   * <prefix>/lib/cmake/rabit/rabitConfig.cmake
#   * <prefix>/lib/cmake/rabit/rabitConfigVersion.cmake
install(FILES "${project_config}" "${version_config}"
  DESTINATION "${config_install_dir}")

# Exported targets file:
#   * <prefix>/lib/cmake/rabit/rabitTargets.cmake
install(EXPORT "${TARGETS_EXPORT_NAME}"
  NAMESPACE "${namespace}"
  DESTINATION "${config_install_dir}")
# }
0 Copyright (c) 2014 by Contributors
1 All rights reserved.
2
3 Redistribution and use in source and binary forms, with or without
4 modification, are permitted provided that the following conditions are met:
5
6 * Redistributions of source code must retain the above copyright notice, this
7 list of conditions and the following disclaimer.
8
9 * Redistributions in binary form must reproduce the above copyright notice,
10 this list of conditions and the following disclaimer in the documentation
11 and/or other materials provided with the distribution.
12
13 * Neither the name of rabit nor the names of its
14 contributors may be used to endorse or promote products derived from
15 this software without specific prior written permission.
16
17 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
21 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
24 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
25 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27
0 OS := $(shell uname)
1
2 RABIT_BUILD_DMLC = 0
3
4 export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas -std=c++11
5 export CFLAGS = -O3 $(WARNFLAGS)
6 export LDFLAGS =-Llib
7
8 #download mpi
9 #echo $(shell scripts/mpi.sh)
10
11 MPICXX=./mpich/bin/mpicxx
12
13 export CXX = g++
14
15
16 #----------------------------
17 # Settings for power and arm arch
18 #----------------------------
19 ARCH := $(shell uname -a)
20 ifneq (,$(filter $(ARCH), armv6l armv7l powerpc64le ppc64le aarch64))
21 CFLAGS += -march=native
22 else
23 CFLAGS += -msse2
24 endif
25
26 ifndef WITH_FPIC
27 WITH_FPIC = 1
28 endif
29 ifeq ($(WITH_FPIC), 1)
30 CFLAGS += -fPIC
31 endif
32
33 ifndef LINT_LANG
34 LINT_LANG="all"
35 endif
36
37 ifeq ($(RABIT_BUILD_DMLC),1)
38 DMLC=dmlc-core
39 else
40 DMLC=../dmlc-core
41 endif
42
43 CFLAGS += -I $(DMLC)/include -I include/
44
45 # build path
46 BPATH=.
47 # object files that make up the rabit library
48 MPIOBJ= $(BPATH)/engine_mpi.o
49 OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\
50 $(BPATH)/c_api.o $(BPATH)/engine_base.o
51 SLIB= lib/librabit.so lib/librabit_mock.so lib/librabit_base.so
52 ALIB= lib/librabit.a lib/librabit_empty.a lib/librabit_mock.a lib/librabit_base.a
53 MPISLIB= lib/librabit_mpi.so
54 MPIALIB= lib/librabit_mpi.a
55 HEADERS=src/*.h include/rabit/*.h include/rabit/internal/*.h
56
57 .PHONY: clean all install mpi python lint doc doxygen
58
59 all: lib/librabit.a lib/librabit_mock.a lib/librabit.so lib/librabit_base.a lib/librabit_mock.so
60 mpi: lib/librabit_mpi.a lib/librabit_mpi.so
61
62 $(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
63 $(BPATH)/engine.o: src/engine.cc $(HEADERS)
64 $(BPATH)/allreduce_robust.o: src/allreduce_robust.cc $(HEADERS)
65 $(BPATH)/engine_mpi.o: src/engine_mpi.cc $(HEADERS)
66 $(BPATH)/engine_empty.o: src/engine_empty.cc $(HEADERS)
67 $(BPATH)/engine_mock.o: src/engine_mock.cc $(HEADERS)
68 $(BPATH)/engine_base.o: src/engine_base.cc $(HEADERS)
69 $(BPATH)/c_api.o: src/c_api.cc $(HEADERS)
70
71 lib/librabit.a lib/librabit.so: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/c_api.o
72 lib/librabit_base.a lib/librabit_base.so: $(BPATH)/allreduce_base.o $(BPATH)/engine_base.o $(BPATH)/c_api.o
73 lib/librabit_mock.a lib/librabit_mock.so: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine_mock.o $(BPATH)/c_api.o
74 lib/librabit_empty.a: $(BPATH)/engine_empty.o $(BPATH)/c_api.o
75 lib/librabit_mpi.a lib/librabit_mpi.so: $(MPIOBJ)
76
77 $(OBJ) :
78 $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
79
80 $(ALIB):
81 ar cr $@ $+
82
83 $(SLIB) :
84 $(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
85
86 $(MPIOBJ) :
87 $(MPICXX) -c $(CFLAGS) -I./mpich/include -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
88
89 $(MPIALIB):
90 ar cr $@ $+
91
92 $(MPISLIB) :
93 $(MPICXX) $(CFLAGS) -I./mpich/include -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) \
94 $(LDFLAGS) -L./mpich/lib -Wl,-rpath,./mpich/lib -lmpi
95
96 lint:
97 $(DMLC)/scripts/lint.py rabit $(LINT_LANG) src include
98
# Build the API documentation. doxygen is run from include/ because the
# Doxyfile's INPUT and OUTPUT_DIRECTORY are relative paths. `&&` makes the
# recipe stop if the directory change fails and lets doxygen's exit status
# reach make; the former trailing `cd -` always succeeded and hid failures.
doc doxygen:
	cd include && doxygen ../doc/Doxyfile
101
102 clean:
103 $(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) $(SLIB) *~ src/*~ include/*~ include/*/*~
0 # Rabit: Reliable Allreduce and Broadcast Interface
1 [![Build Status](https://travis-ci.org/dmlc/rabit.svg?branch=master)](https://travis-ci.org/dmlc/rabit)
2 [![Documentation Status](https://readthedocs.org/projects/rabit/badge/?version=latest)](http://rabit.readthedocs.org/)
3
4 rabit is a light weight library that provides a fault tolerant interface of Allreduce and Broadcast. It is designed to support easy implementations of distributed machine learning programs, many of which fall naturally under the Allreduce abstraction. The goal of rabit is to support ***portable*** , ***scalable*** and ***reliable*** distributed machine learning programs.
5
6 * [Tutorial](guide)
7 * [API Documentation](http://homes.cs.washington.edu/~tqchen/rabit/doc)
8 * You can also directly read the [interface header](include/rabit/rabit.h)
9 * [XGBoost](https://github.com/dmlc/xgboost)
10 - Rabit is one of the backbone libraries supporting distributed XGBoost
11
12 ## Features
13 All these features come from facts about the small rabbit :)
14 * Portable: rabit is light weight and runs everywhere
15 - Rabit is a library instead of a framework, a program only needs to link the library to run
16 - Rabit only relies on a mechanism to start programs, which is provided by most frameworks
17 - You can run rabit programs on many platforms, including Yarn(Hadoop), MPI using the same code
18 * Scalable and Flexible: rabit runs fast
19 * Rabit programs use Allreduce to communicate, and do not suffer the cost between iterations of the MapReduce abstraction.
20 - Programs can call rabit functions in any order, as opposed to frameworks where callbacks are offered and called by the framework, i.e. inversion of control principle.
21 - Programs persist over all the iterations, unless they fail and recover.
22 * Reliable: rabit digs burrows to avoid disasters
23 - Rabit programs can recover the model and results using synchronous function calls.
24 - Rabit programs can set rabit_bootstrap_cache=1 to support allreduce/broadcast operations before loadcheckpoint
25 `
26 rabit::Init(); -> rabit::AllReduce(); -> rabit::loadCheckpoint(); -> for () { rabit::AllReduce(); rabit::Checkpoint();} -> rabit::Shutdown();
27 `
28
29 ## Use Rabit
30 * Typing `make` in the root folder compiles the rabit library into the lib folder
31 * Add lib to the library path and include to the include path of compiler
32 * Languages: You can use rabit in C++ and python
33 - It is also possible to port the library to other languages
34
35 ## Contributing
36 Rabit is an open-source library; contributions are welcome, including:
37 * The rabit core library.
38 * Customized tracker script for new platforms and interface of new languages.
39 * Tutorial and examples about the library.
0 @PACKAGE_INIT@
1
2 include("${CMAKE_CURRENT_LIST_DIR}/@TARGETS_EXPORT_NAME@.cmake")
3 check_required_components("@PROJECT_NAME@")
# code copied from https://crascit.com/2015/07/25/cmake-gtest/
#
# Template for a stand-alone helper project. @GOOGLETEST_DOWNLOAD_ROOT@ is
# substituted when this file is configured.
cmake_minimum_required(VERSION 3.5 FATAL_ERROR)

project(googletest-download NONE)

include(ExternalProject)

# Only the sources are fetched: the configure, build, install and test steps
# are all set to empty commands.
ExternalProject_Add(googletest
  GIT_REPOSITORY    https://github.com/google/googletest.git
  GIT_TAG           release-1.8.0
  SOURCE_DIR        "@GOOGLETEST_DOWNLOAD_ROOT@/googletest-src"
  BINARY_DIR        "@GOOGLETEST_DOWNLOAD_ROOT@/googletest-build"
  CONFIGURE_COMMAND ""
  BUILD_COMMAND     ""
  INSTALL_COMMAND   ""
  TEST_COMMAND      ""
)
# the following code to fetch googletest
# is inspired by and adapted after https://crascit.com/2015/07/25/cmake-gtest/
# download and unpack googletest at configure time
#
# fetch_googletest(<download_module_path> <download_root>)
#   <download_module_path>  directory containing googletest-download.cmake
#   <download_root>         directory in which the helper project is generated
#                           and googletest is checked out and built
#
# Configure-time failures of the helper project abort with FATAL_ERROR rather
# than surfacing later as a missing googletest-src directory.
macro(fetch_googletest _download_module_path _download_root)
  # GOOGLETEST_DOWNLOAD_ROOT is only needed while the template is expanded.
  set(GOOGLETEST_DOWNLOAD_ROOT ${_download_root})
  configure_file(
    ${_download_module_path}/googletest-download.cmake
    ${_download_root}/CMakeLists.txt
    @ONLY
  )
  unset(GOOGLETEST_DOWNLOAD_ROOT)

  # Configure the helper project with the same generator as the main build.
  execute_process(
    COMMAND
      "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
    WORKING_DIRECTORY
      ${_download_root}
    RESULT_VARIABLE
      _fetch_googletest_result
  )
  if(NOT _fetch_googletest_result EQUAL 0)
    message(FATAL_ERROR
      "fetch_googletest: configuring the download project in ${_download_root} failed: ${_fetch_googletest_result}")
  endif()

  # Building the helper project performs the actual checkout.
  execute_process(
    COMMAND
      "${CMAKE_COMMAND}" --build .
    WORKING_DIRECTORY
      ${_download_root}
    RESULT_VARIABLE
      _fetch_googletest_result
  )
  if(NOT _fetch_googletest_result EQUAL 0)
    message(FATAL_ERROR
      "fetch_googletest: downloading googletest into ${_download_root} failed: ${_fetch_googletest_result}")
  endif()
  # A macro shares the caller's scope, so remove the helper variable again.
  unset(_fetch_googletest_result)

  # adds the targets: gtest, gtest_main, gmock, gmock_main
  add_subdirectory(
    ${_download_root}/googletest-src
    ${_download_root}/googletest-build
  )
endmacro()
0 html
1 latex
2 *.sh
3 _*
4 doxygen
0 # Doxyfile 1.7.6.1
1
2 #---------------------------------------------------------------------------
3 # Project related configuration options
4 #---------------------------------------------------------------------------
5 DOXYFILE_ENCODING = UTF-8
6 PROJECT_NAME = "rabit"
7 PROJECT_NUMBER =
8 PROJECT_BRIEF =
9 PROJECT_LOGO =
10 OUTPUT_DIRECTORY = ../doc/doxygen
11 CREATE_SUBDIRS = NO
12 OUTPUT_LANGUAGE = English
13 BRIEF_MEMBER_DESC = YES
14 REPEAT_BRIEF = YES
15 ABBREVIATE_BRIEF =
16 ALWAYS_DETAILED_SEC = NO
17 INLINE_INHERITED_MEMB = NO
18 FULL_PATH_NAMES = YES
19 STRIP_FROM_PATH =
20 STRIP_FROM_INC_PATH =
21 SHORT_NAMES = NO
22 JAVADOC_AUTOBRIEF = NO
23 QT_AUTOBRIEF = NO
24 MULTILINE_CPP_IS_BRIEF = NO
25 INHERIT_DOCS = YES
26 SEPARATE_MEMBER_PAGES = NO
27 TAB_SIZE = 8
28 ALIASES =
29 TCL_SUBST =
30 OPTIMIZE_OUTPUT_FOR_C = YES
31 OPTIMIZE_OUTPUT_JAVA = NO
32 OPTIMIZE_FOR_FORTRAN = NO
33 OPTIMIZE_OUTPUT_VHDL = NO
34 EXTENSION_MAPPING =
35 BUILTIN_STL_SUPPORT = NO
36 CPP_CLI_SUPPORT = NO
37 SIP_SUPPORT = NO
38 IDL_PROPERTY_SUPPORT = YES
39 DISTRIBUTE_GROUP_DOC = NO
40 SUBGROUPING = YES
41 INLINE_GROUPED_CLASSES = NO
42 INLINE_SIMPLE_STRUCTS = NO
43 TYPEDEF_HIDES_STRUCT = NO
44 LOOKUP_CACHE_SIZE = 0
45 #---------------------------------------------------------------------------
46 # Build related configuration options
47 #---------------------------------------------------------------------------
48 EXTRACT_ALL = NO
49 EXTRACT_PRIVATE = NO
50 EXTRACT_STATIC = NO
51 EXTRACT_LOCAL_CLASSES = YES
52 EXTRACT_LOCAL_METHODS = NO
53 EXTRACT_ANON_NSPACES = NO
54 HIDE_UNDOC_MEMBERS = NO
55 HIDE_UNDOC_CLASSES = YES
56 HIDE_FRIEND_COMPOUNDS = NO
57 HIDE_IN_BODY_DOCS = NO
58 INTERNAL_DOCS = NO
59 CASE_SENSE_NAMES = YES
60 HIDE_SCOPE_NAMES = NO
61 SHOW_INCLUDE_FILES = YES
62 FORCE_LOCAL_INCLUDES = NO
63 INLINE_INFO = YES
64 SORT_MEMBER_DOCS = YES
65 SORT_BRIEF_DOCS = NO
66 SORT_MEMBERS_CTORS_1ST = NO
67 SORT_GROUP_NAMES = NO
68 SORT_BY_SCOPE_NAME = NO
69 STRICT_PROTO_MATCHING = NO
70 GENERATE_TODOLIST = YES
71 GENERATE_TESTLIST = YES
72 GENERATE_BUGLIST = YES
73 GENERATE_DEPRECATEDLIST= YES
74 ENABLED_SECTIONS =
75 MAX_INITIALIZER_LINES = 30
76 SHOW_USED_FILES = YES
77 SHOW_FILES = YES
78 SHOW_NAMESPACES = YES
79 FILE_VERSION_FILTER =
80 LAYOUT_FILE =
81 CITE_BIB_FILES =
82 #---------------------------------------------------------------------------
83 # configuration options related to warning and progress messages
84 #---------------------------------------------------------------------------
85 QUIET = NO
86 WARNINGS = YES
87 WARN_IF_UNDOCUMENTED = YES
88 WARN_IF_DOC_ERROR = YES
89 WARN_NO_PARAMDOC = YES
90 WARN_FORMAT = "$file:$line: $text"
91 WARN_LOGFILE =
92 #---------------------------------------------------------------------------
93 # configuration options related to the input files
94 #---------------------------------------------------------------------------
95 INPUT = rabit
96 INPUT_ENCODING = UTF-8
97 FILE_PATTERNS =
98 RECURSIVE = NO
99 EXCLUDE =
100 EXCLUDE_SYMLINKS = NO
101 EXCLUDE_PATTERNS = *-inl.hpp
102 EXCLUDE_SYMBOLS =
103 EXAMPLE_PATH =
104 EXAMPLE_PATTERNS =
105 EXAMPLE_RECURSIVE = NO
106 IMAGE_PATH =
107 INPUT_FILTER =
108 FILTER_PATTERNS =
109 FILTER_SOURCE_FILES = NO
110 FILTER_SOURCE_PATTERNS =
111 #---------------------------------------------------------------------------
112 # configuration options related to source browsing
113 #---------------------------------------------------------------------------
114 SOURCE_BROWSER = NO
115 INLINE_SOURCES = NO
116 STRIP_CODE_COMMENTS = YES
117 REFERENCED_BY_RELATION = NO
118 REFERENCES_RELATION = NO
119 REFERENCES_LINK_SOURCE = YES
120 USE_HTAGS = NO
121 VERBATIM_HEADERS = YES
122 #---------------------------------------------------------------------------
123 # configuration options related to the alphabetical class index
124 #---------------------------------------------------------------------------
125 ALPHABETICAL_INDEX = YES
126 COLS_IN_ALPHA_INDEX = 5
127 IGNORE_PREFIX =
128 #---------------------------------------------------------------------------
129 # configuration options related to the HTML output
130 #---------------------------------------------------------------------------
131 GENERATE_HTML = YES
132 HTML_OUTPUT = html
133 HTML_FILE_EXTENSION = .html
134 HTML_HEADER =
135 HTML_FOOTER =
136 HTML_STYLESHEET =
137 HTML_EXTRA_FILES =
138 HTML_COLORSTYLE_HUE = 220
139 HTML_COLORSTYLE_SAT = 100
140 HTML_COLORSTYLE_GAMMA = 80
141 HTML_TIMESTAMP = YES
142 HTML_DYNAMIC_SECTIONS = NO
143 GENERATE_DOCSET = NO
144 DOCSET_FEEDNAME = "Doxygen generated docs"
145 DOCSET_BUNDLE_ID = org.doxygen.Project
146 DOCSET_PUBLISHER_ID = org.doxygen.Publisher
147 DOCSET_PUBLISHER_NAME = Publisher
148 GENERATE_HTMLHELP = NO
149 CHM_FILE =
150 HHC_LOCATION =
151 GENERATE_CHI = NO
152 CHM_INDEX_ENCODING =
153 BINARY_TOC = NO
154 TOC_EXPAND = NO
155 GENERATE_QHP = NO
156 QCH_FILE =
157 QHP_NAMESPACE = org.doxygen.Project
158 QHP_VIRTUAL_FOLDER = doc
159 QHP_CUST_FILTER_NAME =
160 QHP_CUST_FILTER_ATTRS =
161 QHP_SECT_FILTER_ATTRS =
162 QHG_LOCATION =
163 GENERATE_ECLIPSEHELP = NO
164 ECLIPSE_DOC_ID = org.doxygen.Project
165 DISABLE_INDEX = NO
166 GENERATE_TREEVIEW = NO
167 ENUM_VALUES_PER_LINE = 4
168 TREEVIEW_WIDTH = 250
169 EXT_LINKS_IN_WINDOW = NO
170 FORMULA_FONTSIZE = 10
171 FORMULA_TRANSPARENT = YES
172 USE_MATHJAX = NO
173 MATHJAX_RELPATH = http://www.mathjax.org/mathjax
174 MATHJAX_EXTENSIONS =
175 SEARCHENGINE = YES
176 SERVER_BASED_SEARCH = NO
177 #---------------------------------------------------------------------------
178 # configuration options related to the LaTeX output
179 #---------------------------------------------------------------------------
180 GENERATE_LATEX = YES
181 LATEX_OUTPUT = latex
182 LATEX_CMD_NAME = latex
183 MAKEINDEX_CMD_NAME = makeindex
184 COMPACT_LATEX = NO
185 PAPER_TYPE = a4
186 EXTRA_PACKAGES =
187 LATEX_HEADER =
188 LATEX_FOOTER =
189 PDF_HYPERLINKS = YES
190 USE_PDFLATEX = YES
191 LATEX_BATCHMODE = NO
192 LATEX_HIDE_INDICES = NO
193 LATEX_SOURCE_CODE = NO
194 LATEX_BIB_STYLE = plain
195 #---------------------------------------------------------------------------
196 # configuration options related to the RTF output
197 #---------------------------------------------------------------------------
198 GENERATE_RTF = NO
199 RTF_OUTPUT = rtf
200 COMPACT_RTF = NO
201 RTF_HYPERLINKS = NO
202 RTF_STYLESHEET_FILE =
203 RTF_EXTENSIONS_FILE =
204 #---------------------------------------------------------------------------
205 # configuration options related to the man page output
206 #---------------------------------------------------------------------------
207 GENERATE_MAN = NO
208 MAN_OUTPUT = man
209 MAN_EXTENSION = .3
210 MAN_LINKS = NO
211 #---------------------------------------------------------------------------
212 # configuration options related to the XML output
213 #---------------------------------------------------------------------------
214 GENERATE_XML = YES
215 XML_OUTPUT = xml
216 XML_PROGRAMLISTING = YES
217 #---------------------------------------------------------------------------
218 # configuration options for the AutoGen Definitions output
219 #---------------------------------------------------------------------------
220 GENERATE_AUTOGEN_DEF = NO
221 #---------------------------------------------------------------------------
222 # configuration options related to the Perl module output
223 #---------------------------------------------------------------------------
224 GENERATE_PERLMOD = NO
225 PERLMOD_LATEX = NO
226 PERLMOD_PRETTY = YES
227 PERLMOD_MAKEVAR_PREFIX =
228 #---------------------------------------------------------------------------
229 # Configuration options related to the preprocessor
230 #---------------------------------------------------------------------------
231 ENABLE_PREPROCESSING = NO
232 MACRO_EXPANSION = NO
233 EXPAND_ONLY_PREDEF = NO
234 SEARCH_INCLUDES = YES
235 INCLUDE_PATH =
236 INCLUDE_FILE_PATTERNS =
237 PREDEFINED =
238 EXPAND_AS_DEFINED =
239 SKIP_FUNCTION_MACROS = YES
240 #---------------------------------------------------------------------------
241 # Configuration::additions related to external references
242 #---------------------------------------------------------------------------
243 TAGFILES =
244 GENERATE_TAGFILE =
245 ALLEXTERNALS = NO
246 EXTERNAL_GROUPS = YES
247 PERL_PATH = /usr/bin/perl
248 #---------------------------------------------------------------------------
249 # Configuration options related to the dot tool
250 #---------------------------------------------------------------------------
251 CLASS_DIAGRAMS = YES
252 MSCGEN_PATH =
253 HIDE_UNDOC_RELATIONS = YES
254 HAVE_DOT = NO
255 DOT_NUM_THREADS = 0
256 DOT_FONTNAME = Helvetica
257 DOT_FONTSIZE = 10
258 DOT_FONTPATH =
259 CLASS_GRAPH = YES
260 COLLABORATION_GRAPH = YES
261 GROUP_GRAPHS = YES
262 UML_LOOK = NO
263 TEMPLATE_RELATIONS = NO
264 INCLUDE_GRAPH = YES
265 INCLUDED_BY_GRAPH = YES
266 CALL_GRAPH = NO
267 CALLER_GRAPH = NO
268 GRAPHICAL_HIERARCHY = YES
269 DIRECTORY_GRAPH = YES
270 DOT_IMAGE_FORMAT = png
271 INTERACTIVE_SVG = NO
272 DOT_PATH =
273 DOTFILE_DIRS =
274 MSCFILE_DIRS =
275 DOT_GRAPH_MAX_NODES = 50
276 MAX_DOT_GRAPH_DEPTH = 0
277 DOT_TRANSPARENT = NO
278 DOT_MULTI_TARGETS = YES
279 GENERATE_LEGEND = YES
280 DOT_CLEANUP = YES
0 # Makefile for Sphinx documentation
1 #
2
3 # You can set these variables from the command line.
4 SPHINXOPTS =
5 SPHINXBUILD = sphinx-build
6 PAPER =
7 BUILDDIR = _build
8
9 # User-friendly check for sphinx-build
10 ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
11 $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
12 endif
13
14 # Internal variables.
15 PAPEROPT_a4 = -D latex_paper_size=a4
16 PAPEROPT_letter = -D latex_paper_size=letter
17 ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
18 # the i18n builder cannot share the environment and doctrees with the others
19 I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
20
21 .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext
22
23 help:
24 @echo "Please use \`make <target>' where <target> is one of"
25 @echo " html to make standalone HTML files"
26 @echo " dirhtml to make HTML files named index.html in directories"
27 @echo " singlehtml to make a single large HTML file"
28 @echo " pickle to make pickle files"
29 @echo " json to make JSON files"
30 @echo " htmlhelp to make HTML files and a HTML help project"
31 @echo " qthelp to make HTML files and a qthelp project"
32 @echo " applehelp to make an Apple Help Book"
33 @echo " devhelp to make HTML files and a Devhelp project"
34 @echo " epub to make an epub"
35 @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
36 @echo " latexpdf to make LaTeX files and run them through pdflatex"
37 @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
38 @echo " text to make text files"
39 @echo " man to make manual pages"
40 @echo " texinfo to make Texinfo files"
41 @echo " info to make Texinfo files and run them through makeinfo"
42 @echo " gettext to make PO message catalogs"
43 @echo " changes to make an overview of all changed/added/deprecated items"
44 @echo " xml to make Docutils-native XML files"
45 @echo " pseudoxml to make pseudoxml-XML files for display purposes"
46 @echo " linkcheck to check all external links for integrity"
47 @echo " doctest to run all doctests embedded in the documentation (if enabled)"
48 @echo " coverage to run coverage check of the documentation (if enabled)"
49
50 clean:
51 rm -rf $(BUILDDIR)/*
52
53 html:
54 $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
55 @echo
56 @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
57
58 dirhtml:
59 $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
60 @echo
61 @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
62
63 singlehtml:
64 $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
65 @echo
66 @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
67
68 pickle:
69 $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
70 @echo
71 @echo "Build finished; now you can process the pickle files."
72
73 json:
74 $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
75 @echo
76 @echo "Build finished; now you can process the JSON files."
77
78 htmlhelp:
79 $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
80 @echo
81 @echo "Build finished; now you can run HTML Help Workshop with the" \
82 ".hhp project file in $(BUILDDIR)/htmlhelp."
83
84 qthelp:
85 $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
86 @echo
87 @echo "Build finished; now you can run "qcollectiongenerator" with the" \
88 ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
89 @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/rabit.qhcp"
90 @echo "To view the help file:"
91 @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/rabit.qhc"
92
93 applehelp:
94 $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp
95 @echo
96 @echo "Build finished. The help book is in $(BUILDDIR)/applehelp."
97 @echo "N.B. You won't be able to view it unless you put it in" \
98 "~/Library/Documentation/Help or install it in your application" \
99 "bundle."
100
# Run the Sphinx "devhelp" builder into $(BUILDDIR)/devhelp ($@ is the target
# name) and print the commands that register the result with Devhelp.
devhelp:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "Build finished."
	@echo "To view the help file:"
	@echo "# mkdir -p $$HOME/.local/share/devhelp/rabit"
	@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/rabit"
	@echo "# devhelp"
109
# Run the Sphinx "epub" builder into $(BUILDDIR)/epub ($@ is the target name).
epub:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
114
# Run the Sphinx "latex" builder into $(BUILDDIR)/latex ($@ is the target name).
# Only the LaTeX sources are generated here; see latexpdf for the PDF step.
latex:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
	@echo "Run \`make' in that directory to run these through (pdf)latex" \
	      "(use \`make latexpdf' here to do that automatically)."
121
# Generate the LaTeX sources with the "latex" builder, then run the all-pdf
# target of the Makefile found in $(BUILDDIR)/latex via a recursive $(MAKE).
latexpdf:
	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
	@echo "Running LaTeX files through pdflatex..."
	$(MAKE) -C $(BUILDDIR)/latex all-pdf
	@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
127
# Same as latexpdf, but runs the all-pdf-ja target (platex + dvipdfmx
# according to the progress message) in $(BUILDDIR)/latex.
latexpdfja:
	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
	@echo "Running LaTeX files through platex and dvipdfmx..."
	$(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
	@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
133
# Run the Sphinx "text" builder into $(BUILDDIR)/text ($@ is the target name).
text:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "Build finished. The text files are in $(BUILDDIR)/text."
138
# Run the Sphinx "man" builder into $(BUILDDIR)/man ($@ is the target name).
man:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "Build finished. The manual pages are in $(BUILDDIR)/man."
143
# Run the Sphinx "texinfo" builder into $(BUILDDIR)/texinfo ($@ is the target name).
# Only the Texinfo sources are generated here; see info for the makeinfo step.
texinfo:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
	@echo "Run \`make' in that directory to run these through makeinfo" \
	      "(use \`make info' here to do that automatically)."
150
# Generate the Texinfo sources, then build the Info files by running the
# "info" target of the Makefile found in $(BUILDDIR)/texinfo.
info:
	$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
	@echo "Running Texinfo files through makeinfo..."
	@# Use $(MAKE) rather than a literal "make" so the sub-make is the same
	@# program as the parent and inherits its flags (-j, -n, ...), matching
	@# the latexpdf/latexpdfja targets.
	$(MAKE) -C $(BUILDDIR)/texinfo info
	@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
156
# Run the Sphinx "gettext" builder ($@ is the target name). Unlike the other
# targets this one uses $(I18NSPHINXOPTS) and writes to $(BUILDDIR)/locale.
gettext:
	$(SPHINXBUILD) -b $@ $(I18NSPHINXOPTS) $(BUILDDIR)/locale
	@echo
	@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
161
# Run the Sphinx "changes" builder into $(BUILDDIR)/changes ($@ is the target name).
changes:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "The overview file is in $(BUILDDIR)/changes."
166
# Run the Sphinx "linkcheck" builder into $(BUILDDIR)/linkcheck ($@ is the target name).
linkcheck:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "Link check complete; look for any errors in the above output " \
	      "or in $(BUILDDIR)/linkcheck/output.txt."
172
# Run the Sphinx "doctest" builder into $(BUILDDIR)/doctest ($@ is the target name).
doctest:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo "Testing of doctests in the sources finished, look at the " \
	      "results in $(BUILDDIR)/doctest/output.txt."
177
# Run the Sphinx "coverage" builder into $(BUILDDIR)/coverage ($@ is the target name).
coverage:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo "Testing of coverage in the sources finished, look at the " \
	      "results in $(BUILDDIR)/coverage/python.txt."
182
# Run the Sphinx "xml" builder into $(BUILDDIR)/xml ($@ is the target name).
xml:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "Build finished. The XML files are in $(BUILDDIR)/xml."
187
# Run the Sphinx "pseudoxml" builder into $(BUILDDIR)/pseudoxml ($@ is the target name).
pseudoxml:
	$(SPHINXBUILD) -b $@ $(ALLSPHINXOPTS) $(BUILDDIR)/$@
	@echo
	@echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
0 # -*- coding: utf-8 -*-
1 #
2 # documentation build configuration file, created by
3 # sphinx-quickstart on Thu Jul 23 19:40:08 2015.
4 #
5 # This file is execfile()d with the current directory set to its
6 # containing dir.
7 #
8 # Note that not all possible configuration values are present in this
9 # autogenerated file.
10 #
11 # All configuration values have a default; values that are commented out
12 # serve to show the default.
13 import sys
14 import os, subprocess
15 import shlex
16 # If extensions (or modules to document with autodoc) are in another directory,
17 # add these directories to sys.path here. If the directory is relative to the
18 # documentation root, use os.path.abspath to make it absolute, like shown here.
# Directory that contains this conf.py, and the Python wrapper directory
# that sits next to it.
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
libpath = os.path.join(curr_path, '../wrapper/')
# Prepend both directories in a single step. The resulting order (this
# directory first, the wrapper directory second) is the same as inserting
# the wrapper path at index 0 and then this directory at index 0.
sys.path[:0] = [curr_path, libpath]
23 from sphinx_util import MarkdownParser, AutoStructify
24
25 # -- General configuration ------------------------------------------------
26
27 # General information about the project.
28 project = u'rabit'
29 copyright = u'2015, rabit developers'
30 author = u'rabit developers'
31 github_doc_root = 'https://github.com/dmlc/rabit/tree/master/doc/'
32
33 # add markdown parser
34 MarkdownParser.github_doc_root = github_doc_root
35 source_parsers = {
36 '.md': MarkdownParser,
37 }
38 # Version information.
39 import rabit
40
41 version = rabit.__version__
42 release = rabit.__version__
43
44 # Add any Sphinx extension module names here, as strings. They can be
45 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones
46 extensions = [
47 'sphinx.ext.autodoc',
48 'sphinx.ext.napoleon',
49 'sphinx.ext.mathjax',
50 'breathe',
51 ]
52
53 # Use breathe to include doxygen documents
54 breathe_projects = {'rabit' : 'doxygen/xml/'}
55 breathe_default_project = 'rabit'
56
57 # Add any paths that contain templates here, relative to this directory.
58 templates_path = ['_templates']
59
60 # The suffix(es) of source filenames.
61 # You can specify multiple suffix as a list of string:
62 # source_suffix = ['.rst', '.md']
63 source_suffix = ['.rst', '.md']
64
65 # The encoding of source files.
66 #source_encoding = 'utf-8-sig'
67
68 # The master toctree document.
69 master_doc = 'index'
70
71 # The language for content autogenerated by Sphinx. Refer to documentation
72 # for a list of supported languages.
73 #
74 # This is also used if you do content translation via gettext catalogs.
75 # Usually you set "language" from the command line for these cases.
76 language = None
77
78 # There are two options for replacing |today|: either, you set today to some
79 # non-false value, then it is used:
80 #today = ''
81 # Else, today_fmt is used as the format for a strftime call.
82 #today_fmt = '%B %d, %Y'
83
84 # List of patterns, relative to source directory, that match files and
85 # directories to ignore when looking for source files.
86 exclude_patterns = ['_build']
87
88 # The reST default role (used for this markup: `text`) to use for all
89 # documents.
90 #default_role = None
91
92 # If true, '()' will be appended to :func: etc. cross-reference text.
93 #add_function_parentheses = True
94
95 # If true, the current module name will be prepended to all description
96 # unit titles (such as .. function::).
97 #add_module_names = True
98
99 # If true, sectionauthor and moduleauthor directives will be shown in the
100 # output. They are ignored by default.
101 #show_authors = False
102
103 # The name of the Pygments (syntax highlighting) style to use.
104 pygments_style = 'sphinx'
105
106 # A list of ignored prefixes for module index sorting.
107 #modindex_common_prefix = []
108
109 # If true, keep warnings as "system message" paragraphs in the built documents.
110 #keep_warnings = False
111
112 # If true, `todo` and `todoList` produce output, else they produce nothing.
113 todo_include_todos = False
114
115 # -- Options for HTML output ----------------------------------------------
116
117 # The theme to use for HTML and HTML Help pages. See the documentation for
118 # a list of builtin themes.
119 # html_theme = 'alabaster'
120
121 # Add any paths that contain custom static files (such as style sheets) here,
122 # relative to this directory. They are copied after the builtin static files,
123 # so a file named "default.css" will overwrite the builtin "default.css".
124 html_static_path = ['_static']
125
126 # Output file base name for HTML help builder.
127 htmlhelp_basename = project + 'doc'
128
129 # -- Options for LaTeX output ---------------------------------------------
130 latex_elements = {
131 }
132
133 # Grouping the document tree into LaTeX files. List of tuples
134 # (source start file, target name, title,
135 # author, documentclass [howto, manual, or own class]).
136 latex_documents = [
137 (master_doc, 'rabit.tex', project,
138 author, 'manual'),
139 ]
140
141 # hook for doxygen
def run_doxygen(folder):
    """Run ``make doxygen`` inside ``folder`` and report failures on stderr.

    Parameters
    ----------
    folder : str
        Directory containing the Makefile that provides the ``doxygen`` target.

    Notes
    -----
    Nothing is raised or returned: a missing directory or executable, a
    non-zero exit status and termination by a signal are all written to
    ``sys.stderr`` so the caller can carry on.
    """
    try:
        # ``cwd=`` replaces the former ``cd <folder>; make doxygen`` shell
        # string: make can no longer run in the wrong directory when the
        # ``cd`` fails, and the folder name is never parsed by a shell.
        retcode = subprocess.call(["make", "doxygen"], cwd=folder)
    except OSError as e:
        sys.stderr.write("doxygen execution failed: %s\n" % e)
        return
    if retcode < 0:
        # A negative return code is the negated number of the killing signal.
        sys.stderr.write("doxygen terminated by signal %s\n" % (-retcode))
    elif retcode > 0:
        sys.stderr.write("doxygen exited with status %s\n" % retcode)
150
151
def run_build_lib(folder):
    """Build the library in ``folder`` and stage the doxygen HTML for the docs.

    Runs ``make`` in ``folder``, then replaces ``_build/html/doxygen`` with a
    copy of ``doxygen/html``. Both of those paths are relative to the current
    working directory.

    Parameters
    ----------
    folder : str
        Directory in which ``make`` is executed.

    Notes
    -----
    The steps are best effort: every step is attempted even when an earlier
    one fails, and each failure is reported on ``sys.stderr`` instead of being
    raised.
    """
    steps = (
        (["make"], folder),
        (["rm", "-rf", "_build/html/doxygen"], None),
        # -p creates both directory levels and succeeds if they already exist.
        (["mkdir", "-p", "_build/html"], None),
        (["cp", "-rf", "doxygen/html", "_build/html/doxygen"], None),
    )
    for argv, cwd in steps:
        try:
            retcode = subprocess.call(argv, cwd=cwd)
        except OSError as e:
            sys.stderr.write("build execution failed: %s\n" % e)
            continue
        # Check every step; previously only the final ``cp`` was inspected.
        if retcode < 0:
            sys.stderr.write("build terminated by signal %s\n" % (-retcode))
        elif retcode > 0:
            sys.stderr.write("build step '%s' exited with status %s\n"
                             % (" ".join(argv), retcode))
164
165
def generate_doxygen_xml(app):
    """Hook run on Sphinx's ``builder-inited`` event.

    When the ``READTHEDOCS`` environment variable equals ``'True'`` the
    doxygen output and the library are built first. Afterwards the content of
    ``../wrapper`` is logged to stderr and ``rabit._loadlib()`` is called.

    ``app`` is the object passed in by the event; it is not used here.
    """
    # NOTE(review): the original indentation was not available; the last two
    # statements are assumed to run unconditionally (outside the ``if``).
    # Confirm against the repository copy.
    on_read_the_docs = os.environ.get('READTHEDOCS', None) == 'True'
    if on_read_the_docs:
        run_doxygen('..')
        sys.stderr.write('Check if shared lib exists\n')
        run_build_lib('..')
    wrapper_entries = os.listdir('../wrapper')
    sys.stderr.write('The wrapper path: %s\n' % str(wrapper_entries))
    rabit._loadlib()
175
176
def setup(app):
    """Register this configuration's hooks with the Sphinx application.

    Connects ``generate_doxygen_xml`` to the ``builder-inited`` event, adds
    the ``recommonmark_config`` config value and installs the
    ``AutoStructify`` transform.
    """
    # Relative URLs found in the markdown sources are resolved against the
    # GitHub documentation root.
    recommonmark_options = {
        'url_resolver': lambda url: github_doc_root + url,
    }
    app.connect("builder-inited", generate_doxygen_xml)
    app.add_config_value('recommonmark_config', recommonmark_options, True)
    app.add_transform(AutoStructify)
0 C++ Library API of Rabit
1 ========================
2 This page contains the documentation of the C++ library API of rabit.
3
4 ```eval_rst
5 .. toctree::
6
7 .. doxygennamespace:: rabit
8 ```
0 Tutorial
1 ========
2 This is rabit's tutorial, a ***Reliable Allreduce and Broadcast Interface***.
3 All the example codes are in the [guide](https://github.com/dmlc/rabit/blob/master/guide/) folder of the project.
4 To run the examples locally, you will need to build them with ```make```.
5
6 **List of Topics**
7 * [What is Allreduce](#what-is-allreduce)
8 * [Common Use Case](#common-use-case)
9 * [Use Rabit API](#use-rabit-api)
10 - [Structure of a Rabit Program](#structure-of-a-rabit-program)
11 - [Allreduce and Lazy Preparation](#allreduce-and-lazy-preparation)
12 - [Checkpoint and LazyCheckpoint](#checkpoint-and-lazycheckpoint)
13 * [Compile Programs with Rabit](#compile-programs-with-rabit)
14 * [Running Rabit Jobs](#running-rabit-jobs)
15 * [Fault Tolerance](#fault-tolerance)
16
17 What is Allreduce
18 -----------------
19 The main methods provided by rabit are Allreduce and Broadcast. Allreduce performs reduction across different computation nodes,
20 and returns the result to every node. To understand the behavior of the function, consider the following example in [basic.cc](../guide/basic.cc) (there is a python example right after this if you are more familiar with python).
21 ```c++
22 #include <rabit.h>
23 using namespace rabit;
24 const int N = 3;
25 int main(int argc, char *argv[]) {
26 int a[N];
27 rabit::Init(argc, argv);
28 for (int i = 0; i < N; ++i) {
29 a[i] = rabit::GetRank() + i;
30 }
31 printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
32 rabit::GetRank(), a[0], a[1], a[2]);
33 // allreduce take max of each elements in all processes
34 Allreduce<op::Max>(&a[0], N);
35 printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n",
36 rabit::GetRank(), a[0], a[1], a[2]);
37 // second allreduce that sums everything up
38 Allreduce<op::Sum>(&a[0], N);
39 printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
40 rabit::GetRank(), a[0], a[1], a[2]);
41 rabit::Finalize();
42 return 0;
43 }
44 ```
45 You can run the example using the rabit_demo.py script. The following command
46 starts the rabit program with two worker processes.
47 ```bash
48 ../tracker/rabit_demo.py -n 2 basic.rabit
49 ```
50 This will start two processes, one process with rank 0 and the other with rank 1, both processes run the same code.
51 The ```rabit::GetRank()``` function returns the rank of current process.
52
53 Before the call to Allreduce, process 0 contains the array ```a = {0, 1, 2}```, while process 1 has the array
54 ```a = {1, 2, 3}```. After the call to Allreduce, the array contents in all processes are replaced by the
55 reduction result (in this case, the maximum value in each position across all the processes). So, after the
56 Allreduce call, the result will become ```a = {1, 2, 3}```.
57 Rabit provides different reduction operators, for example, if you change ```op::Max``` to ```op::Sum```,
58 the reduction operation will be a summation, and the result will become ```a = {1, 3, 5}```.
59 You can also run the example with different processes by setting -n to different values.
60
61 If you are more familiar with python, you can also use rabit in python. The same example as before can be found in [basic.py](../guide/basic.py):
62
63 ```python
64 import numpy as np
65 import rabit
66
67 rabit.init()
68 n = 3
69 rank = rabit.get_rank()
70 a = np.zeros(n)
71 for i in xrange(n):
72 a[i] = rank + i
73
74 print '@node[%d] before-allreduce: a=%s' % (rank, str(a))
75 a = rabit.allreduce(a, rabit.MAX)
76 print '@node[%d] after-allreduce-max: a=%s' % (rank, str(a))
77 a = rabit.allreduce(a, rabit.SUM)
78 print '@node[%d] after-allreduce-sum: a=%s' % (rank, str(a))
79 rabit.finalize()
80 ```
81 You can run the program using the following command
82 ```bash
83 ../tracker/rabit_demo.py -n 2 basic.py
84 ```
85
86 Broadcast is another method provided by rabit besides Allreduce. This function allows one node to broadcast its
87 local data to all other nodes. The following code in [broadcast.cc](../guide/broadcast.cc) broadcasts a string from
88 node 0 to all other nodes.
89 ```c++
90 #include <rabit.h>
91 using namespace rabit;
92 const int N = 3;
93 int main(int argc, char *argv[]) {
94 rabit::Init(argc, argv);
95 std::string s;
96 if (rabit::GetRank() == 0) s = "hello world";
97 printf("@node[%d] before-broadcast: s=\"%s\"\n",
98 rabit::GetRank(), s.c_str());
99 // broadcast s from node 0 to all other nodes
100 rabit::Broadcast(&s, 0);
101 printf("@node[%d] after-broadcast: s=\"%s\"\n",
102 rabit::GetRank(), s.c_str());
103 rabit::Finalize();
104 return 0;
105 }
106 ```
107 The following command starts the program with three worker processes.
108 ```bash
109 ../tracker/rabit_demo.py -n 3 broadcast.rabit
110 ```
111 Besides strings, rabit can also broadcast constant-size arrays and vectors.
112
113 The counterpart in python can be found in [broadcast.py](../guide/broadcast.py). Here is a snippet so that you can get a better sense of how simple it is to use the python library:
114
115 ```python
116 import rabit
117 rabit.init()
118 n = 3
119 rank = rabit.get_rank()
120 s = None
121 if rank == 0:
122 s = {'hello world':100, 2:3}
123 print '@node[%d] before-broadcast: s=\"%s\"' % (rank, str(s))
124 s = rabit.broadcast(s, 0)
125 print '@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s))
126 rabit.finalize()
127 ```
128
129 Common Use Case
130 ---------------
131 Many distributed machine learning algorithms involve splitting the data into different nodes,
132 computing statistics locally, and finally aggregating them. Such workflow is usually done repetitively through many iterations before the algorithm converges. Allreduce naturally meets the structure of such programs,
133 common use cases include:
134
135 * Aggregation of gradient values, which can be used in optimization methods such as L-BFGS.
136 * Aggregation of other statistics, which can be used in KMeans and Gaussian Mixture Models.
137 * Find the best split candidate and aggregation of split statistics, used for tree based models.
138
139 Rabit is a reliable and portable library for distributed machine learning programs that allows them to run reliably on different platforms.
140
141 Use Rabit API
142 -------------
143 This section introduces how to use the rabit API.
144 You can always refer to the [API Documentation](http://homes.cs.washington.edu/~tqchen/rabit/doc) for the definition of each function.
145 This section tries to give examples of different aspects of the rabit API.
146
147 #### Structure of a Rabit Program
148 The following code illustrates the common structure of a rabit program. This is an abstract example,
149 you can also refer to [wormhole](https://github.com/dmlc/wormhole/blob/master/learn/kmeans/kmeans.cc) for an example implementation of kmeans algorithm.
150
151 ```c++
152 #include <rabit.h>
153 int main(int argc, char *argv[]) {
154 ...
155 rabit::Init(argc, argv);
156 // sync on expected model size before load checkpoint, if we pass rabit_bootstrap_cache=true
157 rabit::Allreduce<rabit::op::Max>(&model.size(), 1);
158 // load the latest checked model
159 int version = rabit::LoadCheckPoint(&model);
160 // initialize the model if it is the first version
161 if (version == 0) model.InitModel();
162 // the version number marks the iteration to resume
163 for (int iter = version; iter < max_iter; ++iter) {
164 // at this point, the model object should allow us to recover the program state
165 ...
166 // each iteration can contain multiple calls of allreduce/broadcast
167 rabit::Allreduce<rabit::op::Max>(&data[0], n);
168 ...
169 // checkpoint model after one iteration finishes
170 rabit::CheckPoint(&model);
171 }
172 rabit::Finalize();
173 return 0;
174 }
175 ```
176
177 Besides the common Allreduce and Broadcast functions, there are two additional functions: ```LoadCheckPoint```
178 and ```CheckPoint```. These two functions are used for fault-tolerance purposes.
179 As mentioned before, traditional machine learning programs involve several iterations. In each iteration, we start with a model, make some calls
180 to Allreduce or Broadcast and update the model. The calling sequence in each iteration does not need to be the same.
181
182 * When the nodes start from the beginning (i.e. iteration 0), ```LoadCheckPoint``` returns 0, so we can initialize the model.
183 * ```CheckPoint``` saves the model after each iteration.
184 - Efficiency Note: the model is only kept in local memory and no save to disk is performed when calling Checkpoint
185 * When a node goes down and restarts, ```LoadCheckPoint``` will recover the latest saved model and return its version, so the node can resume from that iteration.
186 * When a node goes down, the rest of the nodes will block in the call of Allreduce/Broadcast and wait for
187 the recovery of the failed node until it catches up.
188
189 Please see the [Fault Tolerance](#fault-tolerance) section to understand the recovery procedure executed by rabit.
190
191 #### Allreduce and Lazy Preparation
192 Allreduce is one of the most important functions provided by rabit. You can call allreduce by specifying the
193 reduction operator, pointer to the data and size of the buffer, as follows
194 ```c++
195 Allreduce<operator>(pointer_of_data, size_of_data);
196 ```
197 This is the basic use case of the Allreduce function. It is common that the user writes code to prepare the data needed
198 into the data buffer, passes the data to the Allreduce function, and gets the reduced result. However, when a node restarts
199 from failure, we can directly recover the result from other nodes (see also [Fault Tolerance](#fault-tolerance)) and
200 the data preparation procedure is no longer necessary. Rabit Allreduce adds an optional preparation function parameter
201 to support such a scenario. The user can pass in a function that corresponds to the data preparation procedure to Allreduce
202 calls, and the data preparation function will only be called when necessary. We use [lazy_allreduce.cc](../guide/lazy_allreduce.cc)
203 as an example to demonstrate this feature. It is modified from [basic.cc](../guide/basic.cc), and you can compare the two codes.
204 ```c++
205 #include <rabit.h>
206 using namespace rabit;
207 const int N = 3;
208 int main(int argc, char *argv[]) {
209 int a[N] = {0};
210 rabit::Init(argc, argv);
211 // lazy preparation function
212 auto prepare = [&]() {
213 printf("@node[%d] run prepare function\n", rabit::GetRank());
214 for (int i = 0; i < N; ++i) {
215 a[i] = rabit::GetRank() + i;
216 }
217 };
218 printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
219 rabit::GetRank(), a[0], a[1], a[2]);
220 // allreduce take max of each elements in all processes
221 Allreduce<op::Max>(&a[0], N, prepare);
222 printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n",
223 rabit::GetRank(), a[0], a[1], a[2]);
224 // run second allreduce
225 Allreduce<op::Sum>(&a[0], N);
226 printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
227 rabit::GetRank(), a[0], a[1], a[2]);
228 rabit::Finalize();
229 return 0;
230 }
231 ```
232 Here we use features of C++11 because the lambda function makes things much shorter.
233 There is also C++ compatible callback interface provided in the [API](http://homes.cs.washington.edu/~tqchen/rabit/doc).
234 You can compile the program by typing ```make lazy_allreduce.mock```. We link against the mock library so that we can see
235 the effect when a process goes down. You can run the program using the following command
236 ```bash
237 ../tracker/rabit_demo.py -n 2 lazy_allreduce.mock mock=0,0,1,0
238 ```
239 The additional arguments ```mock=0,0,1,0``` will cause node 0 to kill itself before second call of Allreduce (see also [mock test](#link-against-mock-test-rabit-library)).
240 You will find that the prepare function's print is only executed once and node 0 will no longer execute the preparation function when it restarts from failure.
241
242 You can also find a python version of the example in [lazy_allreduce.py](../guide/lazy_allreduce.py), and run it using the following command
243 ```bash
244 ../tracker/rabit_demo.py -n 2 lazy_allreduce.py mock=0,0,1,0
245
246 ```
247
248 Since the lazy preparation function may not be called during execution, the user should be careful when using this feature. For example, a possible mistake
249 could be putting some memory allocation code in the lazy preparation function, so that the memory is never allocated when the lazy preparation function is not called.
250 The example in [lazy_allreduce.cc](../guide/lazy_allreduce.cc) provides a simple way to migrate normal preparation code ([basic.cc](../guide/basic.cc)) to the lazy version: wrap the preparation
251 code with a lambda function, and pass it to allreduce.
252
253 #### Checkpoint and LazyCheckpoint
254 Common machine learning algorithms usually involve iterative computation. As mentioned in the section ([Structure of a Rabit Program](#structure-of-a-rabit-program)),
255 the user can and should use Checkpoint to ```save``` the progress so far, so that when a node fails, the latest checkpointed model can be loaded.
256
257 There are two model arguments you can pass to Checkpoint and LoadCheckpoint: ```global_model``` and ```local_model```:
258 * ```global_model``` refers to the model that is commonly shared across all the nodes
259 - For example, the centroids of clusters in kmeans are shared across all nodes
260 * ```local_model``` refers to the model that is specifically tied to the current node
261 - For example, in topic modeling, the topic assignments of the subset of documents in the current node are a local model
262
263 Because of the different nature of the two types of models, different strategies are used for them.
264 ```global_model``` is simply saved in the local memory of each node, while ```local_model``` will be replicated to some other
265 nodes (selected using a ring replication strategy). The checkpoint is only saved in the memory without touching the disk which makes rabit programs more efficient.
266 For better efficiency, the user is encouraged to use only ```global_model``` when that is sufficient.
267
268 To enable a model class to be checkpointed, the user can implement a [serialization interface](../include/rabit_serialization.h). The serialization interface already
269 provides serialization functions for STL vector and string. For the python API, the user can checkpoint any python object that can be pickled.
270
271 There is a special Checkpoint function called [LazyCheckpoint](http://homes.cs.washington.edu/~tqchen/rabit/doc/namespacerabit.html#a99f74c357afa5fba2c80cc0363e4e459),
272 which can be used in ```global_model```-only cases under certain conditions.
273 When LazyCheckpoint is called, no action is taken and the rabit engine only remembers the pointer to the model.
274 The serialization will only happen when another node fails and the recovery starts. So the user basically pays no extra cost when calling LazyCheckpoint.
275 To use this function, the user needs to ensure the model remains unchanged until the last call of Allreduce/Broadcast in the current version finishes,
276 so that when the recovery procedure happens in these function calls, the serialized model will be the same.
277
278 For example, consider the following calling sequence
279 ```
280 LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
281 ```
282 The user must only change the model in code3. Such a condition can usually be satisfied in many scenarios, and the user can use LazyCheckpoint to further
283 improve the efficiency of the program.
284
285
286 Compile Programs with Rabit
287 ---------------------------
288 Rabit is a portable library, to use it, you only need to include the rabit header file.
289 * You will need to add the path to [../include](../include) to the header search path of the compiler
290 - Solution 1: add ```-I/path/to/rabit/include``` to the compiler flag in gcc or clang
291 - Solution 2: add the path to the environment variable CPLUS_INCLUDE_PATH
292 * You will need to add the path to [../lib](../lib) to the library search path of the compiler
293 - Solution 1: add ```-L/path/to/rabit/lib``` to the linker flag
294 - Solution 2: add the path to environment variable LIBRARY_PATH AND LD_LIBRARY_PATH
295 * Link against lib/rabit.a
296 - Add ```-lrabit``` to the linker flag
297
298 The procedure above allows you to compile a program with rabit. The following two sections contain additional
299 options you can use to link against different backends other than the normal one.
300
301 #### Link against MPI Allreduce
302 You can link against ```rabit_mpi.a``` instead to use MPI Allreduce; however, the resulting program is backed by MPI and
303 is not fault tolerant anymore.
304 * Simply change the linker flag from ```-lrabit``` to ```-lrabit_mpi```
305 * The final linking needs to be done by mpi wrapper compiler ```mpicxx```
306
307 #### Link against Mock Test Rabit Library
308 If you want to use a mock to test the program in order to see the behavior of the code when some nodes go down, you can link against ```rabit_mock.a``` .
309 * Simply change the linker flag from ```-lrabit``` to ```-lrabit_mock```
310
311 The resulting rabit mock program can take in additional arguments in the following format
312 ```
313 mock=rank,version,seq,ndeath
314 ```
315
316 The four integers specify an event that will cause the program to ```commit suicide```(exit with -2)
317 * rank specifies the rank of the node to kill
318 * version specifies the version (iteration) of the model where you want the process to die
319 * seq specifies the sequence number of the Allreduce/Broadcast call since last checkpoint, where the process will be killed
320 * ndeath specifies how many times this node died already
321
322 For example, consider the following script in the test case
323 ```bash
324 ../tracker/rabit_demo.py -n 10 test_model_recover 10000\
325 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1
326 ```
327 * The first mock will cause node 0 to exit when calling the second Allreduce/Broadcast (seq = 1) in iteration 0
328 * The second mock will cause node 1 to exit when calling the second Allreduce/Broadcast (seq = 1) in iteration 1
329 * The third mock will cause node 1 to exit again when calling second Allreduce/Broadcast (seq = 1) in iteration 1
330 - Note that ndeath = 1 means this will happen only if node 1 died once, which is our case
331
332 Running Rabit Jobs
333 ------------------
334 Rabit is a portable library that can run on multiple platforms.
335 All the rabit jobs can be submitted using [dmlc-tracker](https://github.com/dmlc/dmlc-core/tree/master/tracker)
336
337 Fault Tolerance
338 ---------------
339 This section introduces how fault tolerance works in rabit.
340 The following figure shows how rabit deals with failures.
341
342 ![](http://homes.cs.washington.edu/~tqchen/rabit/fig/fault-tol.png)
343
344 The scenario is as follows:
345 * Node 1 fails between the first and second call of Allreduce after the second checkpoint
346 * The other nodes wait in the call of the second Allreduce in order to help node 1 to recover.
347 * When node 1 restarts, it will call ```LoadCheckPoint```, and get the latest checkpoint from one of the existing nodes.
348 * Then node 1 can start from the latest checkpoint and continue running.
349 * When node 1 calls the first Allreduce again, as the other nodes already know the result, node 1 can get it from one of them.
350 * When node 1 reaches the second Allreduce, the other nodes find out that node 1 has caught up and they can continue the program normally.
351
352 This fault tolerance model is based on a key property of Allreduce and
353 Broadcast: All the nodes get the same result after calling Allreduce/Broadcast.
354 Because of this property, any node can record the results of history
355 Allreduce/Broadcast calls. When a node is recovered, it can fetch the lost
356 results from some alive nodes and rebuild its model.
357
358 The checkpoint is introduced so that we can discard the history results of
359 Allreduce/Broadcast calls before the latest checkpoint. This saves memory
360 consumption used for backup. The checkpoint of each node is a model defined by
361 users and can be split into 2 parts: a global model and a local model. The
362 global model is shared by all nodes and can be backed up by any nodes. The
363 local model of a node is replicated to some other nodes (selected using a ring
364 replication strategy). The checkpoint is only saved in the memory without
365 touching the disk which makes rabit programs more efficient. The strategy of
366 rabit is different from the fail-restart strategy where all the nodes restart
367 from the same checkpoint when any of them fail. In rabit, all the alive nodes
368 will block in the Allreduce call and help the recovery. To catch up, the
369 recovered node fetches its latest checkpoint and the results of
370 Allreduce/Broadcast calls after the checkpoint from some alive nodes.
371
372 This is just a conceptual introduction to rabit's fault tolerance model. The actual implementation is more sophisticated,
373 and can deal with more complicated cases such as multiple nodes failure and node failure during recovery phase.
374
375 Rabit Timeout
376 ---------------
377
378 In certain cases, a rabit cluster may lack the resources to retry failed workers.
379 Because the fault tolerance model assumes infinite retries, this might cause the entire cluster to hang indefinitely.
380 We introduce a sidecar thread which runs when the rabit fault tolerant runtime observes allreduce/broadcast errors.
381 By default, it will wait for 30 mins before all worker programs exit.
382 Users can opt in to this feature and change the threshold by passing rabit_timeout=true and rabit_timeout_sec=x (in seconds).
0 Rabit Documentation
1 =====================
2 rabit is a light weight library that provides a fault tolerant interface of Allreduce and Broadcast. It is designed to support easy implementations of distributed machine learning programs, many of which fall naturally under the Allreduce abstraction. The goal of rabit is to support **portable** , **scalable** and **reliable** distributed machine learning programs.
3
4 API Documents
5 -------------
6 ```eval_rst
7
8 .. toctree::
9 :maxdepth: 2
10
11 python_api.md
12 cpp_api.md
13 parameters.md
14 guide.md
15 ```
16 Indices and tables
17 ------------------
18
19 ```eval_rst
20 * :ref:`genindex`
21 * :ref:`modindex`
22 * :ref:`search`
23 ```
0 Parameters
1 ==========
2 This section lists all the parameters that can be passed to the rabit::Init function as argv.
3 All the parameters are passed in as strings in the format ``parameter-name=parameter-value``.
4 In most settings these parameters have default values or will be automatically detected,
5 and do not need to be manually configured.
6
7 * rabit_tracker_uri [passed in automatically by tracker]
8 - The uri/ip of rabit tracker
9 * rabit_tracker_port [passed in automatically by tracker]
10 - The port of rabit tracker
11 * rabit_task_id [automatically detected]
12 - The unique identifier of computing process
13 - When running on hadoop, this is automatically extracted from an environment variable
14 * rabit_reduce_buffer [default = 256MB]
15 - The memory buffer used to store intermediate result of reduction
16 - Format "digits + unit", can be 128M, 1G
17 * rabit_global_replica [default = 5]
18 - Number of replication copies of result kept for each Allreduce/Broadcast call
19 * rabit_local_replica [default = 2]
20 - Number of replicas of the local model kept in a checkpoint
0 numpy
1 breathe
2 commonmark
3
0 Python API of Rabit
1 ===================
2 This page contains the documentation of the Python API of rabit.
3
4 ```eval_rst
5 .. toctree::
6
7 .. automodule:: rabit
8 :members:
9 :show-inheritance:
10 ```
0 # -*- coding: utf-8 -*-
1 """Helper utility functions for customization."""
2 import sys
3 import os
4 import docutils
5 import subprocess
6
7 if os.environ.get('READTHEDOCS', None) == 'True':
8 subprocess.call('cd ..; rm -rf recommonmark;' +
9 'git clone https://github.com/tqchen/recommonmark', shell=True)
10
11 sys.path.insert(0, os.path.abspath('../recommonmark/'))
12 from recommonmark import parser, transform
13
14 MarkdownParser = parser.CommonMarkParser
15 AutoStructify = transform.AutoStructify
0 export CC = gcc
1 export CXX = g++
2 export MPICXX = mpicxx
3 export LDFLAGS= -pthread -lm -L../lib
4 export CFLAGS = -Wall -O3 -msse2 -std=c++11 -Wno-unknown-pragmas -fPIC -fopenmp -I../include
5
6 .PHONY: clean all lib libmpi
7 BIN = basic.rabit broadcast.rabit
8 MOCKBIN= lazy_allreduce.mock
9
10 all: $(BIN)
11 basic.rabit: basic.cc lib ../lib/librabit.a
12 broadcast.rabit: broadcast.cc lib ../lib/librabit.a
13 lazy_allreduce.mock: lazy_allreduce.cc lib ../lib/librabit.a
14
15 $(BIN) :
16 $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
17
18 $(MOCKBIN) :
19 $(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
20
21 $(OBJ) :
22 $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
23
24 clean:
25 $(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~
0 See tutorial at ../doc/guide.md
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file basic.cc
3 * \brief This is an example demonstrating what is Allreduce
4 *
5 * \author Tianqi Chen
6 */
7 #define _CRT_SECURE_NO_WARNINGS
8 #define _CRT_SECURE_NO_DEPRECATE
9 #include <vector>
10 #include <rabit/rabit.h>
11 using namespace rabit;
12 int main(int argc, char *argv[]) {
13 int N = 3;
14 if (argc > 1) {
15 N = atoi(argv[1]);
16 }
17 std::vector<int> a(N);
18 rabit::Init(argc, argv);
19 for (int i = 0; i < N; ++i) {
20 a[i] = rabit::GetRank() + i;
21 }
22 printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
23 rabit::GetRank(), a[0], a[1], a[2]);
24 // allreduce take max of each elements in all processes
25 Allreduce<op::Max>(&a[0], N);
26 printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n",
27 rabit::GetRank(), a[0], a[1], a[2]);
28 // second allreduce that sums everything up
29 Allreduce<op::Sum>(&a[0], N);
30 printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
31 rabit::GetRank(), a[0], a[1], a[2]);
32 rabit::Finalize();
33 return 0;
34 }
0 #!/usr/bin/python
1 """
2 demo python script of rabit
3 """
4 from __future__ import print_function
5 from builtins import range
6 import os
7 import sys
8 import numpy as np
9 # import rabit, the tracker script will setup the lib path correctly
10 # for normal run without tracker script, add following line
11 # sys.path.append(os.path.dirname(__file__) + '/../python')
12 import rabit
13
rabit.init()
n = 3
rank = rabit.get_rank()

# each node contributes the values rank, rank + 1, ..., rank + n - 1;
# the buffer stays a float array exactly as produced by np.zeros
a = np.zeros(n)
a[:] = [rank + offset for offset in range(n)]

print('@node[%d] before-allreduce: a=%s' % (rank, str(a)))
# element-wise maximum across all nodes
a = rabit.allreduce(a, rabit.MAX)
print('@node[%d] after-allreduce-max: a=%s' % (rank, str(a)))
# element-wise sum across all nodes, applied to the result of the max
a = rabit.allreduce(a, rabit.SUM)
print('@node[%d] after-allreduce-sum: a=%s' % (rank, str(a)))
rabit.finalize()
0 #include <rabit/rabit.h>
1 using namespace rabit;
2 const int N = 3;
3 int main(int argc, char *argv[]) {
4 rabit::Init(argc, argv);
5 std::string s;
6 if (rabit::GetRank() == 0) s = "hello world";
7 printf("@node[%d] before-broadcast: s=\"%s\"\n",
8 rabit::GetRank(), s.c_str());
9 // broadcast s from node 0 to all other nodes
10 rabit::Broadcast(&s, 0);
11 printf("@node[%d] after-broadcast: s=\"%s\"\n",
12 rabit::GetRank(), s.c_str());
13 rabit::Finalize();
14 return 0;
15 }
0 #!/usr/bin/python
1 """
2 demo python script of rabit
3 """
4 from __future__ import print_function
5 import os
6 import sys
7 # add path to wrapper
8 # for normal run without tracker script, add following line
9 # sys.path.append(os.path.dirname(__file__) + '/../wrapper')
10 import rabit
11
rabit.init()
n = 3
rank = rabit.get_rank()
# only the root (rank 0) owns the object before the broadcast
s = {'hello world':100, 2:3} if rank == 0 else None
print('@node[%d] before-broadcast: s=\"%s\"' % (rank, str(s)))
# send the object held by rank 0 to every node
s = rabit.broadcast(s, 0)

print('@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s)))
rabit.finalize()
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file basic.cc
3 * \brief This is an example demonstrating what is Allreduce
4 *
5 * \author Tianqi Chen
6 */
7 #include <rabit/rabit.h>
8
9 using namespace rabit;
10 const int N = 3;
11 int main(int argc, char *argv[]) {
12 int a[N] = {0};
13 rabit::Init(argc, argv);
14 // lazy preparation function
15 auto prepare = [&]() {
16 printf("@node[%d] run prepare function\n", rabit::GetRank());
17 for (int i = 0; i < N; ++i) {
18 a[i] = rabit::GetRank() + i;
19 }
20 };
21 printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
22 rabit::GetRank(), a[0], a[1], a[2]);
23 // allreduce take max of each elements in all processes
24 Allreduce<op::Max>(&a[0], N, prepare);
25 printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
26 rabit::GetRank(), a[0], a[1], a[2]);
27 // rum second allreduce
28 Allreduce<op::Sum>(&a[0], N);
29 printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n",
30 rabit::GetRank(), a[0], a[1], a[2]);
31 rabit::Finalize();
32 return 0;
33 }
0 #!/usr/bin/python
1 """
2 demo python script of rabit: Lazy preparation function
3 """
4 import os
5 import sys
6 import numpy as np
7 # import rabit, the tracker script will setup the lib path correctly
8 # for normal run without tracker script, add following line
9 # sys.path.append(os.path.dirname(__file__) + '/../wrapper')
10 import rabit
11
12
# use mock library so that we can run failure test
rabit.init(lib = 'mock')
n = 3
rank = rabit.get_rank()
a = np.zeros(n)

def prepare(a):
    """Fill ``a`` in place with this node's values (rank + i).

    Passed to rabit.allreduce as ``prepare_fun`` so that the buffer is
    populated lazily instead of before the call.
    """
    print('@node[%d] run prepare function' % rank)
    # must take in reference and modify the reference
    # range exists on both Python 2 and 3, whereas xrange is Python 2 only
    for i in range(n):
        a[i] = rank + i

print('@node[%d] before-allreduce: a=%s' % (rank, str(a)))
a = rabit.allreduce(a, rabit.MAX, prepare_fun = prepare)
print('@node[%d] after-allreduce-max: a=%s' % (rank, str(a)))
a = rabit.allreduce(a, rabit.SUM)
print('@node[%d] after-allreduce-sum: a=%s' % (rank, str(a)))
rabit.finalize()
0 /*!
1 * Copyright by Contributors
2 * \file c_api.h
3 * \author Tianqi Chen
4 * \brief a C style API of rabit.
5 */
6 #ifndef RABIT_C_API_H_
7 #define RABIT_C_API_H_
8
/*
 * RABIT_EXTERN_C gives the declarations below C linkage when compiled as C++
 * and expands to nothing when compiled as C.
 * The declarations in this header use bool and size_t, so the C branch pulls
 * in <stdbool.h> and <stddef.h>; without them a plain C translation unit that
 * includes this header has no definition of bool.
 */
#ifdef __cplusplus
#define RABIT_EXTERN_C extern "C"
#include <cstddef>
#include <cstdio>
#else
#define RABIT_EXTERN_C
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#endif  // __cplusplus
16
17 #if defined(_MSC_VER) || defined(_WIN32)
18 #define RABIT_DLL RABIT_EXTERN_C __declspec(dllexport)
19 #else
20 #define RABIT_DLL RABIT_EXTERN_C
21 #endif // defined(_MSC_VER) || defined(_WIN32)
22
23 /*! \brief rabit unsigned long type */
24 typedef unsigned long rbt_ulong; // NOLINT(*)
25
26 /*!
27 * \brief initialize the rabit module,
28 * call this once before using anything
29 * The additional arguments are not necessary.
30 * Usually rabit will detect settings
31 * from environment variables.
32 * \param argc number of arguments in argv
33 * \param argv the array of input arguments
34 * \return true if rabit is initialized successfully otherwise false
35 */
36 RABIT_DLL bool RabitInit(int argc, char *argv[]);
37
38 /*!
39 * \brief finalize the rabit engine,
40 * call this function after you finished all jobs.
41 * \return true if rabit is finalized successfully otherwise false
42 */
43 RABIT_DLL bool RabitFinalize(void);
44
45 /*!
46 * \brief get rank of previous process in ring topology
47 * \return rank number of worker
48 * */
49 RABIT_DLL int RabitGetRingPrevRank(void);
50
51 /*!
52 * \brief get rank of current process
53 * \return rank number of worker
54 * */
55 RABIT_DLL int RabitGetRank(void);
56
57 /*!
58 * \brief get total number of process
59 * \return total world size
60 * */
61 RABIT_DLL int RabitGetWorldSize(void);
62
63 /*!
64 * \brief query whether rabit is running in distributed mode
65 * \return whether rabit is distributed
66 * */
67 RABIT_DLL int RabitIsDistributed(void);
68
69 /*!
70 * \brief print the msg to the tracker,
71 * this function can be used to communicate the information of the progress to
72 * the user who monitors the tracker
73 * \param msg the message to be printed
74 */
75 RABIT_DLL void RabitTrackerPrint(const char *msg);
76 /*!
77 * \brief get name of processor
78 * \param out_name hold output string
79 * \param out_len hold length of output string
80 * \param max_len maximum buffer length of input
81 */
82 RABIT_DLL void RabitGetProcessorName(char *out_name,
83 rbt_ulong *out_len,
84 rbt_ulong max_len);
85 /*!
86 * \brief broadcast a memory region to all others from root
87 *
88 * Example: int a = 1; Broadcast(&a, sizeof(a), root);
89 * \param sendrecv_data the pointer to the send or receive buffer,
90 * \param size the size of the data
91 * \param root the root of process
92 */
93 RABIT_DLL void RabitBroadcast(void *sendrecv_data,
94 rbt_ulong size, int root);
95
96 /*!
97 * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
98 * the data provided by current node k is [slice_begin, slice_end),
99 * the next node's segment must start with slice_end
100 * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
101 * use a ring based algorithm
102 *
103 * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
104 * \param total_size total size of data to be gathered
105 * \param beginIndex beginning of the current slice in sendrecvbuf of type enum_dtype
106 * \param size_node_slice size of the current node slice
107 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
108 * \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
109 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
110 * \sa ReturnType
111 */
112 RABIT_DLL void RabitAllgather(void *sendrecvbuf,
113 size_t total_size,
114 size_t beginIndex,
115 size_t size_node_slice,
116 size_t size_prev_slice,
117 int enum_dtype);
118
119 /*!
120 * \brief perform in-place allreduce, on sendrecvbuf
121 * this function is NOT thread-safe
122 *
123 * Example Usage: the following code gives sum of the result
124 * vector<int> data(10);
125 * ...
126 * Allreduce<op::Sum>(&data[0], data.size());
127 * ...
128 * \param sendrecvbuf buffer for both sending and recving data
129 * \param count number of elements to be reduced
130 * \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
131 * \param enum_op the enumeration of operation type, see rabit::engine::mpi::OpType in engine.h of rabit
132 * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
133 * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
134 * If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
135 * \param prepare_arg argument passed into the lazy preprocessing function
136 */
137 RABIT_DLL void RabitAllreduce(void *sendrecvbuf,
138 size_t count,
139 int enum_dtype,
140 int enum_op,
141 void (*prepare_fun)(void *arg),
142 void *prepare_arg);
143
144 /*!
145 * \brief load latest check point
146 * \param out_global_model hold output of serialized global_model
147 * \param out_global_len the output length of serialized global model
148 * \param out_local_model hold output of serialized local_model, can be NULL
149 * \param out_local_len the output length of serialized local model, can be NULL
150 *
151 * \return the version number of check point loaded
152 * if returned version == 0, this means no model has been CheckPointed
153 * nothing will be touched
154 */
155 RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
156 rbt_ulong *out_global_len,
157 char **out_local_model,
158 rbt_ulong *out_local_len);
159 /*!
160 * \brief checkpoint the model, meaning we finished a stage of execution
161 * every time we call check point, there is a version number which will increase by one
162 *
163 * \param global_model hold content of serialized global_model
164 * \param global_len the content length of serialized global model
165 * \param local_model hold content of serialized local_model, can be NULL
166 * \param local_len the content length of serialized local model, can be NULL
167 *
168 * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
169 * bring replication cost in CheckPoint function. global_model do not need explicit replication.
170 * So only CheckPoint with global_model if possible
171 */
172 RABIT_DLL void RabitCheckPoint(const char *global_model,
173 rbt_ulong global_len,
174 const char *local_model,
175 rbt_ulong local_len);
176 /*!
177 * \brief get the version number of the current stored model,
178 * which means how many calls to CheckPoint we made so far
179 * \return rabit version number
180 */
181 RABIT_DLL int RabitVersionNumber(void);
182
183
184 /*!
185 * \brief a Dummy function,
186 * used to cause force link of C API into the DLL.
187 * \code
188 * \/\/force link rabit C API library.
189 * static int must_link_rabit_ = RabitLinkTag();
190 * \endcode
191 * \return a dummy integer.
192 */
193 RABIT_DLL int RabitLinkTag(void);
194
195 #endif // RABIT_C_API_H_
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file engine.h
3 * \brief This file defines the core interface of rabit library
4 * \author Tianqi Chen, Nacho, Tianyi
5 */
6 #ifndef RABIT_INTERNAL_ENGINE_H_
7 #define RABIT_INTERNAL_ENGINE_H_
8 #include <string>
9 #include "rabit/serializable.h"
10
11 #if (defined(__GNUC__) && !defined(__clang__))
12 #define _FILE __builtin_FILE()
13 #define _LINE __builtin_LINE()
14 #define _CALLER __builtin_FUNCTION()
15 #else
16 #define _FILE "N/A"
17 #define _LINE -1
18 #define _CALLER "N/A"
19 #endif // (defined(__GNUC__) && !defined(__clang__))
20
21 namespace MPI {
22 /*! \brief MPI data type just to be compatible with MPI reduce function*/
23 class Datatype;
24 }
25
26 /*! \brief namespace of rabit */
27 namespace rabit {
28 /*! \brief core interface of the engine */
29 namespace engine {
30 /*! \brief interface of core Allreduce engine */
31 class IEngine {
32 public:
33 /*!
34 * \brief Preprocessing function, that is called before AllReduce,
35 * used to prepare the data used by AllReduce
36 * \param arg additional possible argument used to invoke the preprocessor
37 */
38 typedef void (PreprocFunction) (void *arg);
39 /*!
40 * \brief reduce function, the same form of MPI reduce function is used,
41 * to be compatible with MPI interface
42 * In all the functions, the memory is ensured to aligned to 64-bit
43 * which means it is OK to cast src,dst to double* int* etc
44 * \param src pointer to source space
45 * \param dst pointer to destination reduction
46 * \param count total number of elements to be reduced (note this is total number of elements instead of bytes)
47 * the definition of the reduce function should be type aware
48 * \param dtype the data type object, to be compatible with MPI reduce
49 */
50 typedef void (ReduceFunction) (const void *src,
51 void *dst, int count,
52 const MPI::Datatype &dtype);
53 /*! \brief virtual destructor */
54 virtual ~IEngine() {}
55 /*!
56 * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
57 * the data provided by current node k is [slice_begin, slice_end),
58 * the next node's segment must start with slice_end
59 * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
60 * use a ring based algorithm
61 *
62 * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
63 * \param total_size total size of data to be gathered
64 * \param slice_begin beginning of the current slice
65 * \param slice_end end of the current slice
66 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
67 * \param _file caller file name used to generate unique cache key
68 * \param _line caller line number used to generate unique cache key
69 * \param _caller caller function name used to generate unique cache key
70 */
71 virtual void Allgather(void *sendrecvbuf,
72 size_t total_size,
73 size_t slice_begin,
74 size_t slice_end,
75 size_t size_prev_slice,
76 const char* _file = _FILE,
77 const int _line = _LINE,
78 const char* _caller = _CALLER) = 0;
79 /*!
80 * \brief performs in-place Allreduce, on sendrecvbuf
81 * this function is NOT thread-safe
82 * \param sendrecvbuf_ buffer for both sending and receiving data
83 * \param type_nbytes the number of bytes the type has
84 * \param count number of elements to be reduced
85 * \param reducer reduce function
86 * \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
87 * will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
88 * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
89 * \param prepare_arg argument used to pass into the lazy preprocessing function
90 * \param _file caller file name used to generate unique cache key
91 * \param _line caller line number used to generate unique cache key
92 * \param _caller caller function name used to generate unique cache key
93 */
94 virtual void Allreduce(void *sendrecvbuf_,
95 size_t type_nbytes,
96 size_t count,
97 ReduceFunction reducer,
98 PreprocFunction prepare_fun = NULL,
99 void *prepare_arg = NULL,
100 const char* _file = _FILE,
101 const int _line = _LINE,
102 const char* _caller = _CALLER) = 0;
103 /*!
104 * \brief broadcasts data from root to every other node
105 * \param sendrecvbuf_ buffer for both sending and receiving data
106 * \param size the size of the data to be broadcasted
107 * \param root the root worker id to broadcast the data
108 * \param _file caller file name used to generate unique cache key
109 * \param _line caller line number used to generate unique cache key
110 * \param _caller caller function name used to generate unique cache key
111 */
112 virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
113 const char* _file = _FILE,
114 const int _line = _LINE,
115 const char* _caller = _CALLER) = 0;
116 /*!
117 * \brief explicitly re-initialize everything before calling LoadCheckPoint
118 * call this function when IEngine throws an exception,
119 * this function should only be used for test purposes
120 */
121 virtual void InitAfterException(void) = 0;
122 /*!
123 * \brief loads the latest check point
124 * \param global_model pointer to the globally shared model/state
125 * when calling this function, the caller needs to guarantee that the global_model
126 * is the same in all nodes
127 * \param local_model pointer to the local model that is specific to current node/rank
128 * this can be NULL when no local model is needed
129 *
130 * \return the version number of the model loaded
131 * if returned version == 0, this means no model has been CheckPointed
132 * the p_model is not touched, users should do necessary initialization by themselves
133 *
134 * Common usage example:
135 * int iter = rabit::LoadCheckPoint(&model);
136 * if (iter == 0) model.InitParameters();
137 * for (i = iter; i < max_iter; ++i) {
138 * do many things, include allreduce
139 * rabit::CheckPoint(model);
140 * }
141 *
142 * \sa CheckPoint, VersionNumber
143 */
144 virtual int LoadCheckPoint(Serializable *global_model,
145 Serializable *local_model = NULL) = 0;
146 /*!
147 * \brief checkpoints the model, meaning a stage of execution was finished
148 * every time we call check point, a version number increases by ones
149 *
150 * \param global_model pointer to the globally shared model/state
151 * when calling this function, the caller needs to guarantee that the global_model
152 * is the same in every node
153 * \param local_model pointer to the local model that is specific to current node/rank
154 * this can be NULL when no local state is needed
155 *
156 * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
157 * bring replication cost in CheckPoint function. global_model does not need explicit replication.
158 * So, only CheckPoint with global_model if possible
159 *
160 * \sa LoadCheckPoint, VersionNumber
161 */
162 virtual void CheckPoint(const Serializable *global_model,
163 const Serializable *local_model = NULL) = 0;
164 /*!
165 * \brief This function can be used to replace CheckPoint for global_model only,
166 * when certain condition is met (see detailed explanation).
167 *
168 * This is a "lazy" checkpoint such that only the pointer to global_model is
169 * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
170 * The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
171 * In other words, global_model can be changed only between the last call of
172 * Allreduce/Broadcast and LazyCheckPoint in the current version
173 *
174 * For example, suppose the calling sequence is:
175 * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
176 *
177 * If the user can only change global_model in code3, then LazyCheckPoint can be used to
178 * improve the efficiency of the program.
179 * \param global_model pointer to the globally shared model/state
180 * when calling this function, the caller needs to guarantee that global_model
181 * is the same in every node
182 * \sa LoadCheckPoint, CheckPoint, VersionNumber
183 */
184 virtual void LazyCheckPoint(const Serializable *global_model) = 0;
185 /*!
186 * \return version number of the current stored model,
187 * which means how many calls to CheckPoint we made so far
188 * \sa LoadCheckPoint, CheckPoint
189 */
190 virtual int VersionNumber(void) const = 0;
191 /*! \brief gets rank of previous node in ring topology */
192 virtual int GetRingPrevRank(void) const = 0;
193 /*! \brief gets rank of current node */
194 virtual int GetRank(void) const = 0;
195 /*! \brief gets total number of nodes */
196 virtual int GetWorldSize(void) const = 0;
197 /*! \brief whether we run in distributed mode */
198 virtual bool IsDistributed(void) const = 0;
199 /*! \brief gets the host name of the current node */
200 virtual std::string GetHost(void) const = 0;
201 /*!
202 * \brief prints the msg in the tracker,
203 * this function can be used to communicate progress information to
204 * the user who monitors the tracker
205 * \param msg message to be printed in the tracker
206 */
207 virtual void TrackerPrint(const std::string &msg) = 0;
208 };
209
210 /*! \brief initializes the engine module */
211 bool Init(int argc, char *argv[]);
212 /*! \brief finalizes the engine module */
213 bool Finalize(void);
214 /*! \brief singleton method to get engine */
215 IEngine *GetEngine(void);
216
217 /*! \brief namespace that contains stubs to be compatible with MPI */
218 namespace mpi {
219 /*!\brief enum of all operators */
220 enum OpType {
221 kMax = 0,
222 kMin = 1,
223 kSum = 2,
224 kBitwiseOR = 3
225 };
226 /*!\brief enum of supported data types */
227 enum DataType {
228 kChar = 0,
229 kUChar = 1,
230 kInt = 2,
231 kUInt = 3,
232 kLong = 4,
233 kULong = 5,
234 kFloat = 6,
235 kDouble = 7,
236 kLongLong = 8,
237 kULongLong = 9
238 };
239 } // namespace mpi
240 /*!
241 * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
242 * the data provided by current node k is [slice_begin, slice_end),
243 * the next node's segment must start with slice_end
244 * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
245 * use a ring based algorithm
246 *
247 * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
248 * \param total_size total size of data to be gathered
249 * \param slice_begin beginning of the current slice
250 * \param slice_end end of the current slice
251 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
252 * \param _file caller file name used to generate unique cache key
253 * \param _line caller line number used to generate unique cache key
254 * \param _caller caller function name used to generate unique cache key
255 */
256 void Allgather(void* sendrecvbuf,
257 size_t total_size,
258 size_t slice_begin,
259 size_t slice_end,
260 size_t size_prev_slice,
261 const char* _file = _FILE,
262 const int _line = _LINE,
263 const char* _caller = _CALLER);
264 /*!
265 * \brief perform in-place Allreduce, on sendrecvbuf
266 * this is an internal function used by rabit to be able to compile with MPI
267 * do not use this function directly
268 * \param sendrecvbuf buffer for both sending and receiving data
269 * \param type_nbytes the number of bytes the type has
270 * \param count number of elements to be reduced
271 * \param reducer reduce function
272 * \param dtype the data type
273 * \param op the reduce operator type
274 * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
275 * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
276 * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
277 * \param prepare_arg argument used to pass into the lazy preprocessing function.
278 * \param _file caller file name used to generate unique cache key
279 * \param _line caller line number used to generate unique cache key
280 * \param _caller caller function name used to generate unique cache key
281 */
282 void Allreduce_(void *sendrecvbuf,
283 size_t type_nbytes,
284 size_t count,
285 IEngine::ReduceFunction red,
286 mpi::DataType dtype,
287 mpi::OpType op,
288 IEngine::PreprocFunction prepare_fun = NULL,
289 void *prepare_arg = NULL,
290 const char* _file = _FILE,
291 const int _line = _LINE,
292 const char* _caller = _CALLER);
293 /*!
294 * \brief handle for customized reducer, used to handle customized reduce
295 * this class is mainly created for compatibility issues with MPI's customized reduce
296 */
297 class ReduceHandle {
298 public:
299 // constructor
300 ReduceHandle(void);
301 // destructor
302 ~ReduceHandle(void);
303 /*!
304 * \brief initialize the reduce function,
305 * with the type the reduce function needs to deal with
306 * the reduce function MUST be commutative
307 */
308 void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes);
309 /*!
310 * \brief customized in-place all reduce operation
311 * \param sendrecvbuf the in place send-recv buffer
312 * \param type_nbytes size of the type, in bytes
313 * \param count number of elements to send
314 * \param prepare_fun Lazy preprocessing function, lazy prepare_fun(prepare_arg)
315 * will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
316 * If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
317 * \param prepare_arg argument used to pass into the lazy preprocessing function
318 * \param _file caller file name used to generate unique cache key
319 * \param _line caller line number used to generate unique cache key
320 * \param _caller caller function name used to generate unique cache key
321 */
322 void Allreduce(void *sendrecvbuf,
323 size_t type_nbytes,
324 size_t count,
325 IEngine::PreprocFunction prepare_fun = NULL,
326 void *prepare_arg = NULL,
327 const char* _file = _FILE,
328 const int _line = _LINE,
329 const char* _caller = _CALLER);
330 /*! \return the number of bytes occupied by the type */
331 static int TypeSize(const MPI::Datatype &dtype);
332
333 protected:
334 // handle function field
335 void *handle_;
336 // reduce function of the reducer
337 IEngine::ReduceFunction *redfunc_;
338 // handle to the type field
339 void *htype_;
340 // the created type in 4 bytes
341 size_t created_type_nbytes_;
342 };
343 } // namespace engine
344 } // namespace rabit
345 #endif // RABIT_INTERNAL_ENGINE_H_
0 /*!
1 * Copyright (c) 2014-2019 by Contributors
2 * \file io.h
3 * \brief utilities with different serializable implementations
4 * \author Tianqi Chen
5 */
6 #ifndef RABIT_INTERNAL_IO_H_
7 #define RABIT_INTERNAL_IO_H_
8 #include <cstdio>
9 #include <vector>
10 #include <cstring>
11 #include <string>
12 #include <algorithm>
13 #include <numeric>
14 #include <limits>
15 #include "rabit/internal/utils.h"
16 #include "rabit/serializable.h"
17
18 namespace rabit {
19 namespace utils {
20 /*! \brief re-use definition of dmlc::SeekStream */
21 typedef dmlc::SeekStream SeekStream;
22 /*! \brief fixed size memory buffer */
23 struct MemoryFixSizeBuffer : public SeekStream {
24 public:
25 // similar to SEEK_END in libc
26 static size_t constexpr SeekEnd = std::numeric_limits<size_t>::max();
27
28 public:
29 MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
30 : p_buffer_(reinterpret_cast<char*>(p_buffer)),
31 buffer_size_(buffer_size) {
32 curr_ptr_ = 0;
33 }
34 virtual ~MemoryFixSizeBuffer(void) {}
35 virtual size_t Read(void *ptr, size_t size) {
36 size_t nread = std::min(buffer_size_ - curr_ptr_, size);
37 if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
38 curr_ptr_ += nread;
39 return nread;
40 }
41 virtual void Write(const void *ptr, size_t size) {
42 if (size == 0) return;
43 utils::Assert(curr_ptr_ + size <= buffer_size_,
44 "write position exceed fixed buffer size");
45 std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
46 curr_ptr_ += size;
47 }
48 virtual void Seek(size_t pos) {
49 if (pos == SeekEnd) {
50 curr_ptr_ = buffer_size_;
51 } else {
52 curr_ptr_ = static_cast<size_t>(pos);
53 }
54 }
55 virtual size_t Tell(void) {
56 return curr_ptr_;
57 }
58 virtual bool AtEnd(void) const {
59 return curr_ptr_ == buffer_size_;
60 }
61
62 private:
63 /*! \brief in memory buffer */
64 char *p_buffer_;
65 /*! \brief current pointer */
66 size_t buffer_size_;
67 /*! \brief current pointer */
68 size_t curr_ptr_;
69 }; // class MemoryFixSizeBuffer
70
71 /*! \brief a in memory buffer that can be read and write as stream interface */
72 struct MemoryBufferStream : public SeekStream {
73 public:
74 explicit MemoryBufferStream(std::string *p_buffer)
75 : p_buffer_(p_buffer) {
76 curr_ptr_ = 0;
77 }
78 virtual ~MemoryBufferStream(void) {}
79 virtual size_t Read(void *ptr, size_t size) {
80 utils::Assert(curr_ptr_ <= p_buffer_->length(),
81 "read can not have position excceed buffer length");
82 size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
83 if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
84 curr_ptr_ += nread;
85 return nread;
86 }
87 virtual void Write(const void *ptr, size_t size) {
88 if (size == 0) return;
89 if (curr_ptr_ + size > p_buffer_->length()) {
90 p_buffer_->resize(curr_ptr_+size);
91 }
92 std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
93 curr_ptr_ += size;
94 }
95 virtual void Seek(size_t pos) {
96 curr_ptr_ = static_cast<size_t>(pos);
97 }
98 virtual size_t Tell(void) {
99 return curr_ptr_;
100 }
101 virtual bool AtEnd(void) const {
102 return curr_ptr_ == p_buffer_->length();
103 }
104
105 private:
106 /*! \brief in memory buffer */
107 std::string *p_buffer_;
108 /*! \brief current pointer */
109 size_t curr_ptr_;
110 }; // class MemoryBufferStream
111 } // namespace utils
112 } // namespace rabit
113 #endif // RABIT_INTERNAL_IO_H_
0 /*!
1 * Copyright (c) 2014-2019 by Contributors
2 * \file rabit-inl.h
3 * \brief implementation of inline template function for rabit interface
4 *
5 * \author Tianqi Chen
6 */
7 #ifndef RABIT_INTERNAL_RABIT_INL_H_
8 #define RABIT_INTERNAL_RABIT_INL_H_
9 // use engine for implementation
10 #include <vector>
11 #include <string>
12 #include "rabit/internal/io.h"
13 #include "rabit/internal/utils.h"
14 #include "rabit/rabit.h"
15
16 namespace rabit {
17 namespace engine {
18 namespace mpi {
19 // template function to translate type to enum indicator
20 template<typename DType>
21 inline DataType GetType(void);
22 template<>
23 inline DataType GetType<char>(void) {
24 return kChar;
25 }
26 template<>
27 inline DataType GetType<unsigned char>(void) {
28 return kUChar;
29 }
30 template<>
31 inline DataType GetType<int>(void) {
32 return kInt;
33 }
34 template<>
35 inline DataType GetType<unsigned int>(void) { // NOLINT(*)
36 return kUInt;
37 }
38 template<>
39 inline DataType GetType<long>(void) { // NOLINT(*)
40 return kLong;
41 }
42 template<>
43 inline DataType GetType<unsigned long>(void) { // NOLINT(*)
44 return kULong;
45 }
46 template<>
47 inline DataType GetType<float>(void) {
48 return kFloat;
49 }
50 template<>
51 inline DataType GetType<double>(void) {
52 return kDouble;
53 }
54 template<>
55 inline DataType GetType<long long>(void) { // NOLINT(*)
56 return kLongLong;
57 }
58 template<>
59 inline DataType GetType<unsigned long long>(void) { // NOLINT(*)
60 return kULongLong;
61 }
62 } // namespace mpi
63 } // namespace engine
64
65 namespace op {
66 struct Max {
67 static const engine::mpi::OpType kType = engine::mpi::kMax;
68 template<typename DType>
69 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
70 if (dst < src) dst = src;
71 }
72 };
73 struct Min {
74 static const engine::mpi::OpType kType = engine::mpi::kMin;
75 template<typename DType>
76 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
77 if (dst > src) dst = src;
78 }
79 };
80 struct Sum {
81 static const engine::mpi::OpType kType = engine::mpi::kSum;
82 template<typename DType>
83 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
84 dst += src;
85 }
86 };
87 struct BitOR {
88 static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
89 template<typename DType>
90 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
91 dst |= src;
92 }
93 };
94 template<typename OP, typename DType>
95 inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
96 const DType* src = (const DType*)src_;
97 DType* dst = (DType*)dst_; // NOLINT(*)
98 for (int i = 0; i < len; i++) {
99 OP::Reduce(dst[i], src[i]);
100 }
101 }
102 } // namespace op
103
104 // intialize the rabit engine
105 inline bool Init(int argc, char *argv[]) {
106 return engine::Init(argc, argv);
107 }
108 // finalize the rabit engine
109 inline bool Finalize(void) {
110 return engine::Finalize();
111 }
112 // get the rank of the previous worker in ring topology
113 inline int GetRingPrevRank(void) {
114 return engine::GetEngine()->GetRingPrevRank();
115 }
116 // get the rank of current process
117 inline int GetRank(void) {
118 return engine::GetEngine()->GetRank();
119 }
120 // the the size of the world
121 inline int GetWorldSize(void) {
122 return engine::GetEngine()->GetWorldSize();
123 }
124 // whether rabit is distributed
125 inline bool IsDistributed(void) {
126 return engine::GetEngine()->IsDistributed();
127 }
128 // get the name of current processor
129 inline std::string GetProcessorName(void) {
130 return engine::GetEngine()->GetHost();
131 }
132 // broadcast data to all other nodes from root
133 inline void Broadcast(void *sendrecv_data, size_t size, int root,
134 const char* _file,
135 const int _line,
136 const char* _caller) {
137 engine::GetEngine()->Broadcast(sendrecv_data, size, root,
138 _file, _line, _caller);
139 }
140 template<typename DType>
141 inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
142 const char* _file,
143 const int _line,
144 const char* _caller) {
145 size_t size = sendrecv_data->size();
146 Broadcast(&size, sizeof(size), root, _file, _line, _caller);
147 if (sendrecv_data->size() != size) {
148 sendrecv_data->resize(size);
149 }
150 if (size != 0) {
151 Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root,
152 _file, _line, _caller);
153 }
154 }
155 inline void Broadcast(std::string *sendrecv_data, int root,
156 const char* _file,
157 const int _line,
158 const char* _caller) {
159 size_t size = sendrecv_data->length();
160 Broadcast(&size, sizeof(size), root, _file, _line, _caller);
161 if (sendrecv_data->length() != size) {
162 sendrecv_data->resize(size);
163 }
164 if (size != 0) {
165 Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root,
166 _file, _line, _caller);
167 }
168 }
169
170 // perform inplace Allreduce
171 template<typename OP, typename DType>
172 inline void Allreduce(DType *sendrecvbuf, size_t count,
173 void (*prepare_fun)(void *arg),
174 void *prepare_arg,
175 const char* _file,
176 const int _line,
177 const char* _caller) {
178 engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
179 engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg,
180 _file, _line, _caller);
181 }
182
183 // C++11 support for lambda prepare function
184 #if DMLC_USE_CXX11
// trampoline with a C callback signature: `fun` must point to a
// std::function<void()>, which is then called
inline void InvokeLambda_(void *fun) {
  std::function<void()> *callable = static_cast<std::function<void()>*>(fun);
  (*callable)();
}
188 template<typename OP, typename DType>
189 inline void Allreduce(DType *sendrecvbuf, size_t count,
190 std::function<void()> prepare_fun,
191 const char* _file,
192 const int _line,
193 const char* _caller) {
194 engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
195 engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun,
196 _file, _line, _caller);
197 }
198
199 // Performs inplace Allgather
200 template<typename DType>
201 inline void Allgather(DType *sendrecvbuf,
202 size_t totalSize,
203 size_t beginIndex,
204 size_t sizeNodeSlice,
205 size_t sizePrevSlice,
206 const char* _file,
207 const int _line,
208 const char* _caller) {
209 engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
210 (beginIndex + sizeNodeSlice) * sizeof(DType),
211 sizePrevSlice * sizeof(DType), _file, _line, _caller);
212 }
213 #endif // C++11
214
215 // print message to the tracker
216 inline void TrackerPrint(const std::string &msg) {
217 engine::GetEngine()->TrackerPrint(msg);
218 }
219 #ifndef RABIT_STRICT_CXX98_
220 inline void TrackerPrintf(const char *fmt, ...) {
221 const int kPrintBuffer = 1 << 10;
222 std::string msg(kPrintBuffer, '\0');
223 va_list args;
224 va_start(args, fmt);
225 vsnprintf(&msg[0], kPrintBuffer, fmt, args);
226 va_end(args);
227 msg.resize(strlen(msg.c_str()));
228 TrackerPrint(msg);
229 }
230
231 #endif // RABIT_STRICT_CXX98_
232 // load latest check point
233 inline int LoadCheckPoint(Serializable *global_model,
234 Serializable *local_model) {
235 return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
236 }
237 // checkpoint the model, meaning we finished a stage of execution
238 inline void CheckPoint(const Serializable *global_model,
239 const Serializable *local_model) {
240 engine::GetEngine()->CheckPoint(global_model, local_model);
241 }
242 // lazy checkpoint the model, only remember the pointer to global_model
243 inline void LazyCheckPoint(const Serializable *global_model) {
244 engine::GetEngine()->LazyCheckPoint(global_model);
245 }
246 // return the version number of currently stored model
247 inline int VersionNumber(void) {
248 return engine::GetEngine()->VersionNumber();
249 }
250 // ---------------------------------
251 // Code to handle customized Reduce
252 // ---------------------------------
253 // function to perform reduction for Reducer
254 template<typename DType, void (*freduce)(DType &dst, const DType &src)>
255 inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
256 const size_t kUnit = sizeof(DType);
257 const char *psrc = reinterpret_cast<const char*>(src_);
258 char *pdst = reinterpret_cast<char*>(dst_);
259
260 for (int i = 0; i < len_; ++i) {
261 DType tdst, tsrc;
262 // use memcpy to avoid alignment issue
263 std::memcpy(&tdst, pdst + (i * kUnit), sizeof(tdst));
264 std::memcpy(&tsrc, psrc + (i * kUnit), sizeof(tsrc));
265 freduce(tdst, tsrc);
266 std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
267 }
268 }
269 // function to perform reduction for Reducer
270 template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
271 inline void ReducerAlign_(const void *src_, void *dst_,
272 int len_, const MPI::Datatype &dtype) {
273 const DType *psrc = reinterpret_cast<const DType*>(src_);
274 DType *pdst = reinterpret_cast<DType*>(dst_);
275 for (int i = 0; i < len_; ++i) {
276 freduce(pdst[i], psrc[i]);
277 }
278 }
279 template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
280 inline Reducer<DType, freduce>::Reducer(void) {
281 // it is safe to directly use handle for aligned data types
282 if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
283 this->handle_.Init(ReducerAlign_<DType, freduce>, sizeof(DType));
284 } else {
285 this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
286 }
287 }
288 template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
289 inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
290 void (*prepare_fun)(void *arg),
291 void *prepare_arg,
292 const char* _file,
293 const int _line,
294 const char* _caller) {
295 handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun,
296 prepare_arg, _file, _line, _caller);
297 }
298 // function to perform reduction for SerializeReducer
299 template<typename DType>
300 inline void SerializeReducerFunc_(const void *src_, void *dst_,
301 int len_, const MPI::Datatype &dtype) {
302 int nbytes = engine::ReduceHandle::TypeSize(dtype);
303 // temp space
304 for (int i = 0; i < len_; ++i) {
305 DType tsrc, tdst;
306 utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes); // NOLINT(*)
307 utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes); // NOLINT(*)
308 tsrc.Load(fsrc);
309 tdst.Load(fdst);
310 // govern const check
311 tdst.Reduce(static_cast<const DType &>(tsrc), nbytes);
312 fdst.Seek(0);
313 tdst.Save(fdst);
314 }
315 }
316 template<typename DType>
317 inline SerializeReducer<DType>::SerializeReducer(void) {
318 handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType));
319 }
320 // closure to call Allreduce
321 template<typename DType>
322 struct SerializeReduceClosure {
323 DType *sendrecvobj;
324 size_t max_nbyte, count;
325 void (*prepare_fun)(void *arg);
326 void *prepare_arg;
327 std::string *p_buffer;
328 // invoke the closure
329 inline void Run(void) {
330 if (prepare_fun != NULL) prepare_fun(prepare_arg);
331 for (size_t i = 0; i < count; ++i) {
332 utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
333 sendrecvobj[i].Save(fs);
334 }
335 }
336 inline static void Invoke(void *c) {
337 static_cast<SerializeReduceClosure<DType>*>(c)->Run();
338 }
339 };
340 template<typename DType>
341 inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
342 size_t max_nbyte, size_t count,
343 void (*prepare_fun)(void *arg),
344 void *prepare_arg,
345 const char* _file,
346 const int _line,
347 const char* _caller) {
348 buffer_.resize(max_nbyte * count);
349 // setup closure
350 SerializeReduceClosure<DType> c;
351 c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
352 c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
353 // invoke here
354 handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
355 SerializeReduceClosure<DType>::Invoke, &c,
356 _file, _line, _caller);
357 for (size_t i = 0; i < count; ++i) {
358 utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
359 sendrecvobj[i].Load(fs);
360 }
361 }
362
363 #if DMLC_USE_CXX11
364 template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)g
365 inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
366 std::function<void()> prepare_fun,
367 const char* _file,
368 const int _line,
369 const char* _caller) {
370 this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun,
371 _file, _line, _caller);
372 }
373 template<typename DType>
374 inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
375 size_t max_nbytes, size_t count,
376 std::function<void()> prepare_fun,
377 const char* _file,
378 const int _line,
379 const char* _caller) {
380 this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun,
381 _file, _line, _caller);
382 }
383 #endif // DMLC_USE_CXX11
384 } // namespace rabit
385 #endif // RABIT_INTERNAL_RABIT_INL_H_
0 /*!
1 * Copyright (c) 2014-2019 by Contributors
2 * \file socket.h
3 * \brief this file aims to provide a wrapper of sockets
4 * \author Tianqi Chen
5 */
6 #ifndef RABIT_INTERNAL_SOCKET_H_
7 #define RABIT_INTERNAL_SOCKET_H_
8 #if defined(_WIN32)
9 #include <winsock2.h>
10 #include <ws2tcpip.h>
11 #ifdef _MSC_VER
12 #pragma comment(lib, "Ws2_32.lib")
13 #endif // _MSC_VER
14 #else
15 #include <fcntl.h>
16 #include <netdb.h>
17 #include <errno.h>
18 #include <unistd.h>
19 #include <arpa/inet.h>
20 #include <netinet/in.h>
21 #include <sys/socket.h>
22 #include <sys/ioctl.h>
23 #endif // defined(_WIN32)
24 #include <string>
25 #include <cstring>
26 #include <vector>
27 #include <unordered_map>
28 #include "utils.h"
29
30 #if defined(_WIN32) || defined(__MINGW32__)
31 typedef int ssize_t;
32 #endif // defined(_WIN32) || defined(__MINGW32__)
33
34 #if defined(_WIN32)
35 typedef int sock_size_t;
36
37 static inline int poll(struct pollfd *pfd, int nfds,
38 int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
39 #else
40 #include <sys/poll.h>
41 typedef int SOCKET;
42 typedef size_t sock_size_t;
43 const int INVALID_SOCKET = -1;
44 #endif // defined(_WIN32)
45
46 namespace rabit {
47 namespace utils {
/*! \brief data structure for network address (IPv4 only) */
struct SockAddr {
  /*! \brief the underlying IPv4 socket address */
  sockaddr_in addr;
  // constructor; leaves addr unset until Set is called
  SockAddr(void) {}
  /*!
   * \brief construct and resolve an address, see Set
   * \param url host name or address to resolve
   * \param port the port of address
   */
  SockAddr(const char *url, int port) {
    this->Set(url, port);
  }
  /*! \return host name of the local machine as reported by gethostname */
  inline static std::string GetHostName(void) {
    std::string buf; buf.resize(256);
    utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
    // cut at the first NUL written by gethostname
    return std::string(buf.c_str());
  }
  /*!
   * \brief set the address
   * \param host host name or address to resolve
   * \param port the port of address
   */
  inline void Set(const char *host, int port) {
    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    // SOCK_STREAM is a socket type and therefore belongs in ai_socktype;
    // ai_protocol stays 0, which lets getaddrinfo pick the protocol
    hints.ai_socktype = SOCK_STREAM;
    addrinfo *res = NULL;
    int sig = getaddrinfo(host, NULL, &hints, &res);
    Check(sig == 0 && res != NULL, "cannot obtain address of %s", host);
    Check(res->ai_family == AF_INET, "Does not support IPv6");
    memcpy(&addr, res->ai_addr, res->ai_addrlen);
    // the lookup was done without a service name, so fill in the port here
    addr.sin_port = htons(port);
    freeaddrinfo(res);
  }
  /*! \brief return port of the address*/
  inline int port(void) const {
    return ntohs(addr.sin_port);
  }
  /*! \return a string representation of the address */
  inline std::string AddrStr(void) const {
    std::string buf; buf.resize(256);
#ifdef _WIN32
    const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
                              &buf[0], buf.length());
#else
    const char *s = inet_ntop(AF_INET, &addr.sin_addr,
                              &buf[0], buf.length());
#endif  // _WIN32
    Assert(s != NULL, "cannot decode address");
    return std::string(s);
  }
};
97
98 /*!
99 * \brief base class containing common operations of TCP and UDP sockets
100 */
101 class Socket {
102 public:
103 /*! \brief the file descriptor of socket */
104 SOCKET sockfd;
105 // default conversion to int
106 inline operator SOCKET() const {
107 return sockfd;
108 }
109 /*!
110 * \return last error of socket operation
111 */
112 inline static int GetLastError(void) {
113 #ifdef _WIN32
114 return WSAGetLastError();
115 #else
116 return errno;
117 #endif // _WIN32
118 }
119 /*! \return whether last error was would block */
120 inline static bool LastErrorWouldBlock(void) {
121 int errsv = GetLastError();
122 #ifdef _WIN32
123 return errsv == WSAEWOULDBLOCK;
124 #else
125 return errsv == EAGAIN || errsv == EWOULDBLOCK;
126 #endif // _WIN32
127 }
128 /*!
129 * \brief start up the socket module
130 * call this before using the sockets
131 */
132 inline static void Startup(void) {
133 #ifdef _WIN32
134 WSADATA wsa_data;
135 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
136 Socket::Error("Startup");
137 }
138 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
139 WSACleanup();
140 utils::Error("Could not find a usable version of Winsock.dll\n");
141 }
142 #endif // _WIN32
143 }
144 /*!
145 * \brief shutdown the socket module after use, all sockets need to be closed
146 */
147 inline static void Finalize(void) {
148 #ifdef _WIN32
149 WSACleanup();
150 #endif // _WIN32
151 }
152 /*!
153 * \brief set this socket to use non-blocking mode
154 * \param non_block whether set it to be non-block, if it is false
155 * it will set it back to block mode
156 */
157 inline void SetNonBlock(bool non_block) {
158 #ifdef _WIN32
159 u_long mode = non_block ? 1 : 0;
160 if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
161 Socket::Error("SetNonBlock");
162 }
163 #else
164 int flag = fcntl(sockfd, F_GETFL, 0);
165 if (flag == -1) {
166 Socket::Error("SetNonBlock-1");
167 }
168 if (non_block) {
169 flag |= O_NONBLOCK;
170 } else {
171 flag &= ~O_NONBLOCK;
172 }
173 if (fcntl(sockfd, F_SETFL, flag) == -1) {
174 Socket::Error("SetNonBlock-2");
175 }
176 #endif // _WIN32
177 }
178 /*!
179 * \brief bind the socket to an address
180 * \param addr
181 */
182 inline void Bind(const SockAddr &addr) {
183 if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
184 sizeof(addr.addr)) == -1) {
185 Socket::Error("Bind");
186 }
187 }
188 /*!
189 * \brief try bind the socket to host, from start_port to end_port
190 * \param start_port starting port number to try
191 * \param end_port ending port number to try
192 * \return the port successfully bind to, return -1 if failed to bind any port
193 */
194 inline int TryBindHost(int start_port, int end_port) {
195 // TODO(tqchen) add prefix check
196 for (int port = start_port; port < end_port; ++port) {
197 SockAddr addr("0.0.0.0", port);
198 if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
199 sizeof(addr.addr)) == 0) {
200 return port;
201 }
202 #if defined(_WIN32)
203 if (WSAGetLastError() != WSAEADDRINUSE) {
204 Socket::Error("TryBindHost");
205 }
206 #else
207 if (errno != EADDRINUSE) {
208 Socket::Error("TryBindHost");
209 }
210 #endif // defined(_WIN32)
211 }
212
213 return -1;
214 }
215 /*! \brief get last error code if any */
216 inline int GetSockError(void) const {
217 int error = 0;
218 socklen_t len = sizeof(error);
219 if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
220 reinterpret_cast<char*>(&error), &len) != 0) {
221 Error("GetSockError");
222 }
223 return error;
224 }
225 /*! \brief check if anything bad happens */
226 inline bool BadSocket(void) const {
227 if (IsClosed()) return true;
228 int err = GetSockError();
229 if (err == EBADF || err == EINTR) return true;
230 return false;
231 }
232 /*! \brief check if socket is already closed */
233 inline bool IsClosed(void) const {
234 return sockfd == INVALID_SOCKET;
235 }
236 /*! \brief close the socket */
237 inline void Close(void) {
238 if (sockfd != INVALID_SOCKET) {
239 #ifdef _WIN32
240 closesocket(sockfd);
241 #else
242 close(sockfd);
243 #endif
244 sockfd = INVALID_SOCKET;
245 } else {
246 Error("Socket::Close double close the socket or close without create");
247 }
248 }
249 // report an socket error
250 inline static void Error(const char *msg) {
251 int errsv = GetLastError();
252 #ifdef _WIN32
253 utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv);
254 #else
255 utils::Error("Socket %s Error:%s", msg, strerror(errsv));
256 #endif
257 }
258
259 protected:
260 explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
261 }
262 };
263
264 /*!
265 * \brief a wrapper of TCP socket that hopefully be cross platform
266 */
267 class TCPSocket : public Socket{
268 public:
269 // constructor
270 TCPSocket(void) : Socket(INVALID_SOCKET) {
271 }
272 explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
273 }
274 /*!
275 * \brief enable/disable TCP keepalive
276 * \param keepalive whether to set the keep alive option on
277 */
278 void SetKeepAlive(bool keepalive) {
279 int opt = static_cast<int>(keepalive);
280 if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
281 reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
282 Socket::Error("SetKeepAlive");
283 }
284 }
285 inline void SetLinger(int timeout = 0) {
286 struct linger sl;
287 sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
288 sl.l_linger = timeout; /* timeout interval in seconds */
289 if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
290 Socket::Error("SO_LINGER");
291 }
292 }
293 /*!
294 * \brief create the socket, call this before using socket
295 * \param af domain
296 */
297 inline void Create(int af = PF_INET) {
298 sockfd = socket(PF_INET, SOCK_STREAM, 0);
299 if (sockfd == INVALID_SOCKET) {
300 Socket::Error("Create");
301 }
302 }
303 /*!
304 * \brief perform listen of the socket
305 * \param backlog backlog parameter
306 */
307 inline void Listen(int backlog = 16) {
308 listen(sockfd, backlog);
309 }
310 /*! \brief get a new connection */
311 TCPSocket Accept(void) {
312 SOCKET newfd = accept(sockfd, NULL, NULL);
313 if (newfd == INVALID_SOCKET) {
314 Socket::Error("Accept");
315 }
316 return TCPSocket(newfd);
317 }
318 /*!
319 * \brief decide whether the socket is at OOB mark
320 * \return 1 if at mark, 0 if not, -1 if an error occured
321 */
322 inline int AtMark(void) const {
323 #ifdef _WIN32
324 unsigned long atmark; // NOLINT(*)
325 if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
326 #else
327 int atmark;
328 if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
329 #endif // _WIN32
330 return static_cast<int>(atmark);
331 }
332 /*!
333 * \brief connect to an address
334 * \param addr the address to connect to
335 * \return whether connect is successful
336 */
337 inline bool Connect(const SockAddr &addr) {
338 return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
339 sizeof(addr.addr)) == 0;
340 }
341 /*!
342 * \brief send data using the socket
343 * \param buf the pointer to the buffer
344 * \param len the size of the buffer
345 * \param flags extra flags
346 * \return size of data actually sent
347 * return -1 if error occurs
348 */
349 inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
350 const char *buf = reinterpret_cast<const char*>(buf_);
351 return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
352 }
353 /*!
354 * \brief receive data using the socket
355 * \param buf_ the pointer to the buffer
356 * \param len the size of the buffer
357 * \param flags extra flags
358 * \return size of data actually received
359 * return -1 if error occurs
360 */
361 inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
362 char *buf = reinterpret_cast<char*>(buf_);
363 return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
364 }
365 /*!
366 * \brief peform block write that will attempt to send all data out
367 * can still return smaller than request when error occurs
368 * \param buf the pointer to the buffer
369 * \param len the size of the buffer
370 * \return size of data actually sent
371 */
372 inline size_t SendAll(const void *buf_, size_t len) {
373 const char *buf = reinterpret_cast<const char*>(buf_);
374 size_t ndone = 0;
375 while (ndone < len) {
376 ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
377 if (ret == -1) {
378 if (LastErrorWouldBlock()) return ndone;
379 Socket::Error("SendAll");
380 }
381 buf += ret;
382 ndone += ret;
383 }
384 return ndone;
385 }
386 /*!
387 * \brief peforma block read that will attempt to read all data
388 * can still return smaller than request when error occurs
389 * \param buf_ the buffer pointer
390 * \param len length of data to recv
391 * \return size of data actually sent
392 */
393 inline size_t RecvAll(void *buf_, size_t len) {
394 char *buf = reinterpret_cast<char*>(buf_);
395 size_t ndone = 0;
396 while (ndone < len) {
397 ssize_t ret = recv(sockfd, buf,
398 static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
399 if (ret == -1) {
400 if (LastErrorWouldBlock()) return ndone;
401 Socket::Error("RecvAll");
402 }
403 if (ret == 0) return ndone;
404 buf += ret;
405 ndone += ret;
406 }
407 return ndone;
408 }
409 /*!
410 * \brief send a string over network
411 * \param str the string to be sent
412 */
413 inline void SendStr(const std::string &str) {
414 int len = static_cast<int>(str.length());
415 utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
416 "error during send SendStr");
417 if (len != 0) {
418 utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
419 "error during send SendStr");
420 }
421 }
422 /*!
423 * \brief recv a string from network
424 * \param out_str the string to receive
425 */
426 inline void RecvStr(std::string *out_str) {
427 int len;
428 utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
429 "error during send RecvStr");
430 out_str->resize(len);
431 if (len != 0) {
432 utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(),
433 "error during send SendStr");
434 }
435 }
436 };
437
438 /*! \brief helper data structure to perform poll */
439 struct PollHelper {
440 public:
441 /*!
442 * \brief add file descriptor to watch for read
443 * \param fd file descriptor to be watched
444 */
445 inline void WatchRead(SOCKET fd) {
446 auto& pfd = fds[fd];
447 pfd.fd = fd;
448 pfd.events |= POLLIN;
449 }
450 /*!
451 * \brief add file descriptor to watch for write
452 * \param fd file descriptor to be watched
453 */
454 inline void WatchWrite(SOCKET fd) {
455 auto& pfd = fds[fd];
456 pfd.fd = fd;
457 pfd.events |= POLLOUT;
458 }
459 /*!
460 * \brief add file descriptor to watch for exception
461 * \param fd file descriptor to be watched
462 */
463 inline void WatchException(SOCKET fd) {
464 auto& pfd = fds[fd];
465 pfd.fd = fd;
466 pfd.events |= POLLPRI;
467 }
468 /*!
469 * \brief Check if the descriptor is ready for read
470 * \param fd file descriptor to check status
471 */
472 inline bool CheckRead(SOCKET fd) const {
473 const auto& pfd = fds.find(fd);
474 return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
475 }
476 /*!
477 * \brief Check if the descriptor is ready for write
478 * \param fd file descriptor to check status
479 */
480 inline bool CheckWrite(SOCKET fd) const {
481 const auto& pfd = fds.find(fd);
482 return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
483 }
484 /*!
485 * \brief Check if the descriptor has any exception
486 * \param fd file descriptor to check status
487 */
488 inline bool CheckExcept(SOCKET fd) const {
489 const auto& pfd = fds.find(fd);
490 return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
491 }
492 /*!
493 * \brief wait for exception event on a single descriptor
494 * \param fd the file descriptor to wait the event for
495 * \param timeout the timeout counter, can be negative, which means wait until the event happen
496 * \return 1 if success, 0 if timeout, and -1 if error occurs
497 */
498 inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*)
499 pollfd pfd;
500 pfd.fd = fd;
501 pfd.events = POLLPRI;
502 return poll(&pfd, 1, timeout);
503 }
504
505 /*!
506 * \brief peform poll on the set defined, read, write, exception
507 * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
508 * \return
509 */
510 inline void Poll(long timeout = -1) { // NOLINT(*)
511 std::vector<pollfd> fdset;
512 fdset.reserve(fds.size());
513 for (auto kv : fds) {
514 fdset.push_back(kv.second);
515 }
516 int ret = poll(fdset.data(), fdset.size(), timeout);
517 if (ret == -1) {
518 Socket::Error("Poll");
519 } else {
520 for (auto& pfd : fdset) {
521 auto revents = pfd.revents & pfd.events;
522 if (!revents) {
523 fds.erase(pfd.fd);
524 } else {
525 fds[pfd.fd].events = revents;
526 }
527 }
528 }
529 }
530
531 std::unordered_map<SOCKET, pollfd> fds;
532 };
533 } // namespace utils
534 } // namespace rabit
535 #endif // RABIT_INTERNAL_SOCKET_H_
0 /*!
1 * Copyright (c) 2015 by Contributors
2 * \file thread_local.h
3 * \brief Common utility for thread local storage.
4 */
5 #ifndef RABIT_INTERNAL_THREAD_LOCAL_H_
6 #define RABIT_INTERNAL_THREAD_LOCAL_H_
7
8 #include "../include/dmlc/base.h"
9
10 #if DMLC_ENABLE_STD_THREAD
11 #include <mutex>
12 #endif // DMLC_ENABLE_STD_THREAD
13
14 #include <memory>
15 #include <vector>
16
17 namespace rabit {
18
// macro handling for threadlocal variables: pick the compiler specific
// keyword where one is known
#ifdef __GNUC__
  #define MX_TREAD_LOCAL __thread
#elif __STDC_VERSION__ >= 201112L
  #define  MX_TREAD_LOCAL _Thread_local
#elif defined(_MSC_VER)
  #define MX_TREAD_LOCAL __declspec(thread)
#endif  // __GNUC__

#ifndef MX_TREAD_LOCAL
// No compiler specific keyword matched. The previous `#message(...)` line is
// not a valid preprocessor directive and left the macro undefined, so fall
// back to the standard C++11 storage class specifier instead.
#define MX_TREAD_LOCAL thread_local
#endif  // MX_TREAD_LOCAL
31
32 /*!
33 * \brief A threadlocal store to store threadlocal variables.
34 * Will return a thread local singleton of type T
35 * \tparam T the type we like to store
36 */
37 template<typename T>
38 class ThreadLocalStore {
39 public:
40 /*! \return get a thread local singleton */
41 static T* Get() {
42 static MX_TREAD_LOCAL T* ptr = nullptr;
43 if (ptr == nullptr) {
44 ptr = new T();
45 Singleton()->RegisterDelete(ptr);
46 }
47 return ptr;
48 }
49
50 private:
51 /*! \brief constructor */
52 ThreadLocalStore() {}
53 /*! \brief destructor */
54 ~ThreadLocalStore() {
55 for (size_t i = 0; i < data_.size(); ++i) {
56 delete data_[i];
57 }
58 }
59 /*! \return singleton of the store */
60 static ThreadLocalStore<T> *Singleton() {
61 static ThreadLocalStore<T> inst;
62 return &inst;
63 }
64 /*!
65 * \brief register str for internal deletion
66 * \param str the string pointer
67 */
68 void RegisterDelete(T *str) {
69 #if DMLC_ENABLE_STD_THREAD
70 std::unique_lock<std::mutex> lock(mutex_);
71 data_.push_back(str);
72 lock.unlock();
73 #else
74 data_.push_back(str);
75 #endif // DMLC_ENABLE_STD_THREAD
76 }
77
78 #if DMLC_ENABLE_STD_THREAD
79 /*! \brief internal mutex */
80 std::mutex mutex_;
81 #endif // DMLC_ENABLE_STD_THREAD
82 /*!\brief internal data */
83 std::vector<T*> data_;
84 };
85 } // namespace rabit
86 #endif // RABIT_INTERNAL_THREAD_LOCAL_H_
0 /*!
1 * Copyright (c) 2014-2019 by Contributors
2 * \file timer.h
3 * \brief This file defines the utils for timing
4 * \author Tianqi Chen, Nacho, Tianyi
5 */
6 #ifndef RABIT_INTERNAL_TIMER_H_
7 #define RABIT_INTERNAL_TIMER_H_
8 #include <time.h>
9 #ifdef __MACH__
10 #include <mach/clock.h>
11 #include <mach/mach.h>
12 #endif // __MACH__
13 #include "./utils.h"
14
15 namespace rabit {
16 namespace utils {
17 /*!
18 * \brief return time in seconds, not cross platform, avoid to use this in most places
19 */
20 inline double GetTime(void) {
21 #ifdef __MACH__
22 clock_serv_t cclock;
23 mach_timespec_t mts;
24 host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock);
25 utils::Check(clock_get_time(cclock, &mts) == 0, "failed to get time");
26 mach_port_deallocate(mach_task_self(), cclock);
27 return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9;
28 #else
29 #if defined(__unix__) || defined(__linux__)
30 timespec ts;
31 utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
32 return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
33 #else
34 return static_cast<double>(time(NULL));
35 #endif // defined(__unix__) || defined(__linux__)
36 #endif // __MACH__
37 }
38 } // namespace utils
39 } // namespace rabit
40 #endif // RABIT_INTERNAL_TIMER_H_
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file utils.h
3 * \brief simple utils to support the code
4 * \author Tianqi Chen
5 */
6 #ifndef RABIT_INTERNAL_UTILS_H_
7 #define RABIT_INTERNAL_UTILS_H_
8 #define _CRT_SECURE_NO_WARNINGS
9 #include <string.h>
10 #include <cstdio>
11 #include <string>
12 #include <cstdlib>
13 #include <stdexcept>
14 #include <vector>
15 #include "dmlc/io.h"
16
17 #ifndef RABIT_STRICT_CXX98_
18 #include <cstdarg>
19 #endif // RABIT_STRICT_CXX98_
20
21 #if !defined(__GNUC__) || defined(__FreeBSD__)
22 #define fopen64 std::fopen
23 #endif // !defined(__GNUC__) || defined(__FreeBSD__)
24
25 #ifdef _MSC_VER
26 // NOTE: sprintf_s is not equivalent to snprintf,
27 // they are equivalent when success, which is sufficient for our case
28 #define snprintf sprintf_s
29 #define vsnprintf vsprintf_s
30
31 #else
32
33 #ifdef _FILE_OFFSET_BITS
34 #if _FILE_OFFSET_BITS == 32
35 #pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
36 #endif // _FILE_OFFSET_BITS == 32
37 #endif // _FILE_OFFSET_BITS
38
39 #ifdef __APPLE__
40 #define off64_t off_t
41 #define fopen64 std::fopen
42 #endif // __APPLE__
43
44 extern "C" {
45 #include <sys/types.h>
46 }
47 #endif // _MSC_VER
48
49 #ifdef _MSC_VER
50 typedef unsigned char uint8_t;
51 typedef unsigned __int16 uint16_t;
52 typedef unsigned __int32 uint32_t;
53 typedef unsigned __int64 uint64_t;
54 typedef __int64 int64_t;
55 #else
56 #include <inttypes.h>
57 #endif // _MSC_VER
58
59 namespace rabit {
60 /*! \brief namespace for helper utils of the project */
61 namespace utils {
62
63 /*! \brief error message buffer length */
64 const int kPrintBuffer = 1 << 12;
65
66 /*! \brief we may want to keep the process alive when there are multiple workers
67 * co-locate in the same process */
68 extern bool STOP_PROCESS_ON_ERROR;
69
/*!
 * \brief Case-insensitive comparison of two C strings
 * \param s1 first string
 * \param s2 second string
 * \return result of the platform routine (_stricmp on MSVC, strcasecmp
 *  elsewhere); zero means the strings are equal ignoring case
 */
inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) {
#if defined(_MSC_VER)
  const int order = _stricmp(s1, s2);
#else
  const int order = strcasecmp(s1, s2);
#endif  // defined(_MSC_VER)
  return order;
}
78
79 /* \brief parse config string too bool*/
80 inline bool StringToBool(const char* s) {
81 return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
82 }
83
84 #ifndef RABIT_CUSTOMIZE_MSG_
85 /*!
86 * \brief handling of Assert error, caused by inappropriate input
87 * \param msg error message
88 */
89 inline void HandleAssertError(const char *msg) {
90 if (STOP_PROCESS_ON_ERROR) {
91 fprintf(stderr, "AssertError:%s, shutting down process\n", msg);
92 exit(-1);
93 } else {
94 fprintf(stderr, "AssertError:%s, rabit is configured to keep process running\n", msg);
95 throw dmlc::Error(msg);
96 }
97 }
98 /*!
99 * \brief handling of Check error, caused by inappropriate input
100 * \param msg error message
101 */
102 inline void HandleCheckError(const char *msg) {
103 if (STOP_PROCESS_ON_ERROR) {
104 fprintf(stderr, "%s, shutting down process\n", msg);
105 exit(-1);
106 } else {
107 fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
108 throw dmlc::Error(msg);
109 }
110 }
/*!
 * \brief default print handler, writes the message to standard output
 * \param msg the message to print
 */
inline void HandlePrint(const char *msg) {
  std::printf("%s", msg);
}
114
115 inline void HandleLogInfo(const char *fmt, ...) {
116 std::string msg(kPrintBuffer, '\0');
117 va_list args;
118 va_start(args, fmt);
119 vsnprintf(&msg[0], kPrintBuffer, fmt, args);
120 va_end(args);
121 fprintf(stdout, "%s", msg.c_str());
122 fflush(stdout);
123 }
124 #else
125 #ifndef RABIT_STRICT_CXX98_
126 // include declarations, some one must implement this
127 void HandleAssertError(const char *msg);
128 void HandleCheckError(const char *msg);
129 void HandlePrint(const char *msg);
130 #endif // RABIT_STRICT_CXX98_
131 #endif // RABIT_CUSTOMIZE_MSG_
132 #ifdef RABIT_STRICT_CXX98_
// these function pointers are to be assigned
// In strict C++98 mode the utilities are only declared here, as function
// pointers with C linkage; no definition is visible in this header, so some
// other translation unit has to define and assign them.
extern "C" void (*Printf)(const char *fmt, ...);
extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
extern "C" void (*Assert)(int exp, const char *fmt, ...);
extern "C" void (*Check)(int exp, const char *fmt, ...);
extern "C" void (*Error)(const char *fmt, ...);
139 #else
140 /*! \brief printf, prints messages to the console */
141 inline void Printf(const char *fmt, ...) {
142 std::string msg(kPrintBuffer, '\0');
143 va_list args;
144 va_start(args, fmt);
145 vsnprintf(&msg[0], kPrintBuffer, fmt, args);
146 va_end(args);
147 HandlePrint(msg.c_str());
148 }
/*!
 * \brief portable version of snprintf
 * \param buf destination buffer
 * \param size capacity of buf in bytes
 * \param fmt printf-style format string, followed by its arguments
 * \return the value returned by the underlying vsnprintf call
 */
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
  va_list ap;
  va_start(ap, fmt);
  const int written = vsnprintf(buf, size, fmt, ap);
  va_end(ap);
  return written;
}
157
158 /*! \brief assert a condition is true, use this to handle debug information */
159 inline void Assert(bool exp, const char *fmt, ...) {
160 if (!exp) {
161 std::string msg(kPrintBuffer, '\0');
162 va_list args;
163 va_start(args, fmt);
164 vsnprintf(&msg[0], kPrintBuffer, fmt, args);
165 va_end(args);
166 HandleAssertError(msg.c_str());
167 }
168 }
169
170 /*!\brief same as assert, but this is intended to be used as a message for users */
171 inline void Check(bool exp, const char *fmt, ...) {
172 if (!exp) {
173 std::string msg(kPrintBuffer, '\0');
174 va_list args;
175 va_start(args, fmt);
176 vsnprintf(&msg[0], kPrintBuffer, fmt, args);
177 va_end(args);
178 HandleCheckError(msg.c_str());
179 }
180 }
181
182 /*! \brief report error message, same as check */
183 inline void Error(const char *fmt, ...) {
184 {
185 std::string msg(kPrintBuffer, '\0');
186 va_list args;
187 va_start(args, fmt);
188 vsnprintf(&msg[0], kPrintBuffer, fmt, args);
189 va_end(args);
190 HandleCheckError(msg.c_str());
191 }
192 }
193 #endif // RABIT_STRICT_CXX98_
194
195 /*! \brief replace fopen, report error when the file open fails */
196 inline std::FILE *FopenCheck(const char *fname, const char *flag) {
197 std::FILE *fp = fopen64(fname, flag);
198 Check(fp != NULL, "can not open file \"%s\"\n", fname);
199 return fp;
200 }
201 } // namespace utils
202 // easy utils that can be directly accessed in xgboost
/*!
 * \brief get the beginning address of a vector
 * \param vec the vector
 * \return pointer to the first element, NULL when the vector is empty
 */
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) {  // NOLINT(*)
  return vec.empty() ? NULL : &vec[0];
}
/*!
 * \brief get the beginning address of a const vector
 * \param vec the vector
 * \return pointer to the first element, NULL when the vector is empty
 */
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) {  // NOLINT(*)
  return vec.empty() ? NULL : &vec[0];
}
/*!
 * \brief get the beginning address of a string
 * \param str the string
 * \return pointer to the first character, NULL when the string is empty
 */
inline char* BeginPtr(std::string &str) {  // NOLINT(*)
  return str.empty() ? NULL : &str[0];
}
/*!
 * \brief get the beginning address of a const string
 * \param str the string
 * \return pointer to the first character, NULL when the string is empty
 */
inline const char* BeginPtr(const std::string &str) {
  return str.empty() ? NULL : &str[0];
}
229 } // namespace rabit
230 #endif // RABIT_INTERNAL_UTILS_H_
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file rabit.h
3 * \brief This file defines rabit's Allreduce/Broadcast interface
4 * The rabit engine contains the actual implementation
5 * Code that only uses this header can also be compiled with MPI Allreduce (non fault-tolerant),
6 *
7 * rabit.h and serializable.h are all that the user needs to use the rabit interface
8 * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
9 */
10 #ifndef RABIT_RABIT_H_ // NOLINT(*)
11 #define RABIT_RABIT_H_ // NOLINT(*)
12 #include <string>
13 #include <vector>
14
15 // whether or not use c++11 support
16 #ifndef DMLC_USE_CXX11
17 #if defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER)
18 #define DMLC_USE_CXX11 1
19 #else
20 #define DMLC_USE_CXX11 (__cplusplus >= 201103L)
21 #endif // defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER)
22 #endif // DMLC_USE_CXX11
23
// keeps rabit api caller signature
// _FILE, _LINE and _CALLER are used in this header as the default arguments
// of the collective calls (Broadcast, Allreduce, Allgather), whose parameter
// docs describe them as inputs for generating a unique cache key.
#ifndef RABIT_API_CALLER_SIGNATURE
#define RABIT_API_CALLER_SIGNATURE

// GCC (but not clang) takes the values from compiler builtins
#if (defined(__GNUC__) && !defined(__clang__))
#define _FILE __builtin_FILE()
#define _LINE __builtin_LINE()
#define _CALLER __builtin_FUNCTION()
#else
// every other compiler gets fixed placeholder values
#define _FILE "N/A"
#define _LINE -1
#define _CALLER "N/A"
#endif  // (defined(__GNUC__) && !defined(__clang__))

#endif  // RABIT_API_CALLER_SIGNATURE
39
40 // optionally support of lambda functions in C++11, if available
41 #if DMLC_USE_CXX11
42 #include <functional>
43 #endif // C++11
44 // engine definition of rabit, defines internal implementation
45 // to use rabit interface, there is no need to read engine.h
46 // rabit.h and serializable.h are enough to use the interface
47 #include "./internal/engine.h"
48
49 /*! \brief rabit namespace */
50 namespace rabit {
51 /*!
52 * \brief defines stream used in rabit
53 * see definition of Stream in dmlc/io.h
54 */
55 typedef dmlc::Stream Stream;
56 /*!
57 * \brief defines serializable objects used in rabit
58 * see definition of Serializable in dmlc/io.h
59 */
60 typedef dmlc::Serializable Serializable;
61
62 /*!
63 * \brief reduction operators namespace
64 */
65 namespace op {
66 /*!
67 * \class rabit::op::Max
68 * \brief maximum reduction operator
69 */
70 struct Max;
71 /*!
72 * \class rabit::op::Min
73 * \brief minimum reduction operator
74 */
75 struct Min;
76 /*!
77 * \class rabit::op::Sum
78 * \brief sum reduction operator
79 */
80 struct Sum;
81 /*!
82 * \class rabit::op::BitOR
83 * \brief bitwise OR reduction operator
84 */
85 struct BitOR;
86 } // namespace op
87 /*!
88 * \brief initializes rabit, call this once at the beginning of your program
89 * \param argc number of arguments in argv
90 * \param argv the array of input arguments
91 * \return true if initialized successfully, otherwise false
92 */
93 inline bool Init(int argc, char *argv[]);
94 /*!
95 * \brief finalizes the rabit engine, call this function after you finished with all the jobs
96 * \return true if finalized successfully, otherwise false
97 */
98 inline bool Finalize();
99 /*! \brief gets rank of the current process
100 * \return rank number of worker*/
101 inline int GetRank();
102 /*! \brief gets total number of processes
103 * \return total world size*/
104 inline int GetWorldSize();
105 /*! \brief whether rabit env is in distributed mode
106 * \return is distributed*/
107 inline bool IsDistributed();
108
109 /*! \brief gets processor's name
110 * \return processor name*/
111 inline std::string GetProcessorName();
112 /*!
113 * \brief prints the msg to the tracker,
114 * this function can be used to communicate progress information to
115 * the user who monitors the tracker
116 * \param msg the message to be printed
117 */
118 inline void TrackerPrint(const std::string &msg);
119
120 #ifndef RABIT_STRICT_CXX98_
121 /*!
122 * \brief prints the msg to the tracker, this function may not be available
123 * in very strict c++98 compilers, though it usually is.
124 * this function can be used to communicate progress information to
125 * the user who monitors the tracker
126 * \param fmt the format string
127 */
128 inline void TrackerPrintf(const char *fmt, ...);
129 #endif // RABIT_STRICT_CXX98_
130 /*!
131 * \brief broadcasts a memory region to every node from the root
132 *
133 * Example: int a = 1; Broadcast(&a, sizeof(a), root);
134 * \param sendrecv_data the pointer to the send/receive buffer,
135 * \param size the data size
136 * \param root the process root
137 * \param _file caller file name used to generate unique cache key
138 * \param _line caller line number used to generate unique cache key
139 * \param _caller caller function name used to generate unique cache key
140 */
141 inline void Broadcast(void *sendrecv_data, size_t size, int root,
142 const char* _file = _FILE,
143 const int _line = _LINE,
144 const char* _caller = _CALLER);
145
146 /*!
147 * \brief broadcasts an std::vector<DType> to every node from root
148 * \param sendrecv_data the pointer to send/receive vector,
149 * for the receiver, the vector does not need to be pre-allocated
150 * \param root the process root
151 * \param _file caller file name used to generate unique cache key
152 * \param _line caller line number used to generate unique cache key
153 * \param _caller caller function name used to generate unique cache key
154 * \tparam DType the data type stored in the vector, has to be a simple data type
155 * that can be directly transmitted by sending the sizeof(DType)
156 */
157 template<typename DType>
158 inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
159 const char* _file = _FILE,
160 const int _line = _LINE,
161 const char* _caller = _CALLER);
162 /*!
163 * \brief broadcasts a std::string to every node from the root
164 * \param sendrecv_data the pointer to the send/receive buffer,
165 * for the receiver, the string does not need to be pre-allocated
166 * \param _file caller file name used to generate unique cache key
167 * \param _line caller line number used to generate unique cache key
168 * \param _caller caller function name used to generate unique cache key
169 * \param root the process root
170 */
171 inline void Broadcast(std::string *sendrecv_data, int root,
172 const char* _file = _FILE,
173 const int _line = _LINE,
174 const char* _caller = _CALLER);
175 /*!
176 * \brief performs in-place Allreduce on sendrecvbuf
177 * this function is NOT thread-safe
178 *
179 * Example Usage: the following code does an Allreduce and outputs the sum as the result
180 * \code{.cpp}
181 * vector<int> data(10);
182 * ...
183 * Allreduce<op::Sum>(&data[0], data.size());
184 * ...
185 * \endcode
186 *
187 * \param sendrecvbuf buffer for both sending and receiving data
188 * \param count number of elements to be reduced
189 * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
190 * will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
191 * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
192 * \param prepare_arg argument used to pass into the lazy preprocessing function
193 * \param _file caller file name used to generate unique cache key
194 * \param _line caller line number used to generate unique cache key
195 * \param _caller caller function name used to generate unique cache key
196 * \tparam OP see namespace op, reduce operator
197 * \tparam DType data type
198 */
199 template<typename OP, typename DType>
200 inline void Allreduce(DType *sendrecvbuf, size_t count,
201 void (*prepare_fun)(void *) = NULL,
202 void *prepare_arg = NULL,
203 const char* _file = _FILE,
204 const int _line = _LINE,
205 const char* _caller = _CALLER);
206
207 /*!
208 * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
209 * the data provided by current node k is [slice_begin, slice_end),
210 * the next node's segment must start with slice_end
211 * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
212 * use a ring based algorithm
213 *
214 * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
215 * \param total_size total size of data to be gathered
216 * \param slice_begin beginning of the current slice
217 * \param slice_end end of the current slice
218 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
219 * \param _file caller file name used to generate unique cache key
220 * \param _line caller line number used to generate unique cache key
221 * \param _caller caller function name used to generate unique cache key
222 */
223 template<typename DType>
224 inline void Allgather(DType *sendrecvbuf_,
225 size_t total_size,
226 size_t slice_begin,
227 size_t slice_end,
228 size_t size_prev_slice,
229 const char* _file = _FILE,
230 const int _line = _LINE,
231 const char* _caller = _CALLER);
232
233 // C++11 support for lambda prepare function
234 #if DMLC_USE_CXX11
235 /*!
236 * \brief performs in-place Allreduce, on sendrecvbuf
237 * with a prepare function specified by a lambda function
238 *
239 * Example Usage:
240 * \code{.cpp}
241 * // the following code does an Allreduce and outputs the sum as the result
242 * vector<int> data(10);
243 * ...
244 * Allreduce<op::Sum>(&data[0], data.size(), [&]() {
245 * for (int i = 0; i < 10; ++i) {
246 * data[i] = i;
247 * }
248 * });
249 * ...
250 * \endcode
251 * \param sendrecvbuf buffer for both sending and receiving data
252 * \param count number of elements to be reduced
253 * \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
254 * by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
255 * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
256 * \param _file caller file name used to generate unique cache key
257 * \param _line caller line number used to generate unique cache key
258 * \param _caller caller function name used to generate unique cache key
259 * \tparam OP see namespace op, reduce operator
260 * \tparam DType data type
261 */
262 template<typename OP, typename DType>
263 inline void Allreduce(DType *sendrecvbuf, size_t count,
264 std::function<void()> prepare_fun,
265 const char* _file = _FILE,
266 const int _line = _LINE,
267 const char* _caller = _CALLER);
268 #endif // C++11
269 /*!
270 * \brief loads the latest check point
271 * \param global_model pointer to the globally shared model/state
272 * when calling this function, the caller needs to guarantee that the global_model
273 * is the same in every node
274 * \param local_model pointer to the local model that is specific to the current node/rank
275 * this can be NULL when no local model is needed
276 *
277 * \return the version number of the check point loaded
278 * if returned version == 0, this means no model has been CheckPointed
279 * the p_model is not touched, users should do the necessary initialization by themselves
280 *
281 * \code{.cpp}
282 * // Example usage code of LoadCheckPoint
283 * int iter = rabit::LoadCheckPoint(&model);
284 * if (iter == 0) model.InitParameters();
285 * for (i = iter; i < max_iter; ++i) {
286 * // do many things, include allreduce
287 * rabit::CheckPoint(model);
288 * }
289 * \endcode
290 * \sa CheckPoint, VersionNumber
291 */
292 inline int LoadCheckPoint(Serializable *global_model,
293 Serializable *local_model = NULL);
294 /*!
295 * \brief checkpoints the model, meaning a stage of execution has finished.
296 * every time we call check point, a version number will be increased by one
297 *
298 * \param global_model pointer to the globally shared model/state
299 * when calling this function, the caller needs to guarantee that the global_model
300 * is the same in every node
301 * \param local_model pointer to the local model that is specific to the current node/rank
302 * this can be NULL when no local state is needed
303 * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
304 * bring replication cost in the CheckPoint function. global_model does not need explicit replication.
305 * So, only CheckPoint with the global_model if possible
306 * \sa LoadCheckPoint, VersionNumber
307 */
308 inline void CheckPoint(const Serializable *global_model,
309 const Serializable *local_model = NULL);
310 /*!
311 * \brief This function can be used to replace CheckPoint for global_model only,
312 * when certain condition is met (see detailed explanation).
313 *
314 * This is a "lazy" checkpoint such that only the pointer to the global_model is
315 * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
316 * The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
317 * In other words, the global_model model can be changed only between the last call of
318 * Allreduce/Broadcast and LazyCheckPoint, both in the same version
319 *
320 * For example, suppose the calling sequence is:
321 * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint/(or can be CheckPoint)
322 *
323 * Then the user MUST only change the global_model in code3.
324 *
325 * The use of LazyCheckPoint instead of CheckPoint will improve the efficiency of the program.
326 * \param global_model pointer to the globally shared model/state
327 * when calling this function, the caller needs to guarantee that the global_model
328 * is the same in every node
329 * \sa LoadCheckPoint, CheckPoint, VersionNumber
330 */
331 inline void LazyCheckPoint(const Serializable *global_model);
332 /*!
333 * \return version number of the current stored model,
334 * which means how many calls to CheckPoint we made so far
335 * \sa LoadCheckPoint, CheckPoint
336 */
337 inline int VersionNumber();
338 // ----- extensions that allow customized reducer ------
339 // helper class to do customized reduce, users do not need to know the type
340 namespace engine {
341 class ReduceHandle;
342 } // namespace engine
/*!
 * \brief template class to make customized reduce and all reduce easy
 *  Do not use reducer directly in the function you call Finalize,
 *  because the destructor can execute after Finalize
 * \tparam DType data type that is to be reduced,
 *  DType must be a struct, with no pointer
 * \tparam freduce the customized reduction function; it receives a mutable
 *  dst and a const src (presumably it folds src into dst - confirm against
 *  the implementation)
 */
template<typename DType, void (*freduce)(DType &dst, const DType &src)>  // NOLINT(*)
class Reducer {
 public:
  /*! \brief constructor, only declared here */
  Reducer();
  /*!
   * \brief customized in-place all reduce operation
   * \param sendrecvbuf the in place send-recv buffer
   * \param count number of elements to be reduced
   * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
   *                    will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
   *                    If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
   * \param prepare_arg argument used to pass into the lazy preprocessing function
   * \param _file caller file name used to generate unique cache key
   * \param _line caller line number used to generate unique cache key
   * \param _caller caller function name used to generate unique cache key
   */
  inline void Allreduce(DType *sendrecvbuf, size_t count,
                        void (*prepare_fun)(void *) = NULL,
                        void *prepare_arg = NULL,
                        const char* _file = _FILE,
                        const int _line = _LINE,
                        const char* _caller = _CALLER);
#if DMLC_USE_CXX11
  /*!
   * \brief customized in-place all reduce operation, with lambda function as preprocessor
   * \param sendrecvbuf pointer to the array of objects to be reduced
   * \param count number of elements to be reduced
   * \param prepare_fun lambda function executed to prepare the data, if necessary
   * \param _file caller file name used to generate unique cache key
   * \param _line caller line number used to generate unique cache key
   * \param _caller caller function name used to generate unique cache key
   */
  inline void Allreduce(DType *sendrecvbuf, size_t count,
                        std::function<void()> prepare_fun,
                        const char* _file = _FILE,
                        const int _line = _LINE,
                        const char* _caller = _CALLER);
#endif  // DMLC_USE_CXX11

 private:
  /*! \brief function handle to do reduce */
  engine::ReduceHandle handle_;
};
/*!
 * \brief template class to make customized reduce,
 *  this class defines complex reducer handles all the data structure that can be
 *  serialized/deserialized into fixed size buffer
 *  Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
 *
 * \tparam DType data type that is to be reduced, DType must contain the following functions:
 *  (1) Save(IStream &fs)  (2) Load(IStream &fs) (3) Reduce(const DType &src, size_t max_nbyte)
 */
template<typename DType>
class SerializeReducer {
 public:
  /*! \brief constructor, only declared here */
  SerializeReducer();
  /*!
   * \brief customized in-place all reduce operation
   * \param sendrecvobj pointer to the array of objects to be reduced
   * \param max_nbyte maximum amount of memory needed to serialize each object
   *        this includes budget limit for intermediate and final result
   * \param count number of elements to be reduced
   * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
   *                    will be called by the function before performing Allreduce, to initialize the data in sendrecvobj.
   *                    If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
   * \param prepare_arg argument used to pass into the lazy preprocessing function
   * \param _file caller file name used to generate unique cache key
   * \param _line caller line number used to generate unique cache key
   * \param _caller caller function name used to generate unique cache key
   */
  inline void Allreduce(DType *sendrecvobj,
                        size_t max_nbyte, size_t count,
                        void (*prepare_fun)(void *) = NULL,
                        void *prepare_arg = NULL,
                        const char* _file = _FILE,
                        const int _line = _LINE,
                        const char* _caller = _CALLER);
// C++11 support for lambda prepare function
#if DMLC_USE_CXX11
  /*!
   * \brief customized in-place all reduce operation, with lambda function as preprocessor
   * \param sendrecvobj pointer to the array of objects to be reduced
   * \param max_nbyte maximum amount of memory needed to serialize each object
   *        this includes budget limit for intermediate and final result
   * \param count number of elements to be reduced
   * \param prepare_fun lambda function executed to prepare the data, if necessary
   * \param _file caller file name used to generate unique cache key
   * \param _line caller line number used to generate unique cache key
   * \param _caller caller function name used to generate unique cache key
   */
  inline void Allreduce(DType *sendrecvobj,
                        size_t max_nbyte, size_t count,
                        std::function<void()> prepare_fun,
                        const char* _file = _FILE,
                        const int _line = _LINE,
                        const char* _caller = _CALLER);
#endif  // DMLC_USE_CXX11

 private:
  /*! \brief function handle to do reduce */
  engine::ReduceHandle handle_;
  /*! \brief temporary buffer used to do reduce */
  std::string buffer_;
};
456 } // namespace rabit
457 // implementation of template functions
458 #include "./internal/rabit-inl.h"
459 #endif // RABIT_RABIT_H_ // NOLINT(*)
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file serializable.h
3 * \brief defines serializable interface of rabit
4 * \author Tianqi Chen
5 */
6 #ifndef RABIT_SERIALIZABLE_H_
7 #define RABIT_SERIALIZABLE_H_
8 #include <vector>
9 #include <string>
10 #include "rabit/internal/utils.h"
11
12 namespace rabit {
13 /*!
14 * \brief defines stream used in rabit
15 * see definition of Stream in dmlc/io.h
16 */
17 typedef dmlc::Stream Stream;
18 /*!
19 * \brief defines serializable objects used in rabit
20 * see definition of Serializable in dmlc/io.h
21 */
22 typedef dmlc::Serializable Serializable;
23
24 } // namespace rabit
25 #endif // RABIT_SERIALIZABLE_H_
0 Rabit Library
1 =====
2 This folder holds the library files generated by the compiler. To generate the library files, type ```make``` in the project root folder. If you want an MPI-compatible library, type ```make mpi```.
3
4 ***List of Files***
5 * rabit.a The rabit package library
6 - Normally you need to link with this one
7 * rabit_mock.a The rabit package library with mock test
8 - This library allows additional mock-test
9 * rabit_mpi.a The MPI backed library
10 - Linking against this library makes the program use MPI Allreduce
11 - This library is not fault-tolerant
12 * rabit_empty.a Dummy package implementation
13 - This is an empty library that does not provide anything
14 - Only introduced to minimize code dependency for projects that only need single machine code
0 """
1 Reliable Allreduce and Broadcast Library.
2
3 Author: Tianqi Chen
4 """
5 # pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value,
6 import pickle
7 import ctypes
8 import os
9 import platform
10 import sys
11 import warnings
12 import numpy as np
13
14 # version information about the doc
15 __version__ = '1.0'
16
17 _LIB = None
18
19 def _find_lib_path(dll_name):
20 """Find the rabit dynamic library files.
21
22 Returns
23 -------
24 lib_path: list(string)
25 List of all found library path to rabit
26 """
27 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
28 # make pythonpack hack: copy this directory one level upper for setup.py
29 dll_path = [curr_path,
30 os.path.join(curr_path, '../lib/'),
31 os.path.join(curr_path, './lib/')]
32 if os.name == 'nt':
33 dll_path = [os.path.join(p, dll_name) for p in dll_path]
34 else:
35 dll_path = [os.path.join(p, dll_name) for p in dll_path]
36 lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
37 #From github issues, most of installation errors come from machines w/o compilers
38 if len(lib_path) == 0 and not os.environ.get('XGBOOST_BUILD_DOC', False):
39 raise RuntimeError(
40 'Cannot find Rabit Libarary in the candicate path, ' +
41 'did you install compilers and run build.sh in root path?\n'
42 'List of candidates:\n' + ('\n'.join(dll_path)))
43 return lib_path
44
45 # load in xgboost library
46 def _loadlib(lib='standard', lib_dll=None):
47 """Load rabit library."""
48 global _LIB
49 if _LIB is not None:
50 warnings.warn('rabit.int call was ignored because it has'\
51 ' already been initialized', level=2)
52 return
53
54 if lib_dll is not None:
55 _LIB = lib_dll
56 return
57
58 if lib == 'standard':
59 dll_name = 'librabit'
60 else:
61 dll_name = 'librabit_' + lib
62
63 if os.name == 'nt':
64 dll_name += '.dll'
65 elif platform.system() == 'Darwin':
66 dll_name += '.dylib'
67 else:
68 dll_name += '.so'
69
70 _LIB = ctypes.cdll.LoadLibrary(_find_lib_path(dll_name)[0])
71 _LIB.RabitGetRank.restype = ctypes.c_int
72 _LIB.RabitGetWorldSize.restype = ctypes.c_int
73 _LIB.RabitVersionNumber.restype = ctypes.c_int
74
75 def _unloadlib():
76 """Unload rabit library."""
77 global _LIB
78 del _LIB
79 _LIB = None
80
81 # reduction operators
82 MAX = 0
83 MIN = 1
84 SUM = 2
85 BITOR = 3
86
def init(args=None, lib='standard', lib_dll=None):
    """Initialize the rabit module, call this once before using anything.

    Parameters
    ----------
    args: list of str or bytes, optional
        The list of arguments used to initialize rabit,
        usually you need to pass in sys.argv.
        When it is None an empty argument list is passed.
    lib: {'standard', 'mock', 'mpi'}, optional
        Type of library we want to load.
        Ignored when lib_dll is specified.
    lib_dll: ctypes.DLL, optional
        The DLL object used as lib.
        When this is presented argument lib will be ignored.
    """
    if args is None:
        args = []
    _loadlib(lib, lib_dll)
    # A c_char_p array only accepts byte strings, so text arguments are
    # encoded first; byte strings are passed through unchanged.
    encoded = [arg if isinstance(arg, bytes) else arg.encode('utf-8')
               for arg in args]
    arr = (ctypes.c_char_p * len(encoded))()
    arr[:] = encoded
    _LIB.RabitInit(len(encoded), arr)
110
def finalize():
    """Shut down the rabit engine.

    Call this once all jobs are finished. The native engine is finalized
    first, then the library handle is released.
    """
    engine = _LIB
    engine.RabitFinalize()
    _unloadlib()
118
def get_rank():
    """Return the rank of the calling process.

    Returns
    -------
    rank : int
        Value reported by the library's RabitGetRank.
    """
    return _LIB.RabitGetRank()
129
def get_world_size():
    """Return the total number of workers.

    Returns
    -------
    n : int
        Value reported by the library's RabitGetWorldSize.
    """
    return _LIB.RabitGetWorldSize()
140
def tracker_print(msg):
    """Print message to the tracker.

    This function can be used to communicate the information of
    the progress to the tracker.

    Parameters
    ----------
    msg : str
        The message to be printed to tracker. Objects that are not
        ``str`` are converted with ``str()`` first.
    """
    if not isinstance(msg, str):
        msg = str(msg)
    # c_char_p needs a byte string: encode the text itself (the original
    # called .encode() on the c_char_p object, which has no such method).
    # On Python 2 ``str`` already is ``bytes`` and is passed through.
    if not isinstance(msg, bytes):
        msg = msg.encode('utf-8')
    _LIB.RabitTrackerPrint(ctypes.c_char_p(msg))
155
def get_processor_name():
    """Query the library for the processor (host) name.

    Returns
    -------
    name : bytes or str
        Content of the name buffer filled by RabitGetProcessorName, as
        returned by the ctypes buffer's ``value`` attribute.
    """
    capacity = 256
    name_buf = ctypes.create_string_buffer(capacity)
    # receives the name length from the library; not used afterwards
    name_len = ctypes.c_ulong()
    _LIB.RabitGetProcessorName(name_buf, ctypes.byref(name_len), capacity)
    return name_buf.value
169
def broadcast(data, root):
    """Send a picklable object from the root node to every other node.

    Parameters
    ----------
    data : any type that can be pickled
        Object to send. Only read on the root; other ranks may pass None.
    root : int
        Rank of the node that owns the data.

    Returns
    -------
    object
        On the root, ``data`` itself; elsewhere the unpickled copy
        received from the root.
    """
    is_root = (get_rank() == root)
    nbytes = ctypes.c_ulong()
    if is_root:
        assert data is not None, 'need to pass in data when broadcasting'
        payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        nbytes.value = len(payload)
    # round one: everybody learns the payload size
    _LIB.RabitBroadcast(ctypes.byref(nbytes),
                        ctypes.sizeof(ctypes.c_ulong), root)
    # round two: the payload itself
    if is_root:
        _LIB.RabitBroadcast(
            ctypes.cast(ctypes.c_char_p(payload), ctypes.c_void_p),
            nbytes.value, root)
        del payload
        return data
    recv_buf = (ctypes.c_char * nbytes.value)()
    _LIB.RabitBroadcast(ctypes.cast(recv_buf, ctypes.c_void_p),
                        nbytes.value, root)
    data = pickle.loads(recv_buf.raw)
    del recv_buf
    return data
206
# enumeration of dtypes
# Maps each numpy dtype accepted by allreduce() to the integer code that is
# handed to RabitAllreduce as the element type.
DTYPE_ENUM__ = {
    np.dtype('int8') : 0,
    np.dtype('uint8') : 1,
    np.dtype('int32') : 2,
    np.dtype('uint32') : 3,
    np.dtype('int64') : 4,
    np.dtype('uint64') : 5,
    np.dtype('float32') : 6,
    np.dtype('float64') : 7
}
218
def allreduce(data, op, prepare_fun=None):
    """Run an allreduce over a numpy array and return the reduced buffer.

    Parameters
    ----------
    data: numpy array
        Input data; its dtype must be a key of DTYPE_ENUM__.
    op: int
        Reduction operator, one of MIN, MAX, SUM, BITOR.
    prepare_fun: function, optional
        Lazy initializer. When given, it is wrapped in a C callback that
        calls ``prepare_fun(data)`` and is handed to the library together
        with the buffer.

    Returns
    -------
    result : numpy array
        The flattened (``ravel``-ed) buffer that was passed to the library.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise Exception('allreduce only takes in numpy.ndarray')
    buf = data.ravel()
    # copy when the flattened array reports the same base object as the input
    if buf.base is data.base:
        buf = buf.copy()
    if buf.dtype not in DTYPE_ENUM__:
        raise Exception('data type %s not supported' % str(buf.dtype))

    if prepare_fun is None:
        prepare_cb = None
    else:
        cb_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

        def _run_prepare(_unused):
            """prepare function."""
            prepare_fun(data)

        prepare_cb = cb_type(_run_prepare)

    _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                        buf.size, DTYPE_ENUM__[buf.dtype],
                        op, prepare_cb, None)
    return buf
263
264
def _load_model(ptr, length):
    """Unpickle an object from a raw memory region (module-internal helper).

    Arguments:
        ptr: ctypes.POINTER(ctypes.c_char)
            pointer to the first byte of the buffer
        length: int
            number of bytes in the buffer
    Returns:
        the object produced by pickle.loads on those bytes
    """
    start = ctypes.addressof(ptr.contents)
    region = (ctypes.c_char * length).from_address(start)
    return pickle.loads(region.raw)
277
def load_checkpoint(with_local=False):
    """Load the latest checkpoint.

    Parameters
    ----------
    with_local: bool, optional
        whether the checkpoint contains a local model

    Returns
    -------
    tuple : tuple
        ``(version, global_model, local_model)`` when with_local is set,
        otherwise ``(version, global_model)``.
        A version of 0 means nothing has been checkpointed yet; the
        model entries are then None.
    """
    global_ptr = ctypes.POINTER(ctypes.c_char)()
    global_size = ctypes.c_ulong()

    if not with_local:
        version = _LIB.RabitLoadCheckPoint(
            ctypes.byref(global_ptr),
            ctypes.byref(global_size),
            None, None)
        if version == 0:
            return (version, None)
        return (version, _load_model(global_ptr, global_size.value))

    local_ptr = ctypes.POINTER(ctypes.c_char)()
    local_size = ctypes.c_ulong()
    version = _LIB.RabitLoadCheckPoint(
        ctypes.byref(global_ptr),
        ctypes.byref(global_size),
        ctypes.byref(local_ptr),
        ctypes.byref(local_size))
    if version == 0:
        return (version, None, None)
    global_model = _load_model(global_ptr, global_size.value)
    local_model = _load_model(local_ptr, local_size.value)
    return (version, global_model, local_model)
318
def checkpoint(global_model, local_model=None):
    """Checkpoint the model.

    Marks the end of a stage of execution. Every checkpoint call
    increases the stored version number by one.

    Parameters
    ----------
    global_model: anytype that can be pickled
        Globally shared model/state. The caller must guarantee that it
        is identical on all nodes when this function is called.

    local_model: anytype that can be pickled
        Model specific to the current node/rank.
        May be None when no local state is needed.

    Notes
    -----
    local_model requires explicit replication for fault tolerance, which
    adds cost to the checkpoint call; global_model does not.
    Prefer global_model whenever possible.
    """
    global_blob = pickle.dumps(global_model)
    if local_model is not None:
        local_blob = pickle.dumps(local_model)
        _LIB.RabitCheckPoint(global_blob, len(global_blob),
                             local_blob, len(local_blob))
        del local_blob
    else:
        _LIB.RabitCheckPoint(global_blob, len(global_blob), None, 0)
    del global_blob
351
def version_number():
    """Return the version number of the currently stored model.

    This equals the number of CheckPoint calls made so far.

    Returns
    -------
    version : int
        Value reported by the library's RabitVersionNumber.
    """
    return _LIB.RabitVersionNumber()
#!/usr/bin/env bash
# Download, configure and install MPICH 3.2 into ./mpich, unless
# ./mpich/lib/libmpich.so is already present.

# Abort on the first failing step: without this a failed download or build
# was ignored and the script still finished with a success status.
set -e

if [ -f mpich/lib/libmpich.so ]; then
  echo "libmpich.so found -- nothing to build."
else
  echo "Downloading mpich source."
  wget http://www.mpich.org/static/downloads/3.2/mpich-3.2.tar.gz
  tar xfz mpich-3.2.tar.gz
  rm mpich-3.2.tar.gz*
  echo "configuring and building mpich."
  cd mpich-3.2
  #CC=gcc CXX=g++ CFLAGS=-m64 CXXFLAGS=-m64 FFLAGS=-m64
  # install next to the source tree; quoted so a path with spaces survives
  ./configure \
    --prefix="$(pwd)/../mpich" \
    --enable-static=false \
    --enable-alloca=true \
    --disable-long-double \
    --enable-threads=single \
    --enable-fortran=no \
    --enable-fast=all \
    --enable-g=none \
    --enable-timing=none \
    --enable-cxx
  make -j4
  make install
  cd -
fi
#!/bin/bash

# Run the test.mk targets one after another, in this order, and stop with
# a failure status as soon as one of them fails.
for target in \
  model_recover_10_10k \
  model_recover_10_10k_die_same \
  model_recover_10_10k_die_hard \
  local_recover_10_10k \
  lazy_recover_10_10k_die_hard \
  lazy_recover_10_10k_die_same \
  ringallreduce_10_10k \
  pylocal_recover_10_10k
do
  make -f test.mk RABIT_BUILD_DMLC=1 "${target}" || exit -1
done
#!/bin/bash

# main script of travis: runs the job selected by the TASK variable.
# ${TASK} is quoted in every test so an unset/empty value compares as an
# empty string instead of producing a "[: ==: unary operator expected" error.
if [ "${TASK}" == "lint" ]; then
  make lint RABIT_BUILD_DMLC=1 || exit -1
fi

if [ "${TASK}" == "doc" ]; then
  make doc 2>log.txt
  # fail when the doc build logged warnings other than the filtered ones
  (cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag" |grep warning) && exit -1
fi

# we should deprecate Makefile based build
if [ "${TASK}" == "build" ]; then
  make all RABIT_BUILD_DMLC=1 || exit -1
fi

if [ "${TASK}" == "mpi-build" ]; then
  ./scripts/mpi_build.sh
  cd test
  make mpi RABIT_BUILD_DMLC=1 && make speed_test.mpi RABIT_BUILD_DMLC=1 || exit -1
fi

if [ "${TASK}" == "cmake-test" ]; then
  mkdir build
  cd build
  cmake -DRABIT_BUILD_TESTS=ON -DRABIT_BUILD_DMLC=ON -DGTEST_ROOT="${HOME}/.local" .. || exit -1
  # known osx gtest 1.8 issue
  # (left unchecked on purpose: the glob matches nothing where no .dylib exists)
  cp "${HOME}"/.local/lib/*.dylib .
  # configure, build and unit-test failures must fail the job; previously
  # only `make install` was checked, so failing tests went unnoticed
  make -j$(nproc) || exit -1
  make test || exit -1
  make install || exit -1
  cd ../test
  ../scripts/travis_runtest.sh || exit -1
  rm -rf ../build
fi
#!/bin/bash
# Installs the Python tooling and a googletest build under ${HOME}/.local.
# NOTE: no `set -e` / `exit` here; .travis.yml runs this file with `source`,
# so exiting would terminate the calling shell.

echo "Testing on: ${TRAVIS_OS_NAME}, Home directory: ${HOME}"

# cpplint was listed twice; once is enough
pip3 install cpplint pylint urllib3 numpy
pip3 install websocket-client kubernetes


# Install googletest under home directory
GTEST_VERSION=1.8.1
GTEST_RELEASE=release-${GTEST_VERSION}.tar.gz
GTEST_TAR_BALL=googletest_${GTEST_RELEASE}

wget "https://github.com/google/googletest/archive/${GTEST_RELEASE}" -O "${GTEST_TAR_BALL}"
# prints the checksum verdict; the result is not acted upon
echo "152b849610d91a9dfa1401293f43230c2e0c33f8  ${GTEST_TAR_BALL}" | sha1sum -c
tar -xf "${GTEST_TAR_BALL}"
pushd .

cd "googletest-release-${GTEST_VERSION}"
mkdir build
cd build
echo "Installing to ${HOME}/.local"
cmake .. -DBUILD_SHARED_LIBS=ON -DCMAKE_INSTALL_PREFIX="${HOME}/.local"
make -j$(nproc)
make install

popd

# quoted so an unset TRAVIS_OS_NAME does not break the `[` test
if [ "${TRAVIS_OS_NAME}" == "linux" ]; then
  sudo apt-get install python3-pip tree
fi

if [ "${TRAVIS_OS_NAME}" == "osx" ]; then
  brew install python3
fi
# Root directory of an external dmlc-core checkout. Its include/ directory is
# added to the public include path of both libraries below.
# This is a path, not a switch: option() only models booleans and defaults to
# OFF, which produced the bogus include directory "OFF/include" when the root
# was not given. A PATH cache entry keeps -DDMLC_ROOT=<dir> working and leaves
# an already existing cache entry or parent-scope value untouched.
set(DMLC_ROOT "" CACHE PATH "Specify root of external dmlc core.")

add_library(allreduce_base "")
add_library(allreduce_mock "")

target_sources(
  allreduce_base
  PRIVATE
    allreduce_base.cc
  PUBLIC
    ${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h
)
target_sources(
  allreduce_mock
  PRIVATE
    allreduce_robust.cc
  PUBLIC
    ${CMAKE_CURRENT_LIST_DIR}/allreduce_mock.h
)

foreach(rabit_target allreduce_base allreduce_mock)
  # only add the dmlc-core headers when a root was actually provided
  # (an empty or OFF value evaluates to false here)
  if(DMLC_ROOT)
    target_include_directories(
      ${rabit_target}
      PUBLIC
        "${DMLC_ROOT}/include")
  endif()
  target_include_directories(
    ${rabit_target}
    PUBLIC
      "${CMAKE_CURRENT_LIST_DIR}/../../include")
endforeach()
0 Source Files of Rabit
1 ====
2 * This folder contains the source files of rabit library
3 * The library headers are in folder [include](../include)
4 * The .h files in this folder are internal header files that are only used by rabit and will not be seen by users
5
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file allreduce_base.cc
3 * \brief Basic implementation of AllReduce
4 *
5 * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
6 */
7 #define _CRT_SECURE_NO_WARNINGS
8 #define _CRT_SECURE_NO_DEPRECATE
9 #define NOMINMAX
10 #include <netinet/tcp.h>
11 #include <cstring>
12 #include <map>
13 #include "allreduce_base.h"
14
15 namespace rabit {
16
namespace utils {
// Global flag, true by default. AllreduceBase::SetParam overwrites it from
// the DMLC_WORKER_STOP_PROCESS_ON_ERROR setting ("true" / "false").
// NOTE(review): the code that consumes this flag is not in this file;
// presumably the utils error helpers. Confirm before relying on it.
bool STOP_PROCESS_ON_ERROR = true;
}
20
21 namespace engine {
22 // constructor
23 AllreduceBase::AllreduceBase(void) {
24 tracker_uri = "NULL";
25 tracker_port = 9000;
26 host_uri = "";
27 slave_port = 9010;
28 nport_trial = 1000;
29 rank = 0;
30 world_size = -1;
31 connect_retry = 5;
32 hadoop_mode = 0;
33 version_number = 0;
34 // 32 K items
35 reduce_ring_mincount = 32 << 10;
36 // tracker URL
37 task_id = "NULL";
38 err_link = NULL;
39 dmlc_role = "worker";
40 this->SetParam("rabit_reduce_buffer", "256MB");
41 // setup possible enviroment variable of interest
42 // include dmlc support direct variables
43 env_vars.push_back("DMLC_TASK_ID");
44 env_vars.push_back("DMLC_ROLE");
45 env_vars.push_back("DMLC_NUM_ATTEMPT");
46 env_vars.push_back("DMLC_TRACKER_URI");
47 env_vars.push_back("DMLC_TRACKER_PORT");
48 env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
49 env_vars.push_back("DMLC_WORKER_STOP_PROCESS_ON_ERROR");
50 }
51
52 // initialization function
53 bool AllreduceBase::Init(int argc, char* argv[]) {
54 // setup from enviroment variables
55 // handler to get variables from env
56 for (size_t i = 0; i < env_vars.size(); ++i) {
57 const char *value = getenv(env_vars[i].c_str());
58 if (value != NULL) {
59 this->SetParam(env_vars[i].c_str(), value);
60 }
61 }
62 // pass in arguments override env variable.
63 for (int i = 0; i < argc; ++i) {
64 char name[256], val[256];
65 if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
66 this->SetParam(name, val);
67 }
68 }
69
70 {
71 // handling for hadoop
72 const char *task_id = getenv("mapred_tip_id");
73 if (task_id == NULL) {
74 task_id = getenv("mapreduce_task_id");
75 }
76 if (hadoop_mode) {
77 utils::Check(task_id != NULL,
78 "hadoop_mode is set but cannot find mapred_task_id");
79 }
80 if (task_id != NULL) {
81 this->SetParam("rabit_task_id", task_id);
82 this->SetParam("rabit_hadoop_mode", "1");
83 }
84 const char *attempt_id = getenv("mapred_task_id");
85 if (attempt_id != 0) {
86 const char *att = strrchr(attempt_id, '_');
87 int num_trial;
88 if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
89 this->SetParam("rabit_num_trial", att + 1);
90 }
91 }
92 // handling for hadoop
93 const char *num_task = getenv("mapred_map_tasks");
94 if (num_task == NULL) {
95 num_task = getenv("mapreduce_job_maps");
96 }
97 if (hadoop_mode) {
98 utils::Check(num_task != NULL,
99 "hadoop_mode is set but cannot find mapred_map_tasks");
100 }
101 if (num_task != NULL) {
102 this->SetParam("rabit_world_size", num_task);
103 }
104 }
105 if (dmlc_role != "worker") {
106 fprintf(stderr, "Rabit Module currently only work with dmlc worker"\
107 ", quit this program by exit 0\n");
108 exit(0);
109 }
110
111 // clear the setting before start reconnection
112 this->rank = -1;
113 //---------------------
114 // start socket
115 utils::Socket::Startup();
116 utils::Assert(all_links.size() == 0, "can only call Init once");
117 this->host_uri = utils::SockAddr::GetHostName();
118 // get information from tracker
119 return this->ReConnectLinks();
120 }
121
122 bool AllreduceBase::Shutdown(void) {
123 try {
124 for (size_t i = 0; i < all_links.size(); ++i) {
125 all_links[i].sock.Close();
126 }
127 all_links.clear();
128 tree_links.plinks.clear();
129
130 if (tracker_uri == "NULL") return true;
131 // notify tracker rank i have shutdown
132 utils::TCPSocket tracker = this->ConnectTracker();
133 tracker.SendStr(std::string("shutdown"));
134 tracker.Close();
135 utils::TCPSocket::Finalize();
136 return true;
137 } catch (const std::exception& e) {
138 fprintf(stderr, "failed to shutdown due to %s\n", e.what());
139 return false;
140 }
141 }
142
143 void AllreduceBase::TrackerPrint(const std::string &msg) {
144 if (tracker_uri == "NULL") {
145 utils::Printf("%s", msg.c_str()); return;
146 }
147 utils::TCPSocket tracker = this->ConnectTracker();
148 tracker.SendStr(std::string("print"));
149 tracker.SendStr(msg);
150 tracker.Close();
151 }
152
153 // util to parse data with unit suffix
154 inline size_t ParseUnit(const char *name, const char *val) {
155 char unit;
156 unsigned long amt; // NOLINT(*)
157 int n = sscanf(val, "%lu%c", &amt, &unit);
158 size_t amount = amt;
159 if (n == 2) {
160 switch (unit) {
161 case 'B': return amount;
162 case 'K': return amount << 10UL;
163 case 'M': return amount << 20UL;
164 case 'G': return amount << 30UL;
165 default: utils::Error("invalid format for %s", name); return 0;
166 }
167 } else if (n == 1) {
168 return amount;
169 } else {
170 utils::Error("invalid format for %s," \
171 "shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name);
172 return 0;
173 }
174 }
175 /*!
176 * \brief set parameters to the engine
177 * \param name parameter name
178 * \param val parameter value
179 */
180 void AllreduceBase::SetParam(const char *name, const char *val) {
181 if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
182 if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
183 if (!strcmp(name, "rabit_task_id")) task_id = val;
184 if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val;
185 if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val);
186 if (!strcmp(name, "DMLC_TASK_ID")) task_id = val;
187 if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
188 if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
189 if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val);
190 if (!strcmp(name, "rabit_reduce_ring_mincount")) {
191 reduce_ring_mincount = atoi(val);
192 utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0");
193 }
194 if (!strcmp(name, "rabit_reduce_buffer")) {
195 reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
196 }
197 if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
198 connect_retry = atoi(val);
199 }
200 if (!strcmp(name, "DMLC_WORKER_STOP_PROCESS_ON_ERROR")) {
201 if (!strcmp(val, "true")) {
202 rabit::utils::STOP_PROCESS_ON_ERROR = true;
203 } else if (!strcmp(val, "false")) {
204 rabit::utils::STOP_PROCESS_ON_ERROR = false;
205 } else {
206 throw std::runtime_error("invalid value of DMLC_WORKER_STOP_PROCESS_ON_ERROR");
207 }
208 }
209 if (!strcmp(name, "rabit_bootstrap_cache")) {
210 rabit_bootstrap_cache = utils::StringToBool(val);
211 }
212 if (!strcmp(name, "rabit_debug")) {
213 rabit_debug = utils::StringToBool(val);
214 }
215 if (!strcmp(name, "rabit_timeout")) {
216 rabit_timeout = utils::StringToBool(val);
217 }
218 if (!strcmp(name, "rabit_timeout_sec")) {
219 timeout_sec = atoi(val);
220 utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
221 }
222 if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
223 if (!strcmp(val, "true"))
224 rabit_enable_tcp_no_delay = true;
225 else
226 rabit_enable_tcp_no_delay = false;
227 }
228 }
229 /*!
230 * \brief initialize connection to the tracker
231 * \return a socket that initializes the connection
232 */
233 utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
234 int magic = kMagic;
235 // get information from tracker
236 utils::TCPSocket tracker;
237 tracker.Create();
238
239 int retry = 0;
240 do {
241 if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
242 if (++retry >= connect_retry) {
243 fprintf(stderr, "connect to (failed): [%s]\n", tracker_uri.c_str());
244 utils::Socket::Error("Connect");
245 } else {
246 fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
247 #if defined(_MSC_VER) || defined (__MINGW32__)
248 Sleep(retry << 1);
249 #else
250 sleep(retry << 1);
251 #endif
252 continue;
253 }
254 }
255 break;
256 } while (1);
257
258 using utils::Assert;
259 Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
260 "ReConnectLink failure 1");
261 Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
262 "ReConnectLink failure 2");
263 utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
264 Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
265 "ReConnectLink failure 3");
266 Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
267 "ReConnectLink failure 3");
268 tracker.SendStr(task_id);
269 return tracker;
270 }
/*!
 * \brief connect to the tracker to fix the missing links
 *   this function is also used when the engine starts up
 * \param cmd command string sent to the tracker right after the handshake
 * \return true on success (always in single node mode),
 *         false when a std::exception was caught
 */
bool AllreduceBase::ReConnectLinks(const char *cmd) {
  // single node mode
  if (tracker_uri == "NULL") {
    rank = 0; world_size = 1; return true;
  }
  try {
    utils::TCPSocket tracker = this->ConnectTracker();
    fprintf(stdout, "task %s connected to the tracker\n", task_id.c_str());
    tracker.SendStr(std::string(cmd));

    // the rank of previous link, next link in ring
    int prev_rank, next_rank;
    // the rank of neighbors
    std::map<int, int> tree_neighbors;
    using utils::Assert;
    // get new ranks
    // received in this order: own rank, parent rank, world size
    int newrank, num_neighbors;
    Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
           "ReConnectLink failure 4");
    Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \
           sizeof(parent_rank), "ReConnectLink failure 4");
    Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
           "ReConnectLink failure 4");
    // a rank of -1 means "not assigned yet" (Init sets it before calling here)
    Assert(rank == -1 || newrank == rank,
           "must keep rank to same if the node already have one");
    rank = newrank;

    // tracker got overwhelmed and not able to assign correct rank
    if (rank == -1) exit(-1);

    fprintf(stdout, "task %s got new rank %d\n", task_id.c_str(), rank);

    // tree neighbors: a count followed by that many ranks
    Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
           sizeof(num_neighbors), "ReConnectLink failure 4");
    for (int i = 0; i < num_neighbors; ++i) {
      int nrank;
      Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
             "ReConnectLink failure 4");
      tree_neighbors[nrank] = 1;
    }
    // ring neighbors
    Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
           "ReConnectLink failure 4");
    Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
           "ReConnectLink failure 4");

    utils::TCPSocket sock_listen;
    if (!sock_listen.IsClosed()) {
      sock_listen.Close();
    }
    // create listening socket
    sock_listen.Create();
    // try the ports in [slave_port, slave_port + nport_trial)
    // NOTE(review): whether the upper bound is inclusive depends on
    // TryBindHost, which is not visible here
    int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
    utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
    sock_listen.Listen();

    // get number of to connect and number of to accept nodes from tracker
    // the round is repeated until every outgoing connection succeeded
    int num_conn, num_accept, num_error = 1;
    do {
      // send over good links
      std::vector<int> good_link;
      for (size_t i = 0; i < all_links.size(); ++i) {
        if (!all_links[i].sock.BadSocket()) {
          good_link.push_back(static_cast<int>(all_links[i].rank));
        } else {
          if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
        }
      }
      int ngood = static_cast<int>(good_link.size());
      Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
             "ReConnectLink failure 5");
      for (size_t i = 0; i < good_link.size(); ++i) {
        Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
               sizeof(good_link[i]), "ReConnectLink failure 6");
      }
      Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
             "ReConnectLink failure 7");
      Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
             sizeof(num_accept), "ReConnectLink failure 8");
      num_error = 0;
      // actively connect to the peers named by the tracker
      for (int i = 0; i < num_conn; ++i) {
        LinkRecord r;
        int hport, hrank;
        std::string hname;
        tracker.RecvStr(&hname);
        Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
               "ReConnectLink failure 9");
        Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
               "ReConnectLink failure 10");

        r.sock.Create();
        if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
          // count the failure; it is reported to the tracker below
          num_error += 1;
          r.sock.Close();
          continue;
        }
        // exchange ranks with the peer and verify it is the expected one
        Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
               "ReConnectLink failure 12");
        Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
               "ReConnectLink failure 13");
        utils::Check(hrank == r.rank,
                     "ReConnectLink failure, link rank inconsistent");
        // reuse the existing record of that rank if there is one
        // (this inner i shadows the loop counter of the enclosing for)
        bool match = false;
        for (size_t i = 0; i < all_links.size(); ++i) {
          if (all_links[i].rank == hrank) {
            Assert(all_links[i].sock.IsClosed(),
                   "Override a link that is active");
            all_links[i].sock = r.sock;
            match = true;
            break;
          }
        }
        if (!match) all_links.push_back(r);
      }
      Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
             "ReConnectLink failure 14");
    } while (num_error != 0);
    // send back socket listening port to tracker
    Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
           "ReConnectLink failure 14");
    // close connection to tracker
    tracker.Close();
    // listen to incoming links
    for (int i = 0; i < num_accept; ++i) {
      LinkRecord r;
      r.sock = sock_listen.Accept();
      Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
             "ReConnectLink failure 15");
      Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
             "ReConnectLink failure 15");
      // same record-reuse logic as for outgoing connections
      bool match = false;
      for (size_t i = 0; i < all_links.size(); ++i) {
        if (all_links[i].rank == r.rank) {
          utils::Assert(all_links[i].sock.IsClosed(),
                        "Override a link that is active");
          all_links[i].sock = r.sock;
          match = true;
          break;
        }
      }
      if (!match) all_links.push_back(r);
    }
    sock_listen.Close();
    this->parent_index = -1;
    // setup tree links and ring structure
    // pointers into all_links are taken only from here on, after the last
    // push_back above
    tree_links.plinks.clear();
    int tcpNoDelay = 1;
    for (size_t i = 0; i < all_links.size(); ++i) {
      utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
      // set the socket to non-blocking mode, enable TCP keepalive
      all_links[i].sock.SetNonBlock(true);
      all_links[i].sock.SetKeepAlive(true);
      if (rabit_enable_tcp_no_delay) {
        // the return value of setsockopt is not checked
        setsockopt(all_links[i].sock, IPPROTO_TCP,
                   TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
      }
      if (tree_neighbors.count(all_links[i].rank) != 0) {
        if (all_links[i].rank == parent_rank) {
          // position of the parent inside tree_links.plinks
          parent_index = static_cast<int>(tree_links.plinks.size());
        }
        tree_links.plinks.push_back(&all_links[i]);
      }
      if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
      if (all_links[i].rank == next_rank) ring_next = &all_links[i];
    }
    // NOTE(review): ring_prev / ring_next are only assigned, never cleared,
    // in this function; confirm they are reset elsewhere, otherwise the two
    // NULL checks below can pass on stale pointers from an earlier call
    Assert(parent_rank == -1 || parent_index != -1,
           "cannot find parent in the link");
    Assert(prev_rank == -1 || ring_prev != NULL,
           "cannot find prev ring in the link");
    Assert(next_rank == -1 || ring_next != NULL,
           "cannot find next ring in the link");
    return true;
  } catch (const std::exception& e) {
    fprintf(stderr, "failed in ReconnectLink %s\n", e.what());
    return false;
  }
}
451 /*!
452 * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
453 *
454 * NOTE on Allreduce:
455 * The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
456 * It only means the current node get the correct result of Allreduce.
457 * However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
458 *
459 * \param sendrecvbuf_ buffer for both sending and recving data
460 * \param type_nbytes the unit number of bytes the type have
461 * \param count number of elements to be reduced
462 * \param reducer reduce function
463 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
464 * \sa ReturnType
465 */
466 AllreduceBase::ReturnType
467 AllreduceBase::TryAllreduce(void *sendrecvbuf_,
468 size_t type_nbytes,
469 size_t count,
470 ReduceFunction reducer) {
471 if (count > reduce_ring_mincount) {
472 return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer);
473 } else {
474 return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer);
475 }
476 }
477 /*!
478 * \brief perform in-place allreduce, on sendrecvbuf,
479 * this function implements tree-shape reduction
480 *
481 * \param sendrecvbuf_ buffer for both sending and recving data
482 * \param type_nbytes the unit number of bytes the type have
483 * \param count number of elements to be reduced
484 * \param reducer reduce function
485 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
486 * \sa ReturnType
487 */
488 AllreduceBase::ReturnType
489 AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
490 size_t type_nbytes,
491 size_t count,
492 ReduceFunction reducer) {
493 RefLinkVector &links = tree_links;
494 if (links.size() == 0 || count == 0) return kSuccess;
495 // total size of message
496 const size_t total_size = type_nbytes * count;
497 // number of links
498 const int nlink = static_cast<int>(links.size());
499 // send recv buffer
500 char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
501 // size of space that we already performs reduce in up pass
502 size_t size_up_reduce = 0;
503 // size of space that we have already passed to parent
504 size_t size_up_out = 0;
505 // size of message we received, and send in the down pass
506 size_t size_down_in = 0;
507 // initialize the link ring-buffer and pointer
508 for (int i = 0; i < nlink; ++i) {
509 if (i != parent_index) {
510 links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
511 }
512 links[i].ResetSize();
513 }
514 // if no childs, no need to reduce
515 if (nlink == static_cast<int>(parent_index != -1)) {
516 size_up_reduce = total_size;
517 }
518 // while we have not passed the messages out
519 while (true) {
520 // select helper
521 bool finished = true;
522 utils::PollHelper watcher;
523 for (int i = 0; i < nlink; ++i) {
524 if (i == parent_index) {
525 if (size_down_in != total_size) {
526 watcher.WatchRead(links[i].sock);
527 // only watch for exception in live channels
528 watcher.WatchException(links[i].sock);
529 finished = false;
530 }
531 if (size_up_out != total_size && size_up_out < size_up_reduce) {
532 watcher.WatchWrite(links[i].sock);
533 }
534 } else {
535 if (links[i].size_read != total_size) {
536 watcher.WatchRead(links[i].sock);
537 }
538 // size_write <= size_read
539 if (links[i].size_write != total_size) {
540 if (links[i].size_write < size_down_in) {
541 watcher.WatchWrite(links[i].sock);
542 }
543 // only watch for exception in live channels
544 watcher.WatchException(links[i].sock);
545 finished = false;
546 }
547 }
548 }
549 // finish runing allreduce
550 if (finished) break;
551 // select must return
552 watcher.Poll();
553 // exception handling
554 for (int i = 0; i < nlink; ++i) {
555 // recive OOB message from some link
556 if (watcher.CheckExcept(links[i].sock)) {
557 return ReportError(&links[i], kGetExcept);
558 }
559 }
560 // read data from childs
561 for (int i = 0; i < nlink; ++i) {
562 if (i != parent_index && watcher.CheckRead(links[i].sock)) {
563 ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
564 if (ret != kSuccess) {
565 return ReportError(&links[i], ret);
566 }
567 }
568 }
569 // this node have childs, peform reduce
570 if (nlink > static_cast<int>(parent_index != -1)) {
571 size_t buffer_size = 0;
572 // do upstream reduce
573 size_t max_reduce = total_size;
574 for (int i = 0; i < nlink; ++i) {
575 if (i != parent_index) {
576 max_reduce = std::min(max_reduce, links[i].size_read);
577 utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
578 "buffer size inconsistent");
579 buffer_size = links[i].buffer_size;
580 }
581 }
582 utils::Assert(buffer_size != 0, "must assign buffer_size");
583 // round to type_n4bytes
584 max_reduce = (max_reduce / type_nbytes * type_nbytes);
585 // peform reduce, can be at most two rounds
586 while (size_up_reduce < max_reduce) {
587 // start position
588 size_t start = size_up_reduce % buffer_size;
589 // peform read till end of buffer
590 size_t nread = std::min(buffer_size - start,
591 max_reduce - size_up_reduce);
592 utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
593 for (int i = 0; i < nlink; ++i) {
594 if (i != parent_index) {
595 reducer(links[i].buffer_head + start,
596 sendrecvbuf + size_up_reduce,
597 static_cast<int>(nread / type_nbytes),
598 MPI::Datatype(type_nbytes));
599 }
600 }
601 size_up_reduce += nread;
602 }
603 }
604 if (parent_index != -1) {
605 // pass message up to parent, can pass data that are already been reduced
606 if (size_up_out < size_up_reduce) {
607 ssize_t len = links[parent_index].sock.
608 Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
609 if (len != -1) {
610 size_up_out += static_cast<size_t>(len);
611 } else {
612 ReturnType ret = Errno2Return();
613 if (ret != kSuccess) {
614 return ReportError(&links[parent_index], ret);
615 }
616 }
617 }
618 // read data from parent
619 if (watcher.CheckRead(links[parent_index].sock) &&
620 total_size > size_down_in) {
621 ssize_t len = links[parent_index].sock.
622 Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
623 if (len == 0) {
624 links[parent_index].sock.Close();
625 return ReportError(&links[parent_index], kRecvZeroLen);
626 }
627 if (len != -1) {
628 size_down_in += static_cast<size_t>(len);
629 utils::Assert(size_down_in <= size_up_out,
630 "Allreduce: boundary error");
631 } else {
632 ReturnType ret = Errno2Return();
633 if (ret != kSuccess) {
634 return ReportError(&links[parent_index], ret);
635 }
636 }
637 }
638 } else {
639 // this is root, can use reduce as most recent point
640 size_down_in = size_up_out = size_up_reduce;
641 }
642 // can pass message down to childs
643 for (int i = 0; i < nlink; ++i) {
644 if (i != parent_index && links[i].size_write < size_down_in) {
645 ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in);
646 if (ret != kSuccess) {
647 return ReportError(&links[i], ret);
648 }
649 }
650 }
651 }
652 return kSuccess;
653 }
654 /*!
655 * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
656 * \param sendrecvbuf_ buffer for both sending and recving data
657 * \param total_size the size of the data to be broadcasted
658 * \param root the root worker id to broadcast the data
659 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
660 * \sa ReturnType
661 */
662 AllreduceBase::ReturnType
663 AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
664 RefLinkVector &links = tree_links;
665 if (links.size() == 0 || total_size == 0) return kSuccess;
666 utils::Check(root < world_size,
667 "Broadcast: root should be smaller than world size");
668 // number of links
669 const int nlink = static_cast<int>(links.size());
670 // size of space already read from data
671 size_t size_in = 0;
672 // input link, -2 means unknown yet, -1 means this is root
673 int in_link = -2;
674
675 // initialize the link statistics
676 for (int i = 0; i < nlink; ++i) {
677 links[i].ResetSize();
678 }
679 // root have all the data
680 if (this->rank == root) {
681 size_in = total_size;
682 in_link = -1;
683 }
684 // while we have not passed the messages out
685 while (true) {
686 bool finished = true;
687 // select helper
688 utils::PollHelper watcher;
689 for (int i = 0; i < nlink; ++i) {
690 if (in_link == -2) {
691 watcher.WatchRead(links[i].sock); finished = false;
692 }
693 if (i == in_link && links[i].size_read != total_size) {
694 watcher.WatchRead(links[i].sock); finished = false;
695 }
696 if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
697 if (links[i].size_write < size_in) {
698 watcher.WatchWrite(links[i].sock);
699 }
700 finished = false;
701 }
702 watcher.WatchException(links[i].sock);
703 }
704 // finish running
705 if (finished) break;
706 // select
707 watcher.Poll();
708 // exception handling
709 for (int i = 0; i < nlink; ++i) {
710 // recive OOB message from some link
711 if (watcher.CheckExcept(links[i].sock)) {
712 return ReportError(&links[i], kGetExcept);
713 }
714 }
715 if (in_link == -2) {
716 // probe in-link
717 for (int i = 0; i < nlink; ++i) {
718 if (watcher.CheckRead(links[i].sock)) {
719 ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
720 if (ret != kSuccess) {
721 return ReportError(&links[i], ret);
722 }
723 size_in = links[i].size_read;
724 if (size_in != 0) {
725 in_link = i; break;
726 }
727 }
728 }
729 } else {
730 // read from in link
731 if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
732 ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
733 if (ret != kSuccess) {
734 return ReportError(&links[in_link], ret);
735 }
736 size_in = links[in_link].size_read;
737 }
738 }
739 // send data to all out-link
740 for (int i = 0; i < nlink; ++i) {
741 if (i != in_link && links[i].size_write < size_in) {
742 ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in);
743 if (ret != kSuccess) {
744 return ReportError(&links[i], ret);
745 }
746 }
747 }
748 }
749 return kSuccess;
750 }
751 /*!
752 * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
753 * the data provided by current node k is [slice_begin, slice_end),
754 * the next node's segment must start with slice_end
755 * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
756 * use a ring based algorithm
757 *
758 * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
759 * \param total_size total size of data to be gathered
760 * \param slice_begin beginning of the current slice
761 * \param slice_end end of the current slice
762 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
763 */
764 AllreduceBase::ReturnType
765 AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
766 size_t slice_begin,
767 size_t slice_end,
768 size_t size_prev_slice) {
769 // read from next link and send to prev one
770 LinkRecord &prev = *ring_prev, &next = *ring_next;
771 // need to reply on special rank structure
772 utils::Assert(next.rank == (rank + 1) % world_size &&
773 rank == (prev.rank + 1) % world_size,
774 "need to assume rank structure");
775 // send recv buffer
776 char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
777 const size_t stop_read = total_size + slice_begin;
778 const size_t stop_write = total_size + slice_begin - size_prev_slice;
779 size_t write_ptr = slice_begin;
780 size_t read_ptr = slice_end;
781
782 while (true) {
783 // select helper
784 bool finished = true;
785 utils::PollHelper watcher;
786 if (read_ptr != stop_read) {
787 watcher.WatchRead(next.sock);
788 finished = false;
789 }
790 if (write_ptr != stop_write) {
791 if (write_ptr < read_ptr) {
792 watcher.WatchWrite(prev.sock);
793 }
794 finished = false;
795 }
796 if (finished) break;
797 watcher.Poll();
798 if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
799 size_t size = stop_read - read_ptr;
800 size_t start = read_ptr % total_size;
801 if (start + size > total_size) {
802 size = total_size - start;
803 }
804 ssize_t len = next.sock.Recv(sendrecvbuf + start, size);
805 if (len != -1) {
806 read_ptr += static_cast<size_t>(len);
807 } else {
808 ReturnType ret = Errno2Return();
809 if (ret != kSuccess) return ReportError(&next, ret);
810 }
811 }
812 if (write_ptr < read_ptr && write_ptr != stop_write) {
813 size_t size = std::min(read_ptr, stop_write) - write_ptr;
814 size_t start = write_ptr % total_size;
815 if (start + size > total_size) {
816 size = total_size - start;
817 }
818 ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
819 if (len != -1) {
820 write_ptr += static_cast<size_t>(len);
821 } else {
822 ReturnType ret = Errno2Return();
823 if (ret != kSuccess) return ReportError(&prev, ret);
824 }
825 }
826 }
827 return kSuccess;
828 }
829 /*!
830 * \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
831 * and will return the cause of failure
832 *
833 * Ring-based algorithm
834 *
835 * \param sendrecvbuf_ buffer for both sending and recving data
836 * \param type_nbytes the unit number of bytes the type have
837 * \param count number of elements to be reduced
838 * \param reducer reduce function
839 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
840 * \sa ReturnType, TryAllreduce
841 */
842 AllreduceBase::ReturnType
843 AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
844 size_t type_nbytes,
845 size_t count,
846 ReduceFunction reducer) {
847 // read from next link and send to prev one
848 LinkRecord &prev = *ring_prev, &next = *ring_next;
849 // need to reply on special rank structure
850 utils::Assert(next.rank == (rank + 1) % world_size &&
851 rank == (prev.rank + 1) % world_size,
852 "need to assume rank structure");
853 // total size of message
854 const size_t total_size = type_nbytes * count;
855 size_t n = static_cast<size_t>(world_size);
856 size_t step = (count + n - 1) / n;
857 size_t r = static_cast<size_t>(next.rank);
858 size_t write_ptr = std::min(r * step, count) * type_nbytes;
859 size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
860 size_t reduce_ptr = read_ptr;
861 // send recv buffer
862 char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
863 // position to stop reading
864 const size_t stop_read = total_size + write_ptr;
865 // position to stop writing
866 size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes;
867 if (stop_write > stop_read) {
868 stop_write -= total_size;
869 utils::Assert(write_ptr <= stop_write, "write ptr boundary check");
870 }
871 // use ring buffer in next position
872 next.InitBuffer(type_nbytes, step, reduce_buffer_size);
873 // set size_read to read pointer for ring buffer to work properly
874 next.size_read = read_ptr;
875
876 while (true) {
877 // select helper
878 bool finished = true;
879 utils::PollHelper watcher;
880 if (read_ptr != stop_read) {
881 watcher.WatchRead(next.sock);
882 finished = false;
883 }
884 if (write_ptr != stop_write) {
885 if (write_ptr < reduce_ptr) {
886 watcher.WatchWrite(prev.sock);
887 }
888 finished = false;
889 }
890 if (finished) break;
891 watcher.Poll();
892 if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
893 ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
894 if (ret != kSuccess) {
895 return ReportError(&next, ret);
896 }
897 // sync the rate
898 read_ptr = next.size_read;
899 utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank);
900 const size_t buffer_size = next.buffer_size;
901 size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes;
902 while (reduce_ptr < max_reduce) {
903 size_t bstart = reduce_ptr % buffer_size;
904 size_t nread = std::min(buffer_size - bstart,
905 max_reduce - reduce_ptr);
906 size_t rstart = reduce_ptr % total_size;
907 nread = std::min(nread, total_size - rstart);
908 reducer(next.buffer_head + bstart,
909 sendrecvbuf + rstart,
910 static_cast<int>(nread / type_nbytes),
911 MPI::Datatype(type_nbytes));
912 reduce_ptr += nread;
913 }
914 }
915 if (write_ptr < reduce_ptr && write_ptr != stop_write) {
916 size_t size = std::min(reduce_ptr, stop_write) - write_ptr;
917 size_t start = write_ptr % total_size;
918 if (start + size > total_size) {
919 size = total_size - start;
920 }
921 ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
922 if (len != -1) {
923 write_ptr += static_cast<size_t>(len);
924 } else {
925 ReturnType ret = Errno2Return();
926 if (ret != kSuccess) return ReportError(&prev, ret);
927 }
928 }
929 }
930 return kSuccess;
931 }
932 /*!
933 * \brief perform in-place allreduce, on sendrecvbuf
934 * use a ring based algorithm
935 *
936 * \param sendrecvbuf_ buffer for both sending and recving data
937 * \param type_nbytes the unit number of bytes the type have
938 * \param count number of elements to be reduced
939 * \param reducer reduce function
940 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
941 * \sa ReturnType
942 */
943 AllreduceBase::ReturnType
944 AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
945 size_t type_nbytes,
946 size_t count,
947 ReduceFunction reducer) {
948 ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer);
949 if (ret != kSuccess) return ret;
950 size_t n = static_cast<size_t>(world_size);
951 size_t step = (count + n - 1) / n;
952 size_t begin = std::min(rank * step, count) * type_nbytes;
953 size_t end = std::min((rank + 1) * step, count) * type_nbytes;
954 // previous rank
955 int prank = ring_prev->rank;
956 // get rank of previous
957 return TryAllgatherRing
958 (sendrecvbuf_, type_nbytes * count,
959 begin, end,
960 (std::min((prank + 1) * step, count) -
961 std::min(prank * step, count)) * type_nbytes);
962 }
963 } // namespace engine
964 } // namespace rabit
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file allreduce_base.h
3 * \brief Basic implementation of AllReduce
4 * using TCP non-block socket and tree-shape reduction.
5 *
6 * This implementation provides basic utility of AllReduce and Broadcast
7 * without considering node failure
8 *
9 * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
10 */
11 #ifndef RABIT_ALLREDUCE_BASE_H_
12 #define RABIT_ALLREDUCE_BASE_H_
13
14 #include <vector>
15 #include <string>
16 #include <algorithm>
17 #include "rabit/internal/utils.h"
18 #include "rabit/internal/engine.h"
19 #include "rabit/internal/socket.h"
20
21 #ifdef RABIT_CXXTESTDEFS_H
22 #define private public
23 #define protected public
24 #endif // RABIT_CXXTESTDEFS_H
25
26
namespace MPI {
/*!
 * \brief minimal data type descriptor, kept compatible with the existing MPI
 *  interface; it only records the size of one element
 */
class Datatype {
 public:
  /*! \brief number of bytes of one element of the type */
  size_t type_size;
  /*!
   * \brief constructor
   * \param nbytes number of bytes of one element of the type
   */
  explicit Datatype(size_t nbytes) : type_size(nbytes) {}
};
}  // namespace MPI
35 namespace rabit {
36 namespace engine {
37 /*! \brief implementation of basic Allreduce engine */
38 class AllreduceBase : public IEngine {
39 public:
40 // magic number to verify server
41 static const int kMagic = 0xff99;
42 // constant one byte out of band message to indicate error happening
43 AllreduceBase(void);
44 virtual ~AllreduceBase(void) {}
45 // initialize the manager
46 virtual bool Init(int argc, char* argv[]);
47 // shutdown the engine
48 virtual bool Shutdown(void);
49 /*!
50 * \brief set parameters to the engine
51 * \param name parameter name
52 * \param val parameter value
53 */
54 virtual void SetParam(const char *name, const char *val);
55 /*!
56 * \brief print the msg in the tracker,
57 * this function can be used to communicate the information of the progress to
58 * the user who monitors the tracker
59 * \param msg message to be printed in the tracker
60 */
61 virtual void TrackerPrint(const std::string &msg);
62
63 /*! \brief get rank of previous node in ring topology*/
64 virtual int GetRingPrevRank(void) const {
65 return ring_prev->rank;
66 }
67 /*! \brief get rank */
68 virtual int GetRank(void) const {
69 return rank;
70 }
71 /*! \brief get rank */
72 virtual int GetWorldSize(void) const {
73 if (world_size == -1) return 1;
74 return world_size;
75 }
76 /*! \brief whether is distributed or not */
77 virtual bool IsDistributed(void) const {
78 return tracker_uri != "NULL";
79 }
80 /*! \brief get rank */
81 virtual std::string GetHost(void) const {
82 return host_uri;
83 }
84
85 /*!
86 * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
87 * the data provided by current node k is [slice_begin, slice_end),
88 * the next node's segment must start with slice_end
89 * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
90 * use a ring based algorithm
91 *
92 * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
93 * \param total_size total size of data to be gathered
94 * \param slice_begin beginning of the current slice
95 * \param slice_end end of the current slice
96 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
97 * \param _file caller file name used to generate unique cache key
98 * \param _line caller line number used to generate unique cache key
99 * \param _caller caller function name used to generate unique cache key
100 */
101 virtual void Allgather(void *sendrecvbuf_, size_t total_size,
102 size_t slice_begin,
103 size_t slice_end,
104 size_t size_prev_slice,
105 const char* _file = _FILE,
106 const int _line = _LINE,
107 const char* _caller = _CALLER) {
108 if (world_size == 1 || world_size == -1) return;
109 utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
110 slice_begin, slice_end, size_prev_slice) == kSuccess,
111 "AllgatherRing failed");
112 }
113 /*!
114 * \brief perform in-place allreduce, on sendrecvbuf
115 * this function is NOT thread-safe
116 * \param sendrecvbuf_ buffer for both sending and recving data
117 * \param type_nbytes the unit number of bytes the type have
118 * \param count number of elements to be reduced
119 * \param reducer reduce function
120 * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
121 * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
122 * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
123 * \param prepare_arg argument used to passed into the lazy preprocessing function
124 * \param _file caller file name used to generate unique cache key
125 * \param _line caller line number used to generate unique cache key
126 * \param _caller caller function name used to generate unique cache key
127 */
128 virtual void Allreduce(void *sendrecvbuf_,
129 size_t type_nbytes,
130 size_t count,
131 ReduceFunction reducer,
132 PreprocFunction prepare_fun = NULL,
133 void *prepare_arg = NULL,
134 const char* _file = _FILE,
135 const int _line = _LINE,
136 const char* _caller = _CALLER) {
137 if (prepare_fun != NULL) prepare_fun(prepare_arg);
138 if (world_size == 1 || world_size == -1) return;
139 utils::Assert(TryAllreduce(sendrecvbuf_,
140 type_nbytes, count, reducer) == kSuccess,
141 "Allreduce failed");
142 }
143 /*!
144 * \brief broadcast data from root to all nodes
145 * \param sendrecvbuf_ buffer for both sending and recving data
146 * \param size the size of the data to be broadcasted
147 * \param root the root worker id to broadcast the data
148 * \param _file caller file name used to generate unique cache key
149 * \param _line caller line number used to generate unique cache key
150 * \param _caller caller function name used to generate unique cache key
151 */
152 virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
153 const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER) {
154 if (world_size == 1 || world_size == -1) return;
155 utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
156 "Broadcast failed");
157 }
158 /*!
159 * \brief load latest check point
160 * \param global_model pointer to the globally shared model/state
161 * when calling this function, the caller need to gauranttees that global_model
162 * is the same in all nodes
163 * \param local_model pointer to local model, that is specific to current node/rank
164 * this can be NULL when no local model is needed
165 *
166 * \return the version number of check point loaded
167 * if returned version == 0, this means no model has been CheckPointed
168 * the p_model is not touched, user should do necessary initialization by themselves
169 *
170 * Common usage example:
171 * int iter = rabit::LoadCheckPoint(&model);
172 * if (iter == 0) model.InitParameters();
173 * for (i = iter; i < max_iter; ++i) {
174 * do many things, include allreduce
175 * rabit::CheckPoint(model);
176 * }
177 *
178 * \sa CheckPoint, VersionNumber
179 */
180 virtual int LoadCheckPoint(Serializable *global_model,
181 Serializable *local_model = NULL) {
182 return 0;
183 }
184 /*!
185 * \brief checkpoint the model, meaning we finished a stage of execution
186 * every time we call check point, there is a version number which will increase by one
187 *
188 * \param global_model pointer to the globally shared model/state
189 * when calling this function, the caller need to gauranttees that global_model
190 * is the same in all nodes
191 * \param local_model pointer to local model, that is specific to current node/rank
192 * this can be NULL when no local state is needed
193 *
194 * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
195 * bring replication cost in CheckPoint function. global_model do not need explicit replication.
196 * So only CheckPoint with global_model if possible
197 *
198 * \sa LoadCheckPoint, VersionNumber
199 */
200 virtual void CheckPoint(const Serializable *global_model,
201 const Serializable *local_model = NULL) {
202 version_number += 1;
203 }
204 /*!
205 * \brief This function can be used to replace CheckPoint for global_model only,
206 * when certain condition is met(see detailed expplaination).
207 *
208 * This is a "lazy" checkpoint such that only the pointer to global_model is
209 * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
210 * The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
211 * In another words, global_model model can be changed only between last call of
212 * Allreduce/Broadcast and LazyCheckPoint in current version
213 *
214 * For example, suppose the calling sequence is:
215 * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
216 *
217 * If user can only changes global_model in code3, then LazyCheckPoint can be used to
218 * improve efficiency of the program.
219 * \param global_model pointer to the globally shared model/state
220 * when calling this function, the caller need to gauranttees that global_model
221 * is the same in all nodes
222 * \sa LoadCheckPoint, CheckPoint, VersionNumber
223 */
224 virtual void LazyCheckPoint(const Serializable *global_model) {
225 version_number += 1;
226 }
227 /*!
228 * \return version number of current stored model,
229 * which means how many calls to CheckPoint we made so far
230 * \sa LoadCheckPoint, CheckPoint
231 */
232 virtual int VersionNumber(void) const {
233 return version_number;
234 }
235 /*!
236 * \brief explicitly re-init everything before calling LoadCheckPoint
237 * call this function when IEngine throw an exception out,
238 * this function is only used for test purpose
239 */
240 virtual void InitAfterException(void) {
241 utils::Error("InitAfterException: not implemented");
242 }
243 /*!
244 * \brief report current status to the job tracker
245 * depending on the job tracker we are in
246 */
247 inline void ReportStatus(void) const {
248 if (hadoop_mode != 0) {
249 fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n",
250 version_number, seq_counter);
251 }
252 }
253
254 protected:
/*! \brief enumeration of possible returning results from Try functions */
enum ReturnTypeEnum {
  /*! \brief execution is successful */
  kSuccess,
  /*! \brief a link was reset by peer */
  kConnReset,
  /*! \brief received a zero length message */
  kRecvZeroLen,
  /*! \brief a neighbor node go down, the connection is dropped */
  kSockError,
  /*!
   * \brief another node which is not my neighbor go down,
   *   get Out-of-Band exception notification from my neighbor
   */
  kGetExcept
};
/*! \brief struct return type to avoid implicit conversion to int/bool */
struct ReturnType {
  /*! \brief internal return type */
  ReturnTypeEnum value;
  /*!
   * \brief default constructor, starts as kSuccess so that a default-constructed
   *  object never holds an indeterminate value when it is compared
   */
  ReturnType() : value(kSuccess) {}
  /*! \brief implicit conversion from the enumeration, so Try functions can return enumerators */
  ReturnType(ReturnTypeEnum value) : value(value) {}  // NOLINT(*)
  /*! \return whether the stored result equals v */
  inline bool operator==(const ReturnTypeEnum &v) const {
    return value == v;
  }
  /*! \return whether the stored result differs from v */
  inline bool operator!=(const ReturnTypeEnum &v) const {
    return value != v;
  }
};
285 /*! \brief translate errno to return type */
286 inline static ReturnType Errno2Return() {
287 int errsv = utils::Socket::GetLastError();
288 if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
289 #ifdef _WIN32
290 if (errsv == WSAEWOULDBLOCK) return kSuccess;
291 if (errsv == WSAECONNRESET) return kConnReset;
292 #endif // _WIN32
293 if (errsv == ECONNRESET) return kConnReset;
294 return kSockError;
295 }
296 // link record to a neighbor
297 struct LinkRecord {
298 public:
299 // socket to get data from/to link
300 utils::TCPSocket sock;
301 // rank of the node in this link
302 int rank;
303 // size of data readed from link
304 size_t size_read;
305 // size of data sent to the link
306 size_t size_write;
307 // pointer to buffer head
308 char *buffer_head;
309 // buffer size, in bytes
310 size_t buffer_size;
311 // constructor
312 LinkRecord(void)
313 : buffer_head(NULL), buffer_size(0) {
314 }
315 // initialize buffer
316 inline void InitBuffer(size_t type_nbytes, size_t count,
317 size_t reduce_buffer_size) {
318 size_t n = (type_nbytes * count + 7)/ 8;
319 buffer_.resize(std::min(reduce_buffer_size, n));
320 // make sure align to type_nbytes
321 buffer_size =
322 buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
323 utils::Assert(type_nbytes <= buffer_size,
324 "too large type_nbytes=%lu, buffer_size=%lu",
325 type_nbytes, buffer_size);
326 // set buffer head
327 buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
328 }
329 // reset the recv and sent size
330 inline void ResetSize(void) {
331 size_write = size_read = 0;
332 }
333 /*!
334 * \brief read data into ring-buffer, with care not to existing useful override data
335 * position after protect_start
336 * \param protect_start all data start from protect_start is still needed in buffer
337 * read shall not override this
338 * \param max_size_read maximum logical amount we can read, size_read cannot exceed this value
339 * \return the type of reading
340 */
341 inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
342 utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
343 utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
344 size_t ngap = size_read - protect_start;
345 utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
346 size_t offset = size_read % buffer_size;
347 size_t nmax = max_size_read - size_read;
348 nmax = std::min(nmax, buffer_size - ngap);
349 nmax = std::min(nmax, buffer_size - offset);
350 if (nmax == 0) return kSuccess;
351 ssize_t len = sock.Recv(buffer_head + offset, nmax);
352 // length equals 0, remote disconnected
353 if (len == 0) {
354 sock.Close(); return kRecvZeroLen;
355 }
356 if (len == -1) return Errno2Return();
357 size_read += static_cast<size_t>(len);
358 return kSuccess;
359 }
360 /*!
361 * \brief read data into array,
362 * this function can not be used together with ReadToRingBuffer
363 * a link can either read into the ring buffer, or existing array
364 * \param max_size maximum size of array
365 * \return true if it is an successful read, false if there is some error happens, check errno
366 */
367 inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
368 if (max_size == size_read) return kSuccess;
369 char *p = static_cast<char*>(recvbuf_);
370 ssize_t len = sock.Recv(p + size_read, max_size - size_read);
371 // length equals 0, remote disconnected
372 if (len == 0) {
373 sock.Close(); return kRecvZeroLen;
374 }
375 if (len == -1) return Errno2Return();
376 size_read += static_cast<size_t>(len);
377 return kSuccess;
378 }
379 /*!
380 * \brief write data in array to sock
381 * \param sendbuf_ head of array
382 * \param max_size maximum size of array
383 * \return true if it is an successful write, false if there is some error happens, check errno
384 */
385 inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
386 const char *p = static_cast<const char*>(sendbuf_);
387 ssize_t len = sock.Send(p + size_write, max_size - size_write);
388 if (len == -1) return Errno2Return();
389 size_write += static_cast<size_t>(len);
390 return kSuccess;
391 }
392
393 private:
394 // recv buffer to get data from child
395 // aligned with 64 bits, will be able to perform 64 bits operations freely
396 std::vector<uint64_t> buffer_;
397 };
398 /*!
399 * \brief simple data structure that works like a vector
400 * but takes reference instead of space
401 */
402 struct RefLinkVector {
403 std::vector<LinkRecord*> plinks;
404 inline LinkRecord &operator[](size_t i) {
405 return *plinks[i];
406 }
407 inline size_t size(void) const {
408 return plinks.size();
409 }
410 };
411 /*!
412 * \brief initialize connection to the tracker
413 * \return a socket that initializes the connection
414 */
415 utils::TCPSocket ConnectTracker(void) const;
416 /*!
417 * \brief connect to the tracker to fix the missing links
418 * this function is also used when the engine start up
419 * \param cmd possible command to sent to tracker
420 */
421 bool ReConnectLinks(const char *cmd = "start");
422 /*!
423 * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
424 *
425 * NOTE on Allreduce:
426 * The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
427 * It only means the current node get the correct result of Allreduce.
428 * However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
429 *
430 * \param sendrecvbuf_ buffer for both sending and recving data
431 * \param type_nbytes the unit number of bytes the type have
432 * \param count number of elements to be reduced
433 * \param reducer reduce function
434 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
435 * \sa ReturnType
436 */
437 ReturnType TryAllreduce(void *sendrecvbuf_,
438 size_t type_nbytes,
439 size_t count,
440 ReduceFunction reducer);
441 /*!
442 * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
443 * \param sendrecvbuf_ buffer for both sending and receiving data
444 * \param size the size of the data to be broadcasted
445 * \param root the root worker id to broadcast the data
446 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
447 * \sa ReturnType
448 */
449 ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
450 /*!
451 * \brief perform in-place allreduce, on sendrecvbuf,
452 * this function implements tree-shape reduction
453 *
454 * \param sendrecvbuf_ buffer for both sending and recving data
455 * \param type_nbytes the unit number of bytes the type have
456 * \param count number of elements to be reduced
457 * \param reducer reduce function
458 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
459 * \sa ReturnType
460 */
461 ReturnType TryAllreduceTree(void *sendrecvbuf_,
462 size_t type_nbytes,
463 size_t count,
464 ReduceFunction reducer);
465 /*!
466 * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
467 * the data provided by current node k is [slice_begin, slice_end),
468 * the next node's segment must start with slice_end
469 * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
470 * use a ring based algorithm
471 *
472 * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
473 * \param total_size total size of data to be gathered
474 * \param slice_begin beginning of the current slice
475 * \param slice_end end of the current slice
476 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
477 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
478 * \sa ReturnType
479 */
480 ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
481 size_t slice_begin, size_t slice_end,
482 size_t size_prev_slice);
483 /*!
484 * \brief perform in-place reduce-scatter on sendrecvbuf_, using a ring based algorithm,
485 *
486 * after the function, node k get k-th segment of the reduction result
487 * the k-th segment is defined by [k * step, min((k + 1) * step,count) )
488 * where step = ceil(count / world_size)
489 *
490 * \param sendrecvbuf_ buffer for both sending and recving data
491 * \param type_nbytes the unit number of bytes the type have
492 * \param count number of elements to be reduced
493 * \param reducer reduce function
494 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
495 * \sa ReturnType, TryAllreduce
496 */
497 ReturnType TryReduceScatterRing(void *sendrecvbuf_,
498 size_t type_nbytes,
499 size_t count,
500 ReduceFunction reducer);
501 /*!
502 * \brief perform in-place allreduce, on sendrecvbuf
503 * use a ring based algorithm, reduce-scatter + allgather
504 *
505 * \param sendrecvbuf_ buffer for both sending and recving data
506 * \param type_nbytes the unit number of bytes the type have
507 * \param count number of elements to be reduced
508 * \param reducer reduce function
509 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
510 * \sa ReturnType
511 */
512 ReturnType TryAllreduceRing(void *sendrecvbuf_,
513 size_t type_nbytes,
514 size_t count,
515 ReduceFunction reducer);
516 /*!
517 * \brief function used to report error when a link goes wrong
518 * \param link the pointer to the link who causes the error
519 * \param err the error type
520 */
521 inline ReturnType ReportError(LinkRecord *link, ReturnType err) {
522 err_link = link; return err;
523 }
524 //---- data structure related to model ----
525 // call sequence counter, records how many calls we made so far
526 // from last call to CheckPoint, LoadCheckPoint
527 int seq_counter;
528 // version number of model
529 int version_number;
530 // whether the job is running in hadoop
531 bool hadoop_mode;
532 //---- local data related to link ----
533 // index of parent link, can be -1, meaning this is root of the tree
534 int parent_index;
535 // rank of parent node, can be -1
536 int parent_rank;
537 // sockets of all links this connects to
538 std::vector<LinkRecord> all_links;
539 // used to record the link where things goes wrong
540 LinkRecord *err_link;
541 // all the links in the reduction tree connection
542 RefLinkVector tree_links;
543 // pointer to links in the ring
544 LinkRecord *ring_prev, *ring_next;
545 //----- meta information-----
546 // list of environment variables that are of possible interest
547 std::vector<std::string> env_vars;
548 // unique identifier of the possible job this process is doing
549 // used to assign ranks, optional, default to NULL
550 std::string task_id;
551 // uri of current host, to be set by Init
552 std::string host_uri;
553 // uri of tracker
554 std::string tracker_uri;
555 // role in dmlc jobs
556 std::string dmlc_role;
557 // port of tracker address
558 int tracker_port;
559 // port of slave process
560 int slave_port, nport_trial;
561 // reduce buffer size
562 size_t reduce_buffer_size;
563 // reduction method
564 int reduce_method;
565 // minimum count of cells to use ring based method
566 size_t reduce_ring_mincount;
567 // current rank
568 int rank;
569 // world size
570 int world_size;
571 // connect retry time
572 int connect_retry;
573 // enable bootstrap cache 0 false 1 true
574 bool rabit_bootstrap_cache = false;
575 // enable detailed logging
576 bool rabit_debug = false;
577 // by default, if rabit worker not recover in half an hour exit
578 int timeout_sec = 1800;
579 // flag to enable rabit_timeout
580 bool rabit_timeout = false;
581 // Enable TCP no delay
582 bool rabit_enable_tcp_no_delay = false;
583 };
584 } // namespace engine
585 } // namespace rabit
586 #endif // RABIT_ALLREDUCE_BASE_H_
0 /*!
1 * Copyright by Contributors
2 * \file allreduce_mock.h
3 * \brief Mock test module of AllReduce engine,
4 * insert failures in certain call point, to test if the engine is robust to failure
5 *
6 * \author Ignacio Cano, Tianqi Chen
7 */
8 #ifndef RABIT_ALLREDUCE_MOCK_H_
9 #define RABIT_ALLREDUCE_MOCK_H_
10 #include <vector>
11 #include <map>
12 #include <sstream>
13 #include "rabit/internal/engine.h"
14 #include "rabit/internal/timer.h"
15 #include "allreduce_robust.h"
16
17 namespace rabit {
18 namespace engine {
19 class AllreduceMock : public AllreduceRobust {
20 public:
21 // constructor
22 AllreduceMock(void) {
23 num_trial = 0;
24 force_local = 0;
25 report_stats = 0;
26 tsum_allreduce = 0.0;
27 tsum_allgather = 0.0;
28 }
29 // destructor
30 virtual ~AllreduceMock(void) {}
31 virtual void SetParam(const char *name, const char *val) {
32 AllreduceRobust::SetParam(name, val);
33 // additional parameters
34 if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
35 if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val);
36 if (!strcmp(name, "report_stats")) report_stats = atoi(val);
37 if (!strcmp(name, "force_local")) force_local = atoi(val);
38 if (!strcmp(name, "mock")) {
39 MockKey k;
40 utils::Check(sscanf(val, "%d,%d,%d,%d",
41 &k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
42 "invalid mock parameter");
43 mock_map[k] = 1;
44 }
45 }
46 virtual void Allreduce(void *sendrecvbuf_,
47 size_t type_nbytes,
48 size_t count,
49 ReduceFunction reducer,
50 PreprocFunction prepare_fun,
51 void *prepare_arg,
52 const char* _file = _FILE,
53 const int _line = _LINE,
54 const char* _caller = _CALLER) {
55 this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
56 double tstart = utils::GetTime();
57 AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
58 count, reducer, prepare_fun, prepare_arg,
59 _file, _line, _caller);
60 tsum_allreduce += utils::GetTime() - tstart;
61 }
62 virtual void Allgather(void *sendrecvbuf,
63 size_t total_size,
64 size_t slice_begin,
65 size_t slice_end,
66 size_t size_prev_slice,
67 const char* _file = _FILE,
68 const int _line = _LINE,
69 const char* _caller = _CALLER) {
70 this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
71 double tstart = utils::GetTime();
72 AllreduceRobust::Allgather(sendrecvbuf, total_size,
73 slice_begin, slice_end,
74 size_prev_slice, _file, _line, _caller);
75 tsum_allgather += utils::GetTime() - tstart;
76 }
77 virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
78 const char* _file = _FILE,
79 const int _line = _LINE,
80 const char* _caller = _CALLER) {
81 this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
82 AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
83 }
84 virtual int LoadCheckPoint(Serializable *global_model,
85 Serializable *local_model) {
86 tsum_allreduce = 0.0;
87 tsum_allgather = 0.0;
88 time_checkpoint = utils::GetTime();
89 if (force_local == 0) {
90 return AllreduceRobust::LoadCheckPoint(global_model, local_model);
91 } else {
92 DummySerializer dum;
93 ComboSerializer com(global_model, local_model);
94 return AllreduceRobust::LoadCheckPoint(&dum, &com);
95 }
96 }
97 virtual void CheckPoint(const Serializable *global_model,
98 const Serializable *local_model) {
99 this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
100 double tstart = utils::GetTime();
101 double tbet_chkpt = tstart - time_checkpoint;
102 if (force_local == 0) {
103 AllreduceRobust::CheckPoint(global_model, local_model);
104 } else {
105 DummySerializer dum;
106 ComboSerializer com(global_model, local_model);
107 AllreduceRobust::CheckPoint(&dum, &com);
108 }
109 time_checkpoint = utils::GetTime();
110 double tcost = utils::GetTime() - tstart;
111 if (report_stats != 0 && rank == 0) {
112 std::stringstream ss;
113 ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
114 << ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
115 << ",check_tcost="<< tcost <<" sec"
116 << ",allreduce_tcost=" << tsum_allreduce << " sec"
117 << ",allgather_tcost=" << tsum_allgather << " sec"
118 << ",between_chpt=" << tbet_chkpt << "sec\n";
119 this->TrackerPrint(ss.str());
120 }
121 tsum_allreduce = 0.0;
122 tsum_allgather = 0.0;
123 }
124
125 virtual void LazyCheckPoint(const Serializable *global_model) {
126 this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
127 AllreduceRobust::LazyCheckPoint(global_model);
128 }
129
130 protected:
131 // force checkpoint to local
132 int force_local;
133 // whether report statistics
134 int report_stats;
135 // sum of allreduce
136 double tsum_allreduce;
137 // sum of allgather
138 double tsum_allgather;
139 double time_checkpoint;
140
141 private:
142 struct DummySerializer : public Serializable {
143 virtual void Load(Stream *fi) {
144 }
145 virtual void Save(Stream *fo) const {
146 }
147 };
148 struct ComboSerializer : public Serializable {
149 Serializable *lhs;
150 Serializable *rhs;
151 const Serializable *c_lhs;
152 const Serializable *c_rhs;
153 ComboSerializer(Serializable *lhs, Serializable *rhs)
154 : lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
155 }
156 ComboSerializer(const Serializable *lhs, const Serializable *rhs)
157 : lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
158 }
159 virtual void Load(Stream *fi) {
160 if (lhs != NULL) lhs->Load(fi);
161 if (rhs != NULL) rhs->Load(fi);
162 }
163 virtual void Save(Stream *fo) const {
164 if (c_lhs != NULL) c_lhs->Save(fo);
165 if (c_rhs != NULL) c_rhs->Save(fo);
166 }
167 };
168 // key to identify the mock stage
169 struct MockKey {
170 int rank;
171 int version;
172 int seqno;
173 int ntrial;
174 MockKey(void) {}
175 MockKey(int rank, int version, int seqno, int ntrial)
176 : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
177 inline bool operator==(const MockKey &b) const {
178 return rank == b.rank &&
179 version == b.version &&
180 seqno == b.seqno &&
181 ntrial == b.ntrial;
182 }
183 inline bool operator<(const MockKey &b) const {
184 if (rank != b.rank) return rank < b.rank;
185 if (version != b.version) return version < b.version;
186 if (seqno != b.seqno) return seqno < b.seqno;
187 return ntrial < b.ntrial;
188 }
189 };
190 // number of failure trials
191 int num_trial;
192 // record all mock actions
193 std::map<MockKey, int> mock_map;
194 // used to generate all kinds of exceptions
195 inline void Verify(const MockKey &key, const char *name) {
196 if (mock_map.count(key) != 0) {
197 num_trial += 1;
198 // data processing frameworks runs on shared process
199 _error("[%d]@@@Hit Mock Error:%s ", rank, name);
200 }
201 }
202 };
203 } // namespace engine
204 } // namespace rabit
205 #endif // RABIT_ALLREDUCE_MOCK_H_
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file allreduce_robust-inl.h
3 * \brief implementation of inline template function in AllreduceRobust
4 *
5 * \author Tianqi Chen
6 */
7 #ifndef RABIT_ALLREDUCE_ROBUST_INL_H_
8 #define RABIT_ALLREDUCE_ROBUST_INL_H_
9 #include <vector>
10
11 namespace rabit {
12 namespace engine {
13 /*!
14 * \brief run message passing algorithm on the allreduce tree
15 * the result is edge message stored in p_edge_in and p_edge_out
16 * \param node_value the value associated with current node
17 * \param p_edge_in used to store input message from each of the edge
18 * \param p_edge_out used to store output message from each of the edge
19 * \param func a function that defines the message passing rule
20 * Parameters of func:
21 * - node_value same as node_value in the main function
22 * - edge_in the array of input messages from each edge,
23 * this includes the output edge, which should be excluded
24 * - out_index array the index of output edge, the function should
25 * exclude the output edge when compute the message passing value
26 * Return of func:
27 * the function returns the output message based on the input message and node_value
28 *
29 * \tparam EdgeType type of edge message, must be simple struct
30 * \tparam NodeType type of node value
31 */
32 template<typename NodeType, typename EdgeType>
33 inline AllreduceRobust::ReturnType
34 AllreduceRobust::MsgPassing(const NodeType &node_value,
35 std::vector<EdgeType> *p_edge_in,
36 std::vector<EdgeType> *p_edge_out,
37 EdgeType(*func)
38 (const NodeType &node_value,
39 const std::vector<EdgeType> &edge_in,
40 size_t out_index)) {
41 RefLinkVector &links = tree_links;
42 if (links.size() == 0) return kSuccess;
43 // number of links
44 const int nlink = static_cast<int>(links.size());
45 // initialize the pointers
46 for (int i = 0; i < nlink; ++i) {
47 links[i].ResetSize();
48 }
49 std::vector<EdgeType> &edge_in = *p_edge_in;
50 std::vector<EdgeType> &edge_out = *p_edge_out;
51 edge_in.resize(nlink);
52 edge_out.resize(nlink);
53 // stages in the process
54 // 0: recv messages from childs
55 // 1: send message to parent
56 // 2: recv message from parent
57 // 3: send message to childs
58 int stage = 0;
59 // if no childs, no need to, directly start passing message
60 if (nlink == static_cast<int>(parent_index != -1)) {
61 utils::Assert(parent_index == 0, "parent must be 0");
62 edge_out[parent_index] = func(node_value, edge_in, parent_index);
63 stage = 1;
64 }
65 // while we have not passed the messages out
66 while (true) {
67 // for node with no parent, directly do stage 3
68 if (parent_index == -1) {
69 utils::Assert(stage != 2 && stage != 1, "invalie stage id");
70 }
71 // poll helper
72 utils::PollHelper watcher;
73 bool done = (stage == 3);
74 for (int i = 0; i < nlink; ++i) {
75 watcher.WatchException(links[i].sock);
76 switch (stage) {
77 case 0:
78 if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
79 watcher.WatchRead(links[i].sock);
80 }
81 break;
82 case 1:
83 if (i == parent_index) {
84 watcher.WatchWrite(links[i].sock);
85 }
86 break;
87 case 2:
88 if (i == parent_index) {
89 watcher.WatchRead(links[i].sock);
90 }
91 break;
92 case 3:
93 if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
94 watcher.WatchWrite(links[i].sock);
95 done = false;
96 }
97 break;
98 default: utils::Error("invalid stage");
99 }
100 }
101 // finish all the stages, and write out message
102 if (done) break;
103 watcher.Poll();
104 // exception handling
105 for (int i = 0; i < nlink; ++i) {
106 // recive OOB message from some link
107 if (watcher.CheckExcept(links[i].sock)) {
108 return ReportError(&links[i], kGetExcept);
109 }
110 }
111 if (stage == 0) {
112 bool finished = true;
113 // read data from childs
114 for (int i = 0; i < nlink; ++i) {
115 if (i != parent_index) {
116 if (watcher.CheckRead(links[i].sock)) {
117 ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
118 if (ret != kSuccess) return ReportError(&links[i], ret);
119 }
120 if (links[i].size_read != sizeof(EdgeType)) finished = false;
121 }
122 }
123 // if no parent, jump to stage 3, otherwise do stage 1
124 if (finished) {
125 if (parent_index != -1) {
126 edge_out[parent_index] = func(node_value, edge_in, parent_index);
127 stage = 1;
128 } else {
129 for (int i = 0; i < nlink; ++i) {
130 edge_out[i] = func(node_value, edge_in, i);
131 }
132 stage = 3;
133 }
134 }
135 }
136 if (stage == 1) {
137 const int pid = this->parent_index;
138 utils::Assert(pid != -1, "MsgPassing invalid stage");
139 ReturnType ret = links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType));
140 if (ret != kSuccess) return ReportError(&links[pid], ret);
141 if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
142 }
143 if (stage == 2) {
144 const int pid = this->parent_index;
145 utils::Assert(pid != -1, "MsgPassing invalid stage");
146 ReturnType ret = links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType));
147 if (ret != kSuccess) return ReportError(&links[pid], ret);
148 if (links[pid].size_read == sizeof(EdgeType)) {
149 for (int i = 0; i < nlink; ++i) {
150 if (i != pid) edge_out[i] = func(node_value, edge_in, i);
151 }
152 stage = 3;
153 }
154 }
155 if (stage == 3) {
156 for (int i = 0; i < nlink; ++i) {
157 if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
158 ReturnType ret = links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType));
159 if (ret != kSuccess) return ReportError(&links[i], ret);
160 }
161 }
162 }
163 }
164 return kSuccess;
165 }
166 } // namespace engine
167 } // namespace rabit
168 #endif // RABIT_ALLREDUCE_ROBUST_INL_H_
0 /*!
1 * Copyright (c) 2014-2019 by Contributors
2 * \file allreduce_robust.cc
3 * \brief Robust implementation of Allreduce
4 *
5 * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
6 */
7 #define _CRT_SECURE_NO_WARNINGS
8 #define _CRT_SECURE_NO_DEPRECATE
9 #define NOMINMAX
10 #include <chrono>
11 #include <thread>
12 #include <limits>
13 #include <utility>
14 #include "rabit/internal/io.h"
15 #include "rabit/internal/timer.h"
16 #include "rabit/internal/utils.h"
17 #include "rabit/internal/engine.h"
18 #include "rabit/internal/rabit-inl.h"
19 #include "allreduce_robust.h"
20
21 #undef _assert
22
23 namespace rabit {
24 namespace engine {
25
26 AllreduceRobust::AllreduceRobust(void) {
27 num_local_replica = 0;
28 num_global_replica = 5;
29 default_local_replica = 2;
30 seq_counter = 0;
31 cur_cache_seq = 0;
32 local_chkpt_version = 0;
33 result_buffer_round = 1;
34 global_lazycheck = NULL;
35 use_local_model = -1;
36 recover_counter = 0;
37 checkpoint_loaded = false;
38 env_vars.push_back("rabit_global_replica");
39 env_vars.push_back("rabit_local_replica");
40 }
41 bool AllreduceRobust::Init(int argc, char* argv[]) {
42 if (AllreduceBase::Init(argc, argv)) {
43 // chenqin: alert user opted in experimental feature.
44 if (rabit_bootstrap_cache) utils::HandleLogInfo(
45 "[EXPERIMENTAL] bootstrap cache has been enabled\n");
46 checkpoint_loaded = false;
47 if (num_global_replica == 0) {
48 result_buffer_round = -1;
49 } else {
50 result_buffer_round = std::max(world_size / num_global_replica, 1);
51 }
52 return true;
53 } else {
54 return false;
55 }
56 }
57 /*! \brief shutdown the engine */
58 bool AllreduceRobust::Shutdown(void) {
59 try {
60 // need to sync the exec before we shutdown, do a pesudo check point
61 // execute checkpoint, note: when checkpoint existing, load will not happen
62 _assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp,
63 cur_cache_seq), "Shutdown: check point must return true");
64 // reset result buffer
65 resbuf.Clear(); seq_counter = 0;
66 cachebuf.Clear(); cur_cache_seq = 0;
67 lookupbuf.Clear();
68 // execute check ack step, load happens here
69 _assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
70 ActionSummary::kSpecialOp, cur_cache_seq), "Shutdown: check ack must return true");
71 // travis ci only osx test hang
72 #if defined (__APPLE__)
73 sleep(1);
74 #endif
75 shutdown_timeout = true;
76 if (rabit_timeout_task.valid()) {
77 rabit_timeout_task.wait();
78 _assert(rabit_timeout_task.get(), "expect timeout task return\n");
79 }
80 return AllreduceBase::Shutdown();
81 } catch (const std::exception& e) {
82 fprintf(stderr, "%s\n", e.what());
83 return false;
84 }
85 }
86
87 /*!
88 * \brief set parameters to the engine
89 * \param name parameter name
90 * \param val parameter value
91 */
92 void AllreduceRobust::SetParam(const char *name, const char *val) {
93 AllreduceBase::SetParam(name, val);
94 if (!strcmp(name, "rabit_global_replica")) num_global_replica = atoi(val);
95 if (!strcmp(name, "rabit_local_replica")) {
96 num_local_replica = atoi(val);
97 }
98 }
99
100 int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf,
101 const size_t type_nbytes, const size_t count) {
102 int index = -1;
103 for (int i = 0 ; i < cur_cache_seq; i++) {
104 size_t nsize = 0;
105 void* name = lookupbuf.Query(i, &nsize);
106 if (nsize == key.length() + 1
107 && strcmp(static_cast<const char*>(name), key.c_str()) == 0) {
108 index = i;
109 break;
110 }
111 }
112 // we should consider way to support duplicated signatures
113 // https://github.com/dmlc/xgboost/issues/5012
114 // _assert(index == -1, "immutable cache key already exists");
115 _assert(type_nbytes*count > 0, "can't set empty cache");
116 void* temp = cachebuf.AllocTemp(type_nbytes, count);
117 cachebuf.PushTemp(cur_cache_seq, type_nbytes, count);
118 std::memcpy(temp, buf, type_nbytes*count);
119
120 std::string k(key);
121 void* name = lookupbuf.AllocTemp(strlen(k.c_str()) + 1, 1);
122 lookupbuf.PushTemp(cur_cache_seq, strlen(k.c_str()) + 1, 1);
123 std::memcpy(name, key.c_str(), strlen(k.c_str()) + 1);
124 cur_cache_seq += 1;
125 return 0;
126 }
127
128 int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
129 const size_t type_nbytes, const size_t count) {
130 // as requester sync with rest of nodes on latest cache content
131 if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache,
132 seq_counter, cur_cache_seq)) return -1;
133
134 int index = -1;
135 for (int i = 0 ; i < cur_cache_seq; i++) {
136 size_t nsize = 0;
137 void* name = lookupbuf.Query(i, &nsize);
138 if (nsize == strlen(key.c_str()) + 1
139 && strcmp(reinterpret_cast<char*>(name), key.c_str()) == 0) {
140 index = i;
141 break;
142 }
143 }
144 // cache doesn't exists
145 if (index == -1) return -1;
146
147 size_t siz = 0;
148 void* temp = cachebuf.Query(index, &siz);
149 utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
150 utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
151 utils::Assert(siz > 0, "cache size should be greater than 0");
152 std::memcpy(buf, temp, type_nbytes*count);
153 return 0;
154 }
155
156 /*!
157 * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
158 * the data provided by current node k is [slice_begin, slice_end),
159 * the next node's segment must start with slice_end
160 * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
161 * use a ring based algorithm
162 *
163 * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
164 * \param total_size total size of data to be gathered
165 * \param slice_begin beginning of the current slice
166 * \param slice_end end of the current slice
167 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
168 * \param _file caller file name used to generate unique cache key
169 * \param _line caller line number used to generate unique cache key
170 * \param _caller caller function name used to generate unique cache key
171 */
172 void AllreduceRobust::Allgather(void *sendrecvbuf,
173 size_t total_size,
174 size_t slice_begin,
175 size_t slice_end,
176 size_t size_prev_slice,
177 const char* _file,
178 const int _line,
179 const char* _caller) {
180 if (world_size == 1 || world_size == -1) return;
181 // genreate unique allgather signature
182 std::string key = std::string(_file) + "::" + std::to_string(_line) + "::"
183 + std::string(_caller) + "#" +std::to_string(total_size);
184
185 // try fetch bootstrap allgather results from cache
186 if (!checkpoint_loaded && rabit_bootstrap_cache &&
187 GetBootstrapCache(key, sendrecvbuf, total_size, 1) != -1) return;
188
189 double start = utils::GetTime();
190 bool recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq);
191
192 if (resbuf.LastSeqNo() != -1 &&
193 (result_buffer_round == -1 ||
194 resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
195 resbuf.DropLast();
196 }
197
198 void *temp = resbuf.AllocTemp(total_size, 1);
199 while (true) {
200 if (recovered) {
201 std::memcpy(temp, sendrecvbuf, total_size); break;
202 } else {
203 std::memcpy(temp, sendrecvbuf, total_size);
204 if (CheckAndRecover(TryAllgatherRing(temp, total_size,
205 slice_begin, slice_end, size_prev_slice))) {
206 std::memcpy(sendrecvbuf, temp, total_size); break;
207 } else {
208 recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq);
209 }
210 }
211 }
212 double delta = utils::GetTime() - start;
213 // log allgather latency
214 if (rabit_debug) {
215 utils::HandleLogInfo("[%d] allgather (%s) finished version %d, seq %d, take %f seconds\n",
216 rank, key.c_str(), version_number, seq_counter, delta);
217 }
218
219 // if bootstrap allgather, store and fetch through cache
220 if (checkpoint_loaded || !rabit_bootstrap_cache) {
221 resbuf.PushTemp(seq_counter, total_size, 1);
222 seq_counter += 1;
223 } else {
224 SetBootstrapCache(key, sendrecvbuf, total_size, 1);
225 }
226 }
227
228 /*!
229 * \brief perform in-place allreduce, on sendrecvbuf
230 * this function is NOT thread-safe
231 * \param sendrecvbuf_ buffer for both sending and recving data
232 * \param type_nbytes the unit number of bytes the type have
233 * \param count number of elements to be reduced
234 * \param reducer reduce function
235 * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
236 * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
237 * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
238 * \param prepare_arg argument used to passed into the lazy preprocessing function
239 * \param _file caller file name used to generate unique cache key
240 * \param _line caller line number used to generate unique cache key
241 * \param _caller caller function name used to generate unique cache key
242 */
243 void AllreduceRobust::Allreduce(void *sendrecvbuf_,
244 size_t type_nbytes,
245 size_t count,
246 ReduceFunction reducer,
247 PreprocFunction prepare_fun,
248 void *prepare_arg,
249 const char* _file,
250 const int _line,
251 const char* _caller) {
252 // skip action in single node
253 if (world_size == 1 || world_size == -1) {
254 if (prepare_fun != NULL) prepare_fun(prepare_arg);
255 return;
256 }
257
258 // genreate unique allreduce signature
259 std::string key = std::string(_file) + "::" + std::to_string(_line) + "::"
260 + std::string(_caller) + "#" +std::to_string(type_nbytes) + "x" + std::to_string(count);
261
262 // try fetch bootstrap allreduce results from cache
263 if (!checkpoint_loaded && rabit_bootstrap_cache &&
264 GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count) != -1) return;
265
266 double start = utils::GetTime();
267 bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq);
268
269 if (resbuf.LastSeqNo() != -1 &&
270 (result_buffer_round == -1 ||
271 resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
272 resbuf.DropLast();
273 }
274
275 if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg);
276 void *temp = resbuf.AllocTemp(type_nbytes, count);
277 while (true) {
278 if (recovered) {
279 std::memcpy(temp, sendrecvbuf_, type_nbytes * count); break;
280 } else {
281 std::memcpy(temp, sendrecvbuf_, type_nbytes * count);
282 if (CheckAndRecover(TryAllreduce(temp, type_nbytes, count, reducer))) {
283 std::memcpy(sendrecvbuf_, temp, type_nbytes * count); break;
284 } else {
285 recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq);
286 }
287 }
288 }
289 double delta = utils::GetTime() - start;
290 // log allreduce latency
291 if (rabit_debug) {
292 utils::HandleLogInfo("[%d] allreduce (%s) finished version %d, seq %d, take %f seconds\n",
293 rank, key.c_str(), version_number, seq_counter, delta);
294 }
295
296 // if bootstrap allreduce, store and fetch through cache
297 if (checkpoint_loaded || !rabit_bootstrap_cache) {
298 resbuf.PushTemp(seq_counter, type_nbytes, count);
299 seq_counter += 1;
300 } else {
301 SetBootstrapCache(key, sendrecvbuf_, type_nbytes, count);
302 }
303 }
304 /*!
305 * \brief broadcast data from root to all nodes
306 * \param sendrecvbuf_ buffer for both sending and recving data
307 * \param size the size of the data to be broadcasted
308 * \param root the root worker id to broadcast the data
309 * \param _file caller file name used to generate unique cache key
310 * \param _line caller line number used to generate unique cache key
311 * \param _caller caller function name used to generate unique cache key
312 */
313 void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root,
314 const char* _file,
315 const int _line,
316 const char* _caller) {
317 // skip action in single node
318 if (world_size == 1 || world_size == -1) return;
319 // genreate unique cache signature
320 std::string key = std::string(_file) + "::" + std::to_string(_line) + "::"
321 + std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root);
322 // try fetch bootstrap allreduce results from cache
323 if (!checkpoint_loaded && rabit_bootstrap_cache &&
324 GetBootstrapCache(key, sendrecvbuf_, total_size, 1) != -1) return;
325 double start = utils::GetTime();
326 bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq);
327 // now we are free to remove the last result, if any
328 if (resbuf.LastSeqNo() != -1 &&
329 (result_buffer_round == -1 ||
330 resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
331 resbuf.DropLast();
332 }
333 void *temp = resbuf.AllocTemp(1, total_size);
334 while (true) {
335 if (recovered) {
336 std::memcpy(temp, sendrecvbuf_, total_size); break;
337 } else {
338 if (CheckAndRecover(TryBroadcast(sendrecvbuf_, total_size, root))) {
339 std::memcpy(temp, sendrecvbuf_, total_size); break;
340 } else {
341 recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq);
342 }
343 }
344 }
345
346 double delta = utils::GetTime() - start;
347 // log broadcast latency
348 if (rabit_debug) {
349 utils::HandleLogInfo(
350 "[%d] broadcast (%s) root %d finished version %d,seq %d, take %f seconds\n",
351 rank, key.c_str(), root, version_number, seq_counter, delta);
352 }
353 // if bootstrap broadcast, store and fetch through cache
354 if (checkpoint_loaded || !rabit_bootstrap_cache) {
355 resbuf.PushTemp(seq_counter, 1, total_size);
356 seq_counter += 1;
357 } else {
358 SetBootstrapCache(key, sendrecvbuf_, total_size, 1);
359 }
360 }
361 /*!
362 * \brief load latest check point
363 * \param global_model pointer to the globally shared model/state
364 * when calling this function, the caller needs to guarantee that global_model
365 * is the same in all nodes
366 * \param local_model pointer to local model, that is specific to current node/rank
367 * this can be NULL when no local model is needed
368 *
369 * \return the version number of check point loaded
370 * if returned version == 0, this means no model has been CheckPointed
371 * the p_model is not touched, user should do necessary initialization by themselves
372 *
373 * Common usage example:
374 * int iter = rabit::LoadCheckPoint(&model);
375 * if (iter == 0) model.InitParameters();
376 * for (i = iter; i < max_iter; ++i) {
377 * do many things, include allreduce
378 * rabit::CheckPoint(model);
379 * }
380 *
381 * \sa CheckPoint, VersionNumber
382 */
383 int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
384 Serializable *local_model) {
385 checkpoint_loaded = true;
386 // skip action in single node
387 if (world_size == 1) return 0;
388 this->LocalModelCheck(local_model != NULL);
389 if (num_local_replica == 0) {
390 utils::Check(local_model == NULL,
391 "need to set rabit_local_replica larger than 1 to checkpoint local_model");
392 }
393 double start = utils::GetTime();
394 // check if we succeed
395 if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp, cur_cache_seq)) {
396 int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
397 if (local_model != NULL) {
398 if (nlocal == num_local_replica + 1) {
399 // load in local model
400 utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]),
401 local_rptr[local_chkpt_version][1]);
402 local_model->Load(&fs);
403 } else {
404 _assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal);
405 }
406 }
407 // reset result buffer
408 resbuf.Clear(); seq_counter = 0;
409 // load from buffer
410 utils::MemoryBufferStream fs(&global_checkpoint);
411 if (global_checkpoint.length() == 0) {
412 version_number = 0;
413 } else {
414 _assert(fs.Read(&version_number, sizeof(version_number)) != 0,
415 "read in version number");
416 global_model->Load(&fs);
417 _assert(local_model == NULL || nlocal == num_local_replica + 1,
418 "local model inconsistent, nlocal=%d", nlocal);
419 }
420 // run another phase of check ack, if recovered from data
421 _assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
422 ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
423
424 if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq)) {
425 utils::Printf("no need to load cache\n");
426 }
427 double delta = utils::GetTime() - start;
428
429 // log broadcast latency
430 if (rabit_debug) {
431 utils::HandleLogInfo("[%d] loadcheckpoint size %ld finished version %d, "
432 "seq %d, take %f seconds\n",
433 rank, global_checkpoint.length(),
434 version_number, seq_counter, delta);
435 }
436 return version_number;
437 } else {
438 // log job fresh start
439 if (rabit_debug) utils::HandleLogInfo("[%d] loadcheckpoint reset\n", rank);
440
441 // reset result buffer
442 resbuf.Clear(); seq_counter = 0; version_number = 0;
443 // nothing loaded, a fresh start, everyone init model
444 return version_number;
445 }
446 }
447 /*!
448 * \brief internal consistency check function,
449 * use check to ensure user always call CheckPoint/LoadCheckPoint
450 * with or without local but not both, this function will set the approperiate settings
451 * in the first call of LoadCheckPoint/CheckPoint
452 *
453 * \param with_local whether the user calls CheckPoint with local model
454 */
455 void AllreduceRobust::LocalModelCheck(bool with_local) {
456 if (use_local_model == -1) {
457 if (with_local) {
458 use_local_model = 1;
459 if (num_local_replica == 0) {
460 num_local_replica = default_local_replica;
461 }
462 } else {
463 use_local_model = 0;
464 num_local_replica = 0;
465 }
466 } else {
467 utils::Check(use_local_model == static_cast<int>(with_local),
468 "Can only call Checkpoint/LoadCheckPoint always with"\
469 "or without local_model, but not mixed case");
470 }
471 }
472 /*!
473 * \brief internal implementation of checkpoint, support both lazy and normal way
474 *
475 * \param global_model pointer to the globally shared model/state
476 * when calling this function, the caller need to gauranttees that global_model
477 * is the same in all nodes
478 * \param local_model pointer to local model, that is specific to current node/rank
479 * this can be NULL when no local state is needed
480 * \param lazy_checkpt whether the action is lazy checkpoint
481 *
482 * \sa CheckPoint, LazyCheckPoint
483 */
484 void AllreduceRobust::CheckPoint_(const Serializable *global_model,
485 const Serializable *local_model,
486 bool lazy_checkpt) {
487 // never do check point in single machine mode
488 if (world_size == 1) {
489 version_number += 1; return;
490 }
491 double start = utils::GetTime();
492 this->LocalModelCheck(local_model != NULL);
493 if (num_local_replica == 0) {
494 utils::Check(local_model == NULL,
495 "need to set rabit_local_replica larger than 1 to checkpoint local_model");
496 }
497 if (num_local_replica != 0) {
498 while (true) {
499 if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
500 // save model to new version place
501 int new_version = !local_chkpt_version;
502
503 local_chkpt[new_version].clear();
504 utils::MemoryBufferStream fs(&local_chkpt[new_version]);
505 if (local_model != NULL) {
506 local_model->Save(&fs);
507 }
508 local_rptr[new_version].clear();
509 local_rptr[new_version].push_back(0);
510 local_rptr[new_version].push_back(local_chkpt[new_version].length());
511 if (CheckAndRecover(TryCheckinLocalState(&local_rptr[new_version],
512 &local_chkpt[new_version]))) break;
513 }
514 // run the ack phase, can be true or false
515 RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckAck);
516 // switch pointer to new version
517 local_chkpt_version = !local_chkpt_version;
518 }
519 // execute checkpoint, note: when checkpoint existing, load will not happen
520 _assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
521 ActionSummary::kSpecialOp, cur_cache_seq),
522 "check point must return true");
523 // this is the critical region where we will change all the stored models
524 // increase version number
525 version_number += 1;
526 // save model
527 if (lazy_checkpt) {
528 global_lazycheck = global_model;
529 } else {
530 global_checkpoint.resize(0);
531 utils::MemoryBufferStream fs(&global_checkpoint);
532 fs.Write(&version_number, sizeof(version_number));
533 global_model->Save(&fs);
534 global_lazycheck = NULL;
535 }
536 double delta = utils::GetTime() - start;
537 // log checkpoint latency
538 if (rabit_debug) {
539 utils::HandleLogInfo(
540 "[%d] checkpoint finished version %d,seq %d, take %f seconds\n",
541 rank, version_number, seq_counter, delta);
542 }
543 start = utils::GetTime();
544 // reset result buffer, mark boostrap phase complete
545 resbuf.Clear(); seq_counter = 0;
546 // execute check ack step, load happens here
547 _assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
548 ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
549
550 delta = utils::GetTime() - start;
551 // log checkpoint ack latency
552 if (rabit_debug) {
553 utils::HandleLogInfo("[%d] checkpoint ack finished version %d, take %f seconds\n",
554 rank, version_number, delta);
555 }
556 }
557 /*!
558 * \brief reset the all the existing links by sending Out-of-Band message marker
559 * after this function finishes, all the messages received and sent before in all live links are discarded,
560 * This allows us to get a fresh start after error has happened
561 *
562 * \return this function can return kSuccess or kSockError
563 * when kSockError is returned, it simply means there are bad sockets in the links,
564 * and some link recovery proceduer is needed
565 */
566 AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
567 // number of links
568 const int nlink = static_cast<int>(all_links.size());
569 for (int i = 0; i < nlink; ++i) {
570 all_links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
571 all_links[i].ResetSize();
572 }
573 // read and discard data from all channels until pass mark
574 while (true) {
575 for (int i = 0; i < nlink; ++i) {
576 if (all_links[i].sock.BadSocket()) continue;
577 if (all_links[i].size_write == 0) {
578 char sig = kOOBReset;
579 ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
580 // error will be filtered in next loop
581 if (len == sizeof(sig)) all_links[i].size_write = 1;
582 }
583 if (all_links[i].size_write == 1) {
584 char sig = kResetMark;
585 ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig));
586 if (len == sizeof(sig)) all_links[i].size_write = 2;
587 }
588 }
589 utils::PollHelper rsel;
590 bool finished = true;
591 for (int i = 0; i < nlink; ++i) {
592 if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
593 rsel.WatchWrite(all_links[i].sock); finished = false;
594 }
595 }
596 if (finished) break;
597 // wait to read from the channels to discard data
598 rsel.Poll();
599 }
600 for (int i = 0; i < nlink; ++i) {
601 if (!all_links[i].sock.BadSocket()) {
602 utils::PollHelper::WaitExcept(all_links[i].sock);
603 }
604 }
605 while (true) {
606 utils::PollHelper rsel;
607 bool finished = true;
608 for (int i = 0; i < nlink; ++i) {
609 if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
610 rsel.WatchRead(all_links[i].sock); finished = false;
611 }
612 }
613 if (finished) break;
614 rsel.Poll();
615 for (int i = 0; i < nlink; ++i) {
616 if (all_links[i].sock.BadSocket()) continue;
617 if (all_links[i].size_read == 0) {
618 int atmark = all_links[i].sock.AtMark();
619 if (atmark < 0) {
620 _assert(all_links[i].sock.BadSocket(), "must already gone bad");
621 } else if (atmark > 0) {
622 all_links[i].size_read = 1;
623 } else {
624 // no at mark, read and discard data
625 ssize_t len = all_links[i].sock.Recv(all_links[i].buffer_head, all_links[i].buffer_size);
626 if (all_links[i].sock.AtMark()) all_links[i].size_read = 1;
627 // zero length, remote closed the connection, close socket
628 if (len == 0) all_links[i].sock.Close();
629 }
630 }
631 }
632 }
633 // start synchronization, use blocking I/O to avoid select
634 for (int i = 0; i < nlink; ++i) {
635 if (!all_links[i].sock.BadSocket()) {
636 char oob_mark;
637 all_links[i].sock.SetNonBlock(false);
638 ssize_t len = all_links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
639 if (len == 0) {
640 all_links[i].sock.Close(); continue;
641 } else if (len > 0) {
642 _assert(oob_mark == kResetMark, "wrong oob msg");
643 _assert(all_links[i].sock.AtMark() != 1, "should already read past mark");
644 } else {
645 _assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
646 }
647 // send out ack
648 char ack = kResetAck;
649 while (true) {
650 len = all_links[i].sock.Send(&ack, sizeof(ack));
651 if (len == sizeof(ack)) break;
652 if (len == -1) {
653 if (errno != EAGAIN && errno != EWOULDBLOCK) break;
654 }
655 }
656 }
657 }
658 // wait all ack
659 for (int i = 0; i < nlink; ++i) {
660 if (!all_links[i].sock.BadSocket()) {
661 char ack;
662 ssize_t len = all_links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
663 if (len == 0) {
664 all_links[i].sock.Close(); continue;
665 } else if (len > 0) {
666 _assert(ack == kResetAck, "wrong Ack MSG");
667 } else {
668 _assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
669 }
670 // set back to nonblock mode
671 all_links[i].sock.SetNonBlock(true);
672 }
673 }
674 for (int i = 0; i < nlink; ++i) {
675 if (all_links[i].sock.BadSocket()) return kSockError;
676 }
677 return kSuccess;
678 }
679 /*!
680 * \brief if err_type indicates an error
681 * recover links according to the error type reported
682 * if there is no error, return true
683 * \param err_type the type of error happening in the system
684 * \return true if err_type is kSuccess, false otherwise
685 */
686 bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
687 shutdown_timeout = err_type == kSuccess;
688 if (err_type == kSuccess) return true;
689
690 _assert(err_link != NULL, "must know the error link");
691 recover_counter += 1;
692 // async launch timeout task if enable_rabit_timeout is set
693 if (rabit_timeout && !rabit_timeout_task.valid()) {
694 utils::Printf("[EXPERIMENTAL] timeout thread expires in %d second(s)\n", timeout_sec);
695 rabit_timeout_task = std::async(std::launch::async, [=]() {
696 if (rabit_debug) {
697 utils::Printf("[%d] timeout thread %ld starts\n", rank,
698 std::this_thread::get_id());
699 }
700 int time = 0;
701 // check if rabit recovered every 100ms
702 while (time++ < 10 * timeout_sec) {
703 std::this_thread::sleep_for(std::chrono::milliseconds(100));
704 if (shutdown_timeout.load()) {
705 if (rabit_debug) {
706 utils::Printf("[%d] timeout task thread %ld exits\n",
707 rank, std::this_thread::get_id());
708 }
709 return true;
710 }
711 }
712 _error("[%d] exit due to time out %d s\n", rank, timeout_sec);
713 return false;
714 });
715 }
716 // simple way, shutdown all links
717 for (size_t i = 0; i < all_links.size(); ++i) {
718 if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
719 }
720 // smooth out traffic to tracker
721 std::this_thread::sleep_for(std::chrono::milliseconds(10*rank));
722 ReConnectLinks("recover");
723 return false;
724 }
/*!
 * \brief message passing function, computes for one outgoing edge the
 *  shortest distance to a possible source of data
 * \param node_value a pair of have_data and size
 *   have_data tells whether the current node has the data
 *   size gives the size of the data when the current node has it
 * \param dist_in the shortest distance to any data source reported on each incoming edge,
 *   std::numeric_limits<int>::max() marks an edge without a reachable source
 * \param out_index the edge index of the output link, its own report is ignored
 * \return pair of (distance, data size) to send on the edge specified by out_index;
 *   the distance is std::numeric_limits<int>::max() and the size 0 when no source is reachable
 */
inline std::pair<int, size_t>
ShortestDist(const std::pair<bool, size_t> &node_value,
             const std::vector< std::pair<int, size_t> > &dist_in,
             size_t out_index) {
  const int kUnreachable = std::numeric_limits<int>::max();
  // a node holding the data is itself the source, one hop away from its neighbor
  if (node_value.first) {
    return std::make_pair(1, node_value.second);
  }
  std::pair<int, size_t> best(kUnreachable, 0);
  for (size_t edge = 0; edge < dist_in.size(); ++edge) {
    const bool usable = (edge != out_index) && (dist_in[edge].first != kUnreachable);
    if (!usable) continue;
    // going through this edge costs one additional hop
    const int hops = dist_in[edge].first + 1;
    if (hops < best.first) {
      best.first = hops;
      best.second = dist_in[edge].second;
    }
  }
  return best;
}
/*!
 * \brief message passing function, decides for one outgoing edge
 *  whether data needs to be requested through it
 * \param node_value a pair of request_data and best_link
 *   request_data stores whether the current node needs to request data
 *   best_link gives the best edge index to fetch the data from,
 *   it can be -1, which means the current node contains the data
 * \param req_in the data requests received on the incoming edges
 * \param out_index the edge index of the output link
 * \return 1 when a request has to be sent on the output edge, 0 otherwise
 */
inline char DataRequest(const std::pair<bool, int> &node_value,
                        const std::vector<char> &req_in,
                        size_t out_index) {
  // requests only ever travel along the best link
  if (static_cast<int>(out_index) != node_value.second) return 0;
  // the node itself needs the data
  if (node_value.first) return 1;
  // otherwise forward a request that arrived on any of the other edges
  for (size_t edge = 0; edge < req_in.size(); ++edge) {
    if (edge != out_index && req_in[edge] != 0) return 1;
  }
  return 0;
}
783 /*!
784 * \brief try to decide the recovery message passing request
785 * \param role the current role of the node
786 * \param p_size used to store the size of the message, for node in state kHaveData,
787 * this size must be set correctly before calling the function
788 * for others, this surves as output parameter
789 *
790 * \param p_recvlink used to store the link current node should recv data from, if necessary
791 * this can be -1, which means current node have the data
792 * \param p_req_in used to store the resulting vector, indicating which link we should send the data to
793 *
794 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
795 * \sa ReturnType
796 */
797 AllreduceRobust::ReturnType
798 AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
799 size_t *p_size,
800 int *p_recvlink,
801 std::vector<bool> *p_req_in) {
802 int best_link = -2;
803 {
804 // get the shortest distance to the request point
805 std::vector<std::pair<int, size_t> > dist_in, dist_out;
806
807 ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
808 &dist_in, &dist_out, ShortestDist);
809 if (succ != kSuccess) return succ;
810 if (role != kHaveData) {
811 for (size_t i = 0; i < dist_in.size(); ++i) {
812 if (dist_in[i].first != std::numeric_limits<int>::max()) {
813 utils::Check(best_link == -2 || *p_size == dist_in[i].second,
814 "[%d] Allreduce size inconsistent, distin=%lu, size=%lu, reporting=%lu\n",
815 rank, dist_in[i].first, *p_size, dist_in[i].second);
816 if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
817 best_link = static_cast<int>(i);
818 *p_size = dist_in[i].second;
819 }
820 }
821 }
822 utils::Check(best_link != -2, "Too many nodes went down and we cannot recover..");
823 } else {
824 best_link = -1;
825 }
826 }
827 // get the node request
828 std::vector<char> req_in, req_out;
829 ReturnType succ = MsgPassing(std::make_pair(role == kRequestData, best_link),
830 &req_in, &req_out, DataRequest);
831 if (succ != kSuccess) return succ;
832 // set p_req_in
833 p_req_in->resize(req_in.size());
834 for (size_t i = 0; i < req_in.size(); ++i) {
835 // set p_req_in
836 (*p_req_in)[i] = (req_in[i] != 0);
837 if (req_out[i] != 0) {
838 _assert(req_in[i] == 0, "cannot get and receive request");
839 _assert(static_cast<int>(i) == best_link, "request result inconsistent");
840 }
841 }
842 *p_recvlink = best_link;
843 return kSuccess;
844 }
845 /*!
846 * \brief try to finish the data recovery request,
847 * this function is used together with TryDecideRouting
848 * \param role the current role of the node
849 * \param sendrecvbuf_ the buffer to store the data to be sent/recived
850 * - if the role is kHaveData, this stores the data to be sent
851 * - if the role is kRequestData, this is the buffer to store the result
852 * - if the role is kPassData, this will not be used, and can be NULL
853 * \param size the size of the data, obtained from TryDecideRouting
854 * \param recv_link the link index to receive data, if necessary, obtained from TryDecideRouting
855 * \param req_in the request of each link to send data, obtained from TryDecideRouting
856 *
857 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
858 * \sa ReturnType, TryDecideRouting
859 */
860 AllreduceRobust::ReturnType
861 AllreduceRobust::TryRecoverData(RecoverType role,
862 void *sendrecvbuf_,
863 size_t size,
864 int recv_link,
865 const std::vector<bool> &req_in) {
866 RefLinkVector &links = tree_links;
867 // no need to run recovery for zero size messages
868 if (links.size() == 0 || size == 0) return kSuccess;
869 _assert(req_in.size() == links.size(), "TryRecoverData");
870 const int nlink = static_cast<int>(links.size());
871 {
872 bool req_data = role == kRequestData;
873 for (int i = 0; i < nlink; ++i) {
874 if (req_in[i]) {
875 _assert(i != recv_link, "TryDecideRouting");
876 req_data = true;
877 }
878 }
879 // do not need to provide data or receive data, directly exit
880 if (!req_data) return kSuccess;
881 }
882 _assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
883 if (role == kPassData) {
884 links[recv_link].InitBuffer(1, size, reduce_buffer_size);
885 }
886 for (int i = 0; i < nlink; ++i) {
887 links[i].ResetSize();
888 }
889 while (true) {
890 bool finished = true;
891 utils::PollHelper watcher;
892 for (int i = 0; i < nlink; ++i) {
893 if (i == recv_link && links[i].size_read != size) {
894 watcher.WatchRead(links[i].sock);
895 finished = false;
896 }
897 if (req_in[i] && links[i].size_write != size) {
898 if (role == kHaveData ||
899 (links[recv_link].size_read != links[i].size_write)) {
900 watcher.WatchWrite(links[i].sock);
901 }
902 finished = false;
903 }
904 watcher.WatchException(links[i].sock);
905 }
906 if (finished) break;
907 watcher.Poll();
908 // exception handling
909 for (int i = 0; i < nlink; ++i) {
910 if (watcher.CheckExcept(links[i].sock)) {
911 return ReportError(&links[i], kGetExcept);
912 }
913 }
914 if (role == kRequestData) {
915 const int pid = recv_link;
916 if (watcher.CheckRead(links[pid].sock)) {
917 ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
918 if (ret != kSuccess) {
919 return ReportError(&links[pid], ret);
920 }
921 }
922 for (int i = 0; i < nlink; ++i) {
923 if (req_in[i] && links[i].size_write != links[pid].size_read) {
924 ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read);
925 if (ret != kSuccess) {
926 return ReportError(&links[i], ret);
927 }
928 }
929 }
930 }
931 if (role == kHaveData) {
932 for (int i = 0; i < nlink; ++i) {
933 if (req_in[i] && links[i].size_write != size) {
934 ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size);
935 if (ret != kSuccess) {
936 return ReportError(&links[i], ret);
937 }
938 }
939 }
940 }
941 if (role == kPassData) {
942 const int pid = recv_link;
943 const size_t buffer_size = links[pid].buffer_size;
944 if (watcher.CheckRead(links[pid].sock)) {
945 size_t min_write = size;
946 for (int i = 0; i < nlink; ++i) {
947 if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
948 }
949 _assert(min_write <= links[pid].size_read, "boundary check");
950 ReturnType ret = links[pid].ReadToRingBuffer(min_write, size);
951 if (ret != kSuccess) {
952 return ReportError(&links[pid], ret);
953 }
954 }
955 for (int i = 0; i < nlink; ++i) {
956 if (req_in[i] && links[pid].size_read != links[i].size_write) {
957 size_t start = links[i].size_write % buffer_size;
958 // send out data from ring buffer
959 size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
960 ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite);
961 if (len != -1) {
962 links[i].size_write += len;
963 } else {
964 ReturnType ret = Errno2Return();
965 if (ret != kSuccess) return ReportError(&links[i], ret);
966 }
967 }
968 }
969 }
970 }
971 return kSuccess;
972 }
973 /*!
974 * \brief try to fetch allreduce/broadcast results from rest of nodes
975 * as collaberative function called by all nodes, only requester node
976 * will pass seqno to rest of nodes and reconstruct/backfill sendrecvbuf_
977 * of specific seqno from other nodes.
978 */
979 AllreduceRobust::ReturnType AllreduceRobust::TryRestoreCache(bool requester,
980 const int min_seq, const int max_seq) {
981 // clear requester and rebuild from those with most cache entries
982 if (requester) {
983 _assert(cur_cache_seq <= max_seq, "requester is expected to have fewer cache entries");
984 cachebuf.Clear();
985 lookupbuf.Clear();
986 cur_cache_seq = 0;
987 }
988 RecoverType role = requester ? kRequestData : kHaveData;
989 size_t size = 1;
990 int recv_link;
991 std::vector<bool> req_in;
992 ReturnType ret = TryDecideRouting(role, &size, &recv_link, &req_in);
993 if (ret != kSuccess) return ret;
994 // only recover missing cache entries in requester
995 // as tryrecoverdata is collective call, need to go through entire cache
996 // and only work on those missing
997 for (int i = 0; i < max_seq; i++) {
998 // restore lookup map
999 size_t cache_size = 0;
1000 void* key = lookupbuf.Query(i, &cache_size);
1001 ret = TryRecoverData(role, &cache_size, sizeof(size_t), recv_link, req_in);
1002 if (ret != kSuccess) return ret;
1003 if (requester) {
1004 key = lookupbuf.AllocTemp(cache_size, 1);
1005 lookupbuf.PushTemp(i, cache_size, 1);
1006 }
1007 ret = TryRecoverData(role, key, cache_size, recv_link, req_in);
1008 if (ret != kSuccess) return ret;
1009 // restore cache content
1010 cache_size = 0;
1011 void* buf = cachebuf.Query(i, &cache_size);
1012 ret = TryRecoverData(role, &cache_size, sizeof(size_t), recv_link, req_in);
1013 if (requester) {
1014 buf = cachebuf.AllocTemp(cache_size, 1);
1015 cachebuf.PushTemp(i, cache_size, 1);
1016 cur_cache_seq +=1;
1017 }
1018 ret = TryRecoverData(role, buf, cache_size, recv_link, req_in);
1019 if (ret != kSuccess) return ret;
1020 }
1021
1022 return kSuccess;
1023 }
1024
1025 /*!
1026 * \brief try to load check point
1027 *
1028 * This is a collaborative function called by all nodes
1029 * only the nodes with requester set to true really needs to load the check point
1030 * other nodes acts as collaborative roles to complete this request
1031 *
1032 * \param requester whether current node is the requester
1033 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
1034 * \sa ReturnType
1035 */
1036 AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
1037 // check in local data
1038 RecoverType role = requester ? kRequestData : kHaveData;
1039 ReturnType succ;
1040 if (num_local_replica != 0) {
1041 if (requester) {
1042 // clear existing history, if any, before load
1043 local_rptr[local_chkpt_version].clear();
1044 local_chkpt[local_chkpt_version].clear();
1045 }
1046 // recover local checkpoint
1047 succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
1048 &local_chkpt[local_chkpt_version]);
1049 if (succ != kSuccess) return succ;
1050 int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
1051 // check if everyone is OK
1052 unsigned state = 0;
1053 if (nlocal == num_local_replica + 1) {
1054 // complete recovery
1055 state = 1;
1056 } else if (nlocal == 0) {
1057 // get nothing
1058 state = 2;
1059 } else {
1060 // partially complete state
1061 state = 4;
1062 }
1063 succ = TryAllreduce(&state, sizeof(state), 1, op::Reducer<op::BitOR, unsigned>);
1064 if (succ != kSuccess) return succ;
1065 utils::Check(state == 1 || state == 2,
1066 "LoadCheckPoint: too many nodes fails, cannot recover local state");
1067 }
1068 // do call save model if the checkpoint was lazy
1069 if (role == kHaveData && global_lazycheck != NULL) {
1070 global_checkpoint.resize(0);
1071 utils::MemoryBufferStream fs(&global_checkpoint);
1072 fs.Write(&version_number, sizeof(version_number));
1073 global_lazycheck->Save(&fs);
1074 global_lazycheck = NULL;
1075 }
1076 // recover global checkpoint
1077 size_t size = this->global_checkpoint.length();
1078 int recv_link;
1079 std::vector<bool> req_in;
1080 succ = TryDecideRouting(role, &size, &recv_link, &req_in);
1081 if (succ != kSuccess) return succ;
1082 if (role == kRequestData) {
1083 global_checkpoint.resize(size);
1084 }
1085 if (size == 0) return kSuccess;
1086 return TryRecoverData(role, BeginPtr(global_checkpoint), size, recv_link, req_in);
1087 }
1088 /*!
1089 * \brief try to get the result of operation specified by seqno
1090 *
1091 * This is a collaborative function called by all nodes
1092 * only the nodes with requester set to true really needs to get the result
1093 * other nodes acts as collaborative roles to complete this request
1094 *
1095 * \param buf the buffer to store the result, this parameter is only used when current node is requester
1096 * \param size the total size of the buffer, this parameter is only used when current node is requester
1097 * \param seqno sequence number of the operation, this is unique index of a operation in current iteration
1098 * \param requester whether current node is the requester
1099 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
1100 * \sa ReturnType
1101 */
1102 AllreduceRobust::ReturnType
1103 AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) {
1104 // if minimum sequence requested is local check point ack,
1105 // this means all nodes have finished local check point, directly return
1106 if (seqno == ActionSummary::kLocalCheckAck) return kSuccess;
1107 if (seqno == ActionSummary::kLocalCheckPoint) {
1108 // new version of local model
1109 int new_version = !local_chkpt_version;
1110 int nlocal = std::max(static_cast<int>(local_rptr[new_version].size()) - 1, 0);
1111 // if we goes to this place, use must have already setup the state once
1112 _assert(nlocal == 1 || nlocal == num_local_replica + 1,
1113 "TryGetResult::Checkpoint");
1114 return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
1115 }
1116
1117 // handles normal data recovery
1118 RecoverType role;
1119 if (!requester) {
1120 sendrecvbuf = resbuf.Query(seqno, &size);
1121 role = sendrecvbuf != NULL ? kHaveData : kPassData;
1122 } else {
1123 role = kRequestData;
1124 }
1125 int recv_link;
1126 std::vector<bool> req_in;
1127 // size of data
1128 size_t data_size = size;
1129 ReturnType succ = TryDecideRouting(role, &data_size, &recv_link, &req_in);
1130 if (succ != kSuccess) return succ;
1131 utils::Check(data_size != 0, "zero size check point is not allowed");
1132 if (role == kRequestData || role == kHaveData) {
1133 utils::Check(data_size == size,
1134 "Allreduce Recovered data size do not match the specification of function call.\n"\
1135 "Please check if calling sequence of recovered program is the " \
1136 "same the original one in current VersionNumber");
1137 }
1138 return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in);
1139 }
1140 /*!
1141 * \brief try to run recover execution for a request action described by flag and seqno,
1142 * the function will keep blocking to run possible recovery operations before the specified action,
1143 * until the requested result is received by a recovering procedure,
1144 * or the function discovers that the requested action is not yet executed, and return false
1145 *
1146 * \param buf the buffer to store the result
1147 * \param size the total size of the buffer
1148 * \param flag flag information about the action \sa ActionSummary
1149 * \param seqno sequence number of the action, if it is special action with flag set,
1150 * seqno needs to be set to ActionSummary::kSpecialOp
1151 *
1152 * \return this function can return true or false
1153 * - true means buf is already set to the
1154 * result by the recovering procedure, the action is complete, no further action is needed
1155 * - false means this is the latest action that has not yet been executed, need to execute the action
1156 */
1157 bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
1158 int cache_seqno, const char* caller) {
1159 // kLoadBootstrapCache should be treated similar as allreduce
1160 // when loadcheck/check/checkack runs in other nodes
1161 if (flag != 0 && flag != ActionSummary::kLoadBootstrapCache) {
1162 _assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations");
1163 }
1164
1165 std::string msg = std::string(caller) + " pass negative seqno "
1166 + std::to_string(seqno) + " flag " + std::to_string(flag)
1167 + " version " + std::to_string(version_number);
1168 _assert(seqno >=0, msg.c_str());
1169
1170 ActionSummary req(flag, flag, seqno, cache_seqno);
1171
1172 while (true) {
1173 this->ReportStatus();
1174 // copy to action and send to allreduce with other nodes
1175 ActionSummary act = req;
1176 // get the reduced action
1177 if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;
1178
1179 if (act.check_ack()) {
1180 if (act.check_point()) {
1181 // if we also have check_point, do check point first
1182 _assert(!act.diff_seq(),
1183 "check ack & check pt cannot occur together with normal ops");
1184 // if we requested checkpoint, we are free to go
1185 if (req.check_point()) return true;
1186 } else if (act.load_check()) {
1187 // if there is only check_ack and load_check, do load_check
1188 if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
1189 // if requested load check, then misson complete
1190 if (req.load_check()) return true;
1191 } else {
1192 // there is no check point and no load check, execute check ack
1193 if (req.check_ack()) return true;
1194 }
1195 // if execute to this point
1196 // this means the action requested has not been completed
1197 // try next round
1198 } else {
1199 if (act.check_point()) {
1200 if (act.diff_seq()) {
1201 _assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
1202 // print checkpoint consensus flag if user turn on debug
1203 if (rabit_debug) {
1204 req.print_flags(rank, "checkpoint req");
1205 act.print_flags(rank, "checkpoint act");
1206 }
1207 /*
1208 * Chen Qin
1209 * at least one hit checkpoint_ code & at least one not hitting
1210 * compare with version_number of req.check_point() set true with rest
1211 * expect to be equal, means rest fall behind in sequence
1212 * use resbuf resbuf to recover
1213 * worker-0 worker-1
1214 * checkpoint(n-1) checkpoint(n-1)
1215 * allreduce allreduce (requester) |
1216 * broadcast V
1217 * checkpoint(n req)
1218 * after catch up to checkpoint n, diff_seq will be false
1219 * */
1220 // assume requester is falling behind
1221 bool requester = req.seqno() == act.seqno();
1222 // if not load cache
1223 if (!act.load_cache()) {
1224 if (act.seqno() > 0) {
1225 if (!requester) {
1226 _assert(req.check_point(), "checkpoint node should be KHaveData role");
1227 buf = resbuf.Query(act.seqno(), &size);
1228 _assert(buf != NULL, "buf should have data from resbuf");
1229 _assert(size > 0, "buf size should be greater than 0");
1230 }
1231 if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
1232 }
1233 } else {
1234 // cache seq no should be smaller than kSpecialOp
1235 _assert(act.seqno(SeqType::kCache) != ActionSummary::kSpecialOp,
1236 "checkpoint with kSpecialOp");
1237 int max_cache_seq = cur_cache_seq;
1238 if (TryAllreduce(&max_cache_seq, sizeof(max_cache_seq), 1,
1239 op::Reducer<op::Max, unsigned>) != kSuccess) continue;
1240
1241 if (TryRestoreCache(req.load_cache(), act.seqno(), max_cache_seq)
1242 != kSuccess) continue;
1243 }
1244 if (requester) return true;
1245 } else {
1246 // no difference in seq no, means we are free to check point
1247 if (req.check_point()) return true;
1248 }
1249 } else {
1250 // no check point
1251 if (act.load_check()) {
1252 // all the nodes called load_check, this is an incomplete action
1253 if (!act.diff_seq()) return false;
1254 // load check have higher priority, do load_check
1255 if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
1256 // if requested load check, then misson complete
1257 if (req.load_check()) return true;
1258 } else {
1259 // run all nodes in a isolated cache restore logic
1260 if (act.load_cache()) {
1261 // print checkpoint consensus flag if user turn on debug
1262 if (rabit_debug) {
1263 req.print_flags(rank, "loadcache req");
1264 act.print_flags(rank, "loadcache act");
1265 }
1266 // load cache should not running in parralel with other states
1267 _assert(!act.load_check(),
1268 "load cache state expect no nodes doing load checkpoint");
1269 _assert(!act.check_point() ,
1270 "load cache state expect no nodes doing checkpoint");
1271 _assert(!act.check_ack(),
1272 "load cache state expect no nodes doing checkpoint ack");
1273
1274 // if all nodes are requester in load cache, skip
1275 if (act.load_cache(SeqType::kCache)) return false;
1276
1277 // bootstrap cache always restore before loadcheckpoint
1278 // requester always have seq diff with non requester
1279 if (act.diff_seq()) {
1280 // restore cache failed, retry from what's left
1281 if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache))
1282 != kSuccess) continue;
1283 }
1284 // if requested load cache, then mission complete
1285 if (req.load_cache()) return true;
1286 continue;
1287 }
1288
1289 // assert no req with load cache set goes into seq catch up
1290 _assert(!req.load_cache(), "load cache not interacte with rest states");
1291
1292 // no special flags, no checkpoint, check ack, load_check
1293 _assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
1294 if (act.diff_seq()) {
1295 bool requester = req.seqno() == act.seqno();
1296 if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
1297 if (requester) return true;
1298 } else {
1299 // all the request is same,
1300 // this is most recent command that is yet to be executed
1301 return false;
1302 }
1303 }
1304 }
1305 // something is still incomplete try next round
1306 }
1307 }
1308 _assert(false, "RecoverExec: should not reach here");
1309 return true;
1310 }
1311 /*!
1312 * \brief try to recover the local state, making each local state to be the result of itself
1313 * plus replication of states in previous num_local_replica hops in the ring
1314 *
1315 * The input parameters must contain the valid local states available in current nodes,
1316 * This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
1317 * If there is sufficient information in the ring, when the function returns, local_chkpt will
1318 * contain num_local_replica + 1 checkpoints (including the chkpt of this node)
1319 * If there is no sufficient information in the ring, this function the number of checkpoints
1320 * will be less than the specified value
1321 *
1322 * \param p_local_rptr the pointer to the segment pointers in the states array
1323 * \param p_local_chkpt the pointer to the storage of local check points
1324 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
1325 * \sa ReturnType
1326 */
1327 AllreduceRobust::ReturnType
1328 AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
1329 std::string *p_local_chkpt) {
1330 // if there is no local replica, we can do nothing
1331 if (num_local_replica == 0) return kSuccess;
1332 std::vector<size_t> &rptr = *p_local_rptr;
1333 std::string &chkpt = *p_local_chkpt;
1334 if (rptr.size() == 0) {
1335 rptr.push_back(0);
1336 _assert(chkpt.length() == 0, "local chkpt space inconsistent");
1337 }
1338 const int n = num_local_replica;
1339 {
1340 // backward passing, passing state in backward direction of the ring
1341 const int nlocal = static_cast<int>(rptr.size() - 1);
1342 _assert(nlocal <= n + 1, "invalid local replica");
1343 std::vector<int> msg_back(n + 1);
1344 msg_back[0] = nlocal;
1345 // backward passing one hop the request
1346 ReturnType succ;
1347 succ = RingPassing(BeginPtr(msg_back),
1348 1 * sizeof(int), (n+1) * sizeof(int),
1349 0 * sizeof(int), n * sizeof(int),
1350 ring_next, ring_prev);
1351 if (succ != kSuccess) return succ;
1352 int msg_forward[2];
1353 msg_forward[0] = nlocal;
1354 succ = RingPassing(msg_forward,
1355 1 * sizeof(int), 2 * sizeof(int),
1356 0 * sizeof(int), 1 * sizeof(int),
1357 ring_prev, ring_next);
1358 if (succ != kSuccess) return succ;
1359 // calculate the number of things we can read from next link
1360 int nread_end = nlocal;
1361 for (int i = 1; i <= n; ++i) {
1362 nread_end = std::max(nread_end, msg_back[i] - i);
1363 }
1364 // gives the size of forward
1365 int nwrite_start = std::min(msg_forward[1] + 1, nread_end);
1366 // get the size of each segments
1367 std::vector<size_t> sizes(nread_end);
1368 for (int i = 0; i < nlocal; ++i) {
1369 sizes[i] = rptr[i + 1] - rptr[i];
1370 }
1371 // pass size through the link
1372 succ = RingPassing(BeginPtr(sizes),
1373 nlocal * sizeof(size_t),
1374 nread_end * sizeof(size_t),
1375 nwrite_start * sizeof(size_t),
1376 nread_end * sizeof(size_t),
1377 ring_next, ring_prev);
1378 if (succ != kSuccess) return succ;
1379 // update rptr
1380 rptr.resize(nread_end + 1);
1381 for (int i = nlocal; i < nread_end; ++i) {
1382 rptr[i + 1] = rptr[i] + sizes[i];
1383 }
1384 chkpt.resize(rptr.back());
1385 // pass data through the link
1386 succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
1387 rptr[nwrite_start], rptr[nread_end],
1388 ring_next, ring_prev);
1389 if (succ != kSuccess) {
1390 rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
1391 }
1392 }
1393 {
1394 // forward passing, passing state in forward direction of the ring
1395 const int nlocal = static_cast<int>(rptr.size() - 1);
1396 _assert(nlocal <= n + 1, "invalid local replica");
1397 std::vector<int> msg_forward(n + 1);
1398 msg_forward[0] = nlocal;
1399 // backward passing one hop the request
1400 ReturnType succ;
1401 succ = RingPassing(BeginPtr(msg_forward),
1402 1 * sizeof(int), (n+1) * sizeof(int),
1403 0 * sizeof(int), n * sizeof(int),
1404 ring_prev, ring_next);
1405 if (succ != kSuccess) return succ;
1406 int msg_back[2];
1407 msg_back[0] = nlocal;
1408 succ = RingPassing(msg_back,
1409 1 * sizeof(int), 2 * sizeof(int),
1410 0 * sizeof(int), 1 * sizeof(int),
1411 ring_next, ring_prev);
1412 if (succ != kSuccess) return succ;
1413 // calculate the number of things we can read from next link
1414 int nread_end = nlocal, nwrite_end = 1;
1415 // have to have itself in order to get other data from prev link
1416 if (nlocal != 0) {
1417 for (int i = 1; i <= n; ++i) {
1418 if (msg_forward[i] == 0) break;
1419 nread_end = std::max(nread_end, i + 1);
1420 nwrite_end = i + 1;
1421 }
1422 if (nwrite_end > n) nwrite_end = n;
1423 } else {
1424 nread_end = 0; nwrite_end = 0;
1425 }
1426 // gives the size of forward
1427 int nwrite_start = std::min(msg_back[1] - 1, nwrite_end);
1428 // next node miss the state of itself, cannot recover
1429 if (nwrite_start < 0) nwrite_start = nwrite_end = 0;
1430 // get the size of each segments
1431 std::vector<size_t> sizes(nread_end);
1432 for (int i = 0; i < nlocal; ++i) {
1433 sizes[i] = rptr[i + 1] - rptr[i];
1434 }
1435 // pass size through the link, check consistency
1436 succ = RingPassing(BeginPtr(sizes),
1437 nlocal * sizeof(size_t),
1438 nread_end * sizeof(size_t),
1439 nwrite_start * sizeof(size_t),
1440 nwrite_end * sizeof(size_t),
1441 ring_prev, ring_next);
1442 if (succ != kSuccess) return succ;
1443 // update rptr
1444 rptr.resize(nread_end + 1);
1445 for (int i = nlocal; i < nread_end; ++i) {
1446 rptr[i + 1] = rptr[i] + sizes[i];
1447 }
1448 chkpt.resize(rptr.back());
1449 // pass data through the link
1450 succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
1451 rptr[nwrite_start], rptr[nwrite_end],
1452 ring_prev, ring_next);
1453 if (succ != kSuccess) {
1454 rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
1455 }
1456 }
1457 return kSuccess;
1458 }
1459 /*!
1460 * \brief try to checkpoint local state, this function is called in normal executation phase
1461 * of checkpoint that contains local state
1462 * the input state must exactly one saved state(local state of current node),
1463 * after complete, this function will get local state from previous num_local_replica nodes and put them
1464 * into local_chkpt and local_rptr
1465 *
1466 * It is also OK to call TryRecoverLocalState instead,
1467 * TryRecoverLocalState makes less assumption about the input, and requires more communications
1468 *
1469 * \param p_local_rptr the pointer to the segment pointers in the states array
1470 * \param p_local_chkpt the pointer to the storage of local check points
1471 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
1472 * \sa ReturnType, TryRecoverLocalState
1473 */
1474 AllreduceRobust::ReturnType
1475 AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
1476 std::string *p_local_chkpt) {
1477 // if there is no local replica, we can do nothing
1478 if (num_local_replica == 0) return kSuccess;
1479 std::vector<size_t> &rptr = *p_local_rptr;
1480 std::string &chkpt = *p_local_chkpt;
1481 _assert(rptr.size() == 2,
1482 "TryCheckinLocalState must have exactly 1 state");
1483 const int n = num_local_replica;
1484 std::vector<size_t> sizes(n + 1);
1485 sizes[0] = rptr[1] - rptr[0];
1486 ReturnType succ;
1487 // pass size through the link
1488 succ = RingPassing(BeginPtr(sizes),
1489 1 * sizeof(size_t),
1490 (n + 1) * sizeof(size_t),
1491 0 * sizeof(size_t),
1492 n * sizeof(size_t),
1493 ring_prev, ring_next);
1494 if (succ != kSuccess) return succ;
1495 // update rptr
1496 rptr.resize(n + 2);
1497 for (int i = 1; i <= n; ++i) {
1498 rptr[i + 1] = rptr[i] + sizes[i];
1499 }
1500 chkpt.resize(rptr.back());
1501 // pass data through the link
1502 succ = RingPassing(BeginPtr(chkpt),
1503 rptr[1], rptr[n + 1],
1504 rptr[0], rptr[n],
1505 ring_prev, ring_next);
1506 if (succ != kSuccess) {
1507 rptr.resize(2); chkpt.resize(rptr.back()); return succ;
1508 }
1509 return kSuccess;
1510 }
1511 /*!
1512 * \brief perform a ring passing to receive data from prev link, and sent data to next link
1513 * this allows data to stream over a ring structure
1514 * sendrecvbuf[0:read_ptr] are already provided by current node
1515 * current node will recv sendrecvbuf[read_ptr:read_end] from prev link
1516 * current node will send sendrecvbuf[write_ptr:write_end] to next link
1517 * write_ptr will wait till the data is readed before sending the data
1518 * this function requires read_end >= write_end
1519 *
1520 * \param sendrecvbuf_ the place to hold the incoming and outgoing data
1521 * \param read_ptr the initial read pointer
1522 * \param read_end the ending position to read
1523 * \param write_ptr the initial write pointer
1524 * \param write_end the ending position to write
1525 * \param read_link pointer to link to previous position in ring
1526 * \param write_link pointer to link of next position in ring
1527 */
1528 AllreduceRobust::ReturnType
1529 AllreduceRobust::RingPassing(void *sendrecvbuf_,
1530 size_t read_ptr,
1531 size_t read_end,
1532 size_t write_ptr,
1533 size_t write_end,
1534 LinkRecord *read_link,
1535 LinkRecord *write_link) {
1536 if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
1537 _assert(write_end <= read_end,
1538 "RingPassing: boundary check1");
1539 _assert(read_ptr <= read_end, "RingPassing: boundary check2");
1540 _assert(write_ptr <= write_end, "RingPassing: boundary check3");
1541 // take reference
1542 LinkRecord &prev = *read_link, &next = *write_link;
1543 // send recv buffer
1544 char *buf = reinterpret_cast<char*>(sendrecvbuf_);
1545 while (true) {
1546 bool finished = true;
1547 utils::PollHelper watcher;
1548 if (read_ptr != read_end) {
1549 watcher.WatchRead(prev.sock);
1550 finished = false;
1551 }
1552 if (write_ptr < read_ptr && write_ptr != write_end) {
1553 watcher.WatchWrite(next.sock);
1554 finished = false;
1555 }
1556 watcher.WatchException(prev.sock);
1557 watcher.WatchException(next.sock);
1558 if (finished) break;
1559 watcher.Poll();
1560 if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
1561 if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
1562 if (read_ptr != read_end && watcher.CheckRead(prev.sock)) {
1563 ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
1564 if (len == 0) {
1565 prev.sock.Close(); return ReportError(&prev, kRecvZeroLen);
1566 }
1567 if (len != -1) {
1568 read_ptr += static_cast<size_t>(len);
1569 } else {
1570 ReturnType ret = Errno2Return();
1571 if (ret != kSuccess) return ReportError(&prev, ret);
1572 }
1573 }
1574 if (write_ptr != write_end && write_ptr < read_ptr) {
1575 size_t nsend = std::min(write_end - write_ptr, read_ptr - write_ptr);
1576 ssize_t len = next.sock.Send(buf + write_ptr, nsend);
1577 if (len != -1) {
1578 write_ptr += static_cast<size_t>(len);
1579 } else {
1580 ReturnType ret = Errno2Return();
1581 if (ret != kSuccess) return ReportError(&prev, ret);
1582 }
1583 }
1584 }
1585 return kSuccess;
1586 }
1587 } // namespace engine
1588 } // namespace rabit
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file allreduce_robust.h
3 * \brief Robust implementation of Allreduce
4 * using TCP non-block socket and tree-shape reduction.
5 *
6 * This implementation considers the failure of nodes
7 *
8 * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
9 */
10 #ifndef RABIT_ALLREDUCE_ROBUST_H_
11 #define RABIT_ALLREDUCE_ROBUST_H_
12 #include <future>
13 #include <vector>
14 #include <string>
15 #include <algorithm>
16 #include "rabit/internal/engine.h"
17 #include "allreduce_base.h"
18
19 namespace rabit {
20 namespace engine {
21 /*! \brief implementation of fault tolerant all reduce engine */
22 class AllreduceRobust : public AllreduceBase {
23 public:
24 AllreduceRobust(void);
25 virtual ~AllreduceRobust(void) {}
26 // initialize the manager
27 virtual bool Init(int argc, char* argv[]);
28 /*! \brief shutdown the engine */
29 virtual bool Shutdown(void);
30 /*!
31 * \brief set parameters to the engine
32 * \param name parameter name
33 * \param val parameter value
34 */
35 virtual void SetParam(const char *name, const char *val);
36 /*!
37 * \brief perform immutable local bootstrap cache insertion
38 * \param key unique cache key
39 * \param buf buffer of allreduce/robust payload to copy
40 * \param buflen total number of bytes
41 * \return -1 if no recovery cache fetched otherwise 0
42 */
43 int SetBootstrapCache(const std::string &key, const void *buf,
44 const size_t type_nbytes, const size_t count);
45 /*!
46 * \brief perform bootstrap cache lookup if nodes in fault recovery
47 * \param key unique cache key
48 * \param buf buffer for recv allreduce/robust payload
49 * \param buflen total number of bytes
50 */
51 int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
52 const size_t count);
53 /*!
54 * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
55 * the data provided by current node k is [slice_begin, slice_end),
56 * the next node's segment must start with slice_end
57 * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
58 * use a ring based algorithm
59 *
60 * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
61 * \param total_size total size of data to be gathered
62 * \param slice_begin beginning of the current slice
63 * \param slice_end end of the current slice
64 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
65 * \param _file caller file name used to generate unique cache key
66 * \param _line caller line number used to generate unique cache key
67 * \param _caller caller function name used to generate unique cache key
68 */
69 virtual void Allgather(void *sendrecvbuf_, size_t total_size,
70 size_t slice_begin,
71 size_t slice_end,
72 size_t size_prev_slice,
73 const char* _file = _FILE,
74 const int _line = _LINE,
75 const char* _caller = _CALLER);
76 /*!
77 * \brief perform in-place allreduce, on sendrecvbuf
78 * this function is NOT thread-safe
79 * \param sendrecvbuf_ buffer for both sending and recving data
80 * \param type_nbytes the unit number of bytes the type have
81 * \param count number of elements to be reduced
82 * \param reducer reduce function
83 * \param prepare_fun Lazy preprocessing function, prepare_fun(prepare_arg)
84 * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
85 * If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
86 * \param prepare_arg argument passed into the lazy preprocessing function
88 * \param _file caller file name used to generate unique cache key
89 * \param _line caller line number used to generate unique cache key
90 * \param _caller caller function name used to generate unique cache key
91 */
92 virtual void Allreduce(void *sendrecvbuf_,
93 size_t type_nbytes,
94 size_t count,
95 ReduceFunction reducer,
96 PreprocFunction prepare_fun = NULL,
97 void *prepare_arg = NULL,
98 const char* _file = _FILE,
99 const int _line = _LINE,
100 const char* _caller = _CALLER);
101 /*!
102 * \brief broadcast data from root to all nodes
103 * \param sendrecvbuf_ buffer for both sending and recving data
104 * \param total_size the size of the data to be broadcast
105 * \param root the root worker id to broadcast the data
106 * \param _file caller file name used to generate unique cache key
107 * \param _line caller line number used to generate unique cache key
108 * \param _caller caller function name used to generate unique cache key
109 */
110 virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
111 const char* _file = _FILE,
112 const int _line = _LINE,
113 const char* _caller = _CALLER);
114 /*!
115 * \brief load latest check point
116 * \param global_model pointer to the globally shared model/state
117 * when calling this function, the caller needs to guarantee that global_model
118 * is the same in all nodes
119 * \param local_model pointer to local model, that is specific to current node/rank
120 * this can be NULL when no local model is needed
121 *
122 * \return the version number of check point loaded
123 * if returned version == 0, this means no model has been CheckPointed
124 * the p_model is not touched, user should do necessary initialization by themselves
125 *
126 * Common usage example:
127 * int iter = rabit::LoadCheckPoint(&model);
128 * if (iter == 0) model.InitParameters();
129 * for (i = iter; i < max_iter; ++i) {
130 * do many things, include allreduce
131 * rabit::CheckPoint(model);
132 * }
133 *
134 * \sa CheckPoint, VersionNumber
135 */
136 virtual int LoadCheckPoint(Serializable *global_model,
137 Serializable *local_model = NULL);
138 /*!
139 * \brief checkpoint the model, meaning we finished a stage of execution
140 * every time we call check point, there is a version number which will increase by one
141 *
142 * \param global_model pointer to the globally shared model/state
143 * when calling this function, the caller need to gauranttees that global_model
144 * is the same in all nodes
145 * \param local_model pointer to local model, that is specific to current node/rank
146 * this can be NULL when no local state is needed
147 *
148 * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
149 * bring replication cost in CheckPoint function. global_model do not need explicit replication.
150 * So only CheckPoint with global_model if possible
151 *
152 * \sa LoadCheckPoint, VersionNumber
153 */
154 virtual void CheckPoint(const Serializable *global_model,
155 const Serializable *local_model = NULL) {
156 this->CheckPoint_(global_model, local_model, false);
157 }
158 /*!
159 * \brief This function can be used to replace CheckPoint for global_model only,
160 * when certain condition is met(see detailed expplaination).
161 *
162 * This is a "lazy" checkpoint such that only the pointer to global_model is
163 * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
164 * The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
165 * In another words, global_model model can be changed only between last call of
166 * Allreduce/Broadcast and LazyCheckPoint in current version
167 *
168 * For example, suppose the calling sequence is:
169 * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
170 *
171 * If user can only changes global_model in code3, then LazyCheckPoint can be used to
172 * improve efficiency of the program.
173 * \param global_model pointer to the globally shared model/state
174 * when calling this function, the caller need to gauranttees that global_model
175 * is the same in all nodes
176 * \sa LoadCheckPoint, CheckPoint, VersionNumber
177 */
178 virtual void LazyCheckPoint(const Serializable *global_model) {
179 this->CheckPoint_(global_model, NULL, true);
180 }
181 /*!
182 * \brief explicitly re-init everything before calling LoadCheckPoint
183 * call this function when IEngine throw an exception out,
184 * this function is only used for test purpose
185 */
186 virtual void InitAfterException(void) {
187 // simple way, shutdown all links
188 for (size_t i = 0; i < all_links.size(); ++i) {
189 if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
190 }
191 ReConnectLinks("recover");
192 }
193
194 protected:
195 // constant one byte out of band message to indicate error happening
196 // and mark for channel cleanup
197 static const char kOOBReset = 95;
198 // and mark for channel cleanup, after OOB signal
199 static const char kResetMark = 97;
200 // and mark for channel cleanup
201 static const char kResetAck = 97;
/*! \brief the roles a node can play during recovery */
enum RecoverType {
  kHaveData = 0,     /*!< current node has the data */
  kRequestData = 1,  /*!< current node requests the data */
  kPassData = 2      /*!< current node only helps to pass data around */
};
211
/*! \brief selects which of the two sequence codes an accessor applies to */
enum SeqType {
  kSeq = 0,   /*!< apply to rabit seq code */
  kCache = 1  /*!< apply to rabit cache seq code */
};
218 /*!
219 * \brief summary of actions proposed in all nodes
220 * this data structure is used to make consensus decision
221 * about next action to take in the recovery mode
222 */
223 struct ActionSummary {
224 // maximumly allowed sequence id
225 static const u_int32_t kSpecialOp = (1 << 26);
226 // special sequence number for local state checkpoint
227 static const u_int32_t kLocalCheckPoint = (1 << 26) - 2;
228 // special sequnce number for local state checkpoint ack signal
229 static const u_int32_t kLocalCheckAck = (1 << 26) - 1;
230 //---------------------------------------------
231 // The following are bit mask of flag used in
232 //----------------------------------------------
233 // some node want to load check point
234 static const int kLoadCheck = 1;
235 // some node want to do check point
236 static const int kCheckPoint = 2;
237 // check point Ack, we use a two phase message in check point,
238 // this is the second phase of check pointing
239 static const int kCheckAck = 4;
240 // there are difference sequence number the nodes proposed
241 // this means we want to do recover execution of the lower sequence
242 // action instead of normal execution
243 static const int kDiffSeq = 8;
244 // there are nodes request load cache
245 static const int kLoadBootstrapCache = 16;
246 // constructor
247 ActionSummary(void) {}
248 // constructor of action
249 explicit ActionSummary(int seqno_flag, int cache_flag = 0,
250 u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
251 seqcode = (minseqno << 5) | seqno_flag;
252 maxseqcode = (maxseqno << 5) | cache_flag;
253 }
254 // minimum number of all operations by default
255 // maximum number of all cache operations otherwise
256 inline u_int32_t seqno(SeqType t = SeqType::kSeq) const {
257 int code = t == SeqType::kSeq ? seqcode : maxseqcode;
258 return code >> 5;
259 }
260 // whether the operation set contains a load_check
261 inline bool load_check(SeqType t = SeqType::kSeq) const {
262 int code = t == SeqType::kSeq ? seqcode : maxseqcode;
263 return (code & kLoadCheck) != 0;
264 }
265 // whether the operation set contains a load_cache
266 inline bool load_cache(SeqType t = SeqType::kSeq) const {
267 int code = t == SeqType::kSeq ? seqcode : maxseqcode;
268 return (code & kLoadBootstrapCache) != 0;
269 }
270 // whether the operation set contains a check point
271 inline bool check_point(SeqType t = SeqType::kSeq) const {
272 int code = t == SeqType::kSeq ? seqcode : maxseqcode;
273 return (code & kCheckPoint) != 0;
274 }
275 // whether the operation set contains a check ack
276 inline bool check_ack(SeqType t = SeqType::kSeq) const {
277 int code = t == SeqType::kSeq ? seqcode : maxseqcode;
278 return (code & kCheckAck) != 0;
279 }
280 // whether the operation set contains different sequence number
281 inline bool diff_seq() const {
282 return (seqcode & kDiffSeq) != 0;
283 }
284 // returns the operation flag of the result
285 inline int flag(SeqType t = SeqType::kSeq) const {
286 int code = t == SeqType::kSeq ? seqcode : maxseqcode;
287 return code & 31;
288 }
289 // print flags in user friendly way
290 inline void print_flags(int rank, std::string prefix ) {
291 utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
292 rank, prefix.c_str(),
293 seqno(), check_point(), check_ack(), load_cache(),
294 diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
295 }
296 // reducer for Allreduce, get the result ActionSummary from all nodes
297 inline static void Reducer(const void *src_, void *dst_,
298 int len, const MPI::Datatype &dtype) {
299 const ActionSummary *src = (const ActionSummary*)src_;
300 ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
301 for (int i = 0; i < len; ++i) {
302 u_int32_t min_seqno = std::min(src[i].seqno(), dst[i].seqno());
303 u_int32_t max_seqno = std::max(src[i].seqno(SeqType::kCache),
304 dst[i].seqno(SeqType::kCache));
305 int action_flag = src[i].flag() | dst[i].flag();
306 // if any node is not requester set to 0 otherwise 1
307 int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
308 // if seqno is different in src and destination
309 int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
310 // apply or to both seq diff flag as well as cache seq diff flag
311 dst[i] = ActionSummary(action_flag | seq_diff_flag,
312 role_flag, min_seqno, max_seqno);
313 }
314 }
315
316 private:
317 // internel sequence code min of rabit seqno
318 u_int32_t seqcode;
319 // internal sequence code max of cache seqno
320 u_int32_t maxseqcode;
321 };
322 /*! \brief data structure to remember result of Bcast and Allreduce calls*/
323 class ResultBuffer{
324 public:
325 // constructor
326 ResultBuffer(void) {
327 this->Clear();
328 }
329 // clear the existing record
330 inline void Clear(void) {
331 seqno_.clear(); size_.clear();
332 rptr_.clear(); rptr_.push_back(0);
333 data_.clear();
334 }
335 // allocate temporal space
336 inline void *AllocTemp(size_t type_nbytes, size_t count) {
337 size_t size = type_nbytes * count;
338 size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
339 utils::Assert(nhop != 0, "cannot allocate 0 size memory");
340 // allocate addational nhop buffer size
341 data_.resize(rptr_.back() + nhop);
342 return BeginPtr(data_) + rptr_.back();
343 }
344 // push the result in temp to the
345 inline void PushTemp(int seqid, size_t type_nbytes, size_t count) {
346 size_t size = type_nbytes * count;
347 size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
348 if (seqno_.size() != 0) {
349 utils::Assert(seqno_.back() < seqid, "PushTemp seqid inconsistent");
350 }
351 seqno_.push_back(seqid);
352 rptr_.push_back(rptr_.back() + nhop);
353 size_.push_back(size);
354 utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent");
355 }
356 // return the stored result of seqid, if any
357 inline void* Query(int seqid, size_t *p_size) {
358 size_t idx = std::lower_bound(seqno_.begin(),
359 seqno_.end(), seqid) - seqno_.begin();
360 if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
361 *p_size = size_[idx];
362 return BeginPtr(data_) + rptr_[idx];
363 }
364 // drop last stored result
365 inline void DropLast(void) {
366 utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
367 seqno_.pop_back();
368 rptr_.pop_back();
369 size_.pop_back();
370 data_.resize(rptr_.back());
371 }
372 // the sequence number of last stored result
373 inline int LastSeqNo(void) const {
374 if (seqno_.size() == 0) return -1;
375 return seqno_.back();
376 }
377
378 private:
379 // sequence number of each
380 std::vector<int> seqno_;
381 // pointer to the positions
382 std::vector<size_t> rptr_;
383 // actual size of each buffer
384 std::vector<size_t> size_;
385 // content of the buffer
386 std::vector<uint64_t> data_;
387 };
388 /*!
389 * \brief internal consistency check function,
390 * use check to ensure user always call CheckPoint/LoadCheckPoint
391 * with or without local but not both, this function will set the approperiate settings
392 * in the first call of LoadCheckPoint/CheckPoint
393 *
394 * \param with_local whether the user calls CheckPoint with local model
395 */
396 void LocalModelCheck(bool with_local);
397 /*!
398 * \brief internal implementation of checkpoint, support both lazy and normal way
399 *
400 * \param global_model pointer to the globally shared model/state
401 * when calling this function, the caller needs to guarantee that global_model
402 * is the same in all nodes
403 * \param local_model pointer to local model, that is specific to current node/rank
404 * this can be NULL when no local state is needed
405 * \param lazy_checkpt whether the action is lazy checkpoint
406 *
407 * \sa CheckPoint, LazyCheckPoint
408 */
409 void CheckPoint_(const Serializable *global_model,
410 const Serializable *local_model,
411 bool lazy_checkpt);
412 /*!
413 * \brief reset the all the existing links by sending Out-of-Band message marker
414 * after this function finishes, all the messages received and sent
415 * before in all live links are discarded,
416 * This allows us to get a fresh start after error has happened
417 *
418 * TODO(tqchen): this function is not yet functional and is not used by the engine,
419 * a simple resetlink and reconnect strategy is used instead
420 *
421 * \return this function can return kSuccess or kSockError
422 * when kSockError is returned, it simply means there are bad sockets in the links,
423 * and some link recovery procedure is needed
424 */
425 ReturnType TryResetLinks(void);
426 /*!
427 * \brief if err_type indicates an error
428 * recover links according to the error type reported
429 * if there is no error, return true
430 * \param err_type the type of error happening in the system
431 * \return true if err_type is kSuccess, false otherwise
432 */
433 bool CheckAndRecover(ReturnType err_type);
434 /*!
435 * \brief try to run recover execution for a request action described by flag and seqno,
436 * the function will keep blocking to run possible recovery operations before the specified action,
437 * until the requested result is received by a recovering procedure,
438 * or the function discovers that the requested action is not yet executed, and return false
439 *
440 * \param buf the buffer to store the result
441 * \param size the total size of the buffer
442 * \param flag flag information about the action \sa ActionSummary
443 * \param seqno sequence number of the action, if it is special action with flag set,
444 * seqno needs to be set to ActionSummary::kSpecialOp
445 *
446 * \return this function can return true or false
447 * - true means buf already set to the
448 * result by recovering procedure, the action is complete, no further action is needed
449 * - false means this is the latest action that has not yet been executed, need to execute the action
450 */
451 bool RecoverExec(void *buf, size_t size, int flag,
452 int seqno = ActionSummary::kSpecialOp,
453 int cacheseqno = ActionSummary::kSpecialOp,
454 const char* caller = _CALLER);
455 /*!
456 * \brief try to load check point
457 *
458 * This is a collaborative function called by all nodes
459 * only the nodes with requester set to true really needs to load the check point
460 * other nodes acts as collaborative roles to complete this request
461 *
462 * \param requester whether current node is the requester
463 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
464 * \sa ReturnType
465 */
466 ReturnType TryLoadCheckPoint(bool requester);
467
468 /*!
469 * \brief try to load cache
470 *
471 * This is a collaborative function called by all nodes
472 * only the nodes with requester set to true really needs to load the check point
473 * other nodes acts as collaborative roles to complete this request
474 * \param requester whether current node is the requester
475 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
476 * \sa ReturnType
477 */
478 ReturnType TryRestoreCache(bool requester, const int min_seq = ActionSummary::kSpecialOp,
479 const int max_seq = ActionSummary::kSpecialOp);
480 /*!
481 * \brief try to get the result of operation specified by seqno
482 *
483 * This is a collaborative function called by all nodes
484 * only the nodes with requester set to true really needs to get the result
485 * other nodes acts as collaborative roles to complete this request
486 *
487 * \param buf the buffer to store the result, this parameter is only used when current node is requester
488 * \param size the total size of the buffer, this parameter is only used when current node is requester
489 * \param seqno sequence number of the operation, this is unique index of a operation in current iteration
490 * \param requester whether current node is the requester
491 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
492 * \sa ReturnType
493 */
494 ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester);
495 /*!
496 * \brief try to decide the routing strategy for recovery
497 * \param role the current role of the node
498 * \param p_size used to store the size of the message, for node in state kHaveData,
499 * this size must be set correctly before calling the function
500 * for others, this serves as output parameter
501
502 * \param p_recvlink used to store the link current node should recv data from, if necessary
503 * this can be -1, which means current node have the data
504 * \param p_req_in used to store the resulting vector, indicating which link we should send the data to
505 *
506 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
507 * \sa ReturnType, TryRecoverData
508 */
509 ReturnType TryDecideRouting(RecoverType role,
510 size_t *p_size,
511 int *p_recvlink,
512 std::vector<bool> *p_req_in);
513 /*!
514 * \brief try to finish the data recovery request,
515 * this function is used together with TryDecideRouting
516 * \param role the current role of the node
517 * \param sendrecvbuf_ the buffer to store the data to be sent/received
518 * - if the role is kHaveData, this stores the data to be sent
519 * - if the role is kRequestData, this is the buffer to store the result
520 * - if the role is kPassData, this will not be used, and can be NULL
521 * \param size the size of the data, obtained from TryDecideRouting
522 * \param recv_link the link index to receive data, if necessary, obtained from TryDecideRouting
523 * \param req_in the request of each link to send data, obtained from TryDecideRouting
524 *
525 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
526 * \sa ReturnType, TryDecideRouting
527 */
528 ReturnType TryRecoverData(RecoverType role,
529 void *sendrecvbuf_,
530 size_t size,
531 int recv_link,
532 const std::vector<bool> &req_in);
533 /*!
534 * \brief try to recover the local state, making each local state to be the result of itself
535 * plus replication of states in previous num_local_replica hops in the ring
536 *
537 * The input parameters must contain the valid local states available in current nodes,
538 * This function tries its best to "complete" the missing parts of local_rptr and local_chkpt
539 * If there is sufficient information in the ring, when the function returns, local_chkpt will
540 * contain num_local_replica + 1 checkpoints (including the chkpt of this node)
541 * If there is not sufficient information in the ring, the number of checkpoints
542 * will be less than the specified value
543 *
544 * \param p_local_rptr the pointer to the segment pointers in the states array
545 * \param p_local_chkpt the pointer to the storage of local check points
546 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
547 * \sa ReturnType
548 */
549 ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
550 std::string *p_local_chkpt);
551 /*!
552 * \brief try to checkpoint local state, this function is called in normal execution phase
553 * of checkpoint that contains local state
554 * the input state must contain exactly one saved state (local state of current node),
555 * after complete, this function will get local state from previous num_local_replica nodes and put them
556 * into local_chkpt and local_rptr
557 *
558 * It is also OK to call TryRecoverLocalState instead,
559 * TryRecoverLocalState makes less assumption about the input, and requires more communications
560 *
561 * \param p_local_rptr the pointer to the segment pointers in the states array
562 * \param p_local_chkpt the pointer to the storage of local check points
563 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
564 * \sa ReturnType, TryRecoverLocalState
565 */
566 ReturnType TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
567 std::string *p_local_chkpt);
568 /*!
569 * \brief perform a ring passing to receive data from prev link, and sent data to next link
570 * this allows data to stream over a ring structure
571 * sendrecvbuf[0:read_ptr] are already provided by current node
572 * current node will recv sendrecvbuf[read_ptr:read_end] from prev link
573 * current node will send sendrecvbuf[write_ptr:write_end] to next link
574 * write_ptr will wait till the data is read before sending the data
575 * this function requires read_end >= write_end
576 *
577 * \param sendrecvbuf_ the place to hold the incoming and outgoing data
578 * \param read_ptr the initial read pointer
579 * \param read_end the ending position to read
580 * \param write_ptr the initial write pointer
581 * \param write_end the ending position to write
582 * \param read_link pointer to link to previous position in ring
583 * \param write_link pointer to link of next position in ring
584 */
585 ReturnType RingPassing(void *senrecvbuf_,
586 size_t read_ptr,
587 size_t read_end,
588 size_t write_ptr,
589 size_t write_end,
590 LinkRecord *read_link,
591 LinkRecord *write_link);
592 /*!
593 * \brief run message passing algorithm on the allreduce tree
594 * the result is edge message stored in p_edge_in and p_edge_out
595 * \param node_value the value associated with current node
596 * \param p_edge_in used to store input message from each of the edge
597 * \param p_edge_out used to store output message from each of the edge
598 * \param func a function that defines the message passing rule
599 * Parameters of func:
600 * - node_value same as node_value in the main function
601 * - edge_in the array of input messages from each edge,
602 * this includes the output edge, which should be excluded
603 * - out_index array the index of output edge, the function should
604 * exclude the output edge when compute the message passing value
605 * Return of func:
606 * the function returns the output message based on the input message and node_value
607 *
608 * \tparam EdgeType type of edge message, must be simple struct
609 * \tparam NodeType type of node value
610 */
611 template<typename NodeType, typename EdgeType>
612 inline ReturnType MsgPassing(const NodeType &node_value,
613 std::vector<EdgeType> *p_edge_in,
614 std::vector<EdgeType> *p_edge_out,
615 EdgeType(*func)
616 (const NodeType &node_value,
617 const std::vector<EdgeType> &edge_in,
618 size_t out_index));
619 //---- recovery data structure ----
620 // the round of result buffer, used to mode the result
621 int result_buffer_round;
622 // result buffer of all reduce
623 ResultBuffer resbuf;
624 // current cached allreduce/broadcast sequence number
625 int cur_cache_seq;
626 // result buffer of cached all reduce
627 ResultBuffer cachebuf;
628 // key of each cache entry
629 ResultBuffer lookupbuf;
630 // last check point global model
631 std::string global_checkpoint;
632 // lazy checkpoint of global model
633 const Serializable *global_lazycheck;
634 // number of replica for local state/model
635 int num_local_replica;
636 // number of default local replica
637 int default_local_replica;
638 // flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
639 int use_local_model;
640 // number of replica for global state/model
641 int num_global_replica;
642 // number of times recovery happens
643 int recover_counter;
644 // --- recovery data structure for local checkpoint
645 // there are two versions of the data structure,
646 // at one time one version is valid and another is used as temp memory
647 // pointer to memory position in the local model
648 // local model is stored in CSR format(like a sparse matrices)
649 // local_model[rptr[0]:rptr[1]] stores the model of current node
650 // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
651 std::vector<size_t> local_rptr[2];
652 // storage for local model replicas
653 std::string local_chkpt[2];
654 // version of local checkpoint can be 1 or 0
655 int local_chkpt_version;
656 // if checkpoint were loaded, used to distinguish results bootstrap cache from seqno cache
657 bool checkpoint_loaded;
658 // sidecar executing timeout task
659 std::future<bool> rabit_timeout_task;
660 // flag to shutdown rabit_timeout_task before timeout
661 std::atomic<bool> shutdown_timeout{false};
662 // error handler
663 void (* _error)(const char *fmt, ...) = utils::Error;
664 // assert handler
665 void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert;
666 };
667 } // namespace engine
668 } // namespace rabit
669 // implementation of inline template function
670 #include "./allreduce_robust-inl.h"
671 #endif // RABIT_ALLREDUCE_ROBUST_H_
0 // Copyright by Contributors
1 // implementations in ctypes
2 #define _CRT_SECURE_NO_WARNINGS
3 #define _CRT_SECURE_NO_DEPRECATE
4
5 #include <cstring>
6 #include <string>
7 #include "rabit/rabit.h"
8 #include "rabit/c_api.h"
9
10 namespace rabit {
11 namespace c_api {
12 // helper use to avoid BitOR operator
13 template<typename OP, typename DType>
14 struct FHelper {
15 static void
16 Allreduce(DType *senrecvbuf_,
17 size_t count,
18 void (*prepare_fun)(void *arg),
19 void *prepare_arg) {
20 rabit::Allreduce<OP>(senrecvbuf_, count,
21 prepare_fun, prepare_arg);
22 }
23 };
24
25 template<typename DType>
26 struct FHelper<op::BitOR, DType> {
27 static void
28 Allreduce(DType *senrecvbuf_,
29 size_t count,
30 void (*prepare_fun)(void *arg),
31 void *prepare_arg) {
32 utils::Error("DataType does not support bitwise or operation");
33 }
34 };
35
36 template<typename OP>
37 void Allreduce_(void *sendrecvbuf_,
38 size_t count,
39 engine::mpi::DataType enum_dtype,
40 void (*prepare_fun)(void *arg),
41 void *prepare_arg) {
42 using namespace engine::mpi;
43 switch (enum_dtype) {
44 case kChar:
45 rabit::Allreduce<OP>
46 (static_cast<char*>(sendrecvbuf_),
47 count, prepare_fun, prepare_arg);
48 return;
49 case kUChar:
50 rabit::Allreduce<OP>
51 (static_cast<unsigned char*>(sendrecvbuf_),
52 count, prepare_fun, prepare_arg);
53 return;
54 case kInt:
55 rabit::Allreduce<OP>
56 (static_cast<int*>(sendrecvbuf_),
57 count, prepare_fun, prepare_arg);
58 return;
59 case kUInt:
60 rabit::Allreduce<OP>
61 (static_cast<unsigned*>(sendrecvbuf_),
62 count, prepare_fun, prepare_arg);
63 return;
64 case kLong:
65 rabit::Allreduce<OP>
66 (static_cast<long*>(sendrecvbuf_), // NOLINT(*)
67 count, prepare_fun, prepare_arg);
68 return;
69 case kULong:
70 rabit::Allreduce<OP>
71 (static_cast<unsigned long*>(sendrecvbuf_), // NOLINT(*)
72 count, prepare_fun, prepare_arg);
73 return;
74 case kFloat:
75 FHelper<OP, float>::Allreduce
76 (static_cast<float*>(sendrecvbuf_),
77 count, prepare_fun, prepare_arg);
78 return;
79 case kDouble:
80 FHelper<OP, double>::Allreduce
81 (static_cast<double*>(sendrecvbuf_),
82 count, prepare_fun, prepare_arg);
83 return;
84 default: utils::Error("unknown data_type");
85 }
86 }
87 void Allreduce(void *sendrecvbuf,
88 size_t count,
89 engine::mpi::DataType enum_dtype,
90 engine::mpi::OpType enum_op,
91 void (*prepare_fun)(void *arg),
92 void *prepare_arg) {
93 using namespace engine::mpi;
94 switch (enum_op) {
95 case kMax:
96 Allreduce_<op::Max>
97 (sendrecvbuf,
98 count, enum_dtype,
99 prepare_fun, prepare_arg);
100 return;
101 case kMin:
102 Allreduce_<op::Min>
103 (sendrecvbuf,
104 count, enum_dtype,
105 prepare_fun, prepare_arg);
106 return;
107 case kSum:
108 Allreduce_<op::Sum>
109 (sendrecvbuf,
110 count, enum_dtype,
111 prepare_fun, prepare_arg);
112 return;
113 case kBitwiseOR:
114 Allreduce_<op::BitOR>
115 (sendrecvbuf,
116 count, enum_dtype,
117 prepare_fun, prepare_arg);
118 return;
119 default: utils::Error("unknown enum_op");
120 }
121 }
122 void Allgather(void *sendrecvbuf_,
123 size_t total_size,
124 size_t beginIndex,
125 size_t size_node_slice,
126 size_t size_prev_slice,
127 int enum_dtype) {
128 using namespace engine::mpi;
129 size_t type_size = 0;
130 switch (enum_dtype) {
131 case kChar:
132 type_size = sizeof(char);
133 rabit::Allgather(static_cast<char*>(sendrecvbuf_), total_size * type_size,
134 beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
135 size_prev_slice * type_size);
136 break;
137 case kUChar:
138 type_size = sizeof(unsigned char);
139 rabit::Allgather(static_cast<unsigned char*>(sendrecvbuf_), total_size * type_size,
140 beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
141 size_prev_slice * type_size);
142 break;
143 case kInt:
144 type_size = sizeof(int);
145 rabit::Allgather(static_cast<int*>(sendrecvbuf_), total_size * type_size,
146 beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
147 size_prev_slice * type_size);
148 break;
149 case kUInt:
150 type_size = sizeof(unsigned);
151 rabit::Allgather(static_cast<unsigned*>(sendrecvbuf_), total_size * type_size,
152 beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
153 size_prev_slice * type_size);
154 break;
155 case kLong:
156 type_size = sizeof(int64_t);
157 rabit::Allgather(static_cast<int64_t*>(sendrecvbuf_), total_size * type_size,
158 beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
159 size_prev_slice * type_size);
160 break;
161 case kULong:
162 type_size = sizeof(uint64_t);
163 rabit::Allgather(static_cast<uint64_t*>(sendrecvbuf_), total_size * type_size,
164 beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
165 size_prev_slice * type_size);
166 break;
167 case kFloat:
168 type_size = sizeof(float);
169 rabit::Allgather(static_cast<float*>(sendrecvbuf_), total_size * type_size,
170 beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
171 size_prev_slice * type_size);
172 break;
173 case kDouble:
174 type_size = sizeof(double);
175 rabit::Allgather(static_cast<double*>(sendrecvbuf_), total_size * type_size,
176 beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
177 size_prev_slice * type_size);
178 break;
179 default: utils::Error("unknown data_type");
180 }
181 }
182
183 // wrapper for serialization
184 struct ReadWrapper : public Serializable {
185 std::string *p_str;
186 explicit ReadWrapper(std::string *p_str)
187 : p_str(p_str) {}
188 virtual void Load(Stream *fi) {
189 uint64_t sz;
190 utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
191 "Read pickle string");
192 p_str->resize(sz);
193 if (sz != 0) {
194 utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0,
195 "Read pickle string");
196 }
197 }
198 virtual void Save(Stream *fo) const {
199 utils::Error("not implemented");
200 }
201 };
202
203 struct WriteWrapper : public Serializable {
204 const char *data;
205 size_t length;
206 explicit WriteWrapper(const char *data,
207 size_t length)
208 : data(data), length(length) {
209 }
210 virtual void Load(Stream *fi) {
211 utils::Error("not implemented");
212 }
213 virtual void Save(Stream *fo) const {
214 uint64_t sz = static_cast<uint16_t>(length);
215 fo->Write(&sz, sizeof(sz));
216 fo->Write(data, length * sizeof(char));
217 }
218 };
219 } // namespace c_api
220 } // namespace rabit
221
222 bool RabitInit(int argc, char *argv[]) {
223 return rabit::Init(argc, argv);
224 }
225
226 bool RabitFinalize() {
227 return rabit::Finalize();
228 }
229
230 int RabitGetRingPrevRank() {
231 return rabit::GetRingPrevRank();
232 }
233
234 int RabitGetRank() {
235 return rabit::GetRank();
236 }
237
238 int RabitGetWorldSize() {
239 return rabit::GetWorldSize();
240 }
241
242 int RabitIsDistributed() {
243 return rabit::IsDistributed();
244 }
245
246 void RabitTrackerPrint(const char *msg) {
247 std::string m(msg);
248 rabit::TrackerPrint(m);
249 }
250
251 void RabitGetProcessorName(char *out_name,
252 rbt_ulong *out_len,
253 rbt_ulong max_len) {
254 std::string s = rabit::GetProcessorName();
255 if (s.length() > max_len) {
256 s.resize(max_len - 1);
257 }
258 strcpy(out_name, s.c_str()); // NOLINT(*)
259 *out_len = static_cast<rbt_ulong>(s.length());
260 }
261
262 void RabitBroadcast(void *sendrecv_data,
263 rbt_ulong size, int root) {
264 rabit::Broadcast(sendrecv_data, size, root);
265 }
266
267 void RabitAllgather(void *sendrecvbuf_,
268 size_t total_size,
269 size_t beginIndex,
270 size_t size_node_slice,
271 size_t size_prev_slice,
272 int enum_dtype) {
273 rabit::c_api::Allgather(sendrecvbuf_,
274 total_size,
275 beginIndex,
276 size_node_slice,
277 size_prev_slice,
278 static_cast<rabit::engine::mpi::DataType>(enum_dtype));
279 }
280
281
282 void RabitAllreduce(void *sendrecvbuf,
283 size_t count,
284 int enum_dtype,
285 int enum_op,
286 void (*prepare_fun)(void *arg),
287 void *prepare_arg) {
288 rabit::c_api::Allreduce
289 (sendrecvbuf, count,
290 static_cast<rabit::engine::mpi::DataType>(enum_dtype),
291 static_cast<rabit::engine::mpi::OpType>(enum_op),
292 prepare_fun, prepare_arg);
293 }
294
295 int RabitLoadCheckPoint(char **out_global_model,
296 rbt_ulong *out_global_len,
297 char **out_local_model,
298 rbt_ulong *out_local_len) {
299 // NOTE: this function is not thread-safe
300 using rabit::BeginPtr;
301 using namespace rabit::c_api; // NOLINT(*)
302 static std::string global_buffer;
303 static std::string local_buffer;
304
305 ReadWrapper sg(&global_buffer);
306 ReadWrapper sl(&local_buffer);
307 int version;
308
309 if (out_local_model == NULL) {
310 version = rabit::LoadCheckPoint(&sg, NULL);
311 *out_global_model = BeginPtr(global_buffer);
312 *out_global_len = static_cast<rbt_ulong>(global_buffer.length());
313 } else {
314 version = rabit::LoadCheckPoint(&sg, &sl);
315 *out_global_model = BeginPtr(global_buffer);
316 *out_global_len = static_cast<rbt_ulong>(global_buffer.length());
317 *out_local_model = BeginPtr(local_buffer);
318 *out_local_len = static_cast<rbt_ulong>(local_buffer.length());
319 }
320 return version;
321 }
322
323 void RabitCheckPoint(const char *global_model,
324 rbt_ulong global_len,
325 const char *local_model,
326 rbt_ulong local_len) {
327 using namespace rabit::c_api; // NOLINT(*)
328 WriteWrapper sg(global_model, global_len);
329 WriteWrapper sl(local_model, local_len);
330 if (local_model == NULL) {
331 rabit::CheckPoint(&sg, NULL);
332 } else {
333 rabit::CheckPoint(&sg, &sl);
334 }
335 }
336
337 int RabitVersionNumber() {
338 return rabit::VersionNumber();
339 }
340
/*! \brief C API entry point that always returns 0 */
int RabitLinkTag() {
  const int tag = 0;
  return tag;
}
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file engine.cc
3 * \brief this file governs which implementation of engine we are actually using
4 * provides an singleton of engine interface
5 *
6 * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
7 */
8 #define _CRT_SECURE_NO_WARNINGS
9 #define _CRT_SECURE_NO_DEPRECATE
10 #define NOMINMAX
11
12 #include <memory>
13 #include "rabit/internal/engine.h"
14 #include "allreduce_base.h"
15 #include "allreduce_robust.h"
16 #include "rabit/internal/thread_local.h"
17
18 namespace rabit {
19 namespace engine {
20 // singleton sync manager
21 #ifndef RABIT_USE_BASE
22 #ifndef RABIT_USE_MOCK
23 typedef AllreduceRobust Manager;
24 #else
25 typedef AllreduceMock Manager;
26 #endif // RABIT_USE_MOCK
27 #else
28 typedef AllreduceBase Manager;
29 #endif // RABIT_USE_BASE
30
31 /*! \brief entry to to easily hold returning information */
32 struct ThreadLocalEntry {
33 /*! \brief stores the current engine */
34 std::unique_ptr<Manager> engine;
35 /*! \brief whether init has been called */
36 bool initialized;
37 /*! \brief constructor */
38 ThreadLocalEntry() : initialized(false) {}
39 };
40
41 // define the threadlocal store.
42 typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
43
44 /*! \brief intiialize the synchronization module */
45 bool Init(int argc, char *argv[]) {
46 ThreadLocalEntry* e = EngineThreadLocal::Get();
47 if (e->engine.get() == nullptr) {
48 e->initialized = true;
49 e->engine.reset(new Manager());
50 return e->engine->Init(argc, argv);
51 } else {
52 return true;
53 }
54 }
55
56 /*! \brief finalize syncrhonization module */
57 bool Finalize() {
58 ThreadLocalEntry* e = EngineThreadLocal::Get();
59 if (e->engine.get() != nullptr) {
60 if (e->engine->Shutdown()) {
61 e->engine.reset(nullptr);
62 e->initialized = false;
63 return true;
64 } else {
65 return false;
66 }
67 } else {
68 return true;
69 }
70 }
71
72 /*! \brief singleton method to get engine */
73 IEngine *GetEngine() {
74 // un-initialized default manager.
75 static AllreduceBase default_manager;
76 ThreadLocalEntry* e = EngineThreadLocal::Get();
77 IEngine* ptr = e->engine.get();
78 if (ptr == nullptr) {
79 utils::Check(!e->initialized, "the rabit has not been initialized");
80 return &default_manager;
81 } else {
82 return ptr;
83 }
84 }
85
86 // perform in-place allgather, on sendrecvbuf
87 void Allgather(void *sendrecvbuf_, size_t total_size,
88 size_t slice_begin,
89 size_t slice_end,
90 size_t size_prev_slice,
91 const char* _file,
92 const int _line,
93 const char* _caller) {
94 GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
95 slice_end, size_prev_slice, _file, _line, _caller);
96 }
97
98
99 // perform in-place allreduce, on sendrecvbuf
100 void Allreduce_(void *sendrecvbuf,
101 size_t type_nbytes,
102 size_t count,
103 IEngine::ReduceFunction red,
104 mpi::DataType dtype,
105 mpi::OpType op,
106 IEngine::PreprocFunction prepare_fun,
107 void *prepare_arg,
108 const char* _file,
109 const int _line,
110 const char* _caller) {
111 GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun,
112 prepare_arg, _file, _line, _caller);
113 }
114
115 // code for reduce handle
116 ReduceHandle::ReduceHandle(void)
117 : handle_(NULL), redfunc_(NULL), htype_(NULL) {
118 }
119
120 ReduceHandle::~ReduceHandle(void) {}
121
122 int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
123 return static_cast<int>(dtype.type_size);
124 }
125
126 void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
127 utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
128 redfunc_ = redfunc;
129 }
130
131 void ReduceHandle::Allreduce(void *sendrecvbuf,
132 size_t type_nbytes, size_t count,
133 IEngine::PreprocFunction prepare_fun,
134 void *prepare_arg,
135 const char* _file,
136 const int _line,
137 const char* _caller) {
138 utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
139 GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
140 redfunc_, prepare_fun, prepare_arg,
141 _file, _line, _caller);
142 }
143 } // namespace engine
144 } // namespace rabit
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file engine_mock.cc
3 * \brief this is an engine implementation that will
4 * insert failures in certain call point, to test if the engine is robust to failure
5 * \author Tianqi Chen
6 */
7 // define use BASE, so we will use the AllreduceBase Manager
8 #define _CRT_SECURE_NO_WARNINGS
9 #define _CRT_SECURE_NO_DEPRECATE
10 #define NOMINMAX
11 // switch engine to AllreduceBase
12 #define RABIT_USE_BASE
13 #include "engine.cc"
14
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file engine_empty.cc
3 * \brief this file provides a dummy implementation of engine that does nothing
4 * this file provides a way to fall back to single node program without causing too many dependencies
5 * This is usually NOT needed, use engine_mpi or engine for real distributed version
6 * \author Tianqi Chen
7 */
8 #define _CRT_SECURE_NO_WARNINGS
9 #define _CRT_SECURE_NO_DEPRECATE
10 #define NOMINMAX
11
12 #include "rabit/internal/engine.h"
13
14 namespace rabit {
15
16 namespace utils {
17 bool STOP_PROCESS_ON_ERROR = true;
18 }
19
20 namespace engine {
21 /*! \brief EmptyEngine */
22 class EmptyEngine : public IEngine {
23 public:
24 EmptyEngine(void) {
25 version_number = 0;
26 }
27 virtual void Allgather(void *sendrecvbuf_,
28 size_t total_size,
29 size_t slice_begin,
30 size_t slice_end,
31 size_t size_prev_slice,
32 const char* _file,
33 const int _line,
34 const char* _caller) {
35 utils::Error("EmptyEngine:: Allgather is not supported");
36 }
37 virtual int GetRingPrevRank(void) const {
38 utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
39 return -1;
40 }
41 virtual void Allreduce(void *sendrecvbuf_,
42 size_t type_nbytes,
43 size_t count,
44 ReduceFunction reducer,
45 PreprocFunction prepare_fun,
46 void *prepare_arg,
47 const char* _file,
48 const int _line,
49 const char* _caller) {
50 utils::Error("EmptyEngine:: Allreduce is not supported,"\
51 "use Allreduce_ instead");
52 }
53 virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
54 const char* _file, const int _line, const char* _caller) {
55 }
56 virtual void InitAfterException(void) {
57 utils::Error("EmptyEngine is not fault tolerant");
58 }
59 virtual int LoadCheckPoint(Serializable *global_model,
60 Serializable *local_model = NULL) {
61 return 0;
62 }
63 virtual void CheckPoint(const Serializable *global_model,
64 const Serializable *local_model = NULL) {
65 version_number += 1;
66 }
67 virtual void LazyCheckPoint(const Serializable *global_model) {
68 version_number += 1;
69 }
70 virtual int VersionNumber(void) const {
71 return version_number;
72 }
73 /*! \brief get rank of current node */
74 virtual int GetRank(void) const {
75 return 0;
76 }
77 /*! \brief get total number of */
78 virtual int GetWorldSize(void) const {
79 return 1;
80 }
81 /*! \brief whether it is distributed */
82 virtual bool IsDistributed(void) const {
83 return false;
84 }
85 /*! \brief get the host name of current node */
86 virtual std::string GetHost(void) const {
87 return std::string("");
88 }
89 virtual void TrackerPrint(const std::string &msg) {
90 // simply print information into the tracker
91 utils::Printf("%s", msg.c_str());
92 }
93
94 private:
95 int version_number;
96 };
97
98 // singleton sync manager
99 EmptyEngine manager;
100
/*! \brief initialize the synchronization module, nothing to do here so it always succeeds */
bool Init(int argc, char *argv[]) {
  const bool ok = true;
  return ok;
}
/*! \brief finalize synchronization module, nothing to do here so it always succeeds */
bool Finalize(void) {
  const bool ok = true;
  return ok;
}
109
110 /*! \brief singleton method to get engine */
111 IEngine *GetEngine(void) {
112 return &manager;
113 }
114 // perform in-place allreduce, on sendrecvbuf
115 void Allreduce_(void *sendrecvbuf,
116 size_t type_nbytes,
117 size_t count,
118 IEngine::ReduceFunction red,
119 mpi::DataType dtype,
120 mpi::OpType op,
121 IEngine::PreprocFunction prepare_fun,
122 void *prepare_arg,
123 const char* _file,
124 const int _line,
125 const char* _caller) {
126 if (prepare_fun != NULL) prepare_fun(prepare_arg);
127 }
128
129 // code for reduce handle
130 ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
131 }
132 ReduceHandle::~ReduceHandle(void) {}
133
134 int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
135 return 0;
136 }
137 void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
138 void ReduceHandle::Allreduce(void *sendrecvbuf,
139 size_t type_nbytes, size_t count,
140 IEngine::PreprocFunction prepare_fun,
141 void *prepare_arg,
142 const char* _file,
143 const int _line,
144 const char* _caller) {
145 if (prepare_fun != NULL) prepare_fun(prepare_arg);
146 }
147 } // namespace engine
148 } // namespace rabit
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file engine_mock.cc
3 * \brief this is an engine implementation that will
4 * insert failures in certain call point, to test if the engine is robust to failure
5 * \author Tianqi Chen
6 */
7 // define use MOCK, so we will use mock Manager
8 #define _CRT_SECURE_NO_WARNINGS
9 #define _CRT_SECURE_NO_DEPRECATE
10 #define NOMINMAX
11 // switch engine to AllreduceMock
12 #define RABIT_USE_MOCK
13 #include "allreduce_mock.h"
14 #include "engine.cc"
15
0 /*!
1 * Copyright (c) 2014 by Contributors
2 * \file engine_mpi.cc
3 * \brief this file gives an implementation of engine interface using MPI,
4 * this will allow rabit program to run with MPI, but do not comes with fault tolerant
5 *
6 * \author Tianqi Chen
7 */
8 #define _CRT_SECURE_NO_WARNINGS
9 #define _CRT_SECURE_NO_DEPRECATE
10 #define NOMINMAX
11 #include <mpi.h>
12 #include <cstdio>
13 #include "rabit/internal/engine.h"
14 #include "rabit/internal/utils.h"
15
16 namespace rabit {
17
18 namespace utils {
19 bool STOP_PROCESS_ON_ERROR = true;
20 }
21
22 namespace engine {
23 /*! \brief implementation of engine using MPI */
24 class MPIEngine : public IEngine {
25 public:
26 MPIEngine(void) {
27 version_number = 0;
28 }
29 virtual void Allgather(void *sendrecvbuf_,
30 size_t total_size,
31 size_t slice_begin,
32 size_t slice_end,
33 size_t size_prev_slice,
34 const char* _file,
35 const int _line,
36 const char* _caller) {
37 utils::Error("MPIEngine:: Allgather is not supported");
38 }
39 virtual void Allreduce(void *sendrecvbuf_,
40 size_t type_nbytes,
41 size_t count,
42 ReduceFunction reducer,
43 PreprocFunction prepare_fun,
44 void *prepare_arg,
45 const char* _file,
46 const int _line,
47 const char* _caller) {
48 utils::Error("MPIEngine:: Allreduce is not supported,"\
49 "use Allreduce_ instead");
50 }
51 virtual int GetRingPrevRank(void) const {
52 utils::Error("MPIEngine:: GetRingPrevRank is not supported");
53 }
54 virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
55 const char* _file, const int _line,
56 const char* _caller) {
57 MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
58 }
59 virtual void InitAfterException(void) {
60 utils::Error("MPI is not fault tolerant");
61 }
62 virtual int LoadCheckPoint(Serializable *global_model,
63 Serializable *local_model = NULL) {
64 return 0;
65 }
66 virtual void CheckPoint(const Serializable *global_model,
67 const Serializable *local_model = NULL) {
68 version_number += 1;
69 }
70 virtual void LazyCheckPoint(const Serializable *global_model) {
71 version_number += 1;
72 }
73 virtual int VersionNumber(void) const {
74 return version_number;
75 }
76 /*! \brief get rank of current node */
77 virtual int GetRank(void) const {
78 return MPI::COMM_WORLD.Get_rank();
79 }
80 /*! \brief get total number of */
81 virtual int GetWorldSize(void) const {
82 return MPI::COMM_WORLD.Get_size();
83 }
84 /*! \brief whether it is distributed */
85 virtual bool IsDistributed(void) const {
86 return true;
87 }
88 /*! \brief get the host name of current node */
89 virtual std::string GetHost(void) const {
90 int len;
91 char name[MPI_MAX_PROCESSOR_NAME];
92 MPI::Get_processor_name(name, len);
93 name[len] = '\0';
94 return std::string(name);
95 }
96 virtual void TrackerPrint(const std::string &msg) {
97 // simply print information into the tracker
98 if (GetRank() == 0) {
99 utils::Printf("%s", msg.c_str());
100 }
101 }
102
103 private:
104 int version_number;
105 };
106
107 // singleton sync manager
108 MPIEngine manager;
109
110 /*! \brief initialize the synchronization module */
111 bool Init(int argc, char *argv[]) {
112 try {
113 MPI::Init(argc, argv);
114 return true;
115 } catch (const std::exception& e) {
116 fprintf(stderr, " failed in MPI Init %s\n", e.what());
117 return false;
118 }
119 }
120 /*! \brief finalize syncrhonization module */
121 bool Finalize(void) {
122 try {
123 MPI::Finalize();
124 return true;
125 } catch (const std::exception& e) {
126 fprintf(stderr, "failed in MPI shutdown %s\n", e.what());
127 return false;
128 }
129 }
130
131 /*! \brief singleton method to get engine */
132 IEngine *GetEngine(void) {
133 return &manager;
134 }
135 // transform enum to MPI data type
136 inline MPI::Datatype GetType(mpi::DataType dtype) {
137 using namespace mpi;
138 switch (dtype) {
139 case kChar: return MPI::CHAR;
140 case kUChar: return MPI::BYTE;
141 case kInt: return MPI::INT;
142 case kUInt: return MPI::UNSIGNED;
143 case kLong: return MPI::LONG;
144 case kULong: return MPI::UNSIGNED_LONG;
145 case kFloat: return MPI::FLOAT;
146 case kDouble: return MPI::DOUBLE;
147 case kLongLong: return MPI::LONG_LONG;
148 case kULongLong: return MPI::UNSIGNED_LONG_LONG;
149 }
150 utils::Error("unknown mpi::DataType");
151 return MPI::CHAR;
152 }
153 // transform enum to MPI OP
154 inline MPI::Op GetOp(mpi::OpType otype) {
155 using namespace mpi;
156 switch (otype) {
157 case kMax: return MPI::MAX;
158 case kMin: return MPI::MIN;
159 case kSum: return MPI::SUM;
160 case kBitwiseOR: return MPI::BOR;
161 }
162 utils::Error("unknown mpi::OpType");
163 return MPI::MAX;
164 }
165 // perform in-place allreduce, on sendrecvbuf
166 void Allreduce_(void *sendrecvbuf,
167 size_t type_nbytes,
168 size_t count,
169 IEngine::ReduceFunction red,
170 mpi::DataType dtype,
171 mpi::OpType op,
172 IEngine::PreprocFunction prepare_fun,
173 void *prepare_arg,
174 const char* _file,
175 const int _line,
176 const char* _caller) {
177 if (prepare_fun != NULL) prepare_fun(prepare_arg);
178 MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf,
179 count, GetType(dtype), GetOp(op));
180 }
181
182 // code for reduce handle
183 ReduceHandle::ReduceHandle(void)
184 : handle_(NULL), redfunc_(NULL), htype_(NULL) {
185 }
186 ReduceHandle::~ReduceHandle(void) {
187 if (handle_ != NULL) {
188 MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
189 op->Free();
190 delete op;
191 }
192 if (htype_ != NULL) {
193 MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
194 dtype->Free();
195 delete dtype;
196 }
197 }
198 int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
199 return dtype.Get_size();
200 }
201 void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
202 utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
203 if (type_nbytes != 0) {
204 MPI::Datatype *dtype = new MPI::Datatype();
205 if (type_nbytes % 8 == 0) {
206 *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*)
207 } else if (type_nbytes % 4 == 0) {
208 *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
209 } else {
210 *dtype = MPI::CHAR.Create_contiguous(type_nbytes);
211 }
212 dtype->Commit();
213 created_type_nbytes_ = type_nbytes;
214 htype_ = dtype;
215 }
216 MPI::Op *op = new MPI::Op();
217 MPI::User_function *pf = redfunc;
218 op->Init(pf, true);
219 handle_ = op;
220 }
221 void ReduceHandle::Allreduce(void *sendrecvbuf,
222 size_t type_nbytes, size_t count,
223 IEngine::PreprocFunction prepare_fun,
224 void *prepare_arg,
225 const char* _file,
226 const int _line,
227 const char* _caller) {
228 utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
229 MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
230 MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
231 if (created_type_nbytes_ != type_nbytes || dtype == NULL) {
232 if (dtype == NULL) {
233 dtype = new MPI::Datatype();
234 } else {
235 dtype->Free();
236 }
237 if (type_nbytes % 8 == 0) {
238 *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*)
239 } else if (type_nbytes % 4 == 0) {
240 *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
241 } else {
242 *dtype = MPI::CHAR.Create_contiguous(type_nbytes);
243 }
244 dtype->Commit();
245 created_type_nbytes_ = type_nbytes;
246 }
247 if (prepare_fun != NULL) prepare_fun(prepare_arg);
248 MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
249 }
250 } // namespace engine
251 } // namespace rabit
0 *.mpi
1 *_test
2 *_recover
# Build rules for the rabit test programs in this directory.
# The parent directory is expected to hold the rabit sources and to
# produce ../lib/librabit_mock.a (and ../lib/librabit_mpi.so for `mpi`).

# Set to 1 when dmlc-core sits next to this directory's parent rather than
# one level further up.
RABIT_BUILD_DMLC = 0

ifeq ($(RABIT_BUILD_DMLC),1)
DMLC=../dmlc-core
else
DMLC=../../dmlc-core
endif

MPICXX=../mpich/bin/mpicxx
export LDFLAGS= -L../lib -pthread -lm
export CFLAGS = -Wall -O3 -Wno-unknown-pragmas

export CC = gcc
export CXX = g++


#----------------------------
# Settings for power and arm arch
#----------------------------
# Any word of the `uname -a` output that matches one of the listed
# architectures selects -march=native; everything else gets -msse2.
ARCH := $(shell uname -a)
ifneq (,$(filter $(ARCH), armv6l armv7l powerpc64le ppc64le aarch64))
CFLAGS += -march=native
else
CFLAGS += -msse2
endif

ifndef WITH_FPIC
WITH_FPIC = 1
endif
ifeq ($(WITH_FPIC), 1)
CFLAGS += -fPIC
endif

CFLAGS += -I../include -I $(DMLC)/include -std=c++11

# specify tensor path
BIN = speed_test model_recover local_recover lazy_recover
OBJ = $(RABIT_OBJ) speed_test.o model_recover.o local_recover.o lazy_recover.o
MPIBIN = speed_test.mpi

.PHONY: clean all lib mpi

all: $(BIN)

# Rebuild the rabit libraries from scratch. $(MAKE) keeps command-line
# flags and the jobserver working, and separate recipe lines make a failed
# step stop the build instead of being hidden by a trailing `cd -`.
lib:
	$(MAKE) -C .. clean
	$(MAKE) -C ..

# Build the MPI flavour of the rabit library.
mpi:
	$(MAKE) -C .. mpi

# programs
speed_test.o: speed_test.cc ../include/rabit/*.h lib mpi
model_recover.o: model_recover.cc ../include/rabit/*.h lib
local_recover.o: local_recover.cc ../include/rabit/*.h lib
lazy_recover.o: lazy_recover.cc ../include/rabit/*.h lib

# we can link against MPI version to get use MPI
speed_test: speed_test.o $(RABIT_OBJ)
speed_test.mpi: speed_test.o $(MPIOBJ)
model_recover: model_recover.o $(RABIT_OBJ)
local_recover: local_recover.o $(RABIT_OBJ)
lazy_recover: lazy_recover.o $(RABIT_OBJ)

$(BIN) :
	$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) ../lib/librabit_mock.a $(LDFLAGS)

$(OBJ) :
	$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )

$(MPIBIN) :
	$(MPICXX) $(CFLAGS) -I../mpich/include -shared -o $@ $(filter %.cpp %.o %.c %.cc, $^) \
	../lib/librabit_mpi.so $(LDFLAGS) -L../mpich/lib -Wl,-rpath,../mpich/lib -lmpi

clean:
	$(RM) $(OBJ) $(BIN) $(MPIBIN) $(MPIOBJ) *~ ../src/*~
0 Testcases of Rabit
1 ====
2 This folder contains internal testcases that check the correctness and efficiency of the rabit API.
3
4 The example running scripts for testcases are given by test.mk
5 * type ```make -f test.mk testcasename``` to run a specific testcase
6
7
8 Helper Scripts
9 ====
10 * test.mk contains Makefile documentation of all testcases
11 * keepalive.sh is a helper bash script that restarts a program when it dies abnormally
12
13 List of Programs
14 ====
15 * speed_test: test the running speed of rabit API
16 * test_local_recover: test recovery of local state when error happens
17 * test_model_recover: test recovery of global state when error happens
# Unit tests for rabit: one GoogleTest executable registered with CTest.
find_package(GTest REQUIRED)

add_executable(
  unit_tests
  test_io.cc
  allreduce_robust_test.cc
  allreduce_base_test.cc
  allreduce_mock_test.cc
  test_main.cpp)

target_link_libraries(
  unit_tests PRIVATE
  GTest::GTest GTest::Main
  rabit_base rabit_mock rabit)

# Nothing links against an executable, so its include paths are PRIVATE.
target_include_directories(unit_tests PRIVATE
  "$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"
  "$<BUILD_INTERFACE:${DMLC_ROOT}/include>")

# Place the binary in the top of the rabit build tree for every
# configuration, so the test command finds it in one fixed location.
set_target_properties(unit_tests
  PROPERTIES
  CXX_STANDARD 11
  CXX_STANDARD_REQUIRED ON
  RUNTIME_OUTPUT_DIRECTORY ${rabit_BINARY_DIR}
  RUNTIME_OUTPUT_DIRECTORY_DEBUG ${rabit_BINARY_DIR}
  RUNTIME_OUTPUT_DIRECTORY_RELEASE ${rabit_BINARY_DIR})

add_test(
  NAME TestRabitLib
  COMMAND unit_tests
  WORKING_DIRECTORY ${rabit_BINARY_DIR})
0 Unittests for Rabit
0 #define RABIT_CXXTESTDEFS_H
1 #include <gtest/gtest.h>
2
3 #include <string>
4 #include <iostream>
5 #include "../../src/allreduce_base.h"
6
7 TEST(allreduce_base, init_task)
8 {
9 rabit::engine::AllreduceBase base;
10
11 std::string rabit_task_id = "rabit_task_id=1";
12 char cmd[rabit_task_id.size()+1];
13 std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd);
14 cmd[rabit_task_id.size()] = '\0';
15
16 char* argv[] = {cmd};
17 base.Init(1, argv);
18 EXPECT_EQ(base.task_id, "1");
19 }
20
21 TEST(allreduce_base, init_with_cache_on)
22 {
23 rabit::engine::AllreduceBase base;
24
25 std::string rabit_task_id = "rabit_task_id=1";
26 char cmd[rabit_task_id.size()+1];
27 std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd);
28 cmd[rabit_task_id.size()] = '\0';
29
30 std::string rabit_bootstrap_cache = "rabit_bootstrap_cache=1";
31 char cmd2[rabit_bootstrap_cache.size()+1];
32 std::copy(rabit_bootstrap_cache.begin(), rabit_bootstrap_cache.end(), cmd2);
33 cmd2[rabit_bootstrap_cache.size()] = '\0';
34
35 std::string rabit_debug = "rabit_debug=1";
36 char cmd3[rabit_debug.size()+1];
37 std::copy(rabit_debug.begin(), rabit_debug.end(), cmd3);
38 cmd3[rabit_debug.size()] = '\0';
39
40 char* argv[] = {cmd, cmd2, cmd3};
41 base.Init(3, argv);
42 EXPECT_EQ(base.task_id, "1");
43 EXPECT_EQ(base.rabit_bootstrap_cache, 1);
44 EXPECT_EQ(base.rabit_debug, 1);
45 }
46
47 TEST(allreduce_base, init_with_ring_reduce)
48 {
49 rabit::engine::AllreduceBase base;
50
51 std::string rabit_task_id = "rabit_task_id=1";
52 char cmd[rabit_task_id.size()+1];
53 std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd);
54 cmd[rabit_task_id.size()] = '\0';
55
56 std::string rabit_reduce_ring_mincount = "rabit_reduce_ring_mincount=1";
57 char cmd2[rabit_reduce_ring_mincount.size()+1];
58 std::copy(rabit_reduce_ring_mincount.begin(), rabit_reduce_ring_mincount.end(), cmd2);
59 cmd2[rabit_reduce_ring_mincount.size()] = '\0';
60
61 char* argv[] = {cmd, cmd2};
62 base.Init(2, argv);
63 EXPECT_EQ(base.task_id, "1");
64 EXPECT_EQ(base.reduce_ring_mincount, 1);
65 }
0 #define RABIT_CXXTESTDEFS_H
1 #include <gtest/gtest.h>
2
3 #include <string>
4 #include <iostream>
5 #include "../../src/allreduce_base.h"
6
7 TEST(allreduce_base, init_task)
8 {
9 rabit::engine::AllreduceBase base;
10
11 std::string rabit_task_id = "rabit_task_id=1";
12 char cmd[rabit_task_id.size()+1];
13 std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd);
14 cmd[rabit_task_id.size()] = '\0';
15
16 char* argv[] = {cmd};
17 base.Init(1, argv);
18 EXPECT_EQ(base.task_id, "1");
19 }
20
21 TEST(allreduce_base, init_with_cache_on)
22 {
23 rabit::engine::AllreduceBase base;
24
25 std::string rabit_task_id = "rabit_task_id=1";
26 char cmd[rabit_task_id.size()+1];
27 std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd);
28 cmd[rabit_task_id.size()] = '\0';
29
30 std::string rabit_bootstrap_cache = "rabit_bootstrap_cache=1";
31 char cmd2[rabit_bootstrap_cache.size()+1];
32 std::copy(rabit_bootstrap_cache.begin(), rabit_bootstrap_cache.end(), cmd2);
33 cmd2[rabit_bootstrap_cache.size()] = '\0';
34
35 std::string rabit_debug = "rabit_debug=1";
36 char cmd3[rabit_debug.size()+1];
37 std::copy(rabit_debug.begin(), rabit_debug.end(), cmd3);
38 cmd3[rabit_debug.size()] = '\0';
39
40 char* argv[] = {cmd, cmd2, cmd3};
41 base.Init(3, argv);
42 EXPECT_EQ(base.task_id, "1");
43 EXPECT_EQ(base.rabit_bootstrap_cache, 1);
44 EXPECT_EQ(base.rabit_debug, 1);
45 }
46
47 TEST(allreduce_base, init_with_ring_reduce)
48 {
49 rabit::engine::AllreduceBase base;
50
51 std::string rabit_task_id = "rabit_task_id=1";
52 char cmd[rabit_task_id.size()+1];
53 std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd);
54 cmd[rabit_task_id.size()] = '\0';
55
56 std::string rabit_reduce_ring_mincount = "rabit_reduce_ring_mincount=1";
57 char cmd2[rabit_reduce_ring_mincount.size()+1];
58 std::copy(rabit_reduce_ring_mincount.begin(), rabit_reduce_ring_mincount.end(), cmd2);
59 cmd2[rabit_reduce_ring_mincount.size()] = '\0';
60
61 char* argv[] = {cmd, cmd2};
62 base.Init(2, argv);
63 EXPECT_EQ(base.task_id, "1");
64 EXPECT_EQ(base.reduce_ring_mincount, 1);
65 }
0 #define RABIT_CXXTESTDEFS_H
1 #include <gtest/gtest.h>
2
3 #include <string>
4 #include <iostream>
5 #include "../../src/allreduce_mock.h"
6
7 TEST(allreduce_mock, mock_allreduce)
8 {
9 rabit::engine::AllreduceMock m;
10
11 std::string mock_str = "mock=0,0,0,0";
12 char cmd[mock_str.size()+1];
13 std::copy(mock_str.begin(), mock_str.end(), cmd);
14 cmd[mock_str.size()] = '\0';
15
16 char* argv[] = {cmd};
17 m.Init(1, argv);
18 m.rank = 0;
19 EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
20 }
21
22 TEST(allreduce_mock, mock_broadcast)
23 {
24 rabit::engine::AllreduceMock m;
25 std::string mock_str = "mock=0,1,2,0";
26 char cmd[mock_str.size()+1];
27 std::copy(mock_str.begin(), mock_str.end(), cmd);
28 cmd[mock_str.size()] = '\0';
29 char* argv[] = {cmd};
30 m.Init(1, argv);
31 m.rank = 0;
32 m.version_number=1;
33 m.seq_counter=2;
34 EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
35 }
0 #define RABIT_CXXTESTDEFS_H
1 #include <gtest/gtest.h>
2
3 #include <string>
4 #include <iostream>
5 #include "../../src/allreduce_mock.h"
6
7 TEST(allreduce_mock, mock_allreduce)
8 {
9 rabit::engine::AllreduceMock m;
10
11 std::string mock_str = "mock=0,0,0,0";
12 char cmd[mock_str.size()+1];
13 std::copy(mock_str.begin(), mock_str.end(), cmd);
14 cmd[mock_str.size()] = '\0';
15
16 char* argv[] = {cmd};
17 m.Init(1, argv);
18 m.rank = 0;
19 EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
20 }
21
22 TEST(allreduce_mock, mock_broadcast)
23 {
24 rabit::engine::AllreduceMock m;
25 std::string mock_str = "mock=0,1,2,0";
26 char cmd[mock_str.size()+1];
27 std::copy(mock_str.begin(), mock_str.end(), cmd);
28 cmd[mock_str.size()] = '\0';
29 char* argv[] = {cmd};
30 m.Init(1, argv);
31 m.rank = 0;
32 m.version_number=1;
33 m.seq_counter=2;
34 EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
35 }
36
37 TEST(allreduce_mock, mock_gather)
38 {
39 rabit::engine::AllreduceMock m;
40 std::string mock_str = "mock=3,13,22,0";
41 char cmd[mock_str.size()+1];
42 std::copy(mock_str.begin(), mock_str.end(), cmd);
43 cmd[mock_str.size()] = '\0';
44 char* argv[] = {cmd};
45 m.Init(1, argv);
46 m.rank = 3;
47 m.version_number=13;
48 m.seq_counter=22;
49 EXPECT_EXIT(m.Allgather(nullptr,0,0,0,0), ::testing::ExitedWithCode(255), "");
50 }
0 #define RABIT_CXXTESTDEFS_H
1 #include <gtest/gtest.h>
2
3 #include <chrono>
4 #include <string>
5 #include <iostream>
6 #include "../../src/allreduce_robust.h"
7
8 inline void mockerr(const char *fmt, ...) {EXPECT_STRCASEEQ(fmt, "[%d] exit due to time out %d s\n");}
9 inline void mockassert(bool val, const char *fmt, ...) {}
10 rabit::engine::AllreduceRobust::ReturnType err_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSockError);
11 rabit::engine::AllreduceRobust::ReturnType succ_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSuccess);
12
13 TEST(allreduce_robust, sync_error_timeout)
14 {
15 rabit::engine::AllreduceRobust m;
16
17 std::string rabit_timeout = "rabit_timeout=1";
18 char cmd[rabit_timeout.size()+1];
19 std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
20 cmd[rabit_timeout.size()] = '\0';
21
22 std::string rabit_timeout_sec = "rabit_timeout_sec=1";
23 char cmd1[rabit_timeout_sec.size()+1];
24 std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
25 cmd1[rabit_timeout_sec.size()] = '\0';
26
27 char* argv[] = {cmd,cmd1};
28 m.Init(2, argv);
29 m.rank = 0;
30 m.rabit_bootstrap_cache = 1;
31 m._error = mockerr;
32 m._assert = mockassert;
33 EXPECT_EQ(m.CheckAndRecover(err_type), false);
34 std::this_thread::sleep_for(std::chrono::milliseconds(1500));
35 EXPECT_EQ(m.rabit_timeout_task.get(), false);
36 }
37
38 TEST(allreduce_robust, sync_error_reset)
39 {
40 rabit::engine::AllreduceRobust m;
41
42 std::string rabit_timeout = "rabit_timeout=1";
43 char cmd[rabit_timeout.size()+1];
44 std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
45 cmd[rabit_timeout.size()] = '\0';
46
47 std::string rabit_timeout_sec = "rabit_timeout_sec=1";
48 char cmd1[rabit_timeout_sec.size()+1];
49 std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
50 cmd1[rabit_timeout_sec.size()] = '\0';
51
52 std::string rabit_debug = "rabit_debug=1";
53 char cmd2[rabit_debug.size()+1];
54 std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
55 cmd2[rabit_debug.size()] = '\0';
56
57 char* argv[] = {cmd, cmd1,cmd2};
58 m.Init(3, argv);
59 m.rank = 0;
60 m._assert = mockassert;
61 EXPECT_EQ(m.CheckAndRecover(err_type), false);
62 std::this_thread::sleep_for(std::chrono::milliseconds(100));
63 EXPECT_EQ(m.CheckAndRecover(succ_type), true);
64 EXPECT_EQ(m.rabit_timeout_task.get(), true);
65 m.Shutdown();
66 }
67
68 TEST(allreduce_robust, sync_success_error_timeout)
69 {
70 rabit::engine::AllreduceRobust m;
71
72 std::string rabit_timeout = "rabit_timeout=1";
73 char cmd[rabit_timeout.size()+1];
74 std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
75 cmd[rabit_timeout.size()] = '\0';
76
77 std::string rabit_timeout_sec = "rabit_timeout_sec=1";
78 char cmd1[rabit_timeout_sec.size()+1];
79 std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
80 cmd1[rabit_timeout_sec.size()] = '\0';
81
82 std::string rabit_debug = "rabit_debug=1";
83 char cmd2[rabit_debug.size()+1];
84 std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
85 cmd2[rabit_debug.size()] = '\0';
86
87 char* argv[] = {cmd, cmd1,cmd2};
88 m.Init(3, argv);
89 m.rank = 0;
90 m.rabit_bootstrap_cache = 1;
91 m._assert = mockassert;
92 m._error = mockerr;
93 EXPECT_EQ(m.CheckAndRecover(succ_type), true);
94 std::this_thread::sleep_for(std::chrono::milliseconds(100));
95 EXPECT_EQ(m.CheckAndRecover(err_type), false);
96 std::this_thread::sleep_for(std::chrono::milliseconds(1500));
97 EXPECT_EQ(m.rabit_timeout_task.get(), false);
98 }
99
100 TEST(allreduce_robust, sync_success_error_success)
101 {
102 rabit::engine::AllreduceRobust m;
103
104 std::string rabit_timeout = "rabit_timeout=1";
105 char cmd[rabit_timeout.size()+1];
106 std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
107 cmd[rabit_timeout.size()] = '\0';
108
109 std::string rabit_timeout_sec = "rabit_timeout_sec=1";
110 char cmd1[rabit_timeout_sec.size()+1];
111 std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
112 cmd1[rabit_timeout_sec.size()] = '\0';
113
114 std::string rabit_debug = "rabit_debug=1";
115 char cmd2[rabit_debug.size()+1];
116 std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
117 cmd2[rabit_debug.size()] = '\0';
118
119 char* argv[] = {cmd, cmd1,cmd2};
120 m.Init(3, argv);
121 m.rank = 0;
122 m.rabit_bootstrap_cache = 1;
123 m._assert = mockassert;
124 EXPECT_EQ(m.CheckAndRecover(succ_type), true);
125 std::this_thread::sleep_for(std::chrono::milliseconds(10));
126
127 EXPECT_EQ(m.CheckAndRecover(err_type), false);
128 std::this_thread::sleep_for(std::chrono::milliseconds(10));
129 EXPECT_EQ(m.CheckAndRecover(succ_type), true);
130 std::this_thread::sleep_for(std::chrono::milliseconds(1100));
131 EXPECT_EQ(m.rabit_timeout_task.get(), true);
132 m.Shutdown();
133 }
134
135 TEST(allreduce_robust, sync_error_no_reset_timeout)
136 {
137 rabit::engine::AllreduceRobust m;
138
139 std::string rabit_timeout = "rabit_timeout=1";
140 char cmd[rabit_timeout.size()+1];
141 std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
142 cmd[rabit_timeout.size()] = '\0';
143
144 std::string rabit_timeout_sec = "rabit_timeout_sec=1";
145 char cmd1[rabit_timeout_sec.size()+1];
146 std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
147 cmd1[rabit_timeout_sec.size()] = '\0';
148
149 std::string rabit_debug = "rabit_debug=1";
150 char cmd2[rabit_debug.size()+1];
151 std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
152 cmd2[rabit_debug.size()] = '\0';
153
154 char* argv[] = {cmd, cmd1,cmd2};
155 m.Init(3, argv);
156 m.rank = 0;
157 m.rabit_bootstrap_cache = 1;
158 m._assert = mockassert;
159 m._error = mockerr;
160 auto start = std::chrono::system_clock::now();
161
162 EXPECT_EQ(m.CheckAndRecover(err_type), false);
163 std::this_thread::sleep_for(std::chrono::milliseconds(1100));
164
165 EXPECT_EQ(m.CheckAndRecover(err_type), false);
166
167 m.rabit_timeout_task.wait();
168 auto end = std::chrono::system_clock::now();
169 std::chrono::duration<double> diff = end-start;
170
171 EXPECT_EQ(m.rabit_timeout_task.get(), false);
172 // expect second error don't overwrite/reset timeout task
173 EXPECT_LT(diff.count(), 2);
174 }
175
176 TEST(allreduce_robust, no_timeout_shut_down)
177 {
178 rabit::engine::AllreduceRobust m;
179
180 std::string rabit_timeout = "rabit_timeout=1";
181 char cmd[rabit_timeout.size()+1];
182 std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
183 cmd[rabit_timeout.size()] = '\0';
184
185 std::string rabit_timeout_sec = "rabit_timeout_sec=1";
186 char cmd1[rabit_timeout_sec.size()+1];
187 std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
188 cmd1[rabit_timeout_sec.size()] = '\0';
189
190 std::string rabit_debug = "rabit_debug=1";
191 char cmd2[rabit_debug.size()+1];
192 std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
193 cmd2[rabit_debug.size()] = '\0';
194
195 char* argv[] = {cmd, cmd1,cmd2};
196 m.Init(3, argv);
197 m.rank = 0;
198
199 EXPECT_EQ(m.CheckAndRecover(succ_type), true);
200 std::this_thread::sleep_for(std::chrono::milliseconds(10));
201 m.Shutdown();
202 }
203
204 TEST(allreduce_robust, shut_down_before_timeout)
205 {
206 rabit::engine::AllreduceRobust m;
207
208 std::string rabit_timeout = "rabit_timeout=1";
209 char cmd[rabit_timeout.size()+1];
210 std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
211 cmd[rabit_timeout.size()] = '\0';
212
213 std::string rabit_timeout_sec = "rabit_timeout_sec=1";
214 char cmd1[rabit_timeout_sec.size()+1];
215 std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
216 cmd1[rabit_timeout_sec.size()] = '\0';
217
218 std::string rabit_debug = "rabit_debug=1";
219 char cmd2[rabit_debug.size()+1];
220 std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
221 cmd2[rabit_debug.size()] = '\0';
222
223 char* argv[] = {cmd, cmd1,cmd2};
224 m.Init(3, argv);
225 m.rank = 0;
226 rabit::engine::AllreduceRobust::LinkRecord a;
227 m.err_link = &a;
228
229 EXPECT_EQ(m.CheckAndRecover(err_type), false);
230 std::this_thread::sleep_for(std::chrono::milliseconds(10));
231 m.Shutdown();
232 }
0 /*!
1 * Copyright (c) 2019 by Contributors
2 */
3 #include <gtest/gtest.h>
4 #include <rabit/internal/io.h>
5
6 #include <vector>
7
8 namespace rabit {
9 TEST(MemoryFixSizeBuffer, Seek) {
10 size_t constexpr kSize { 64 };
11 std::vector<int32_t> memory( kSize );
12 utils::MemoryFixSizeBuffer buf(memory.data(), memory.size());
13 buf.Seek(utils::MemoryFixSizeBuffer::SeekEnd);
14 size_t end = buf.Tell();
15 ASSERT_EQ(end, kSize);
16 }
17 } // namespace rabit
0 #include "gtest/gtest.h"
1
2 int main(int argc, char** argv)
3 {
4 ::testing::InitGoogleTest(&argc, argv);
5 ::testing::FLAGS_gtest_death_test_style = "threadsafe";
6 return RUN_ALL_TESTS();
7 }
0 // this is a test case to test whether rabit can recover model when
1 // facing an exception
2 #include <rabit/rabit.h>
3 #include <cstdio>
4 #include <cstdlib>
5 #include <cmath>
6 using namespace rabit;
7
8 // dummy model
9 class Model : public rabit::Serializable {
10 public:
11 // iterations
12 std::vector<float> data;
13 // load from stream
14 virtual void Load(rabit::Stream *fi) {
15 fi->Read(&data);
16 }
17 /*! \brief save the model to the stream */
18 virtual void Save(rabit::Stream *fo) const {
19 fo->Write(data);
20 }
21 virtual void InitModel(size_t n) {
22 data.clear();
23 data.resize(n, 1.0f);
24 }
25 };
26
27 inline void TestMax(Model *model, int ntrial, int iter) {
28 int rank = rabit::GetRank();
29 int nproc = rabit::GetWorldSize();
30 const int z = iter + 111;
31
32 std::vector<float> ndata(model->data.size());
33 for (size_t i = 0; i < ndata.size(); ++i) {
34 ndata[i] = (i * (rank+1)) % z + model->data[i];
35 }
36 rabit::Allreduce<op::Max>(&ndata[0], ndata.size());
37
38 for (size_t i = 0; i < ndata.size(); ++i) {
39 float rmax = (i * 1) % z + model->data[i];
40 for (int r = 0; r < nproc; ++r) {
41 rmax = std::max(rmax, (float)((i * (r+1)) % z) + model->data[i]);
42 }
43 utils::Check(rmax == ndata[i], "[%d] TestMax check failurem i=%lu, rmax=%f, ndata=%f", rank, i, rmax, ndata[i]);
44 }
45 }
46
47 inline void TestSum(Model *model, int ntrial, int iter) {
48 int rank = rabit::GetRank();
49 int nproc = rabit::GetWorldSize();
50 const int z = 131 + iter;
51
52 std::vector<float> ndata(model->data.size());
53 for (size_t i = 0; i < ndata.size(); ++i) {
54 ndata[i] = (i * (rank+1)) % z + model->data[i];
55 }
56 Allreduce<op::Sum>(&ndata[0], ndata.size());
57
58 for (size_t i = 0; i < ndata.size(); ++i) {
59 float rsum = model->data[i] * nproc;
60 for (int r = 0; r < nproc; ++r) {
61 rsum += (float)((i * (r+1)) % z);
62 }
63 utils::Check(fabsf(rsum - ndata[i]) < 1e-5 ,
64 "[%d] TestSum check failure, local=%g, allreduce=%g", rank, rsum, ndata[i]);
65 }
66 model->data = ndata;
67 }
68
69 inline void TestBcast(size_t n, int root, int ntrial, int iter) {
70 int rank = rabit::GetRank();
71 std::string s; s.resize(n);
72 for (size_t i = 0; i < n; ++i) {
73 s[i] = char(i % 126 + 1);
74 }
75 std::string res;
76 if (root == rank) {
77 res = s;
78 rabit::Broadcast(&res, root);
79 } else {
80 rabit::Broadcast(&res, root);
81 }
82 utils::Check(res == s, "[%d] TestBcast fail", rank);
83 }
84
85 int main(int argc, char *argv[]) {
86 if (argc < 3) {
87 printf("Usage: <ndata> <config>\n");
88 return 0;
89 }
90 int n = atoi(argv[1]);
91 rabit::Init(argc, argv);
92 int rank = rabit::GetRank();
93 int nproc = rabit::GetWorldSize();
94 std::string name = rabit::GetProcessorName();
95 Model model;
96 srand(0);
97 int ntrial = 0;
98 for (int i = 1; i < argc; ++i) {
99 int n;
100 if (sscanf(argv[i], "rabit_num_trial=%d", &n) == 1) ntrial = n;
101 }
102 int iter = rabit::LoadCheckPoint(&model);
103 if (iter == 0) {
104 model.InitModel(n);
105 printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
106 } else {
107 printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
108 }
109 for (int r = iter; r < 3; ++r) {
110 TestMax(&model, ntrial, r);
111 printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
112 int step = std::max(nproc / 3, 1);
113 for (int i = 0; i < nproc; i += step) {
114 TestBcast(n, i, ntrial, r);
115 }
116 printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
117 TestSum(&model, ntrial, r);
118 printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
119 rabit::LazyCheckPoint(&model);
120 printf("[%d] !!!CheckPoint pass, iter=%d\n", rank, r);
121 }
122 rabit::Finalize();
123 return 0;
124 }
0 // this is a test case to test whether rabit can recover model when
1 // facing an exception
2 #include <rabit/rabit.h>
3 #include <cstdio>
4 #include <cstdlib>
5 #include <cmath>
6
7 using namespace rabit;
8
9 // dummy model
10 class Model : public rabit::Serializable {
11 public:
12 // iterations
13 std::vector<float> data;
14 // load from stream
15 virtual void Load(rabit::Stream *fi) {
16 fi->Read(&data);
17 }
18 /*! \brief save the model to the stream */
19 virtual void Save(rabit::Stream *fo) const {
20 fo->Write(data);
21 }
22 virtual void InitModel(size_t n, float v) {
23 data.clear();
24 data.resize(n, v);
25 }
26 };
27
28 inline void TestMax(Model *model, Model *local, int ntrial, int iter) {
29 int rank = rabit::GetRank();
30 int nproc = rabit::GetWorldSize();
31 const int z = iter + 111;
32 std::vector<float> ndata(model->data.size());
33 rabit::Allreduce<op::Max>(&ndata[0], ndata.size(),
34 [&]() {
35 // use lambda expression to prepare the data
36 for (size_t i = 0; i < ndata.size(); ++i) {
37 ndata[i] = (i * (rank+1)) % z + local->data[i];
38 }
39 });
40
41 for (size_t i = 0; i < ndata.size(); ++i) {
42 float rmax = (i * 1) % z + model->data[i];
43 for (int r = 0; r < nproc; ++r) {
44 rmax = std::max(rmax, (float)((i * (r+1)) % z) + model->data[i] + r);
45 }
46 utils::Check(rmax == ndata[i], "[%d] TestMax check failure", rank);
47 }
48 model->data = ndata;
49 local->data = ndata;
50 for (size_t i = 0; i < ndata.size(); ++i) {
51 local->data[i] = ndata[i] + rank;
52 }
53 }
54
55 inline void TestSum(Model *model, Model *local, int ntrial, int iter) {
56 int rank = rabit::GetRank();
57 int nproc = rabit::GetWorldSize();
58 const int z = 131 + iter;
59
60 std::vector<float> ndata(model->data.size());
61 for (size_t i = 0; i < ndata.size(); ++i) {
62 ndata[i] = (i * (rank+1)) % z + local->data[i];
63 }
64 Allreduce<op::Sum>(&ndata[0], ndata.size());
65
66 for (size_t i = 0; i < ndata.size(); ++i) {
67 float rsum = 0.0f;
68 for (int r = 0; r < nproc; ++r) {
69 rsum += (float)((i * (r+1)) % z) + model->data[i] + r;
70 }
71 utils::Check(fabsf(rsum - ndata[i]) < 1e-5 ,
72 "[%d] TestSum check failure, local=%g, allreduce=%g", rank, rsum, ndata[i]);
73 }
74 model->data = ndata;
75 for (size_t i = 0; i < ndata.size(); ++i) {
76 local->data[i] = ndata[i] + rank;
77 }
78 }
79
80 inline void TestBcast(size_t n, int root, int ntrial, int iter) {
81 int rank = rabit::GetRank();
82 std::string s; s.resize(n);
83 for (size_t i = 0; i < n; ++i) {
84 s[i] = char(i % 126 + 1);
85 }
86 std::string res;
87 if (root == rank) {
88 res = s;
89 rabit::Broadcast(&res, root);
90 } else {
91 rabit::Broadcast(&res, root);
92 }
93 utils::Check(res == s, "[%d] TestBcast fail", rank);
94 }
95
96 int main(int argc, char *argv[]) {
97 if (argc < 3) {
98 printf("Usage: <ndata>\n");
99 return 0;
100 }
101 int n = atoi(argv[1]);
102 rabit::Init(argc, argv);
103 int rank = rabit::GetRank();
104 int nproc = rabit::GetWorldSize();
105 std::string name = rabit::GetProcessorName();
106 Model model, local;
107 srand(0);
108 int ntrial = 0;
109 for (int i = 1; i < argc; ++i) {
110 int n;
111 if (sscanf(argv[i], "repeat=%d", &n) == 1) ntrial = n;
112 }
113 int iter = rabit::LoadCheckPoint(&model, &local);
114 if (iter == 0) {
115 model.InitModel(n, 1.0f);
116 local.InitModel(n, 1.0f + rank);
117 printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
118 } else {
119 printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
120 }
121 for (int r = iter; r < 3; ++r) {
122 TestMax(&model, &local, ntrial, r);
123 printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
124 int step = std::max(nproc / 3, 1);
125 for (int i = 0; i < nproc; i += step) {
126 TestBcast(n, i, ntrial, r);
127 }
128 printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
129 TestSum(&model, &local, ntrial, r);
130 printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
131 rabit::CheckPoint(&model, &local);
132 printf("[%d] !!!CheckPoint pass, iter=%d\n", rank, r);
133 }
134 rabit::Finalize();
135 return 0;
136 }
0 #!/usr/bin/env python3
1
2 from __future__ import print_function
3 from builtins import range
4
5 import sys
6 sys.path.append('../python')
7
8 import rabit
9 import numpy as np
10
# Start rabit with the 'mock' library variant and look up this worker's rank.
rabit.init(lib='mock')
rank = rabit.get_rank()

# Vector length and number of allreduce/checkpoint rounds.
ndim = 10
num_rounds = 3
# Per-worker input: a vector of length ndim filled with this worker's rank.
data = rank * np.ones(ndim)

# A non-zero version means a checkpoint was loaded and the run resumes from it.
# NOTE(review): the True argument presumably asks for the local model too - confirm.
version, model, local = rabit.load_checkpoint(True)
if version != 0:
    print('[%d] restart from version %d' % (rank, version))
else:
    model = np.zeros(ndim)
    local = np.ones(ndim)

for round_idx in range(version, num_rounds):
    contribution = data + model + local
    reduced = rabit.allreduce(contribution, rabit.SUM)
    print('[%d] iter=%d: %s' % (rank, round_idx, str(reduced)))
    # the reduced vector becomes the global model; the local model records the round
    model = reduced
    local[:] = round_idx
    rabit.checkpoint(model, local)

rabit.finalize()
0 // this is a test case to test whether rabit can recover model when
1 // facing an exception
2 #include <rabit/rabit.h>
3 #include <cstdio>
4 #include <cstdlib>
5 #include <cmath>
6
7 using namespace rabit;
8
9 // dummy model
10 class Model : public rabit::Serializable {
11 public:
12 // iterations
13 std::vector<float> data;
14 // load from stream
15 virtual void Load(rabit::Stream *fi) {
16 fi->Read(&data);
17 }
18 /*! \brief save the model to the stream */
19 virtual void Save(rabit::Stream *fo) const {
20 fo->Write(data);
21 }
22 virtual void InitModel(size_t n) {
23 data.clear();
24 data.resize(n, 1.0f);
25 }
26 };
27
28 inline void TestMax(Model *model, int iter) {
29 int rank = rabit::GetRank();
30 int nproc = rabit::GetWorldSize();
31 const int z = iter + 111;
32
33 std::vector<float> ndata(model->data.size());
34 for (size_t i = 0; i < ndata.size(); ++i) {
35 ndata[i] = (i * (rank+1)) % z + model->data[i];
36 }
37 rabit::Allreduce<op::Max>(&ndata[0], ndata.size());
38
39 for (size_t i = 0; i < ndata.size(); ++i) {
40 float rmax = (i * 1) % z + model->data[i];
41 for (int r = 0; r < nproc; ++r) {
42 rmax = std::max(rmax, (float)((i * (r+1)) % z) + model->data[i]);
43 }
44 utils::Check(rmax == ndata[i], "[%d] TestMax check failurem i=%lu, rmax=%f, ndata=%f", rank, i, rmax, ndata[i]);
45 }
46 model->data = ndata;
47 }
48
49 inline void TestSum(Model *model, int iter) {
50 int rank = rabit::GetRank();
51 int nproc = rabit::GetWorldSize();
52 const int z = 131 + iter;
53
54 std::vector<float> ndata(model->data.size());
55 for (size_t i = 0; i < ndata.size(); ++i) {
56 ndata[i] = (i * (rank+1)) % z + model->data[i];
57 }
58 Allreduce<op::Sum>(&ndata[0], ndata.size());
59
60 for (size_t i = 0; i < ndata.size(); ++i) {
61 float rsum = model->data[i] * nproc;
62 for (int r = 0; r < nproc; ++r) {
63 rsum += (float)((i * (r+1)) % z);
64 }
65 utils::Check(fabsf(rsum - ndata[i]) < 1e-5 ,
66 "[%d] TestSum check failure, local=%g, allreduce=%g", rank, rsum, ndata[i]);
67 }
68 model->data = ndata;
69 }
70
71 inline void TestAllgather(Model *model, int iter) {
72 int rank = rabit::GetRank();
73 int nproc = rabit::GetWorldSize();
74 const int z = 131 + iter;
75
76 std::vector<float> ndata(model->data.size() * nproc);
77 size_t beginSlice = rank * model->data.size();
78 for (size_t i = 0; i < model->data.size(); ++i) {
79 ndata[beginSlice + i] = (i * (rank+1)) % z + model->data[i];
80 }
81 Allgather(&ndata[0], ndata.size(), beginSlice,
82 model->data.size(), model->data.size());
83
84 for (size_t i = 0; i < ndata.size(); ++i) {
85 int curRank = i / model->data.size();
86 int remainder = i % model->data.size();
87 float data = (remainder * (curRank+1)) % z + model->data[remainder];
88 utils::Check(fabsf(data - ndata[i]) < 1e-5 ,
89 "[%d] TestAllgather check failure, local=%g, allgatherring=%g", rank, data, ndata[i]);
90 }
91 model->data = ndata;
92 }
93
94 inline void TestBcast(size_t n, int root) {
95 int rank = rabit::GetRank();
96 std::string s; s.resize(n);
97 for (size_t i = 0; i < n; ++i) {
98 s[i] = char(i % 126 + 1);
99 }
100 std::string res;
101 if (root == rank) {
102 res = s;
103 }
104 rabit::Broadcast(&res, root);
105
106 utils::Check(res == s, "[%d] TestBcast fail", rank);
107 }
108
109 int main(int argc, char *argv[]) {
110 if (argc < 3) {
111 printf("Usage: <ndata> <config>\n");
112 return 0;
113 }
114 int n = atoi(argv[1]);
115 rabit::Init(argc, argv);
116 int rank = rabit::GetRank();
117 int nproc = rabit::GetWorldSize();
118 std::string name = rabit::GetProcessorName();
119
120 int max_rank = rank;
121 rabit::Allreduce<op::Max>(&max_rank, 1);
122 utils::Check(max_rank == nproc - 1, "max rank is world size-1");
123
124 Model model;
125 srand(0);
126 int ntrial = 0;
127 for (int i = 1; i < argc; ++i) {
128 int n;
129 if (sscanf(argv[i], "rabit_num_trial=%d", &n) == 1) ntrial = n;
130 }
131 int iter = rabit::LoadCheckPoint(&model);
132 if (iter == 0) {
133 model.InitModel(n);
134 }
135 printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
136
137 for (int r = iter; r < 3; ++r) {
138 TestMax(&model, r);
139 printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
140 int step = std::max(nproc / 3, 1);
141 for (int i = 0; i < nproc; i += step) {
142 TestBcast(n, i);
143 }
144 printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
145
146 TestSum(&model, r);
147 printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
148 TestAllgather(&model, r);
149 printf("[%d] !!!TestAllgather pass, iter=%d\n", rank, r);
150 rabit::CheckPoint(&model);
151 printf("[%d] !!!Checkpoint pass, iter=%d\n", rank, r);
152 }
153 rabit::Finalize();
154 return 0;
155 }
156
0 import os
1 import argparse
2 import sys
3
def main():
    """Run the speed-test sweep for the rabit and the MPI executable.

    For each of the two executables, every (data size, repeat count) pair is
    launched on every machine count through the submit script. Progress is
    written to stderr. A command that exits with a non-zero status is reported
    on stderr and the sweep continues.
    """
    parser = argparse.ArgumentParser(
        description='Run the rabit and MPI speed-test executables over a grid '
                    'of data sizes and machine counts')
    parser.add_argument('-ho', '--host_dir', required=True)
    parser.add_argument('-s', '--submit_script', required=True)
    parser.add_argument('-rex', '--rabit_exec', required=True)
    parser.add_argument('-mpi', '--mpi_exec', required=True)
    args = parser.parse_args()

    # Each data size is paired with its own repeat count.
    ndata = [10**4, 10**5, 10**6, 10**7]
    nrepeat = [10**4, 10**3, 10**2, 10]
    workloads = list(zip(ndata, nrepeat))

    machines = [2, 4, 8, 16, 31]

    for executable in (args.rabit_exec, args.mpi_exec):
        # Terminate the header line so it does not run into the next message.
        sys.stderr.write('Executable %s\n' % executable)
        sys.stderr.flush()
        for data, repeat in workloads:
            for machine in machines:
                # The host file for a machine count is <host_dir>/hosts<count>.
                host_file = os.path.join(args.host_dir, 'hosts%d' % machine)
                cmd = 'python %s %d %s %s %d %d' % (
                    args.submit_script, machine, host_file, executable, data, repeat)
                sys.stderr.write('data=%d, repeat=%d, machine=%d\n' % (data, repeat, machine))
                sys.stderr.flush()
                status = os.system(cmd)
                if status != 0:
                    # Best effort: report the failure and keep going.
                    sys.stderr.write('command exited with status %d: %s\n' % (status, cmd))
                    sys.stderr.flush()
            # NOTE(review): the nesting level of this separator could not be
            # recovered from the source; it is emitted after each data size.
            sys.stderr.write('\n')
            sys.stderr.flush()
31
# Run the sweep only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
0 // This program is used to test the speed of rabit API
1 #include <rabit/rabit.h>
2 #include <rabit/internal/timer.h>
3 #include <cstdio>
4 #include <cstdlib>
5 #include <cmath>
6 #include <time.h>
7
8 using namespace rabit;
9
// Accumulated time spent inside the Max allreduce, the Sum allreduce and the
// broadcast calls, plus the duration of the whole test loop; PrintStats
// reports them as seconds.
double max_tdiff, sum_tdiff, bcast_tdiff, tot_tdiff;
11
12 inline void TestMax(size_t n) {
13 int rank = rabit::GetRank();
14 std::vector<float> ndata(n);
15 for (size_t i = 0; i < ndata.size(); ++i) {
16 ndata[i] = (i * (rank+1)) % 111;
17 }
18 double tstart = utils::GetTime();
19 rabit::Allreduce<op::Max>(&ndata[0], ndata.size());
20 max_tdiff += utils::GetTime() - tstart;
21 }
22
23 inline void TestSum(size_t n) {
24 int rank = rabit::GetRank();
25 const int z = 131;
26 std::vector<float> ndata(n);
27 for (size_t i = 0; i < ndata.size(); ++i) {
28 ndata[i] = (i * (rank+1)) % z;
29 }
30 double tstart = utils::GetTime();
31 rabit::Allreduce<op::Sum>(&ndata[0], ndata.size());
32 sum_tdiff += utils::GetTime() - tstart;
33 }
34
35 inline void TestBcast(size_t n, int root) {
36 int rank = rabit::GetRank();
37 std::string s; s.resize(n);
38 for (size_t i = 0; i < n; ++i) {
39 s[i] = char(i % 126 + 1);
40 }
41 std::string res;
42 res.resize(n);
43 if (root == rank) {
44 res = s;
45 }
46 double tstart = utils::GetTime();
47 rabit::Broadcast(&res[0], res.length(), root);
48 bcast_tdiff += utils::GetTime() - tstart;
49 }
50
51 inline void PrintStats(const char *name, double tdiff, int n, int nrep, size_t size) {
52 int nproc = rabit::GetWorldSize();
53 double tsum = tdiff;
54 rabit::Allreduce<op::Sum>(&tsum, 1);
55 double tavg = tsum / nproc;
56 double tsqr = tdiff - tavg;
57 tsqr *= tsqr;
58 rabit::Allreduce<op::Sum>(&tsqr, 1);
59 double tstd = sqrt(tsqr / nproc);
60 if (rabit::GetRank() == 0) {
61 rabit::TrackerPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd);
62 double ndata = n;
63 ndata *= nrep * size;
64 if (n != 0) {
65 rabit::TrackerPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 );
66 }
67 }
68 }
69
70 int main(int argc, char *argv[]) {
71 if (argc < 3) {
72 printf("Usage: <ndata> <nrepeat>\n");
73 return 0;
74 }
75 srand(0);
76 int n = atoi(argv[1]);
77 int nrep = atoi(argv[2]);
78 utils::Check(nrep >= 1, "need to at least repeat running once");
79 rabit::Init(argc, argv);
80 //int rank = rabit::GetRank();
81 int nproc = rabit::GetWorldSize();
82 std::string name = rabit::GetProcessorName();
83 max_tdiff = sum_tdiff = bcast_tdiff = 0;
84 double tstart = utils::GetTime();
85 for (int i = 0; i < nrep; ++i) {
86 TestMax(n);
87 TestSum(n);
88 TestBcast(n, rand() % nproc);
89 }
90 tot_tdiff = utils::GetTime() - tstart;
91 // use allreduce to get the sum and std of time
92 PrintStats("max_tdiff", max_tdiff, n, nrep, sizeof(float));
93 PrintStats("sum_tdiff", sum_tdiff, n, nrep, sizeof(float));
94 PrintStats("bcast_tdiff", bcast_tdiff, n, nrep, sizeof(char));
95 PrintStats("tot_tdiff", tot_tdiff, 0, nrep, sizeof(float));
96 rabit::Finalize();
97 return 0;
98 }
# Selects where dmlc-core is looked up: ../dmlc-core when set to 1,
# ../../dmlc-core otherwise.
RABIT_BUILD_DMLC = 0

ifeq ($(RABIT_BUILD_DMLC),1)
DMLC=../dmlc-core
else
DMLC=../../dmlc-core
endif

# this is a makefile used to show testcases of rabit
# None of the targets below produces a file, so all of them are declared
# phony; otherwise a stray file with the same name would silently skip a test.
.PHONY: all model_recover_10_10k model_recover_10_10k_die_same \
	model_recover_10_10k_die_hard local_recover_10_10k \
	pylocal_recover_10_10k lazy_recover_10_10k_die_hard \
	lazy_recover_10_10k_die_same ringallreduce_10_10k

all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k lazy_recover_10_10k_die_hard lazy_recover_10_10k_die_same ringallreduce_10_10k pylocal_recover_10_10k

# Every test submits a 10-worker local job through dmlc-submit; the first
# program argument (10000) is the data size.
# NOTE(review): each mock=... argument presumably describes one injected
# failure handled by the mock engine - confirm the meaning of the four fields.

# this experiment tests recovery when processes actually exit
# NOTE(review): the job is presumably kept alive by --local-num-attempt - confirm
model_recover_10_10k:
	$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 rabit_bootstrap_cache=true rabit_debug=true rabit_reduce_ring_mincount=1 rabit_timeout=true rabit_timeout_sec=5

model_recover_10_10k_die_same:
	$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 rabit_bootstrap_cache=1

model_recover_10_10k_die_hard:
	$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 rabit_bootstrap_cache=1

local_recover_10_10k:
	$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 local_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1

# same scenario as local_recover_10_10k, driven through the python script
pylocal_recover_10_10k:
	$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 local_recover.py 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1

lazy_recover_10_10k_die_hard:
	$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0

lazy_recover_10_10k_die_same:
	$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0

# no mock failures and no retry attempts here; only rabit_reduce_ring_mincount is set
ringallreduce_10_10k:
	$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 rabit_reduce_ring_mincount=10