Codebase list rabit / 7c3bcbf5-ac8c-4921-8f2a-3d63b208b357/main python / rabit.py
7c3bcbf5-ac8c-4921-8f2a-3d63b208b357/main

Tree @7c3bcbf5-ac8c-4921-8f2a-3d63b208b357/main (Download .tar.gz)

rabit.py @7c3bcbf5-ac8c-4921-8f2a-3d63b208b357/mainraw · history · blame

"""
Reliable Allreduce and Broadcast Library.

Author: Tianqi Chen
"""
# pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value,
import pickle
import ctypes
import os
import platform
import sys
import warnings
import numpy as np

# Version string of this Python binding, reported in the documentation.
__version__ = "1.0"

# Handle to the loaded rabit shared library; set by _loadlib(), reset by _unloadlib().
_LIB = None

def _find_lib_path(dll_name):
    """Find the rabit dynamic library files.

    Parameters
    ----------
    dll_name : str
        File name of the shared library to look for (e.g. ``librabit.so``).

    Returns
    -------
    lib_path: list(string)
       List of all found library paths to rabit, in search order
       (this file's directory, then ``../lib/``, then ``./lib/``).

    Raises
    ------
    RuntimeError
        If no candidate file exists and the ``XGBOOST_BUILD_DOC``
        environment variable is not set.
    """
    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    # make pythonpack hack: copy this directory one level upper for setup.py
    search_dirs = (curr_path,
                   os.path.join(curr_path, '../lib/'),
                   os.path.join(curr_path, './lib/'))
    # The same file name is searched on every platform, so no OS branching.
    dll_path = [os.path.join(d, dll_name) for d in search_dirs]
    # os.path.isfile() is already False for paths that do not exist.
    lib_path = [p for p in dll_path if os.path.isfile(p)]
    # From github issues, most of installation errors come from machines w/o compilers
    if not lib_path and not os.environ.get('XGBOOST_BUILD_DOC', False):
        raise RuntimeError(
            'Cannot find Rabit Library in the candidate path, '
            'did you install compilers and run build.sh in root path?\n'
            'List of candidates:\n' + '\n'.join(dll_path))
    return lib_path

# load in rabit library
def _loadlib(lib='standard', lib_dll=None):
    """Load the rabit shared library into the module-level ``_LIB``.

    Parameters
    ----------
    lib : str, optional
        Library flavour. ``'standard'`` loads ``librabit``; any other
        value ``x`` loads ``librabit_x``.
    lib_dll : ctypes.CDLL, optional
        An already-loaded library object. When given it is stored as-is,
        ``lib`` is ignored and no restype setup is performed.

    Notes
    -----
    If a library is already loaded, a warning is issued and ``_LIB`` is
    left unchanged.
    """
    global _LIB
    if _LIB is not None:
        # warnings.warn accepts `stacklevel`; the former `level=2` keyword
        # made this call raise TypeError instead of emitting the warning.
        warnings.warn('rabit.init call was ignored because it has'
                      ' already been initialized', stacklevel=2)
        return

    if lib_dll is not None:
        _LIB = lib_dll
        return

    dll_name = 'librabit' if lib == 'standard' else 'librabit_' + lib

    if os.name == 'nt':
        dll_name += '.dll'
    elif platform.system() == 'Darwin':
        dll_name += '.dylib'
    else:
        dll_name += '.so'

    # NOTE(review): when XGBOOST_BUILD_DOC is set and nothing is found,
    # _find_lib_path returns [] and the [0] below raises IndexError.
    _LIB = ctypes.cdll.LoadLibrary(_find_lib_path(dll_name)[0])
    # Declare the int return types of the scalar query functions.
    _LIB.RabitGetRank.restype = ctypes.c_int
    _LIB.RabitGetWorldSize.restype = ctypes.c_int
    _LIB.RabitVersionNumber.restype = ctypes.c_int

def _unloadlib():
    """Forget the loaded rabit library by resetting ``_LIB`` to None.

    Only the Python reference is dropped; dropping a ctypes library
    reference does not by itself unmap the shared object from the process.
    """
    global _LIB
    # Rebinding releases the old reference; the former `del _LIB` right
    # before this assignment was redundant.
    _LIB = None

# Reduction operator codes, passed as the `op` argument of allreduce():
# MAX=0, MIN=1, SUM=2, BITOR=3.
MAX, MIN, SUM, BITOR = range(4)

