"""
Reliable Allreduce and Broadcast Library.
Author: Tianqi Chen
"""
# pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value,
import pickle
import ctypes
import os
import platform
import sys
import warnings
import numpy as np
# Version string of this Python module.
__version__ = '1.0'
# Handle to the loaded rabit shared library: assigned by _loadlib(),
# reset to None by _unloadlib(), and None while nothing is loaded.
_LIB = None
def _find_lib_path(dll_name):
    """Locate the rabit dynamic library file named ``dll_name``.

    The file is searched for in the directory containing this module and
    in the ``../lib/`` and ``./lib/`` directories relative to it.

    Parameters
    ----------
    dll_name : str
        File name of the library, including its platform suffix.

    Returns
    -------
    lib_path : list(string)
        Every candidate path that exists as a file, in search order.

    Raises
    ------
    RuntimeError
        If no candidate exists and the ``XGBOOST_BUILD_DOC`` environment
        variable is unset or empty.
    """
    here = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    # make pythonpack hack: copy this directory one level upper for setup.py
    search_dirs = (here,
                   os.path.join(here, '../lib/'),
                   os.path.join(here, './lib/'))
    candidates = [os.path.join(folder, dll_name) for folder in search_dirs]
    found = [path for path in candidates if os.path.isfile(path)]
    if found or os.environ.get('XGBOOST_BUILD_DOC', False):
        return found
    # From github issues, most installation errors come from machines
    # without compilers.
    raise RuntimeError(
        'Cannot find Rabit Libarary in the candicate path, ' +
        'did you install compilers and run build.sh in root path?\n'
        'List of candidates:\n' + ('\n'.join(candidates)))
# load in rabit library
def _loadlib(lib='standard', lib_dll=None):
    """Load the rabit library into the module-level ``_LIB`` handle.

    Parameters
    ----------
    lib : str, optional
        Library flavour. ``'standard'`` selects ``librabit``; any other
        value selects ``librabit_<lib>``. The platform suffix
        (``.dll``, ``.dylib`` or ``.so``) is appended automatically.
    lib_dll : object, optional
        An already loaded library object. When given it is stored as the
        handle directly and ``lib`` is ignored.

    Notes
    -----
    If a library is already loaded the call only emits a warning and
    leaves the existing handle untouched.
    """
    global _LIB
    if _LIB is not None:
        # warnings.warn accepts ``stacklevel``; the previous ``level``
        # keyword raised TypeError instead of emitting the warning.
        warnings.warn('rabit.init call was ignored because it has'
                      ' already been initialized', stacklevel=2)
        return
    if lib_dll is not None:
        _LIB = lib_dll
        return
    dll_name = 'librabit' if lib == 'standard' else 'librabit_' + lib
    if os.name == 'nt':
        dll_name += '.dll'
    elif platform.system() == 'Darwin':
        dll_name += '.dylib'
    else:
        dll_name += '.so'
    # Use the first candidate returned by the search.
    _LIB = ctypes.cdll.LoadLibrary(_find_lib_path(dll_name)[0])
    # Declare the integer return types of the query functions.
    _LIB.RabitGetRank.restype = ctypes.c_int
    _LIB.RabitGetWorldSize.restype = ctypes.c_int
    _LIB.RabitVersionNumber.restype = ctypes.c_int
def _unloadlib():
    """Drop the module's reference to the rabit library.

    Only the Python-side handle ``_LIB`` is reset to ``None``; no explicit
    unload call is made on the library object.
    """
    global _LIB
    # Rebinding releases the previous reference by itself; a preceding
    # ``del`` added nothing and failed with NameError if the name was absent.
    _LIB = None
# Reduction operator codes, passed as the ``op`` argument of allreduce.
# NOTE(review): presumably these must match the operator numbering of the
# native rabit library -- confirm before changing.
MAX, MIN, SUM, BITOR = range(4)
def init(args=None, lib='standard', lib_dll=None):
    """Initialize the rabit module, call this once before using anything.

    Parameters
    ----------
    args : list of str, optional
        The list of arguments used to initialize rabit; usually you
        pass in sys.argv. An empty argument list is used when it is None.
        Text arguments are encoded as UTF-8 before being handed to the
        native library; other values are passed through unchanged.
    lib : {'standard', 'mock', 'mpi'}, optional
        Type of library we want to load.
        Ignored when lib_dll is specified.
    lib_dll : ctypes.DLL, optional
        The DLL object used as lib.
        When this is presented argument lib will be ignored.
    """
    if args is None:
        args = []
    _loadlib(lib, lib_dll)
    # A c_char_p array only accepts byte strings, so text is encoded first.
    # The extra bytes check keeps Python 2 ``str`` values untouched.
    encoded = [arg.encode('utf-8')
               if isinstance(arg, str) and not isinstance(arg, bytes)
               else arg
               for arg in args]
    arr = (ctypes.c_char_p * len(encoded))()
    arr[:] = encoded
    _LIB.RabitInit(len(encoded), arr)
