
coll/base: Add MPI_Bcast based on a scatter followed by an allgather

Implements MPI_Bcast using a binomial tree scatter followed by
a recursive doubling allgather.

Signed-off-by: Mikhail Kurnosov <mkurnosov@gmail.com>
This commit is contained in:
Mikhail Kurnosov 2018-06-20 20:29:15 +07:00, committed by Nathan Hjelm
parent e305e80aff
commit c500739293
6 changed files with 199 additions and 10 deletions


@@ -713,3 +713,181 @@ int ompi_coll_base_bcast_intra_knomial(
    return ompi_coll_base_bcast_intra_generic(buf, count, datatype, root, comm, module,
                                              segcount, data->cached_kmtree);
}
/*
 * ompi_coll_base_bcast_intra_scatter_allgather
 *
 * Function:  Bcast using a binomial tree scatter followed by a recursive
 *            doubling allgather.
 * Accepts:   Same arguments as MPI_Bcast
 * Returns:   MPI_SUCCESS or error code
 *
 * Limitations: count >= comm_size
 * Time complexity: O(\alpha\log(p) + \beta*m((p-1)/p))
 *   Binomial tree scatter: \alpha\log(p) + \beta*m((p-1)/p)
 *   Recursive doubling allgather: \alpha\log(p) + \beta*m((p-1)/p)
 *
 * Example, p=8, count=8, root=0
 *    Binomial tree scatter         Recursive doubling allgather
 * 0: --+  --+  --+  [0*******]  <-+ [01******]  <--+   [0123****]  <--+
 * 1:   |   2|  <-+  [*1******]  <-+ [01******]  <--|-+ [0123****]  <--+-+
 * 2:  4|  <-+  --+  [**2*****]  <-+ [**23****]  <--+ | [0123****]  <--+-+-+
 * 3:   |       <-+  [***3****]  <-+ [**23****]  <----+ [0123****]  <--+-+-+-+
 * 4: <-+  --+  --+  [****4***]  <-+ [****45**]  <--+   [****4567]  <--+ | | |
 * 5:       2|  <-+  [*****5**]  <-+ [****45**]  <--|-+ [****4567]  <----+ | |
 * 6:      <-+  --+  [******6*]  <-+ [******67]  <--+ | [****4567]  <------+ |
 * 7:           <-+  [*******7]  <-+ [******67]  <--|-+ [****4567]  <--------+
 */
int ompi_coll_base_bcast_intra_scatter_allgather(
    void *buf, int count, struct ompi_datatype_t *datatype, int root,
    struct ompi_communicator_t *comm, mca_coll_base_module_t *module,
    uint32_t segsize)
{
    int err = MPI_SUCCESS;
    ptrdiff_t lb, extent;
    size_t datatype_size;
    MPI_Status status;
    ompi_datatype_get_extent(datatype, &lb, &extent);
    ompi_datatype_type_size(datatype, &datatype_size);
    int comm_size = ompi_comm_size(comm);
    int rank = ompi_comm_rank(comm);

    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "coll:base:bcast_intra_scatter_allgather: rank %d/%d",
                 rank, comm_size));
    if (comm_size < 2 || datatype_size == 0)
        return MPI_SUCCESS;

    if (count < comm_size) {
        OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                     "coll:base:bcast_intra_scatter_allgather: rank %d/%d "
                     "count %d switching to basic linear bcast",
                     rank, comm_size, count));
        return ompi_coll_base_bcast_intra_basic_linear(buf, count, datatype,
                                                       root, comm, module);
    }

    int vrank = (rank - root + comm_size) % comm_size;
    int recv_count = 0, send_count = 0;
    int scatter_count = (count + comm_size - 1) / comm_size; /* ceil(count / comm_size) */
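    /* Example: count=10, comm_size=4 gives scatter_count=3, so vranks 0..2
     * own 3 elements each and vrank 3 owns the remaining 1; the receive
     * size below is clipped by the total count. */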
    int curr_count = (rank == root) ? count : 0;

    /* Scatter by binomial tree: receive data from parent */
    int mask = 0x1;
    while (mask < comm_size) {
        if (vrank & mask) {
            int parent = (rank - mask + comm_size) % comm_size;
            /* Compute an upper bound on recv block size */
            recv_count = count - vrank * scatter_count;
            if (recv_count <= 0) {
                curr_count = 0;
            } else {
                /* Recv data from parent */
                err = MCA_PML_CALL(recv((char *)buf + (ptrdiff_t)vrank * scatter_count * extent,
                                        recv_count, datatype, parent,
                                        MCA_COLL_BASE_TAG_BCAST, comm, &status));
                if (MPI_SUCCESS != err) { goto cleanup_and_return; }
                /* Get received count */
                curr_count = (int)(status._ucount / datatype_size);
            }
            break;
        }
        mask <<= 1;
    }
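    /*
     * At this point mask holds the bit at which this process received its
     * data (for the root, mask has simply grown past comm_size). Children,
     * if any, sit at vrank + mask/2, vrank + mask/4, ..., vrank + 1 in
     * vrank space, when those fall inside the communicator.
     */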
    /* Scatter by binomial tree: send data to child processes */
    mask >>= 1;
    while (mask > 0) {
        if (vrank + mask < comm_size) {
            send_count = curr_count - scatter_count * mask;
            if (send_count > 0) {
                int child = (rank + mask) % comm_size;
                err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
                                        send_count, datatype, child,
                                        MCA_COLL_BASE_TAG_BCAST,
                                        MCA_PML_BASE_SEND_STANDARD, comm));
                if (MPI_SUCCESS != err) { goto cleanup_and_return; }
                curr_count -= send_count;
            }
        }
        mask >>= 1;
    }
    /*
     * Allgather by recursive doubling.
     * Each process now owns curr_count elements of buf, starting at
     * offset vrank * scatter_count.
     */
    int rem_count = count - vrank * scatter_count;
    curr_count = (scatter_count < rem_count) ? scatter_count : rem_count;
    if (curr_count < 0)
        curr_count = 0;
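    /* The recomputation above makes curr_count exactly this process's own
     * block size: at most scatter_count elements, clipped (possibly to
     * zero) for the last vranks by the total count. */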
    mask = 0x1;
    while (mask < comm_size) {
        int vremote = vrank ^ mask;
        int remote = (vremote + root) % comm_size;
        int vrank_tree_root = ompi_rounddown(vrank, mask);
        int vremote_tree_root = ompi_rounddown(vremote, mask);

        if (vremote < comm_size) {
            ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent;
            ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent;
            recv_count = count - vremote_tree_root * scatter_count;
            if (recv_count < 0)
                recv_count = 0;
            err = ompi_coll_base_sendrecv((char *)buf + send_offset,
                                          curr_count, datatype, remote,
                                          MCA_COLL_BASE_TAG_BCAST,
                                          (char *)buf + recv_offset,
                                          recv_count, datatype, remote,
                                          MCA_COLL_BASE_TAG_BCAST,
                                          comm, &status, rank);
            if (MPI_SUCCESS != err) { goto cleanup_and_return; }
            recv_count = (int)(status._ucount / datatype_size);
            curr_count += recv_count;
        }

        /*
         * Non-power-of-two case: if this process had no partner at this step,
         * another process must send it the current result. A recursive-halving
         * scheme is used to find that process.
         */
        if (vremote_tree_root + mask > comm_size) {
            int nprocs_alldata = comm_size - vrank_tree_root - mask;
            int offset = scatter_count * (vrank_tree_root + mask);
            for (int rhalving_mask = mask >> 1; rhalving_mask > 0; rhalving_mask >>= 1) {
                vremote = vrank ^ rhalving_mask;
                remote = (vremote + root) % comm_size;
                int tree_root = ompi_rounddown(vrank, rhalving_mask << 1);
                /*
                 * Send only if:
                 * 1) current process has data: (vremote > vrank) && (vrank < tree_root + nprocs_alldata)
                 * 2) remote process does not have data at any step: vremote >= tree_root + nprocs_alldata
                 */
                if ((vremote > vrank) && (vrank < tree_root + nprocs_alldata)
                    && (vremote >= tree_root + nprocs_alldata)) {
                    err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)offset * extent,
                                            recv_count, datatype, remote,
                                            MCA_COLL_BASE_TAG_BCAST,
                                            MCA_PML_BASE_SEND_STANDARD, comm));
                    if (MPI_SUCCESS != err) { goto cleanup_and_return; }

                } else if ((vremote < vrank) && (vremote < tree_root + nprocs_alldata)
                           && (vrank >= tree_root + nprocs_alldata)) {
                    err = MCA_PML_CALL(recv((char *)buf + (ptrdiff_t)offset * extent,
                                            count - offset, datatype, remote,
                                            MCA_COLL_BASE_TAG_BCAST,
                                            comm, &status));
                    if (MPI_SUCCESS != err) { goto cleanup_and_return; }
                    recv_count = (int)(status._ucount / datatype_size);
                    curr_count += recv_count;
                }
            }
        }
        mask <<= 1;
    }

cleanup_and_return:
    return err;
}
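For illustration only (not part of this change set), here is a minimal sketch of a caller that satisfies the count >= comm_size limitation, so this algorithm is eligible once forced through coll/tuned. The mpirun flags in the leading comment are an assumption based on the tuned component's parameter names; check your Open MPI build.

/*
 * Minimal sketch: broadcast whose count is a multiple of the communicator
 * size. Assumed run command:
 *   mpirun -np 8 --mca coll_tuned_use_dynamic_rules 1 \
 *          --mca coll_tuned_bcast_algorithm 8 ./bcast_demo
 */
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char **argv)
{
    int rank, size;
    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    int count = 4 * size;                      /* count >= comm_size */
    int *data = malloc(count * sizeof(int));
    if (rank == 0) {                           /* only the root fills the buffer */
        for (int i = 0; i < count; i++)
            data[i] = i;
    }

    MPI_Bcast(data, count, MPI_INT, 0, MPI_COMM_WORLD);

    /* every rank now holds the root's data */
    printf("rank %d: data[0]=%d data[%d]=%d\n",
           rank, data[0], count - 1, data[count - 1]);

    free(data);
    MPI_Finalize();
    return 0;
}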


@@ -222,6 +222,7 @@ int ompi_coll_base_bcast_intra_binomial(BCAST_ARGS, uint32_t segsize);
int ompi_coll_base_bcast_intra_bintree(BCAST_ARGS, uint32_t segsize);
int ompi_coll_base_bcast_intra_split_bintree(BCAST_ARGS, uint32_t segsize);
int ompi_coll_base_bcast_intra_knomial(BCAST_ARGS, uint32_t segsize, int radix);
int ompi_coll_base_bcast_intra_scatter_allgather(BCAST_ARGS, uint32_t segsize);
/* Exscan */
int ompi_coll_base_exscan_intra_recursivedoubling(EXSCAN_ARGS);


@@ -108,15 +108,6 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
    return err;
}
/*
* ompi_rounddown: Rounds a number down to nearest multiple.
* rounddown(10,4) = 8, rounddown(6,3) = 6, rounddown(14,3) = 12
*/
static int ompi_rounddown(int num, int factor)
{
    num /= factor;
    return num * factor; /* floor(num / factor) * factor */
}
/*
* ompi_coll_base_reduce_scatter_block_intra_recursivedoubling


@@ -93,3 +93,13 @@ unsigned int ompi_mirror_perm(unsigned int x, int nbits)
    x = ((x >> 16) | (x << 16));
    return x >> (sizeof(x) * CHAR_BIT - nbits);
}
/*
* ompi_rounddown: Rounds a number down to nearest multiple.
* rounddown(10,4) = 8, rounddown(6,3) = 6, rounddown(14,3) = 12
*/
int ompi_rounddown(int num, int factor)
{
    num /= factor;
    return num * factor; /* floor(num / factor) * factor */
}
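/* Note: C integer division truncates toward zero, so this rounds toward
 * zero rather than down for negative num; the callers here pass
 * non-negative values only. */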


@@ -78,5 +78,11 @@ ompi_coll_base_sendrecv( void* sendbuf, size_t scount, ompi_datatype_t* sdatatyp
*/
unsigned int ompi_mirror_perm(unsigned int x, int nbits);
/*
* ompi_rounddown: Rounds a number down to nearest multiple.
* rounddown(10,4) = 8, rounddown(6,3) = 6, rounddown(14,3) = 12
*/
int ompi_rounddown(int num, int factor);
END_C_DECLS
#endif /* MCA_COLL_BASE_UTIL_EXPORT_H */


@@ -43,6 +43,7 @@ static mca_base_var_enum_value_t bcast_algorithms[] = {
    {5, "binary_tree"},
    {6, "binomial"},
    {7, "knomial"},
    {8, "scatter_allgather"},
    {0, NULL}
};
@@ -78,7 +79,7 @@ int ompi_coll_tuned_bcast_intra_check_forced_init (coll_tuned_force_algorithm_mc
    mca_param_indices->algorithm_param_index =
        mca_base_component_var_register(&mca_coll_tuned_component.super.collm_version,
                                        "bcast_algorithm",
                                        "Which bcast algorithm is used. Can be locked down to choice of: 0 ignore, 1 basic linear, 2 chain, 3: pipeline, 4: split binary tree, 5: binary tree, 6: binomial tree, 7: knomial tree.",
                                        "Which bcast algorithm is used. Can be locked down to choice of: 0 ignore, 1 basic linear, 2 chain, 3: pipeline, 4: split binary tree, 5: binary tree, 6: binomial tree, 7: knomial tree, 8: scatter_allgather.",
                                        MCA_BASE_VAR_TYPE_INT, new_enum, 0, MCA_BASE_VAR_FLAG_SETTABLE,
                                        OPAL_INFO_LVL_5,
                                        MCA_BASE_VAR_SCOPE_ALL,
@@ -157,6 +158,8 @@ int ompi_coll_tuned_bcast_intra_do_this(void *buf, int count,
    case (7):
        return ompi_coll_base_bcast_intra_knomial(buf, count, dtype, root, comm, module,
                                                  segsize, coll_tuned_bcast_knomial_radix);
    case (8):
        return ompi_coll_base_bcast_intra_scatter_allgather(buf, count, dtype, root, comm, module, segsize);
    } /* switch */
    OPAL_OUTPUT((ompi_coll_tuned_stream, "coll:tuned:bcast_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
                 algorithm, ompi_coll_tuned_forced_max_algorithms[BCAST]));
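
With this registration in place, the new algorithm can be forced at run time by setting the tuned component's bcast_algorithm parameter (full MCA name coll_tuned_bcast_algorithm) to 8, as in the sketch after the new function above. Forced selection in coll/tuned is normally gated by its dynamic-rules switch (coll_tuned_use_dynamic_rules, mentioned here as an assumption since it lies outside this diff).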