Remove lambdas
sharadmv committed Mar 30, 2023
1 parent ec7f0b6 commit 391cbc1
Showing 5 changed files with 102 additions and 61 deletions.
2 changes: 1 addition & 1 deletion jax_triton/pallas/__init__.py
@@ -14,7 +14,7 @@

"""Module for pallas, a jaxpr "dialect" for Triton."""
from jax_triton.pallas.core import BlockSpec
from jax_triton.pallas.core import Config
from jax_triton.pallas.core import KernelConfig
from jax_triton.pallas.pallas_call import pallas_call
from jax_triton.pallas.pallas_call import pallas_call_p
from jax_triton.pallas.primitives import atomic_add
14 changes: 13 additions & 1 deletion jax_triton/pallas/core.py
@@ -18,7 +18,7 @@
import functools
from functools import partial

from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
from jax._src import api_util
@@ -95,6 +95,18 @@ class GridSpec:

Platform = str


@dataclasses.dataclass
class KernelConfig:
in_specs: Optional[Sequence[Optional[BlockSpec]]] = None
out_specs: Optional[Sequence[Optional[BlockSpec]]] = None
grid: Optional[Union[Grid, int]] = None
meta: dict[str, Any] = dataclasses.field(default_factory=dict)
compiler_params: dict[Platform, dict[str, Any]] = dataclasses.field(default_factory=dict)

def replace(self, *args, **kwargs):
return dataclasses.replace(self, *args, **kwargs)

@dataclasses.dataclass
class Config:
meta: dict[str, Any]
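
A minimal standalone sketch of how the new KernelConfig dataclass above can be constructed for an autotuning sweep (not part of this commit's diff; the `pl` import alias is an assumption, mirroring the tests further down):

from jax_triton import pallas as pl

# Concrete values replace the old lambda-based grid/spec arguments: each config
# carries its own meta kwargs and grid, and `replace` derives variants.
base = pl.KernelConfig(meta=dict(block_size=8), grid=1)
configs = [
    base.replace(meta=dict(block_size=block_size), grid=8 // block_size)
    for block_size in [1, 2, 4, 8]
]
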
113 changes: 71 additions & 42 deletions jax_triton/pallas/pallas_call.py
@@ -306,24 +306,16 @@ def _compute_spec(config: Config, spec: MaybeSpec,
spec = spec(**config.meta)
return spec

def specialize_kernel(config: Config,
def specialize_kernel(config: pallas_core.KernelConfig,
func: Callable,
grid: Optional[pallas_core.Grid],
name: Optional[str],
in_specs: Optional[list[Optional[BlockSpec]]],
out_specs: Optional[list[Optional[BlockSpec]]],
in_avals: tuple[jax_core.ShapedArray, ...],
out_avals: tuple[jax_core.ShapedArray, ...],
in_tree: tree_util.PyTreeDef,
compiler_params: dict[str, Any]
) -> tuple[SpecializedKernel, ...]:
specialized_grid = grid
if callable(specialized_grid):
specialized_grid = specialized_grid(**config.meta)
specialized_grid = pallas_core.preprocess_grid(specialized_grid)
specialized_in_specs = map(partial(_compute_spec, config), in_specs)
specialized_out_specs = map(partial(_compute_spec, config), out_specs)
if specialized_grid == ():
grid = config.grid
if grid == ():
in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
for arg in in_avals]
out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
@@ -333,42 +325,76 @@ def specialize_kernel(config: Config,
state.shaped_array_ref(
pallas_core.compute_shape_from_block_spec(block_spec, aval.shape),
aval.dtype)
for block_spec, aval in zip(specialized_in_specs, in_avals)]
for block_spec, aval in zip(config.in_specs, in_avals)]
out_ref_avals = [
state.shaped_array_ref(
pallas_core.compute_shape_from_block_spec(block_spec, aval.shape),
aval.dtype)
for block_spec, aval in zip(specialized_out_specs, out_avals)]
in_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid),
specialized_in_specs)
out_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid),
specialized_out_specs)
grid_spec = pallas_core.GridSpec(specialized_grid, (*in_block_mappings, *out_block_mappings), ())
for block_spec, aval in zip(config.out_specs, out_avals)]
in_block_mappings = map(
partial(pallas_core.convert_block_spec_to_block_mapping, grid),
config.in_specs)
out_block_mappings = map(
partial(pallas_core.convert_block_spec_to_block_mapping, grid),
config.out_specs)
grid_spec = pallas_core.GridSpec(grid, (*in_block_mappings, *out_block_mappings), ())
jaxpr, consts, out_tree = tracing_utils.initial_style_open_jaxpr(
func, in_tree, tuple((*in_ref_avals, *out_ref_avals)), "pallas_call", **config.meta)
return SpecializedKernel("foo", jaxpr, len(consts), grid_spec,
dict(compiler_params, **config.compiler_params)), consts, out_tree

def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
def _canonicalize_kernel_config(
maybe_kernel_config: Optional[pallas_core.KernelConfig],
in_avals: Sequence[jax_core.AbstractValue],
out_avals: Sequence[jax_core.AbstractValue],
in_specs: Optional[Sequence[Optional[BlockSpec]]],
out_specs: Optional[Sequence[Optional[BlockSpec]]],
grid: Optional[Union[Grid, int]],
) -> pallas_core.KernelConfig:
if not maybe_kernel_config:
config = pallas_core.KernelConfig(in_specs=in_specs, out_specs=out_specs, grid=grid)
else:
config = maybe_kernel_config
grid = maybe_kernel_config.grid
grid, in_specs, out_specs = config.grid, config.in_specs, config.out_specs
grid = pallas_core.preprocess_grid(grid)
if in_specs is not None and not isinstance(in_specs, (tuple, list)):
in_specs = (in_specs,)
if out_specs is not None and not isinstance(out_specs, (tuple, list)):
out_specs = (out_specs,)
if in_specs is None:
in_specs = [None] * len(in_avals)
if out_specs is None:
out_specs = [None] * len(out_avals)
return config.replace(grid=grid, in_specs=in_specs, out_specs=out_specs)

def pallas_call(f: Callable, out_shape: Any, *,
grid: Optional[Grid] = None,
config: Optional[pallas_core.KernelConfig] = None,
in_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
out_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
input_output_aliases: Dict[int, int] = {},
interpret: bool = False,
name: Optional[str] = None,
autotuning_configs: Optional[list[Config]] = None,
autotuning_configs: Optional[Sequence[pallas_core.KernelConfig]] = None,
debug: bool = False,
**compiler_params: Any):
if config is not None:
if grid is not None or in_specs is not None or out_specs is not None:
raise ValueError("Cannot specify both `config` and any of `grid`, "
"`in_specs`, or `out_specs`.")
if autotuning_configs is not None:
raise ValueError("Cannot specify both `config` and `autotuning_configs`")
if autotuning_configs is not None:
if grid is not None or in_specs is not None or out_specs is not None:
raise ValueError("Cannot specify both `autotuning_configs` and any of `grid`, "
"`in_specs`, or `out_specs`.")
singleton = False
if not isinstance(out_shape, (tuple, list)):
out_shape = (out_shape,)
singleton = True
if not isinstance(out_shape, tuple):
out_shape = tuple(out_shape)
if in_specs is not None and not isinstance(in_specs, (tuple, list)):
in_specs = (in_specs,)
if out_specs is not None and not isinstance(out_specs, (tuple, list)):
out_specs = (out_specs,)

if not name:
name = f.__name__ if hasattr(f, "__name__") else "unnamed"

Expand All @@ -382,29 +408,32 @@ def wrapped(*args):
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(a.shape, a.dtype)
for a in flat_out_shapes)
canonicalized_configs = []
if autotuning_configs is None:
canonicalized_configs.append(_canonicalize_kernel_config(config,
flat_in_avals,
flat_out_avals,
in_specs,
out_specs,
grid))
else:
canonicalized_configs.extend(map(partial(_canonicalize_kernel_config,
in_avals=flat_in_avals,
out_avals=flat_out_avals,
in_specs=in_specs,
out_specs=out_specs,
grid=grid),
autotuning_configs))
kernels = []
flat_in_specs = in_specs
flat_out_specs = out_specs
if flat_in_specs is None:
flat_in_specs = [None] * len(flat_in_avals)
if flat_out_specs is None:
flat_out_specs = [None] * len(flat_out_avals)
all_consts = []
if autotuning_configs is None:
if len(canonicalized_configs) == 0:
raise ValueError("Cannot pass in empty autotuning configs")
for canonicalized_config in canonicalized_configs:
specialized_kernel, consts, jaxpr_out_tree = specialize_kernel(
Config({}, {}), f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals,
canonicalized_config, f, name, flat_in_avals,
flat_out_avals, jaxpr_in_tree, compiler_params)
kernels.append(specialized_kernel)
all_consts.extend(consts)
else:
if len(autotuning_configs) == 0:
raise ValueError("Cannot pass in empty autotuning configs")
for config in autotuning_configs:
specialized_kernel, consts, jaxpr_out_tree = specialize_kernel(
config, f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals, flat_out_avals,
jaxpr_in_tree, compiler_params)
kernels.append(specialized_kernel)
all_consts.extend(consts)
if all_consts:
raise NotImplementedError("Cannot handle consts.")
del jaxpr_out_tree
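
A short usage sketch of the reworked pallas_call API above (not part of this commit's diff). The kernel body and the `pl` import are assumptions patterned on the tests below; the point is that a single KernelConfig now stands in for grid/in_specs/out_specs, and mixing `config` with those arguments (or with `autotuning_configs`) raises a ValueError:

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

def add_one(x_ref, o_ref, *, block_size):
  # Hypothetical kernel body patterned on tests/pallas_test.py.
  idx = pl.program_id(0) * block_size + jnp.arange(block_size)
  o_ref[idx] = x_ref[idx] + 1.

# One fixed configuration; for a sweep, pass autotuning_configs=[...] instead.
add_one_fn = pl.pallas_call(
    add_one,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    config=pl.KernelConfig(meta=dict(block_size=4), grid=2),
)
result = add_one_fn(jnp.arange(8, dtype=jnp.float32))
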
8 changes: 5 additions & 3 deletions jax_triton/pallas/triton_ir_lowering.py
@@ -837,9 +837,11 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
if debug:
print(kernel.jaxpr)
print(kernel.grid_spec)
compiler_params = kernel.compiler_params
num_warps = compiler_params.get("num_warps", 4)
num_stages = compiler_params.get("num_stages", 3)
compiler_params = dict(kernel.compiler_params)
num_warps = compiler_params.pop("num_warps", 4)
num_stages = compiler_params.pop("num_stages", 3)
if compiler_params:
raise ValueError(f"Invalid compiler params: {compiler_params}")
compilation_result = compile_jaxpr(kernel.jaxpr, kernel.num_consts,
tuple((*in_shapes, *out_shapes)),
kernel.grid_spec, kernel.name, num_warps, num_stages)
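
The stricter handling above amounts to a pop-and-validate pattern; as a standalone illustration (assuming only num_warps and num_stages are recognized, with the defaults shown in the diff):

def extract_compiler_params(params: dict) -> tuple[int, int]:
  params = dict(params)                    # copy so the pops don't mutate the kernel's dict
  num_warps = params.pop("num_warps", 4)   # defaults match the lowering rule above
  num_stages = params.pop("num_stages", 3)
  if params:                               # leftovers are now rejected instead of silently ignored
    raise ValueError(f"Invalid compiler params: {params}")
  return num_warps, num_stages
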
26 changes: 12 additions & 14 deletions tests/pallas_test.py
@@ -857,10 +857,10 @@ def test_basic_autotuning(self):

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
grid=lambda block_size: 8 // block_size,
autotuning_configs=[
pl.Config(dict(block_size=2), {}),
pl.Config(dict(block_size=4), {}),
pl.KernelConfig(meta=dict(block_size=block_size),
grid=8 // block_size)
for block_size in [1, 2, 4, 8]
])
def add_one(x_ref, o_ref, *, block_size):
idx = pl.program_id(0) * block_size + jnp.arange(block_size)
@@ -873,18 +873,16 @@ def test_basic_autotuning_with_block_spec(self):

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
grid=lambda block_size: 8 // block_size,
in_specs=[
lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)),
],
out_specs=[
lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)),
],
autotuning_configs=[
pl.Config(dict(block_size=1), {}),
pl.Config(dict(block_size=2), {}),
pl.Config(dict(block_size=4), {}),
pl.Config(dict(block_size=8), {}),
pl.KernelConfig(meta=dict(block_size=block_size),
in_specs=[
pl.BlockSpec(lambda i: i, (block_size,))
],
out_specs=[
pl.BlockSpec(lambda i: i, (block_size,))
],
grid=8 // block_size)
for block_size in [1, 2, 4, 8]
],
debug=True)
def add_one(x_ref, o_ref, *, block_size):
