from collections import deque
import gc
import logging
import threading
from dask.utils import format_bytes
from .compatibility import PYPY
from .metrics import thread_time

logger = _logger = logging.getLogger(__name__)


class ThrottledGC:
"""Wrap gc.collect to protect against excessively repeated calls.
Allows to run throttled garbage collection in the workers as a
countermeasure to e.g.: https://github.com/dask/zict/issues/19
collect() does nothing when repeated calls are so costly and so frequent
that the thread would spend more than max_in_gc_frac doing GC.
warn_if_longer is a duration in seconds (10s by default) that can be used
to log a warning level message whenever an actual call to gc.collect()
lasts too long.
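
    A minimal usage sketch (``do_some_work`` is a hypothetical stand-in
    for a real workload)::

        throttled_gc = ThrottledGC(max_in_gc_frac=0.05)
        for _ in range(1000):
            do_some_work()
            throttled_gc.collect()  # calls gc.collect() only occasionally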
"""

    def __init__(self, max_in_gc_frac=0.05, warn_if_longer=1, logger=None):
self.max_in_gc_frac = max_in_gc_frac
self.warn_if_longer = warn_if_longer
self.last_collect = thread_time()
self.last_gc_duration = 0
self.logger = logger if logger is not None else _logger

    def collect(self):
# In case of non-monotonicity in the clock, assume that any Python
# operation lasts at least 1e-6 second.
MIN_RUNTIME = 1e-6
collect_start = thread_time()
elapsed = max(collect_start - self.last_collect, MIN_RUNTIME)
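        # Only collect if the previous collection was cheap relative to the
        # time elapsed since it; this bounds the fraction of thread time
        # spent inside gc.collect() by max_in_gc_frac.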
if self.last_gc_duration / elapsed < self.max_in_gc_frac:
self.logger.debug(
"Calling gc.collect(). %0.3fs elapsed since previous call.", elapsed
)
gc.collect()
self.last_collect = collect_start
self.last_gc_duration = max(thread_time() - collect_start, MIN_RUNTIME)
if self.last_gc_duration > self.warn_if_longer:
self.logger.warning(
"gc.collect() took %0.3fs. This is usually"
" a sign that some tasks handle too"
" many Python objects at the same time."
" Rechunking the work into smaller tasks"
" might help.",
self.last_gc_duration,
)
else:
self.logger.debug("gc.collect() took %0.3fs", self.last_gc_duration)
else:
            self.logger.debug(
                "previous gc.collect() took %0.3fs but only %0.3fs "
                "elapsed since it: throttling.",
                self.last_gc_duration,
                elapsed,
            )


class FractionalTimer:
"""
An object that measures runtimes, accumulates them and computes
a running fraction of the recent runtimes over the corresponding
elapsed time.
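
    A minimal sketch (illustrative only: it swaps in a wall-clock timer for
    the default per-thread CPU timer, and the sleeps stand in for work)::

        import time

        timer = FractionalTimer(n_samples=2, timer=time.monotonic)
        for _ in range(3):
            timer.start_timing()
            time.sleep(0.1)   # measured runtime
            timer.stop_timing()
            time.sleep(0.1)   # unmeasured elapsed time
        print(timer.running_fraction)  # roughly 0.5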
"""

    MULT = 1e9  # convert to nanoseconds

    def __init__(self, n_samples, timer=thread_time):
self._timer = timer
self._n_samples = n_samples
self._start_stops = deque()
self._durations = deque()
self._cur_start = None
self._running_sum = None
self._running_fraction = None

    def _add_measurement(self, start, stop):
start_stops = self._start_stops
durations = self._durations
if stop < start or (start_stops and start < start_stops[-1][1]):
# Ignore if non-monotonic
return
# Use integers to ensure exact running sum computation
duration = int((stop - start) * self.MULT)
start_stops.append((start, stop))
durations.append(duration)
n = len(durations)
assert n == len(start_stops)
if n >= self._n_samples:
if self._running_sum is None:
assert n == self._n_samples
self._running_sum = sum(durations)
else:
old_start, old_stop = start_stops.popleft()
old_duration = durations.popleft()
self._running_sum += duration - old_duration
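                # The retained samples span roughly from the end of the
                # dropped oldest measurement to the newest stop time,
                # hence the (stop - old_stop) denominator.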
                if stop > old_stop:
self._running_fraction = (
self._running_sum / (stop - old_stop) / self.MULT
)

    def start_timing(self):
assert self._cur_start is None
self._cur_start = self._timer()

    def stop_timing(self):
stop = self._timer()
start = self._cur_start
self._cur_start = None
assert start is not None
self._add_measurement(start, stop)

    @property
def running_fraction(self):
return self._running_fraction


class GCDiagnosis:
"""
    An object that hooks itself into the gc callbacks to collect
    timing and memory statistics, and log interesting info.

    Don't instantiate this directly except for tests.
    Instead, use the global instance.
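
    For tests it can be used as a context manager (a sketch)::

        with GCDiagnosis(warn_over_frac=0.1):
            ...  # exercise code; full GC passes are timed and logged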
"""

    N_SAMPLES = 30

    def __init__(self, warn_over_frac=0.1, info_over_rss_win=10 * 1e6):
self._warn_over_frac = warn_over_frac
self._info_over_rss_win = info_over_rss_win
self._enabled = False

    def enable(self):
if PYPY:
return
assert not self._enabled
self._fractional_timer = FractionalTimer(n_samples=self.N_SAMPLES)
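        # psutil is optional: without it, RSS is reported as 0 and the
        # memory-released info message in _gc_callback never triggers.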
try:
import psutil
except ImportError:
self._proc = None
else:
self._proc = psutil.Process()
cb = self._gc_callback
assert cb not in gc.callbacks
        # NOTE: gc.callbacks keeps a global reference to self, so __del__
        # can never run while the callback is registered
gc.callbacks.append(cb)
self._enabled = True

    def disable(self):
if PYPY:
return
assert self._enabled
gc.callbacks.remove(self._gc_callback)
self._enabled = False

    @property
def enabled(self):
return self._enabled

    def __enter__(self):
self.enable()
return self

    def __exit__(self, *args):
self.disable()

    def _gc_callback(self, phase, info):
        # Young generations are small and collected very often;
        # don't waste time measuring them
if info["generation"] != 2:
return
if self._proc is not None:
rss = self._proc.memory_info().rss
else:
rss = 0
if phase == "start":
self._fractional_timer.start_timing()
self._gc_rss_before = rss
return
assert phase == "stop"
self._fractional_timer.stop_timing()
frac = self._fractional_timer.running_fraction
if frac is not None and frac >= self._warn_over_frac:
logger.warning(
"full garbage collections took %d%% CPU time "
"recently (threshold: %d%%)",
100 * frac,
100 * self._warn_over_frac,
)
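        # rss_saved is negative if the process RSS grew during the
        # collection; the threshold check below then simply fails.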
rss_saved = self._gc_rss_before - rss
if rss_saved >= self._info_over_rss_win:
logger.info(
"full garbage collection released %s "
"from %d reference cycles (threshold: %s)",
format_bytes(rss_saved),
info["collected"],
format_bytes(self._info_over_rss_win),
)
if info["uncollectable"] > 0:
# This should ideally never happen on Python 3, but who knows?
logger.warning(
"garbage collector couldn't collect %d objects, "
"please look in gc.garbage",
info["uncollectable"],
)


_gc_diagnosis = GCDiagnosis()
_gc_diagnosis_users = 0
_gc_diagnosis_lock = threading.Lock()


def enable_gc_diagnosis():
"""
Ask to enable global GC diagnosis.
"""
if PYPY:
return
global _gc_diagnosis_users
with _gc_diagnosis_lock:
if _gc_diagnosis_users == 0:
_gc_diagnosis.enable()
else:
assert _gc_diagnosis.enabled
_gc_diagnosis_users += 1


def disable_gc_diagnosis(force=False):
"""
Ask to disable global GC diagnosis.
"""
if PYPY:
return
global _gc_diagnosis_users
with _gc_diagnosis_lock:
if _gc_diagnosis_users > 0:
_gc_diagnosis_users -= 1
if _gc_diagnosis_users == 0:
_gc_diagnosis.disable()
elif force:
_gc_diagnosis.disable()
_gc_diagnosis_users = 0
else:
assert _gc_diagnosis.enabled
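

# A typical usage sketch: pair the module-level helpers so the shared
# GCDiagnosis instance stays enabled while at least one user needs it.
#
#     enable_gc_diagnosis()
#     try:
#         ...  # run worker code; full GC passes are now timed and logged
#     finally:
#         disable_gc_diagnosis()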
|