Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Support migration of cusparse<T>csrgemm and cusparseXcsrgemmNnz #2065

Open
wants to merge 50 commits into
base: SYCLomatic
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
d68f2af
[SYCLomatic] Migrate cusparse<T>csrgemm
zhiweij1 Jun 17, 2024
4f67b79
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jun 18, 2024
e60892e
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jun 20, 2024
ebac48b
GPU Buffer test need debug
zhiweij1 Jun 21, 2024
aed64e0
Buffer GPU test flaky
zhiweij1 Jun 21, 2024
f97ab2d
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jun 25, 2024
5b1da27
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jun 26, 2024
665074e
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jun 26, 2024
6297460
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jul 5, 2024
a00980d
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jul 5, 2024
8caaaed
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jul 9, 2024
7f52f3c
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jul 12, 2024
5b16a0d
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jul 16, 2024
ce2907d
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jul 18, 2024
ebdaeb3
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jul 18, 2024
f62de9c
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jul 22, 2024
98f654e
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Jul 24, 2024
e8f54a4
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Aug 1, 2024
c49014b
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Aug 7, 2024
cf52784
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Aug 8, 2024
9a04430
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Aug 16, 2024
d2fedb0
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Aug 22, 2024
3017987
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Aug 26, 2024
f1a3141
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Aug 26, 2024
977683e
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Aug 28, 2024
afb839b
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Aug 29, 2024
f1580db
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Sep 6, 2024
5e092e3
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Sep 10, 2024
504b095
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Sep 11, 2024
cf4b498
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Sep 12, 2024
c512d53
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Sep 19, 2024
73ae9e2
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Sep 24, 2024
cad2ffa
Fix issue
zhiweij1 Sep 24, 2024
ba9db5d
Fix
zhiweij1 Sep 26, 2024
ab58f1e
Fix bug
zhiweij1 Sep 27, 2024
e5851fc
Fix
zhiweij1 Sep 27, 2024
6757912
Fix
zhiweij1 Sep 27, 2024
de6ff00
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Oct 8, 2024
4df539f
Extend helper function signature
zhiweij1 Oct 9, 2024
0c95ea5
Fix
zhiweij1 Oct 9, 2024
2147b3d
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Oct 9, 2024
098740c
Revert back
zhiweij1 Oct 10, 2024
800bdf2
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Oct 16, 2024
fadf3f3
Refine
zhiweij1 Oct 17, 2024
02b780f
Fix lit
zhiweij1 Oct 18, 2024
c937c86
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Oct 18, 2024
ef36ab8
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Oct 18, 2024
c7d6c2e
Update lit
zhiweij1 Oct 18, 2024
6f6a5e9
Merge remote-tracking branch 'origin/SYCLomatic' into sparse_gemm
zhiweij1 Oct 21, 2024
3ff754a
Refine
zhiweij1 Oct 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions clang/lib/DPCT/APINamesCUSPARSE.inc
Original file line number Diff line number Diff line change
Expand Up @@ -531,3 +531,28 @@ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseSpSM_solve", CALL(MapNames::getDpctNamespace() + "sparse::spsm",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7))))

ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseScsrgemm",
CALL(MapNames::getDpctNamespace() + "sparse::csrgemm", ARG(0), ARG(1),
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(8), ARG(9), ARG(10),
ARG(11), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18),
ARG(19))))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseDcsrgemm",
CALL(MapNames::getDpctNamespace() + "sparse::csrgemm", ARG(0), ARG(1),
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(8), ARG(9), ARG(10),
ARG(11), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18),
ARG(19))))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseCcsrgemm",
CALL(MapNames::getDpctNamespace() + "sparse::csrgemm", ARG(0), ARG(1),
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(8), ARG(9), ARG(10),
ARG(11), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18),
ARG(19))))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseZcsrgemm",
CALL(MapNames::getDpctNamespace() + "sparse::csrgemm", ARG(0), ARG(1),
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(8), ARG(9), ARG(10),
ARG(11), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18),
ARG(19))))
10 changes: 5 additions & 5 deletions clang/lib/DPCT/APINames_cuSPARSE.inc
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,11 @@ ENTRY(cusparseScsrgeam2, cusparseScsrgeam2, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDcsrgeam2, cusparseDcsrgeam2, false, NO_FLAG, P4, "comment")
ENTRY(cusparseCcsrgeam2, cusparseCcsrgeam2, false, NO_FLAG, P4, "comment")
ENTRY(cusparseZcsrgeam2, cusparseZcsrgeam2, false, NO_FLAG, P4, "comment")
ENTRY(cusparseXcsrgemmNnz, cusparseXcsrgemmNnz, false, NO_FLAG, P4, "comment")
ENTRY(cusparseScsrgemm, cusparseScsrgemm, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDcsrgemm, cusparseDcsrgemm, false, NO_FLAG, P4, "comment")
ENTRY(cusparseCcsrgemm, cusparseCcsrgemm, false, NO_FLAG, P4, "comment")
ENTRY(cusparseZcsrgemm, cusparseZcsrgemm, false, NO_FLAG, P4, "comment")
ENTRY(cusparseXcsrgemmNnz, cusparseXcsrgemmNnz, true, NO_FLAG, P4, "DPCT1130")
ENTRY(cusparseScsrgemm, cusparseScsrgemm, true, NO_FLAG, P4, "comment")
ENTRY(cusparseDcsrgemm, cusparseDcsrgemm, true, NO_FLAG, P4, "comment")
ENTRY(cusparseCcsrgemm, cusparseCcsrgemm, true, NO_FLAG, P4, "comment")
ENTRY(cusparseZcsrgemm, cusparseZcsrgemm, true, NO_FLAG, P4, "comment")
ENTRY(cusparseScsrgemm2_bufferSizeExt, cusparseScsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDcsrgemm2_bufferSizeExt, cusparseDcsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
ENTRY(cusparseCcsrgemm2_bufferSizeExt, cusparseCcsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
Expand Down
101 changes: 99 additions & 2 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3792,8 +3792,9 @@ void SPBLASFunctionCallRule::registerMatcher(MatchFinder &MF) {
"cusparseCsrsv_solveEx",
/*level 3*/
"cusparseScsrmm", "cusparseDcsrmm", "cusparseCcsrmm", "cusparseZcsrmm",
"cusparseScsrmm2", "cusparseDcsrmm2", "cusparseCcsrmm2",
"cusparseZcsrmm2",
"cusparseScsrgemm", "cusparseDcsrgemm", "cusparseCcsrgemm",
"cusparseZcsrgemm", "cusparseXcsrgemmNnz", "cusparseScsrmm2",
"cusparseDcsrmm2", "cusparseCcsrmm2", "cusparseZcsrmm2",
/*Generic*/
"cusparseCreateCsr", "cusparseDestroySpMat", "cusparseCsrGet",
"cusparseSpMatGetFormat", "cusparseSpMatGetIndexBase",
Expand Down Expand Up @@ -3861,6 +3862,102 @@ void SPBLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
EA.applyAllSubExprRepl();
return;
}
if (FuncName == "cusparseXcsrgemmNnz") {
std::vector<std::string> MigratedArgs;
for (const auto &Arg : CE->arguments()) {
MigratedArgs.push_back(ExprAnalysis::ref(Arg));
}
// We need find the next cusparse<T>csrgemm API call which is using the
// result of this API call, otherwise a warning will be emitted.
auto findOuterCS = [](const Stmt *Input) {
const CompoundStmt *CS = nullptr;
DpctGlobalInfo::findAncestor<Stmt>(
Input, [&](const DynTypedNode &Cur) -> bool {
if (Cur.get<DoStmt>() || Cur.get<ForStmt>() ||
Cur.get<WhileStmt>() || Cur.get<SwitchStmt>() ||
Cur.get<IfStmt>())
return true;
if (const CompoundStmt *S = Cur.get<CompoundStmt>())
CS = S;
return false;
});
return CS;
};
const CompoundStmt *CS1 = findOuterCS(CE);
// Find all the cusparse<T>csrgemm calls in this range.
using namespace clang::ast_matchers;
auto Matcher =
findAll(callExpr(callee(functionDecl(hasAnyName(
"cusparseScsrgemm", "cusparseDcsrgemm",
"cusparseCcsrgemm", "cusparseZcsrgemm"))))
.bind("CallExpr"));
auto CEResults = match(Matcher, *CS1, DpctGlobalInfo::getContext());
// Find the correct call
const CallExpr* CorrectCall = nullptr;
for (auto &Result : CEResults) {
const CallExpr *MatchedCE = Result.getNodeAs<CallExpr>("CallExpr");
if (MatchedCE) {
// 1. The context should be the same
const CompoundStmt *CS2 = findOuterCS(MatchedCE);
if (CS1 != CS2)
continue;
// 2. The args should be the same
std::vector<std::string> MatchedCEMigratedArgs;
for (const auto &Arg : MatchedCE->arguments()) {
MatchedCEMigratedArgs.push_back(ExprAnalysis::ref(Arg));
}
const static std::map<unsigned /*CE*/, unsigned /*MatchedCE*/> IdxMap =
{
{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5},
{6, 6}, {7, 7}, {8, 9}, {9, 10}, {10, 11}, {11, 12},
{12, 14}, {13, 15}, {14, 16}, {15, 18},
};
bool IsSame = true;
for (const auto &P : IdxMap) {
if (MigratedArgs[P.first] != MatchedCEMigratedArgs[P.second]) {
IsSame = false;
break;
}
}
if (IsSame) {
CorrectCall = MatchedCE;
break;
}
}
}
zhiweij1 marked this conversation as resolved.
Show resolved Hide resolved
if (!CorrectCall) {
// emit warning
zhiweij1 marked this conversation as resolved.
Show resolved Hide resolved
return;
}
const static std::map<unsigned /*CE*/, unsigned /*MatchedCE*/>
InsertBeforeIdxMap = {
{8, 8},
{12, 13},
};
std::string MigratedCall;
MigratedCall =
MapNames::getDpctNamespace() + "sparse::csrgemm_nnz(";
for (unsigned i = 0; i < MigratedArgs.size(); i++) {
if (InsertBeforeIdxMap.count(i)) {
zhiweij1 marked this conversation as resolved.
Show resolved Hide resolved
MigratedCall +=
(ExprAnalysis::ref(CorrectCall->getArg(InsertBeforeIdxMap.at(i))) +
", ");
}
MigratedCall += MigratedArgs[i];
if (i != MigratedArgs.size() - 1)
MigratedCall += ", ";
}
MigratedCall += ")";
auto DefRange = getDefinitionRange(CE->getBeginLoc(), CE->getEndLoc());
SourceLocation Begin = DefRange.getBegin();
SourceLocation End = DefRange.getEnd();
End = End.getLocWithOffset(
Lexer::MeasureTokenLength(End, DpctGlobalInfo::getSourceManager(),
DpctGlobalInfo::getContext().getLangOpts()));
emplaceTransformation(replaceText(Begin, End, std::move(MigratedCall),
DpctGlobalInfo::getSourceManager()));
return;
}
}

REGISTER_RULE(SPBLASFunctionCallRule, PassKind::PK_Migration,
Expand Down
Loading
Loading