Skip to content

Commit

Permalink
Update sparse_finch notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Nov 28, 2024
1 parent e9b560b commit 40e95a8
Showing 1 changed file with 209 additions and 23 deletions.
232 changes: 209 additions & 23 deletions examples/sparse_finch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"import sparse\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import networkx as nx\n",
"\n",
"import numpy as np\n",
"import scipy.sparse as sps\n",
Expand Down Expand Up @@ -105,7 +106,7 @@
"metadata": {},
"outputs": [],
"source": [
"ITERS = 3\n",
"ITERS = 1\n",
"rng = np.random.default_rng(0)"
]
},
Expand Down Expand Up @@ -134,6 +135,13 @@
" return elapsed / ITERS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MTTKRP"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -146,26 +154,30 @@
"importlib.reload(sparse)\n",
"\n",
"configs = [\n",
" {\"I_\": 100, \"J_\": 25, \"K_\": 10, \"L_\": 10, \"DENSITY\": 0.001},\n",
" {\"I_\": 100, \"J_\": 25, \"K_\": 100, \"L_\": 10, \"DENSITY\": 0.001},\n",
" {\"I_\": 100, \"J_\": 25, \"K_\": 100, \"L_\": 100, \"DENSITY\": 0.001},\n",
" {\"I_\": 1000, \"J_\": 25, \"K_\": 100, \"L_\": 100, \"DENSITY\": 0.001},\n",
" {\"I_\": 1000, \"J_\": 25, \"K_\": 1000, \"L_\": 100, \"DENSITY\": 0.001},\n",
" {\"I_\": 1000, \"J_\": 25, \"K_\": 1000, \"L_\": 1000, \"DENSITY\": 0.001},\n",
"]\n",
"nonzeros = [10000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]\n",
"nonzeros = [100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]\n",
"\n",
"if CI_MODE:\n",
" configs = configs[:1]\n",
" nonzeros = nonzeros[:1]\n",
"\n",
"finch_times = []\n",
"numba_times = []\n",
"finch_galley_times = []\n",
"\n",
"for config in configs:\n",
" B_sps = sparse.random((config[\"I_\"], config[\"K_\"], config[\"L_\"]), density=config[\"DENSITY\"], random_state=rng) * 10\n",
" D_sps = rng.random((config[\"L_\"], config[\"J_\"])) * 10\n",
" C_sps = rng.random((config[\"K_\"], config[\"J_\"])) * 10\n",
" B_sps = sparse.random(\n",
" (config[\"I_\"], config[\"K_\"], config[\"L_\"]),\n",
" density=config[\"DENSITY\"],\n",
" random_state=rng,\n",
" )\n",
" D_sps = rng.random((config[\"L_\"], config[\"J_\"]))\n",
" C_sps = rng.random((config[\"K_\"], config[\"J_\"]))\n",
"\n",
" # ======= Finch =======\n",
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
Expand All @@ -175,7 +187,7 @@
" D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
" C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
"\n",
" @sparse.compiled\n",
" @sparse.compiled(opt=\"default\")\n",
" def mttkrp_finch(B, D, C):\n",
" return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
"\n",
Expand All @@ -184,6 +196,23 @@
" # Benchmark\n",
" time_finch = benchmark(mttkrp_finch, info=\"Finch\", args=[B, D, C])\n",
"\n",
"    # ======= Finch Galley =======\n",
"    # Same MTTKRP kernel as the default-optimizer Finch run, but compiled with\n",
"    # opt=\"galley\" so the two schedulers can be compared on identical inputs.\n",
"    os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
"    importlib.reload(sparse)\n",
"\n",
"    B = sparse.asarray(B_sps.todense(), format=\"csf\")\n",
"    D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
"    C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
"\n",
"    # Distinct name: do not shadow the default-optimizer `mttkrp_finch` kernel.\n",
"    @sparse.compiled(opt=\"galley\")\n",
"    def mttkrp_finch_galley(B, D, C):\n",
"        return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
"\n",
"    # Compile\n",
"    result_finch_galley = mttkrp_finch_galley(B, D, C)\n",
"    # The optimizer choice must not change the numerical result.\n",
"    np.testing.assert_allclose(result_finch_galley.todense(), result_finch.todense())\n",
"    # Benchmark\n",
"    time_finch_galley = benchmark(mttkrp_finch_galley, info=\"Finch Galley\", args=[B, D, C])\n",
"\n",
" # ======= Numba =======\n",
" os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
" importlib.reload(sparse)\n",
Expand All @@ -201,8 +230,10 @@
" time_numba = benchmark(mttkrp_numba, info=\"Numba\", args=[B, D, C])\n",
"\n",
" np.testing.assert_allclose(result_finch.todense(), result_numba.todense())\n",
"\n",
" finch_times.append(time_finch)\n",
" numba_times.append(time_numba)"
" numba_times.append(time_numba)\n",
" finch_galley_times.append(time_finch_galley)"
]
},
{
Expand All @@ -215,6 +246,7 @@
"\n",
"ax.plot(nonzeros, finch_times, \"o-\", label=\"Finch\")\n",
"ax.plot(nonzeros, numba_times, \"o-\", label=\"Numba\")\n",
"ax.plot(nonzeros, finch_galley_times, \"o-\", label=\"Finch Galley\")\n",
"ax.grid(True)\n",
"ax.set_xlabel(\"no. of elements\")\n",
"ax.set_ylabel(\"time (sec)\")\n",
Expand All @@ -226,6 +258,13 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SDDMM"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -235,15 +274,13 @@
"print(\"SDDMM Example:\\n\")\n",
"\n",
"configs = [\n",
" {\"LEN\": 10, \"DENSITY\": 0.1},\n",
" {\"LEN\": 50, \"DENSITY\": 0.05},\n",
" {\"LEN\": 100, \"DENSITY\": 0.01},\n",
" {\"LEN\": 500, \"DENSITY\": 0.005},\n",
" {\"LEN\": 1000, \"DENSITY\": 0.001},\n",
" {\"LEN\": 5000, \"DENSITY\": 0.00005},\n",
" {\"LEN\": 5000, \"DENSITY\": 0.00001},\n",
" {\"LEN\": 10000, \"DENSITY\": 0.00001},\n",
" {\"LEN\": 20000, \"DENSITY\": 0.00001},\n",
" {\"LEN\": 25000, \"DENSITY\": 0.00001},\n",
" {\"LEN\": 30000, \"DENSITY\": 0.00001},\n",
"]\n",
"size_n = [10, 50, 100, 500, 1000, 5000, 10000]\n",
"size_n = [5000, 10000, 20000, 25000, 30000]\n",
"\n",
"if CI_MODE:\n",
" configs = configs[:1]\n",
Expand All @@ -252,25 +289,27 @@
"finch_times = []\n",
"numba_times = []\n",
"scipy_times = []\n",
"finch_galley_times = []\n",
"\n",
"for config in configs:\n",
" LEN = config[\"LEN\"]\n",
" DENSITY = config[\"DENSITY\"]\n",
"\n",
" a_sps = rng.random((LEN, LEN)) * 10\n",
" b_sps = rng.random((LEN, LEN)) * 10\n",
" s_sps = sps.random(LEN, LEN, format=\"coo\", density=DENSITY, random_state=rng) * 10\n",
" a_sps = rng.random((LEN, LEN))\n",
" b_sps = rng.random((LEN, LEN))\n",
" s_sps = sps.random(LEN, LEN, format=\"coo\", density=DENSITY, random_state=rng)\n",
" s_sps.sum_duplicates()\n",
"\n",
" # ======= Finch =======\n",
" print(\"finch\")\n",
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
" importlib.reload(sparse)\n",
"\n",
" s = sparse.asarray(s_sps)\n",
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
"\n",
" @sparse.compiled\n",
" @sparse.compiled(opt=\"default\")\n",
" def sddmm_finch(s, a, b):\n",
" return sparse.sum(\n",
" s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
Expand All @@ -282,7 +321,30 @@
" # Benchmark\n",
" time_finch = benchmark(sddmm_finch, info=\"Finch\", args=[s, a, b])\n",
"\n",
"    # ======= Finch Galley =======\n",
"    # Same SDDMM kernel as the default-optimizer Finch run, compiled with opt=\"galley\".\n",
"    print(\"finch galley\")\n",
"    os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
"    importlib.reload(sparse)\n",
"\n",
"    s = sparse.asarray(s_sps)\n",
"    a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
"    b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
"\n",
"    # Distinct name: do not shadow the default-optimizer `sddmm_finch` kernel.\n",
"    @sparse.compiled(opt=\"galley\")\n",
"    def sddmm_finch_galley(s, a, b):\n",
"        # Mathematically this is `s * (a @ b)`, spelled as a broadcasted product\n",
"        # reduced over the shared (last) axis.\n",
"        return sparse.sum(\n",
"            s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
"            axis=-1,\n",
"        )\n",
"\n",
"    # Compile\n",
"    result_finch_galley = sddmm_finch_galley(s, a, b)\n",
"    # Benchmark\n",
"    time_finch_galley = benchmark(sddmm_finch_galley, info=\"Finch Galley\", args=[s, a, b])\n",
"\n",
" # ======= Numba =======\n",
" print(\"numba\")\n",
" os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
" importlib.reload(sparse)\n",
"\n",
Expand All @@ -299,6 +361,8 @@
" time_numba = benchmark(sddmm_numba, info=\"Numba\", args=[s, a, b])\n",
"\n",
" # ======= SciPy =======\n",
" print(\"scipy\")\n",
"\n",
" def sddmm_scipy(s, a, b):\n",
" return s.multiply(a @ b)\n",
"\n",
Expand All @@ -312,7 +376,8 @@
"\n",
" finch_times.append(time_finch)\n",
" numba_times.append(time_numba)\n",
" scipy_times.append(time_scipy)"
" scipy_times.append(time_scipy)\n",
" finch_galley_times.append(time_finch_galley)"
]
},
{
Expand All @@ -326,13 +391,134 @@
"ax.plot(size_n, finch_times, \"o-\", label=\"Finch\")\n",
"ax.plot(size_n, numba_times, \"o-\", label=\"Numba\")\n",
"ax.plot(size_n, scipy_times, \"o-\", label=\"SciPy\")\n",
"ax.plot(size_n, finch_galley_times, \"o-\", label=\"Finch Galley\")\n",
"\n",
"ax.grid(True)\n",
"ax.set_xlabel(\"size N\")\n",
"ax.set_ylabel(\"time (sec)\")\n",
"ax.set_title(\"SDDMM\")\n",
"ax.set_xscale(\"log\")\n",
"# ax.set_yscale('log')\n",
"# ax.set_xscale(\"log\")\n",
"# ax.set_yscale(\"log\")\n",
"ax.legend(loc=\"best\", numpoints=1)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Counting Triangles"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Counting Triangles Example:\\n\")\n",
"\n",
"# One benchmark point per config: an undirected G(n, p) random graph with\n",
"# n = LEN nodes and edge probability p = DENSITY.\n",
"configs = [\n",
"    {\"LEN\": 1000, \"DENSITY\": 0.1},\n",
"    {\"LEN\": 2000, \"DENSITY\": 0.1},\n",
"    {\"LEN\": 3000, \"DENSITY\": 0.1},\n",
"    {\"LEN\": 4000, \"DENSITY\": 0.1},\n",
"    {\"LEN\": 5000, \"DENSITY\": 0.1},\n",
"]\n",
"\n",
"if CI_MODE:\n",
"    configs = configs[:1]\n",
"\n",
"# Derived from `configs` so the plotted x-axis can never drift out of sync\n",
"# with the measured timings.\n",
"size_n = [config[\"LEN\"] for config in configs]\n",
"\n",
"finch_times = []\n",
"finch_galley_times = []\n",
"networkx_times = []\n",
"scipy_times = []\n",
"\n",
"for config in configs:\n",
"    LEN = config[\"LEN\"]\n",
"    DENSITY = config[\"DENSITY\"]\n",
"\n",
"    # Fixed seed so every run benchmarks the same graph.\n",
"    G = nx.gnp_random_graph(n=LEN, p=DENSITY, seed=0)\n",
"    a_sps = nx.to_scipy_sparse_array(G)\n",
"\n",
"    # ======= Finch =======\n",
"    print(\"finch\")\n",
"    os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
"    importlib.reload(sparse)\n",
"\n",
"    a = sparse.asarray(a_sps)\n",
"\n",
"    @sparse.compiled(opt=\"default\")\n",
"    def ct_finch(a):\n",
"        # sum_{i,j,k} a[i,j] * a[i,k] * a[k,j] visits every triangle 6 times\n",
"        # (3 starting nodes x 2 directions), hence the division by 6.\n",
"        return sparse.sum(\n",
"            a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
"        ) / sparse.asarray(6)\n",
"\n",
"    # Compile\n",
"    result_finch = ct_finch(a)\n",
"    # Benchmark\n",
"    time_finch = benchmark(ct_finch, info=\"Finch\", args=[a])\n",
"\n",
"    # ======= Finch Galley =======\n",
"    print(\"finch galley\")\n",
"    os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
"    importlib.reload(sparse)\n",
"\n",
"    a = sparse.asarray(a_sps)\n",
"\n",
"    # Same kernel as `ct_finch`, compiled with the \"galley\" optimizer; it gets its\n",
"    # own name so the default-optimizer kernel is not shadowed.\n",
"    @sparse.compiled(opt=\"galley\")\n",
"    def ct_finch_galley(a):\n",
"        return sparse.sum(\n",
"            a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
"        ) / sparse.asarray(6)\n",
"\n",
"    # Compile\n",
"    result_finch_galley = ct_finch_galley(a)\n",
"    # Benchmark\n",
"    time_finch_galley = benchmark(ct_finch_galley, info=\"Finch Galley\", args=[a])\n",
"\n",
"    # ======= SciPy =======\n",
"    print(\"scipy\")\n",
"\n",
"    def ct_scipy(a):\n",
"        # `@` and `*` share precedence and associate left: this is (a @ a) * a,\n",
"        # i.e. the 2-path counts masked elementwise by the adjacency matrix.\n",
"        return (a @ a * a).sum() / 6\n",
"\n",
"    # Reference result used to validate both Finch runs below.\n",
"    result_scipy = ct_scipy(a_sps)\n",
"    # Benchmark\n",
"    time_scipy = benchmark(ct_scipy, info=\"SciPy\", args=[a_sps])\n",
"\n",
"    # ======= NetworkX =======\n",
"    print(\"networkx\")\n",
"\n",
"    def ct_networkx(graph):\n",
"        # nx.triangles reports a per-node count, so each triangle appears 3 times.\n",
"        return sum(nx.triangles(graph).values()) / 3\n",
"\n",
"    # Benchmark\n",
"    time_networkx = benchmark(ct_networkx, info=\"NetworkX\", args=[G])\n",
"\n",
"    # Timings are only comparable if every backend computes the same count.\n",
"    np.testing.assert_allclose(result_finch.todense(), result_scipy)\n",
"    np.testing.assert_allclose(result_finch_galley.todense(), result_scipy)\n",
"\n",
"    finch_times.append(time_finch)\n",
"    finch_galley_times.append(time_finch_galley)\n",
"    networkx_times.append(time_networkx)\n",
"    scipy_times.append(time_scipy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(nrows=1, ncols=1)\n",
"\n",
"ax.plot(size_n, finch_times, \"o-\", label=\"Finch\")\n",
"ax.plot(size_n, networkx_times, \"o-\", label=\"NetworkX\")\n",
"ax.plot(size_n, scipy_times, \"o-\", label=\"SciPy\")\n",
"ax.plot(size_n, finch_galley_times, \"o-\", label=\"Finch Galley\")\n",
"\n",
"ax.grid(True)\n",
"ax.set_xlabel(\"size N\")\n",
"ax.set_ylabel(\"time (sec)\")\n",
"ax.set_title(\"Counting Triangles\")\n",
"# ax.set_xscale(\"log\")\n",
"# ax.set_yscale(\"log\")\n",
"ax.legend(loc=\"best\", numpoints=1)\n",
"\n",
"plt.show()"
Expand All @@ -355,7 +541,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 40e95a8

Please sign in to comment.