Source code for distributed/utils_perf.py

from collections import deque
import gc
import logging
import threading

from dask.utils import format_bytes

from .compatibility import PYPY
from .metrics import thread_time


logger = _logger = logging.getLogger(__name__)


class ThrottledGC:
    """Wrap gc.collect to protect against excessively repeated calls.

    Allows running throttled garbage collection in the workers as a
    countermeasure to e.g.: https://github.com/dask/zict/issues/19

    collect() does nothing when repeated calls are so costly and so
    frequent that the thread would spend more than a fraction
    max_in_gc_frac of its time doing garbage collection.

    warn_if_longer is a duration in seconds (1s by default); a
    warning-level message is logged whenever an actual call to
    gc.collect() lasts longer than this.
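
    Example (a minimal, illustrative sketch):

    >>> throttled_gc = ThrottledGC(max_in_gc_frac=0.05)
    >>> for _ in range(100):
    ...     throttled_gc.collect()  # most rapid repeat calls are skipped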
    """

    def __init__(self, max_in_gc_frac=0.05, warn_if_longer=1, logger=None):
        self.max_in_gc_frac = max_in_gc_frac
        self.warn_if_longer = warn_if_longer
        self.last_collect = thread_time()
        self.last_gc_duration = 0
        self.logger = logger if logger is not None else _logger

    def collect(self):
        # In case of non-monotonicity in the clock, assume that any Python
        # operation lasts at least 1e-6 second.
        MIN_RUNTIME = 1e-6

        collect_start = thread_time()
        elapsed = max(collect_start - self.last_collect, MIN_RUNTIME)
        if self.last_gc_duration / elapsed < self.max_in_gc_frac:
            self.logger.debug(
                "Calling gc.collect(). %0.3fs elapsed since previous call.", elapsed
            )
            gc.collect()
            self.last_collect = collect_start
            self.last_gc_duration = max(thread_time() - collect_start, MIN_RUNTIME)
            if self.last_gc_duration > self.warn_if_longer:
                self.logger.warning(
                    "gc.collect() took %0.3fs. This is usually"
                    " a sign that some tasks handle too"
                    " many Python objects at the same time."
                    " Rechunking the work into smaller tasks"
                    " might help.",
                    self.last_gc_duration,
                )
            else:
                self.logger.debug("gc.collect() took %0.3fs", self.last_gc_duration)
        else:
            self.logger.debug(
                "last gc.collect() took %0.3fs but only %0.3fs "
                "elapsed since the previous call: throttling.",
                self.last_gc_duration,
                elapsed,
            )


class FractionalTimer:
    """
    An object that measures runtimes, accumulates the last n_samples of
    them in a sliding window, and computes the running fraction of time
    spent in the measured code over the window's elapsed time.
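
    Example (a minimal, illustrative sketch using time.monotonic as the
    timer):

    >>> import time
    >>> ft = FractionalTimer(n_samples=2, timer=time.monotonic)
    >>> for _ in range(3):
    ...     ft.start_timing()
    ...     time.sleep(0.01)
    ...     ft.stop_timing()
    >>> 0 < ft.running_fraction <= 1
    True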
    """

    MULT = 1e9  # convert seconds to nanoseconds

    def __init__(self, n_samples, timer=thread_time):
        self._timer = timer
        self._n_samples = n_samples
        self._start_stops = deque()
        self._durations = deque()
        self._cur_start = None
        self._running_sum = None
        self._running_fraction = None

    def _add_measurement(self, start, stop):
        start_stops = self._start_stops
        durations = self._durations
        if stop < start or (start_stops and start < start_stops[-1][1]):
            # Ignore if non-monotonic
            return

        # Use integers to ensure exact running sum computation
        duration = int((stop - start) * self.MULT)
        start_stops.append((start, stop))
        durations.append(duration)

        n = len(durations)
        assert n == len(start_stops)
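        # Once the window holds n_samples measurements, maintain a running
        # sum of durations: each new sample displaces the oldest, and the
        # fraction is that sum over the wall time spanned by the window.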
        if n >= self._n_samples:
            if self._running_sum is None:
                assert n == self._n_samples
                self._running_sum = sum(durations)
            else:
                old_start, old_stop = start_stops.popleft()
                old_duration = durations.popleft()
                self._running_sum += duration - old_duration
                if stop >= old_start:
                    self._running_fraction = (
                        self._running_sum / (stop - old_stop) / self.MULT
                    )

    def start_timing(self):
        assert self._cur_start is None
        self._cur_start = self._timer()

    def stop_timing(self):
        stop = self._timer()
        start = self._cur_start
        self._cur_start = None
        assert start is not None
        self._add_measurement(start, stop)

    @property
    def running_fraction(self):
        return self._running_fraction


class GCDiagnosis:
    """
    An object that hooks itself into the gc callbacks to collect
    timing and memory statistics around full collections, and log
    a message when they cross the configured thresholds.

    Don't instantiate this directly except for tests.
    Instead, use the global instance via enable_gc_diagnosis().
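
    Example (an illustrative sketch for testing only; production code
    should go through enable_gc_diagnosis() instead):

    >>> import gc
    >>> with GCDiagnosis(warn_over_frac=0.5):
    ...     _ = gc.collect()  # full collections are now timed and logged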
    """

    N_SAMPLES = 30

    def __init__(self, warn_over_frac=0.1, info_over_rss_win=10 * 1e6):
        self._warn_over_frac = warn_over_frac
        self._info_over_rss_win = info_over_rss_win
        self._enabled = False

    def enable(self):
        if PYPY:
            return
        assert not self._enabled
        self._fractional_timer = FractionalTimer(n_samples=self.N_SAMPLES)
        try:
            import psutil
        except ImportError:
            self._proc = None
        else:
            self._proc = psutil.Process()

        cb = self._gc_callback
        assert cb not in gc.callbacks
        # NOTE: gc.callbacks keeps a reference to this object, so
        # __del__ will never be called while the callback is registered
        gc.callbacks.append(cb)
        self._enabled = True

    def disable(self):
        if PYPY:
            return
        assert self._enabled
        gc.callbacks.remove(self._gc_callback)
        self._enabled = False

    @property
    def enabled(self):
        return self._enabled

    def __enter__(self):
        self.enable()
        return self

    def __exit__(self, *args):
        self.disable()

    def _gc_callback(self, phase, info):
        # Young generations are small and collected very often;
        # don't waste time measuring them
        if info["generation"] != 2:
            return
        if self._proc is not None:
            rss = self._proc.memory_info().rss
        else:
            rss = 0
        if phase == "start":
            self._fractional_timer.start_timing()
            self._gc_rss_before = rss
            return
        assert phase == "stop"
        self._fractional_timer.stop_timing()
        frac = self._fractional_timer.running_fraction
        if frac is not None and frac >= self._warn_over_frac:
            logger.warning(
                "full garbage collections took %d%% CPU time "
                "recently (threshold: %d%%)",
                100 * frac,
                100 * self._warn_over_frac,
            )
        rss_saved = self._gc_rss_before - rss
        if rss_saved >= self._info_over_rss_win:
            logger.info(
                "full garbage collection released %s "
                "from %d reference cycles (threshold: %s)",
                format_bytes(rss_saved),
                info["collected"],
                format_bytes(self._info_over_rss_win),
            )
        if info["uncollectable"] > 0:
            # This should ideally never happen on Python 3, but who knows?
            logger.warning(
                "garbage collector couldn't collect %d objects, "
                "please look in gc.garbage",
                info["uncollectable"],
            )


_gc_diagnosis = GCDiagnosis()
_gc_diagnosis_users = 0
_gc_diagnosis_lock = threading.Lock()


def enable_gc_diagnosis():
    """
    Ask to enable global GC diagnosis.
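
    Calls are reference-counted, so pair each call with a later
    disable_gc_diagnosis() (illustrative sketch):

    >>> enable_gc_diagnosis()
    >>> # ... run the workload; full GC passes are timed ...
    >>> disable_gc_diagnosis()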
    """
    if PYPY:
        return
    global _gc_diagnosis_users
    with _gc_diagnosis_lock:
        if _gc_diagnosis_users == 0:
            _gc_diagnosis.enable()
        else:
            assert _gc_diagnosis.enabled
        _gc_diagnosis_users += 1


def disable_gc_diagnosis(force=False):
    """
    Ask to disable global GC diagnosis. If force=True, disable it
    unconditionally and reset the user count.
    """
    if PYPY:
        return
    global _gc_diagnosis_users
    with _gc_diagnosis_lock:
        if _gc_diagnosis_users > 0:
            _gc_diagnosis_users -= 1
            if _gc_diagnosis_users == 0:
                _gc_diagnosis.disable()
            elif force:
                _gc_diagnosis.disable()
                _gc_diagnosis_users = 0
            else:
                assert _gc_diagnosis.enabled