Add metadata support to TranslateBetweenGrid for Star VC (elemental#151)
aj-prime authored Apr 26, 2023
1 parent bf3cd78 commit 2f7f309
Showing 1 changed file with 41 additions and 9 deletions.
include/El/blas_like/level1/Copy/TranslateBetweenGrids.hpp (41 additions & 9 deletions)
@@ -3856,7 +3856,8 @@ void TranslateBetweenGrids(
     EL_DEBUG_CSE;

     /* Overview
+    We broadcast the size of A to all the ranks in B to make sure that
+    all ranks in the B subgrid have the correct size of A.
     Since we are using blocking communication, some care is required
     to avoid deadlocks. Let's start with a naive algorithm for
     [STAR,VC] matrices and optimize it in steps:
@@ -3883,21 +3884,53 @@
     */

     // Matrix dimensions
-    const Int m = A.Height();
-    const Int n = A.Width();
+    Int m = A.Height();
+    Int n = A.Width();
+    Int strideA = A.RowStride();
+    Int ALDim = A.LDim();

+    // Create A metadata
+    Int recvMetaData[4];
+    Int metaData[4];
+
+    SyncInfo<El::Device::CPU> syncGeneralMetaData = SyncInfo<El::Device::CPU>();
+    mpi::Comm const& viewingCommB = B.Grid().ViewingComm();
+
+    const bool inAGrid = A.Participating();
+    const bool inBGrid = B.Participating();
+
+    if(inAGrid)
+    {
+        metaData[0] = m;
+        metaData[1] = n;
+        metaData[2] = strideA;
+        metaData[3] = ALDim;
+    }
+    else
+    {
+        metaData[0] = 0;
+        metaData[1] = 0;
+        metaData[2] = 0;
+        metaData[3] = 0;
+    }
+    const std::vector<Int> sendMetaData (metaData, metaData + 4);
+    mpi::AllReduce( sendMetaData.data(), recvMetaData, 4, mpi::MAX, viewingCommB, syncGeneralMetaData);
+    m = recvMetaData[0];
+    n = recvMetaData[1];
+    strideA = recvMetaData[2];
+    ALDim = recvMetaData[3];
+
+
+    B.Resize(m, n);
     const Int nLocA = A.LocalWidth();
     const Int nLocB = B.LocalWidth();

     // Return immediately if there is no local data
-    const bool inAGrid = A.Participating();
-    const bool inBGrid = B.Participating();
     if (!inAGrid && !inBGrid) {
         return;
     }

     // Compute the number of messages to send/recv
-    const Int strideA = A.RowStride();
     const Int strideB = B.RowStride();
     const Int strideGCD = GCD(strideA, strideB);
     const Int numSends = Min(strideB/strideGCD, nLocA);
@@ -3913,7 +3946,6 @@
     // that we can match send/recv communicators. Since A's VC
     // communicator is not necessarily defined on every process, we
     // instead work with A's owning group.
-    mpi::Comm const& viewingCommB = B.Grid().ViewingComm();
     mpi::Group owningGroupA = A.Grid().OwningGroup();
     const int sizeA = A.Grid().Size();
     vector<int> viewingRanksA(sizeA), owningRanksA(sizeA);
@@ -3976,15 +4008,15 @@
             // Copy data locally
             copy::util::InterleaveMatrix(
                 m, messageWidth,
-                A.LockedBuffer(0,jLocA), 1, numSends*A.LDim(),
+                A.LockedBuffer(0,jLocA), 1, numSends*ALDim,
                 B.Buffer(0,jLocB), 1, numRecvs*B.LDim(),
                 syncInfo);
         }
         else if (viewingRank == sendViewingRank) {
             // Send data to other rank
             copy::util::InterleaveMatrix(
                 m, messageWidth,
-                A.LockedBuffer(0,jLocA), 1, numSends*A.LDim(),
+                A.LockedBuffer(0,jLocA), 1, numSends*ALDim,
                 messageBuf.data(), 1, m,
                 syncInfo);
             mpi::Send(

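The heart of the change is the metadata exchange: every rank in the union of the two grids contributes A's dimensions (zeros if it does not own A), and a MAX-AllReduce over B's viewing communicator leaves every rank with the true values, so ranks outside A's grid can still size B correctly before any point-to-point traffic begins. Below is a minimal, self-contained sketch of that pattern written against plain MPI rather than Elemental's mpi:: wrappers; the communicator, the participation test, and the dimension values (m, n, strideA, ALDim) are illustrative stand-ins, not the library's actual setup.

// Sketch of the metadata exchange pattern used in this commit, in plain MPI.
// Ranks that "own" A contribute its real dimensions; everyone else
// contributes zeros, so a MAX-reduction leaves all ranks with the real values.
#include <mpi.h>
#include <cstdio>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    MPI_Comm unionComm = MPI_COMM_WORLD;  // stands in for B's viewing comm

    int rank;
    MPI_Comm_rank(unionComm, &rank);

    // Pretend only even ranks participate in A's grid (illustrative).
    const bool inAGrid = (rank % 2 == 0);

    // [height, width, rowStride, leadingDim]; zeroed on non-participants
    int dims[4] = {0, 0, 0, 0};
    if (inAGrid) {
        dims[0] = 100;  // m
        dims[1] = 50;   // n
        dims[2] = 2;    // strideA
        dims[3] = 100;  // ALDim
    }

    // MAX over the union communicator: every rank ends up with A's metadata.
    int recv[4];
    MPI_Allreduce(dims, recv, 4, MPI_INT, MPI_MAX, unionComm);

    printf("rank %d sees m=%d n=%d stride=%d ldim=%d\n",
           rank, recv[0], recv[1], recv[2], recv[3]);

    MPI_Finalize();
    return 0;
}

The MAX/zero-fill trick works because the zero contributions from non-owners can never exceed the real values. It costs one collective, but it avoids having to pick a broadcast root, which may not even be defined on every process, the same reason the commit reduces over viewingCommB rather than using A's (possibly undefined) VC communicator.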