1
1

automate the allreduce selection logic.

This commit was SVN r18484.
Этот коммит содержится в:
Rich Graham 2008-05-22 20:53:35 +00:00
родитель 8faeeab81a
Коммит f2a4b67809
5 изменённых файлов: 109 добавлений и 99 удалений

Просмотреть файл

@ -126,11 +126,16 @@ BEGIN_C_DECLS
int force_barrier;
/** MCA parameter: method to force a given reduce method to be used.
* 0 - FANIN_FAN_OUT_REDUCE_FN,
* 1 - REDUCE_SCATTER_GATHER_FN,
* 0 - FANIN_FAN_OUT_REDUCE_FN
* 1 - REDUCE_SCATTER_GATHER_FN
*/
int force_reduce;
/** MCA parameter: method to force a given allreduce method to be used.
* 0 - FANIN_FANOUT_ALLREDUCE_FN
* 1 - REDUCE_SCATTER_ALLGATHER_FN
*/
int force_allreduce;
};
@ -157,11 +162,23 @@ BEGIN_C_DECLS
N_REDUCE_FNS
};
enum{
SHORT_DATA_FN,
LONG_DATA_FN,
SHORT_DATA_FN_REDUCE,
LONG_DATA_FN_REDUCE,
N_REDUCE_FNS_USED
};
/* all-reduce */
enum{
FANIN_FANOUT_ALLREDUCE_FN,
REDUCE_SCATTER_ALLGATHER_FN,
N_ALLREDUCE_FNS
};
enum{
SHORT_DATA_FN_ALLREDUCE,
LONG_DATA_FN_ALLREDUCE,
N_ALLREDUCE_FNS_USED
};
/* enum for node type */
enum{
@ -463,6 +480,11 @@ BEGIN_C_DECLS
mca_coll_base_module_barrier_fn_t barrier_functions[N_BARRIER_FNS];
mca_coll_base_module_reduce_fn_t list_reduce_functions[N_REDUCE_FNS];
mca_coll_base_module_reduce_fn_t reduce_functions[N_REDUCE_FNS_USED];
mca_coll_base_module_allreduce_fn_t
list_allreduce_functions[N_ALLREDUCE_FNS];
mca_coll_base_module_allreduce_fn_t
allreduce_functions[N_ALLREDUCE_FNS_USED];
};
@ -584,6 +606,16 @@ BEGIN_C_DECLS
struct ompi_communicator_t *comm,
struct mca_coll_base_module_1_1_0_t *module);
int mca_coll_sm2_allreduce_intra_reducescatter_allgather(
void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype,
struct ompi_op_t *op, struct ompi_communicator_t *comm,
struct mca_coll_base_module_1_1_0_t *module);
int mca_coll_sm2_allreduce_intra_fanin_fanout(void *sbuf, void *rbuf,
int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm,
struct mca_coll_base_module_1_1_0_t *module);
/**
* Shared memory blocking reduce
*/

Просмотреть файл