def init(args=None, lib='standard', lib_dll=None):
    """Initialize the rabit module, call this once before using anything.

    Parameters
    ----------
    args: list of str or bytes, optional
        The list of arguments used to initialize rabit;
        usually you need to pass in sys.argv.
        ``str`` items are UTF-8 encoded before being passed to the library.
        Defaults to an empty list when it is None.
    lib: {'standard', 'mock', 'mpi'}, optional
        Type of library we want to load.
        Ignored when lib_dll is specified.
    lib_dll: ctypes.DLL, optional
        The DLL object used as lib.
        When this is presented argument lib will be ignored.
    """
    if args is None:
        args = []
    _loadlib(lib, lib_dll)
    # ctypes.c_char_p slots only accept bytes on Python 3, so encode str
    # items; `encoded` stays referenced until RabitInit returns.
    encoded = [a.encode('utf-8') if isinstance(a, str) else a for a in args]
    arr = (ctypes.c_char_p * len(encoded))(*encoded)
    _LIB.RabitInit(len(encoded), arr)

def finalize():
    """Shut down the rabit engine and drop the library handle.

    Call this once all jobs are finished; it invokes ``RabitFinalize``
    and then resets the module-level library reference.
    """
    engine = _LIB
    engine.RabitFinalize()
    _unloadlib()

def get_rank():
    """Return the rank of the current process.

    Returns
    -------
    rank : int
        Value reported by ``RabitGetRank`` for this process.
    """
    return _LIB.RabitGetRank()

def get_world_size():
    """Return the total number of workers.

    Returns
    -------
    n : int
        Value reported by ``RabitGetWorldSize``.
    """
    return _LIB.RabitGetWorldSize()

def tracker_print(msg):
    """Print message to the tracker.

    This function can be used to communicate the information of
    the progress to the tracker.

    Parameters
    ----------
    msg : str
        The message to be printed to tracker. Non-str values are
        converted with ``str()`` first; the text is sent UTF-8 encoded.
    """
    if not isinstance(msg, str):
        msg = str(msg)
    # c_char_p needs bytes: encode the text first, then wrap it. The old
    # `c_char_p(msg).encode(...)` failed because c_char_p rejects str and
    # has no `.encode` method.
    _LIB.RabitTrackerPrint(ctypes.c_char_p(msg.encode('utf-8')))

def get_processor_name():
    """Get the processor (host) name from rabit.

    Returns
    -------
    name : bytes
        The name of the processor (host) as raw bytes, read from a
        256-byte buffer up to its first NUL byte; it is not decoded to str.
    """
    max_len = 256
    name_buf = ctypes.create_string_buffer(max_len)
    out_len = ctypes.c_ulong()
    _LIB.RabitGetProcessorName(name_buf, ctypes.byref(out_len), max_len)
    return name_buf.value

def broadcast(data, root):
    """Broadcast a picklable object from rank ``root`` to all other ranks.

    Parameters
    ----------
    data : any type that can be pickled
        Input data, if current rank does not equal root, this can be None
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object
        On ``root``, the input ``data`` itself; on every other rank, the
        object unpickled from the received bytes.
    """
    is_root = get_rank() == root
    nbytes = ctypes.c_ulong()
    if is_root:
        assert data is not None, 'need to pass in data when broadcasting'
        payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        nbytes.value = len(payload)
    # First broadcast: distribute the payload length from root.
    _LIB.RabitBroadcast(ctypes.byref(nbytes),
                        ctypes.sizeof(ctypes.c_ulong), root)
    if is_root:
        # Second broadcast (root side): send the pickled bytes.
        _LIB.RabitBroadcast(
            ctypes.cast(ctypes.c_char_p(payload), ctypes.c_void_p),
            nbytes.value, root)
        return data
    # Second broadcast (receiver side): fill a buffer of the announced size.
    recv = (ctypes.c_char * nbytes.value)()
    _LIB.RabitBroadcast(ctypes.cast(recv, ctypes.c_void_p),
                        nbytes.value, root)
    return pickle.loads(recv.raw)

# Mapping from numpy dtype to the integer type code that allreduce()
# passes to RabitAllreduce; codes follow the order of the tuple below.
DTYPE_ENUM__ = {
    np.dtype(type_name): type_code
    for type_code, type_name in enumerate((
        'int8', 'uint8',
        'int32', 'uint32',
        'int64', 'uint64',
        'float32', 'float64',
    ))
}

