Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions loopy/library/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@
from loopy.translation_unit import add_callable_to_table

# getting the callable 'max' from target
max_scalar_callable = target.get_device_ast_builder().known_callables["max"]
max_scalar_callable = target.known_device_callables["max"]

Check warning on line 371 in loopy/library/reduction.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "known_device_callables" is partially unknown   Type of "known_device_callables" is "dict[Unknown, Unknown]" (reportUnknownMemberType)

# type specialize the callable
max_scalar_callable, callables_table = max_scalar_callable.with_types(
Expand Down Expand Up @@ -404,7 +404,7 @@
from loopy.translation_unit import add_callable_to_table

# getting the callable 'min' from target
min_scalar_callable = target.get_device_ast_builder().known_callables["min"]
min_scalar_callable = target.known_device_callables["min"]

Check warning on line 407 in loopy/library/reduction.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "known_device_callables" is partially unknown   Type of "known_device_callables" is "dict[Unknown, Unknown]" (reportUnknownMemberType)

# type specialize the callable
min_scalar_callable, callables_table = min_scalar_callable.with_types(
Expand Down
27 changes: 24 additions & 3 deletions loopy/target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,24 @@
"""
raise NotImplementedError()

@property
def known_host_callables(self):

Check warning on line 195 in loopy/target/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Return type, "dict[Unknown, Unknown]", is partially unknown (reportUnknownParameterType)
"""
Returns a mapping from function ids to corresponding
:class:`loopy.kernel.function_interface.InKernelCallable` for the
function ids known to *self* for host code generation.
"""
return {}

@property
def known_device_callables(self):

Check warning on line 204 in loopy/target/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Return type, "dict[Unknown, Unknown]", is partially unknown (reportUnknownParameterType)
"""
Returns a mapping from function ids to corresponding
:class:`loopy.kernel.function_interface.InKernelCallable` for the
function ids known to *self* for device code generation.
"""
return {}


@dataclass(frozen=True)
class ASTBuilderBase(ABC, Generic[ASTType]):
Expand All @@ -206,10 +224,9 @@
"""
Returns a mapping from function ids to corresponding
:class:`loopy.kernel.function_interface.InKernelCallable` for the
function ids known to *self.target*.
function ids known to *self.target* for device code generation.
"""
# FIXME: @inducer: Do we need to move this to TargetBase?
return {}
return dict(self.target.known_device_callables)

Check warning on line 229 in loopy/target/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument type is partially unknown   Argument corresponds to parameter "map" in function "__init__"   Argument type is "dict[Unknown, Unknown]" (reportUnknownArgumentType)

Check warning on line 229 in loopy/target/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "known_device_callables" is partially unknown   Type of "known_device_callables" is "dict[Unknown, Unknown]" (reportUnknownMemberType)

def symbol_manglers(self):
return []
Expand Down Expand Up @@ -351,6 +368,10 @@


class DummyHostASTBuilder(ASTBuilderBase[None]):
@property
def known_callables(self):

Check warning on line 372 in loopy/target/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Method "known_callables" is not marked as override but is overriding a method in class "ASTBuilderBase[None]" (reportImplicitOverride)

Check warning on line 372 in loopy/target/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Return type, "dict[Unknown, Unknown]", is partially unknown (reportUnknownParameterType)
return dict(self.target.known_host_callables)

Check warning on line 373 in loopy/target/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument type is partially unknown   Argument corresponds to parameter "map" in function "__init__"   Argument type is "dict[Unknown, Unknown]" (reportUnknownArgumentType)

Check warning on line 373 in loopy/target/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "known_host_callables" is partially unknown   Type of "known_host_callables" is "dict[Unknown, Unknown]" (reportUnknownMemberType)

def get_function_definition(self, codegen_state, codegen_result,
schedule_index, function_decl, function_body):
return function_body
Expand Down
32 changes: 21 additions & 11 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,13 @@ def get_host_ast_builder(self):
def get_device_ast_builder(self):
return CFamilyASTBuilder(self)

@property
@override
def known_device_callables(self):
callables = super().known_device_callables
callables.update(get_c_callables())
return callables

# {{{ types

@memoize_method
Expand Down Expand Up @@ -890,13 +897,6 @@ def preamble_generators(self):
lambda preamble_info: _preamble_generator(
preamble_info, self.preamble_function_qualifier)])

@property
@override
def known_callables(self):
callables = super().known_callables
callables.update(get_c_callables())
return callables

# }}}

# {{{ code generation
Expand Down Expand Up @@ -1606,19 +1606,29 @@ class CWithGNULibcTarget(CTarget):
def get_device_ast_builder(self):
return CWithGNULibcASTBuilder(self)


class CWithGNULibcASTBuilder(CASTBuilder):
@property
def known_callables(self):
callables = super().known_callables
@override
def known_device_callables(self):
callables = super().known_device_callables
callables.update(get_gnu_libc_callables())
return callables


class CWithGNULibcASTBuilder(CASTBuilder):
pass


class ExecutableCWithGNULibcTarget(ExecutableCTarget):
def get_device_ast_builder(self):
return CWithGNULibcASTBuilder(self)

@property
@override
def known_device_callables(self):
callables = super().known_device_callables
callables.update(get_gnu_libc_callables())
return callables

# }}}

# vim: foldmethod=marker
16 changes: 6 additions & 10 deletions loopy/target/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,12 @@ def split_kernel_at_global_barriers(self):
def get_device_ast_builder(self):
return CUDACASTBuilder(self)

@property
def known_device_callables(self):
callables = super().known_device_callables
callables.update(get_cuda_callables())
return callables

# {{{ types

@memoize_method
Expand Down Expand Up @@ -330,16 +336,6 @@ class CUDACASTBuilder(CFamilyASTBuilder):

preamble_function_qualifier = "inline __device__"

# {{{ library

@property
def known_callables(self):
callables = super().known_callables
callables.update(get_cuda_callables())
return callables

# }}}

# {{{ top-level codegen

def get_function_declaration(
Expand Down
13 changes: 7 additions & 6 deletions loopy/target/opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,13 @@ def split_kernel_at_global_barriers(self):
def get_device_ast_builder(self):
return OpenCLCASTBuilder(self)

@property
@override
def known_device_callables(self):
callables = super().known_device_callables
callables.update(get_opencl_callables())
return callables

@memoize_method
def get_dtype_registry(self) -> DTypeRegistry:
from loopy.target.c.compyte.dtypes import (
Expand Down Expand Up @@ -673,12 +680,6 @@ def vector_dtype(self, base, count):
class OpenCLCASTBuilder(CFamilyASTBuilder):
# {{{ library

@property
def known_callables(self):
callables = super().known_callables
callables.update(get_opencl_callables())
return callables

def symbol_manglers(self):
return (
[*super().symbol_manglers(), opencl_symbol_mangler])
Expand Down
31 changes: 20 additions & 11 deletions loopy/target/pyopencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,26 @@ def get_host_ast_builder(self):
def get_device_ast_builder(self):
return PyOpenCLCASTBuilder(self)

@property
@override
def known_device_callables(self):
from loopy.library.random123 import get_random123_callables

# order matters: e.g. prefer our abs() over that of the
# superclass
callables = super().known_device_callables
callables.update(get_pyopencl_callables())
callables.update(get_random123_callables(self))
return callables

@property
@override
def known_host_callables(self):
from loopy.target.c import get_c_callables
callables = super().known_host_callables
callables.update(get_c_callables())
return callables

# {{{ types

@override
Expand Down Expand Up @@ -1224,17 +1244,6 @@ def get_function_declaration(

# {{{ library

@property
def known_callables(self):
from loopy.library.random123 import get_random123_callables

# order matters: e.g. prefer our abs() over that of the
# superclass
callables = super().known_callables
callables.update(get_pyopencl_callables())
callables.update(get_random123_callables(self.target))
return callables

def preamble_generators(self):
return ([pyopencl_preamble_generator, *super().preamble_generators()])

Expand Down
5 changes: 1 addition & 4 deletions loopy/target/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,7 @@ class PythonASTBuilderBase(ASTBuilderBase[Generable]):

@property
def known_callables(self):
from loopy.target.c import get_c_callables
callables = super().known_callables
callables.update(get_c_callables())
return callables
return dict(self.target.known_host_callables)

def preamble_generators(self):
return (
Expand Down
2 changes: 1 addition & 1 deletion loopy/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def __init__(self, rule_mapping_context, callables_table, target):
@cached_property
def known_callables(self):
from loopy.kernel.function_interface import CallableKernel
return (frozenset(self.target.get_device_ast_builder().known_callables)
return (frozenset(self.target.known_device_callables)
| {name
for name, clbl in self.callables_table.items()
if isinstance(clbl, CallableKernel)})
Expand Down
2 changes: 1 addition & 1 deletion loopy/translation_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def resolve_callables(t_unit: TranslationUnit) -> TranslationUnit:
# get registered callables
known_callables = dict(t_unit.callables_table)
# get target specific callables
known_callables.update(t_unit.target.get_device_ast_builder().known_callables)
known_callables.update(t_unit.target.known_device_callables)
# get loopy specific callables
known_callables.update(get_loopy_callables())

Expand Down
Loading