@ -26,7 +26,6 @@ extern uint64_t timers[7];
/**
* Shared memory blocking allreduce.
*/
static
int mca_coll_sm2_allreduce_intra_fanin_fanout(void *sbuf, void *rbuf, int count,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
@ -909,8 +908,10 @@ int mca_coll_sm2_allreduce_intra_recursive_doubling(void *sbuf, void *rbuf,
}
/* apply collective operation to first half of the data */
ompi_op_reduce(op,(void *)extra_rank_write_data_pointer,
(void *)my_write_pointer, n_my_count,dtype);
if( 0 < n_my_count ) {
ompi_op_reduce(op,(void *)extra_rank_write_data_pointer,
(void *)my_write_pointer, n_my_count,dtype);
}
/* wait for my partner to finish reducing the data */
@ -923,16 +924,16 @@ int mca_coll_sm2_allreduce_intra_recursive_doubling(void *sbuf, void *rbuf,
/* read my partner's data */
/* adjust read an write pointers */
extra_rank_write_data_pointer+=
(count_this_stripe-n_my_count)*dt_extent;
extra_rank_write_data_pointer+=(n_my_count*dt_extent);
rc=ompi_ddt_copy_content_same_ddt(dtype,
count_this_stripe-n_my_count,
(char *)(my_write_pointer+
(count_this_stripe-n_my_count)*dt_extent),
(char *)extra_rank_write_data_pointer);
if( 0 != rc ) {
return OMPI_ERROR;
if( 0 < (count_this_stripe-n_my_count) ) {
rc=ompi_ddt_copy_content_same_ddt(dtype,
count_this_stripe-n_my_count,
(char *)(my_write_pointer+n_my_count*dt_extent),
(char *)extra_rank_write_data_pointer);
if( 0 != rc ) {
return OMPI_ERROR;
}
}
/* now we are ready for the power of 2 portion of the
@ -963,9 +964,9 @@ int mca_coll_sm2_allreduce_intra_recursive_doubling(void *sbuf, void *rbuf,
sm_buffer_desc->proc_memory[extra_rank].data_segment;
/* offset into my half of the data */
extra_rank_write_data_pointer+=
(count_this_stripe/2)*dt_extent;
((count_this_stripe/2)*dt_extent);
my_extra_write_pointer=my_write_pointer+
(count_this_stripe/2)*dt_extent;
((count_this_stripe/2)*dt_extent);
/* wait until remote data is read */
while( extra_ctl_pointer->flag < tag ) {
@ -973,8 +974,10 @@ int mca_coll_sm2_allreduce_intra_recursive_doubling(void *sbuf, void *rbuf,
}
/* apply collective operation to second half of the data */
ompi_op_reduce(op,(void *)extra_rank_write_data_pointer,
(void *)my_extra_write_pointer, n_my_count,dtype);
if( 0 < n_my_count ) {
ompi_op_reduce(op,(void *)extra_rank_write_data_pointer,
(void *)my_extra_write_pointer, n_my_count,dtype);
}
/* signal that I am done, so my partner can read my data */
MB();
@ -1159,7 +1162,6 @@ Error:
/**
* Shared memory blocking allreduce.
*/
static
int mca_coll_sm2_allreduce_intra_reducescatter_allgather(void *sbuf, void *rbuf,
int count, struct ompi_datatype_t *dtype,
struct ompi_op_t *op, struct ompi_communicator_t *comm,
@ -1402,17 +1404,6 @@ int mca_coll_sm2_allreduce_intra_reducescatter_allgather(void *sbuf, void *rbuf,
* at the number of procs in the exchange, so a divide by two at each
* iteration will give the right number of proc for the given iteration
*/
/* debug
{ int *int_tmp=(int *)my_base_pointer;
int i;
fprintf(stderr," GGG my rank %d data in tmp :: ",my_rank);
for (i=0 ; i < count_this_stripe ; i++ ) {
fprintf(stderr," %d ",int_tmp[i]);
}
fprintf(stderr,"\n");
fflush(stderr);
}
end debug */
n_proc_data=my_exchange_node->n_largest_pow_2;
starting_proc=0;
for(exchange=my_exchange_node->n_exchanges-1;exchange>=0;exchange--) {
@ -1888,40 +1879,34 @@ int mca_coll_sm2_allreduce_intra(void *sbuf, void *rbuf, int count,
{
/* local variables */
int rc;
mca_coll_sm2_module_t *sm_module;
ptrdiff_t dt_extent;
size_t len_data_buffer;
#if 0
if( 0 != (op->o_flags & OMPI_OP_FLAGS_COMMUTE)) {
/* Commutative Operation */
rc= mca_coll_sm2_allreduce_intra_recursive_doubling(sbuf, rbuf, count,
dtype, op, comm, module);
if( OMPI_SUCCESS != rc ) {
goto Error;
}
#endif
rc= mca_coll_sm2_allreduce_intra_reducescatter_allgather(sbuf, rbuf, count,
dtype, op, comm, module);
if( OMPI_SUCCESS != rc ) {
goto Error;
}
#if 0
} else {
/* Non-Commutative Operation */
#endif
#if 0
rc= mca_coll_sm2_allreduce_intra_fanin_fanout_pipeline(
sbuf, rbuf, count,dtype, op, comm, module);
if( OMPI_SUCCESS != rc ) {
goto Error;
}
/* Non-Commutative Operation */
rc= mca_coll_sm2_allreduce_intra_fanin_fanout(sbuf, rbuf, count,
dtype, op, comm, module);
if( OMPI_SUCCESS != rc ) {
goto Error;
}
sm_module=(mca_coll_sm2_module_t *) module;
/* get size of data needed - same layout as user data, so that
* we can apply the reudction routines directly on these buffers
*/
rc=ompi_ddt_type_extent(dtype, &dt_extent);
if( OMPI_SUCCESS != rc ) {
goto Error;
}
#endif
len_data_buffer=count*dt_extent;
if( len_data_buffer <= sm_module->short_message_size) {
rc=sm_module->allreduce_functions[SHORT_DATA_FN_ALLREDUCE]
(sbuf, rbuf, count, dtype, op, comm, module);
}
else {
rc=sm_module->allreduce_functions[LONG_DATA_FN_ALLREDUCE]
(sbuf, rbuf, count, dtype, op, comm, module);
}
if( OMPI_SUCCESS != rc ) {
goto Error;
}
return OMPI_SUCCESS;

Просмотреть файл

@ -35,19 +35,6 @@
#include "ompi/mca/coll/base/base.h"
#include "orte/mca/rml/rml.h"
/* debug */
#include <signal.h>
extern int debug_print;
extern int my_debug_rank;
extern void debug_module(void);
void dbg_handler(int my_signal) {
/* debug_print=1; */
debug_module();
return;
}
/* end debug */
/*
* Public string showing the coll ompi_sm V2 component version number
@ -129,10 +116,6 @@ mca_coll_sm2_component_t mca_coll_sm2_component = {
*/
static int sm2_open(void)
{
/* debug */
int retVal;
struct sigaction new_sigact;
/* end debug */
/* local variables */
mca_coll_sm2_component_t *cs = &mca_coll_sm2_component;
@ -205,17 +188,8 @@ static int sm2_open(void)
mca_coll_sm2_param_register_int("force_barrier",(-1));
cs->force_reduce=
mca_coll_sm2_param_register_int("force_reduce",(-1));
/* debug */
/*
new_sigact.sa_handler=dbg_handler;
sigemptyset(&(new_sigact.sa_mask));
retVal=sigaction(SIGUSR2,&new_sigact,NULL);
*/
signal(SIGUSR2,dbg_handler);
/* end debug */
cs->force_allreduce=
mca_coll_sm2_param_register_int("force_allreduce",(-1));
return OMPI_SUCCESS;
}

Просмотреть файл

@ -31,6 +31,7 @@
#include "ompi/constants.h"
#include "ompi/communicator/communicator.h"
#include "ompi/mca/coll/coll.h"
#include "opal/util/show_help.h"
#include "coll_sm2.h"
#include "ompi/mca/coll/base/base.h"
#include "ompi/mca/dpm/dpm.h"
@ -748,17 +749,35 @@ mca_coll_sm2_comm_query(struct ompi_communicator_t *comm, int *priority)
mca_coll_sm2_reduce_intra_fanin;
sm_module->list_reduce_functions[REDUCE_SCATTER_GATHER_FN]=
mca_coll_sm2_reduce_intra_reducescatter_gather;
sm_module->reduce_functions[SHORT_DATA_FN]=
sm_module->reduce_functions[SHORT_DATA_FN_REDUCE]=
sm_module->list_reduce_functions[FANIN_REDUCE_FN];
sm_module->reduce_functions[LONG_DATA_FN]=
sm_module->reduce_functions[LONG_DATA_FN_REDUCE]=
sm_module->list_reduce_functions[REDUCE_SCATTER_GATHER_FN];
if( ( 0 <= mca_coll_sm2_component.force_reduce ) &&
( N_REDUCE_FNS > mca_coll_sm2_component.force_reduce ) ) {
/* set user specifed function */
mca_coll_base_module_barrier_fn_t tmp_fn=
sm_module->reduce_functions[mca_coll_sm2_component.force_reduce];
sm_module->reduce_functions[SHORT_DATA_FN]=tmp_fn;
sm_module->reduce_functions[LONG_DATA_FN]=tmp_fn;
mca_coll_base_module_reduce_fn_t tmp_fn=sm_module->
list_reduce_functions[mca_coll_sm2_component.force_reduce];
sm_module->reduce_functions[SHORT_DATA_FN_REDUCE]=tmp_fn;
sm_module->reduce_functions[LONG_DATA_FN_REDUCE]=tmp_fn;
}
/* allreduce */
sm_module->list_allreduce_functions[FANIN_FANOUT_ALLREDUCE_FN]=
mca_coll_sm2_allreduce_intra_fanin_fanout;
sm_module->list_allreduce_functions[REDUCE_SCATTER_ALLGATHER_FN]=
mca_coll_sm2_allreduce_intra_reducescatter_allgather;
sm_module->allreduce_functions[SHORT_DATA_FN_ALLREDUCE]=
sm_module->list_allreduce_functions[FANIN_FANOUT_ALLREDUCE_FN];
sm_module->allreduce_functions[LONG_DATA_FN_ALLREDUCE]=
sm_module->list_allreduce_functions[REDUCE_SCATTER_ALLGATHER_FN];
if( ( 0 <= mca_coll_sm2_component.force_allreduce ) &&
( N_ALLREDUCE_FNS > mca_coll_sm2_component.force_allreduce ) ) {
/* set user specifed function */
mca_coll_base_module_allreduce_fn_t tmp_fn=sm_module->
list_allreduce_functions[mca_coll_sm2_component.force_allreduce];
sm_module->allreduce_functions[SHORT_DATA_FN_ALLREDUCE]=tmp_fn;
sm_module->allreduce_functions[LONG_DATA_FN_ALLREDUCE]=tmp_fn;
}
/*

Просмотреть файл

@ -741,11 +741,11 @@ int mca_coll_sm2_reduce_intra(void *sbuf, void *rbuf, int count,
len_data_buffer=count*dt_extent;
if( len_data_buffer <= sm_module->short_message_size) {
rc=sm_module->reduce_functions[SHORT_DATA_FN]
rc=sm_module->reduce_functions[SHORT_DATA_FN_REDUCE]
(sbuf, rbuf, count, dtype, op, root, comm, module);
}
else {
rc=sm_module->reduce_functions[LONG_DATA_FN]
rc=sm_module->reduce_functions[LONG_DATA_FN_REDUCE]
(sbuf, rbuf, count, dtype, op, root, comm, module);
}