def allreduce(data, op, prepare_fun=None):
    """Perform allreduce, return the result.

    Parameters
    ----------
    data: numpy array
        Input data. The reduction works on a private copy, so ``data``
        itself is only written to by ``prepare_fun`` (if given).
    op: int
        Reduction operators, can be MIN, MAX, SUM, BITOR
    prepare_fun: function
        Lazy preprocessing function, if it is not None, prepare_fun(data)
        will be called by the function before performing allreduce, to
        initialize the data.
        If the result of Allreduce can be recovered directly,
        then prepare_fun will NOT be called

    Returns
    -------
    result : numpy.ndarray
        A new flattened (1-D, C-order) array with ``data.size`` elements
        holding the reduced values.

    Raises
    ------
    TypeError
        If ``data`` is not a ``numpy.ndarray``.
    ValueError
        If the dtype of ``data`` is not supported.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError('allreduce only takes in numpy.ndarray')
    if data.dtype not in DTYPE_ENUM__:
        raise ValueError('data type %s not supported' % str(data.dtype))
    # flatten() always returns a fresh C-contiguous copy. The previous
    # `ravel()` + `buf.base is data.base` check skipped the copy when `data`
    # owned its memory, so the reduction overwrote the caller's array.
    buf = data.flatten()
    dtype_code = DTYPE_ENUM__[data.dtype]
    if prepare_fun is None:
        _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                            buf.size, dtype_code, op, None, None)
    else:
        func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

        def pfunc(args):
            """prepare function: initialize data, then refresh buf from it."""
            prepare_fun(data)
            # buf is a copy, so whatever prepare_fun wrote into data must be
            # copied in before the reduction reads buf.
            buf[:] = data.ravel()

        # Keep the C callback object referenced for the whole call.
        callback = func_ptr(pfunc)
        _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                            buf.size, dtype_code, op, callback, None)
    return buf


def _load_model(ptr, length):
    """Unpickle a model stored in a raw memory buffer.

    Internal helper of this module.

    Arguments:
        ptr: ctypes.POINTER(ctypes.c_char)
            pointer to the first byte of the buffer
        length: int
            number of bytes in the buffer
    """
    # NOTE(review): pickle.loads executes arbitrary code; this assumes the
    # checkpoint bytes come from a trusted rabit job.
    start = ctypes.addressof(ptr.contents)
    raw_bytes = ctypes.string_at(start, length)
    return pickle.loads(raw_bytes)

def load_checkpoint(with_local=False):
    """Load the latest checkpoint.

    Parameters
    ----------
    with_local: bool, optional
        whether the checkpoint contains a local model

    Returns
    -------
    tuple : tuple
        if with_local: (version, global_model, local_model)
        otherwise: (version, global_model)
        A returned version of 0 means nothing has been checkpointed yet,
        in which case the model entries are None.
    """
    global_ptr = ctypes.POINTER(ctypes.c_char)()
    global_size = ctypes.c_ulong()
    if not with_local:
        version = _LIB.RabitLoadCheckPoint(
            ctypes.byref(global_ptr), ctypes.byref(global_size), None, None)
        if version == 0:
            return (version, None)
        return (version, _load_model(global_ptr, global_size.value))

    local_ptr = ctypes.POINTER(ctypes.c_char)()
    local_size = ctypes.c_ulong()
    version = _LIB.RabitLoadCheckPoint(
        ctypes.byref(global_ptr), ctypes.byref(global_size),
        ctypes.byref(local_ptr), ctypes.byref(local_size))
    if version == 0:
        return (version, None, None)
    global_model = _load_model(global_ptr, global_size.value)
    local_model = _load_model(local_ptr, local_size.value)
    return (version, global_model, local_model)

def checkpoint(global_model, local_model=None):
    """Checkpoint the model.

    This marks the end of a stage of execution; every call increases the
    stored version number by one.

    Parameters
    ----------
    global_model: anytype that can be pickled
        Globally shared model/state at the time of the call; the caller
        must guarantee that global_model is the same on all nodes.

    local_model: anytype that can be pickled
       Local model, specific to the current node/rank.
       This can be None when no local state is needed.

    Notes
    -----
    local_model requires explicit replication for fault tolerance, which
    adds replication cost to this function, whereas global_model does not
    need explicit replication. Prefer global_model where possible.
    """
    global_bytes = pickle.dumps(global_model)
    if local_model is not None:
        local_bytes = pickle.dumps(local_model)
        _LIB.RabitCheckPoint(global_bytes, len(global_bytes),
                             local_bytes, len(local_bytes))
        return
    _LIB.RabitCheckPoint(global_bytes, len(global_bytes), None, 0)

def version_number():
    """Return the version number of the currently stored model.

    The version counts how many checkpoint calls have been made so far.

    Returns
    -------
    version : int
        Value reported by ``RabitVersionNumber``.
    """
    return _LIB.RabitVersionNumber()