Fix misc. issues surrounding ProfileStats #1121

Open · wants to merge 1 commit into main
14 changes: 7 additions & 7 deletions aesara/compile/function/pfunc.py
@@ -5,7 +5,7 @@

import logging
from copy import copy
from typing import Optional
from typing import Optional, Union

from aesara.compile.function.types import Function, UnusedInputError, orig_function
from aesara.compile.io import In, Out
@@ -282,7 +282,7 @@ def pfunc(
name=None,
rebuild_strict=True,
allow_input_downcast=None,
profile=None,
profile: Optional[Union[bool, str, ProfileStats]] = None,
on_unused_input=None,
output_keys=None,
fgraph: Optional[FunctionGraph] = None,
@@ -322,13 +322,13 @@ def pfunc(
general, or precise, type. None (default) is almost like
False, but allows downcasting of Python float scalars to
floatX.
profile : None, True, str, or ProfileStats instance
Accumulate profiling information into a given ProfileStats instance.
profile
Accumulate profiling information into a given `ProfileStats` instance.
None is the default, and means to use the value of config.profile.
If argument is `True` then a new ProfileStats instance will be used.
If argument is a string, a new ProfileStats instance will be created
If argument is ``True`` then a new `ProfileStats` instance will be used.
If argument is a string, a new `ProfileStats` instance will be created
with that string as its `message` attribute. This profiling object will
be available via self.profile.
be available via `Function.profile`.
on_unused_input : {'raise', 'warn','ignore', None}
What to do if a variable in the 'inputs' list is not used in the graph.
fgraph
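Example (not part of the diff; a minimal sketch of the documented `profile` semantics, using `aesara.function`, which forwards to `pfunc` — the variable names are illustrative):

import aesara
import aesara.tensor as at
from aesara.compile import ProfileStats

x = at.vector("x")
y = (x ** 2).sum()

# profile=True: a fresh ProfileStats instance is created for this function.
f = aesara.function([x], y, profile=True)

# profile=<str>: a fresh ProfileStats is created with that string as its message.
g = aesara.function([x], y, profile="squared-sum")

# profile=<ProfileStats>: profiling information accumulates into the given
# instance, which is afterwards reachable as Function.profile.
stats = ProfileStats(message="shared stats")
h = aesara.function([x], y, profile=stats)
print(h.profile is stats)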
10 changes: 6 additions & 4 deletions aesara/compile/function/types.py
@@ -14,6 +14,7 @@
import aesara.compile.profiling
from aesara.compile.io import In, SymbolicInput, SymbolicOutput
from aesara.compile.ops import deep_copy_op, view_op
from aesara.compile.profiling import ProfileStats
from aesara.configdefaults import config
from aesara.graph.basic import (
Constant,
@@ -731,10 +732,10 @@ def checkSV(sv_ori, sv_rpl):
message = name
else:
message = str(profile.message) + " copy"
profile = aesara.compile.profiling.ProfileStats(message=message)
profile = ProfileStats(message=message)
# profile -> object
elif isinstance(profile, str):
profile = aesara.compile.profiling.ProfileStats(message=profile)
profile = ProfileStats(message=profile)

f_cpy = maker.__class__(
inputs=ins,
@@ -1688,7 +1689,7 @@ def orig_function(
mode=None,
accept_inplace=False,
name=None,
profile=None,
profile: Optional[ProfileStats] = None,
on_unused_input=None,
output_keys=None,
fgraph: Optional[FunctionGraph] = None,
@@ -1712,7 +1713,8 @@
accept_inplace : bool
True iff the graph can contain inplace operations prior to the
rewrite phase (default is False).
profile : None or ProfileStats instance
profile :
`ProfileStats` instance.
on_unused_input : {'raise', 'warn', 'ignore', None}
What to do if a variable in the 'inputs' list is not used in the graph.
output_keys
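Example (not part of the diff; a hedged sketch of the copy path touched above — the `Function.copy` keywords shown are assumed from the surrounding API, not defined in this hunk):

import aesara
import aesara.tensor as at

x = at.vector("x")
f = aesara.function([x], x * 2, profile="original")

# Passing a string: the copy gets its own ProfileStats with that string as its
# message, via the `isinstance(profile, str)` branch shown above.
f_copy = f.copy(name="copy of original", profile="copy of original")
print(f.profile.message, f_copy.profile.message)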
147 changes: 20 additions & 127 deletions aesara/compile/profiling.py
@@ -8,15 +8,12 @@
# TODO: what to do about 'diff summary'? (ask Fred?)
#

import atexit
import copy
import logging
import operator
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import numpy as np

@@ -41,107 +38,11 @@ def extended_open(filename, mode="r"):
yield f


logger = logging.getLogger("aesara.compile.profiling")

aesara_imported_time: float = time.time()
total_fct_exec_time: float = 0.0
total_graph_rewrite_time: float = 0.0
total_time_linker: float = 0.0

_atexit_print_list: List["ProfileStats"] = []
_atexit_registered: bool = False


def _atexit_print_fn():
"""Print `ProfileStat` objects in `_atexit_print_list` to `_atexit_print_file`."""
if config.profile:
to_sum = []

if config.profiling__destination == "stderr":
destination_file = "<stderr>"
elif config.profiling__destination == "stdout":
destination_file = "<stdout>"
else:
destination_file = config.profiling__destination

with extended_open(destination_file, mode="w"):

# Reverse sort in the order of compile+exec time
for ps in sorted(
_atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time
)[::-1]:
if (
ps.fct_callcount >= 1
or ps.compile_time > 1
or getattr(ps, "callcount", 0) > 1
):
ps.summary(
file=destination_file,
n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply,
)

if ps.show_sum:
to_sum.append(ps)
else:
# TODO print the name if there is one!
print("Skipping empty Profile")
if len(to_sum) > 1:
# Make a global profile
cum = copy.copy(to_sum[0])
msg = f"Sum of all({len(to_sum)}) printed profiles at exit."
cum.message = msg
for ps in to_sum[1:]:
for attr in [
"compile_time",
"fct_call_time",
"fct_callcount",
"vm_call_time",
"rewriter_time",
"linker_time",
"validate_time",
"import_time",
"linker_node_make_thunks",
]:
setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr))

# merge dictionary
for attr in [
"apply_time",
"apply_callcount",
"apply_cimpl",
"variable_shape",
"variable_strides",
"variable_offset",
"linker_make_thunk_time",
]:
cum_attr = getattr(cum, attr)
for key, val in getattr(ps, attr.items()):
assert key not in cum_attr, (key, cum_attr)
cum_attr[key] = val

if cum.rewriter_profile and ps.rewriter_profile:
try:
merge = cum.rewriter_profile[0].merge_profile(
cum.rewriter_profile[1], ps.rewriter_profile[1]
)
assert len(merge) == len(cum.rewriter_profile[1])
cum.rewriter_profile = (cum.rewriter_profile[0], merge)
except Exception as e:
print(e)
cum.rewriter_profile = None
else:
cum.rewriter_profile = None

cum.summary(
file=destination_file,
n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply,
)

if config.print_global_stats:
print_global_stats()


def print_global_stats():
"""
@@ -190,26 +91,12 @@ class ProfileStats:

Parameters
----------
atexit_print : bool
True means that this object will be printed to stderr (using .summary())
at the end of the program.
**kwargs : misc initializers
These should (but need not) match the names of the class vars declared
in this class.

"""

def reset(self):
"""Ignore previous function call"""
# self.compile_time = 0.
self.fct_call_time = 0.0
self.fct_callcount = 0
self.vm_call_time = 0.0
self.apply_time = {}
self.apply_callcount = {}
# self.apply_cimpl = None
# self.message = None

#
# Note on implementation:
# Class variables are used here so that each one can be
@@ -277,7 +164,7 @@ def reset(self):

linker_make_thunk_time: Dict = {}

line_width = config.profiling__output_line_width
line_width: int = config.profiling__output_line_width

nb_nodes: int = -1
# The number of nodes in the graph. We need the information separately in
@@ -289,7 +176,7 @@

# param is called flag_time_thunks because most other attributes with time
# in the name are times *of* something, rather than configuration flags.
def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs):
def __init__(self, flag_time_thunks=None, message=None):
self.apply_callcount = {}
self.output_size = {}
# Keys are `(FunctionGraph, Variable)`
@@ -298,20 +185,25 @@ def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs):
self.variable_shape = {}
self.variable_strides = {}
self.variable_offset = {}
self.message = message
if flag_time_thunks is None:
self.flag_time_thunks = config.profiling__time_thunks
else:
self.flag_time_thunks = flag_time_thunks
self.__dict__.update(kwargs)
if atexit_print:
global _atexit_print_list
_atexit_print_list.append(self)
global _atexit_registered
if not _atexit_registered:
atexit.register(_atexit_print_fn)
_atexit_registered = True

self.ignore_first_call = config.profiling__ignore_first_call

def reset(self):
"""Ignore previous function call"""
# self.compile_time = 0.
self.fct_call_time = 0.0
self.fct_callcount = 0
self.vm_call_time = 0.0
self.apply_time = {}
self.apply_callcount = {}
self.apply_cimpl = None
self.message = None

def class_time(self):
"""
dict op -> total time on thunks
@@ -360,7 +252,7 @@ def class_impl(self):
rval = {}
for (fgraph, node) in self.apply_callcount:
typ = type(node.op)
if self.apply_cimpl[node]:
if self.apply_cimpl and self.apply_cimpl[node]:
impl = "C "
else:
impl = "Py"
@@ -438,7 +330,7 @@ def op_impl(self):
# timing is stored by node, we compute timing by Op on demand
rval = {}
for (fgraph, node) in self.apply_callcount:
if self.apply_cimpl[node]:
if self.apply_cimpl and self.apply_cimpl[node]:
rval[node.op] = "C "
else:
rval[node.op] = "Py"
@@ -785,7 +677,8 @@ def summary_nodes(self, file=sys.stderr, N=None):
def summary_function(self, file):
print("Function profiling", file=file)
print("==================", file=file)
print(f" Message: {self.message}", file=file)
if self.message:
print(f" Message: {self.message}", file=file)
print(
f" Time in {self.fct_callcount} calls to Function.__call__: {self.fct_call_time:e}s",
file=file,
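Example (not part of the diff; with the atexit machinery removed in this file, summaries are printed explicitly — the names below and the optional atexit hook are illustrative):

import atexit
import sys

import numpy as np

import aesara
import aesara.tensor as at
from aesara.compile import ProfileStats

stats = ProfileStats(message="manual summary")

x = at.vector("x")
f = aesara.function([x], (x + 1).sum(), profile=stats)
f(np.ones(16, dtype=aesara.config.floatX))

# Print a summary on demand...
stats.summary(file=sys.stderr)

# ...or opt back into an end-of-process report by registering a hook yourself.
atexit.register(lambda: stats.summary(file=sys.stderr))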
12 changes: 6 additions & 6 deletions aesara/scan/utils.py
@@ -136,13 +136,13 @@ def __init__(self, condition):


class ScanProfileStats(ProfileStats):
show_sum = False
callcount = 0
nbsteps = 0
call_time = 0.0
show_sum: bool = False
callcount: int = 0
nbsteps: int = 0
call_time: float = 0.0

def __init__(self, atexit_print=True, name=None, **kwargs):
super().__init__(atexit_print, **kwargs)
def __init__(self, name: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.name = name

def summary_globals(self, file):
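Example (not part of the diff; `ScanProfileStats` is normally constructed internally by `Scan`, so this only illustrates the new keyword forwarding):

from aesara.scan.utils import ScanProfileStats

# `name` is consumed here; any remaining keywords (e.g. `message`) are now
# forwarded unchanged to ProfileStats.__init__.
sps = ScanProfileStats(name="scan_inner", message="inner-graph profile")
print(sps.name, sps.message)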
27 changes: 19 additions & 8 deletions tests/compile/test_profiling.py
@@ -1,9 +1,7 @@
# Test of memory profiling


from io import StringIO

import numpy as np
import pytest

import aesara.tensor as at
from aesara.compile import ProfileStats
@@ -13,8 +11,13 @@
from aesara.tensor.type import fvector, scalars


pytestmark = pytest.mark.filterwarnings("error")


class TestProfiling:
# Test of Aesara profiling with min_peak_memory=True
"""
Test Aesara profiling with ``min_peak_memory=True``.
"""

def test_profiling(self):

Expand All @@ -32,14 +35,17 @@ def test_profiling(self):
z += [at.outer(x[i], x[i + 1]).sum(axis=1) for i in range(len(x) - 1)]
z += [x[i] + x[i + 1] for i in range(len(x) - 1)]

p = ProfileStats(False, gpu_checks=False)
p = ProfileStats()

if config.mode in ("DebugMode", "DEBUG_MODE", "FAST_COMPILE"):
m = "FAST_RUN"
else:
m = None

f = function(x, z, profile=p, name="test_profiling", mode=m)
with pytest.warns(
UserWarning, match=".*CVM does not support memory profiling.*"
):
f = function(x, z, profile=p, name="test_profiling", mode=m)

inp = [np.arange(1024, dtype="float32") + 1 for i in range(len(x))]
f(*inp)
@@ -87,14 +93,19 @@ def test_ifelse(self):

z = ifelse(at.lt(a, b), x * 2, y * 2)

p = ProfileStats(False, gpu_checks=False)
p = ProfileStats()

if config.mode in ("DebugMode", "DEBUG_MODE", "FAST_COMPILE"):
m = "FAST_RUN"
else:
m = None

f_ifelse = function([a, b, x, y], z, profile=p, name="test_ifelse", mode=m)
with pytest.warns(
UserWarning, match=".*CVM does not support memory profiling.*"
):
f_ifelse = function(
[a, b, x, y], z, profile=p, name="test_ifelse", mode=m
)

val1 = 0.0
val2 = 1.0