def finalize():
    """Shut down the rabit engine and drop the library handle.

    Call this once all jobs are finished.
    """
    _LIB.RabitFinalize()
    _unloadlib()
def get_rank():
    """Return the rank of the current process.

    Returns
    -------
    rank : int
        Value reported by the library's ``RabitGetRank``.
    """
    return _LIB.RabitGetRank()
def get_world_size():
    """Return the total number of workers.

    Returns
    -------
    n : int
        Value reported by the library's ``RabitGetWorldSize``.
    """
    return _LIB.RabitGetWorldSize()
def tracker_print(msg):
    """Print message to the tracker.

    This function can be used to communicate the information of
    the progress to the tracker.

    Parameters
    ----------
    msg : str
        The message to be printed to tracker. Byte strings are sent
        unchanged; any other object is converted with ``str()`` and
        encoded as UTF-8.
    """
    if isinstance(msg, bytes):
        payload = msg
    else:
        payload = str(msg).encode('utf-8')
    # Encode before wrapping: c_char_p takes bytes and has no ``encode``.
    _LIB.RabitTrackerPrint(ctypes.c_char_p(payload))
def get_processor_name():
    """Return the processor (host) name reported by rabit.

    Returns
    -------
    name : str
        Contents of the name buffer up to the first NUL, as exposed by
        ctypes (a byte string on Python 3).
    """
    max_len = 256
    name_len = ctypes.c_ulong()
    name_buf = ctypes.create_string_buffer(max_len)
    # The library writes into name_buf and name_len; only the buffer
    # contents are returned.
    _LIB.RabitGetProcessorName(name_buf, ctypes.byref(name_len), max_len)
    return name_buf.value
def broadcast(data, root):
    """Broadcast a picklable object from the root node to all other nodes.

    Two native broadcasts are issued: the first carries the byte length
    of the pickled object, the second carries the pickled bytes.

    Parameters
    ----------
    data : any type that can be pickled
        Input data; may be None on every rank other than root.
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object
        ``data`` itself on the root rank, otherwise the object unpickled
        from the received bytes.

    Notes
    -----
    Non-root ranks call ``pickle.loads`` on the bytes left in the receive
    buffer by the broadcast; only use this among trusted workers.
    """
    rank = get_rank()
    nbytes = ctypes.c_ulong()
    is_root = (root == rank)
    if is_root:
        assert data is not None, 'need to pass in data when broadcasting'
        payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        nbytes.value = len(payload)
    # first broadcast: payload size
    _LIB.RabitBroadcast(ctypes.byref(nbytes),
                        ctypes.sizeof(ctypes.c_ulong), root)
    if is_root:
        # second broadcast, sending side
        _LIB.RabitBroadcast(
            ctypes.cast(ctypes.c_char_p(payload), ctypes.c_void_p),
            nbytes.value, root)
        return data
    # second broadcast, receiving side
    recv_buf = (ctypes.c_char * nbytes.value)()
    _LIB.RabitBroadcast(ctypes.cast(recv_buf, ctypes.c_void_p),
                        nbytes.value, root)
    return pickle.loads(recv_buf.raw)
# Maps each supported numpy dtype to the integer code handed to
# RabitAllreduce; the code is the position of the dtype in this tuple.
DTYPE_ENUM__ = {
    np.dtype(type_name): code
    for code, type_name in enumerate((
        'int8',
        'uint8',
        'int32',
        'uint32',
        'int64',
        'uint64',
        'float32',
        'float64',
    ))
}
def allreduce(data, op, prepare_fun=None):
    """Perform allreduce, return the result.

    Parameters
    ----------
    data : numpy array
        Input data.
    op : int
        Reduction operator, can be MIN, MAX, SUM, BITOR.
    prepare_fun : function
        Lazy preprocessing function; if it is not None, prepare_fun(data)
        is made available to the library as a callback to initialize the
        data before the reduction. If the result of Allreduce can be
        recovered directly, then prepare_fun will NOT be called.

    Returns
    -------
    result : numpy.ndarray
        The flattened (1-D) buffer the reduction was run on.

    Raises
    ------
    TypeError
        If data is not a numpy.ndarray or its dtype is not supported.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError('allreduce only takes in numpy.ndarray')
    buf = data.ravel()
    # The flattened array is copied only when it reports the same ``base``
    # as ``data``.
    # NOTE(review): when ``data`` owns its memory, ``ravel`` may return a
    # view of it, in which case no copy is made here and the reduction
    # presumably writes into ``data``'s memory. prepare_fun receives
    # ``data`` rather than ``buf``, so it only affects the reduced values
    # when the two share memory -- confirm this is the intended contract.
    if buf.base is data.base:
        buf = buf.copy()
    if buf.dtype not in DTYPE_ENUM__:
        raise TypeError('data type %s not supported' % str(buf.dtype))
    if prepare_fun is None:
        callback = None
    else:
        def pfunc(args):
            """prepare function."""
            prepare_fun(data)
        # Bound to a local so the ctypes callback object stays referenced
        # for the whole native call.
        callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(pfunc)
    _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                        buf.size, DTYPE_ENUM__[buf.dtype],
                        op, callback, None)
    return buf
def _load_model(ptr, length):
    """Unpickle a model from the memory region described by ptr and length.

    Internal helper of this module.

    Arguments:
        ptr: ctypes.POINTER(ctypes.c_char)
            pointer to the start of the buffer
        length: int
            number of bytes in the buffer

    Note: the bytes are passed to ``pickle.loads``, so the buffer must
    come from a trusted source.
    """
    start = ctypes.addressof(ptr.contents)
    region_type = ctypes.c_char * length
    region = region_type.from_address(start)
    return pickle.loads(region.raw)
def load_checkpoint(with_local=False):
    """Load the latest checkpoint.

    Parameters
    ----------
    with_local : bool, optional
        Whether the checkpoint contains a local model.

    Returns
    -------
    tuple : tuple
        if with_local: (version, global_model, local_model)
        else: (version, global_model)
        When the returned version is 0 the models in the tuple are None.
    """
    global_ptr = ctypes.POINTER(ctypes.c_char)()
    global_size = ctypes.c_ulong()
    if not with_local:
        # No local model requested: pass None for the local outputs.
        version = _LIB.RabitLoadCheckPoint(
            ctypes.byref(global_ptr),
            ctypes.byref(global_size),
            None, None)
        if version == 0:
            return (version, None)
        return (version, _load_model(global_ptr, global_size.value))
    local_ptr = ctypes.POINTER(ctypes.c_char)()
    local_size = ctypes.c_ulong()
    version = _LIB.RabitLoadCheckPoint(
        ctypes.byref(global_ptr),
        ctypes.byref(global_size),
        ctypes.byref(local_ptr),
        ctypes.byref(local_size))
    if version == 0:
        return (version, None, None)
    global_model = _load_model(global_ptr, global_size.value)
    local_model = _load_model(local_ptr, local_size.value)
    return (version, global_model, local_model)
def checkpoint(global_model, local_model=None):
    """Checkpoint the model.

    This marks the end of a stage of execution. Every call to checkpoint
    increases the version number by one.

    Parameters
    ----------
    global_model : any type that can be pickled
        Globally shared model/state when calling this function;
        the caller needs to guarantee that global_model is the same on
        all nodes.
    local_model : any type that can be pickled
        Local model that is specific to the current node/rank.
        This can be None when no local state is needed.

    Notes
    -----
    local_model requires explicit replication of the model for
    fault-tolerance, which adds replication cost to the checkpoint call,
    while global_model does not need explicit replication.
    It is recommended to use global_model if possible.
    """
    global_blob = pickle.dumps(global_model)
    if local_model is not None:
        local_blob = pickle.dumps(local_model)
        _LIB.RabitCheckPoint(global_blob, len(global_blob),
                             local_blob, len(local_blob))
    else:
        # No local state: pass a null buffer of length zero.
        _LIB.RabitCheckPoint(global_blob, len(global_blob), None, 0)
def version_number():
    """Return the version number of the currently stored model.

    This is the number of calls to CheckPoint made so far.

    Returns
    -------
    version : int
        Value reported by the library's ``RabitVersionNumber``.
    """
    return _LIB.RabitVersionNumber()