1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202 | import inspect
import logging
import os
import shutil
import sys
from typing import List
from types import ModuleType
import filecmp
from importlib import import_module
import click
from tornado.httpclient import AsyncHTTPClient
from dask.utils import tmpfile
from .utils import import_file
logger = logging.getLogger(__name__)
def validate_preload_argv(ctx, param, value):
    """Click option callback providing validation of preload subcommand arguments.

    Parameters
    ----------
    ctx: click.Context
        Context of the command being parsed; only ``ctx.params["preload"]``
        is read.
    param:
        The click parameter being validated (unused; present because click
        passes it to every option callback).
    value: sequence of str
        Extra command-line arguments intended for a click-configurable
        ``dask_setup`` of a preload module.

    Returns
    -------
    ``value``, unchanged, when validation succeeds.

    Raises
    ------
    click.NoSuchOption
        Extra arguments were given without any ``--preload`` and one of them
        looks like an option (starts with ``-``).
    click.UsageError
        Extra arguments were given without any ``--preload``; more than one
        preload module exposes a click-configurable ``dask_setup``; or extra
        arguments were given but no preload module accepts them.
    """
    preload_names = ctx.params.get("preload", None)

    if not preload_names:
        if not value:
            # No preload argv provided and no preload modules specified.
            return value

        # Report a usage error matching standard click error conventions.
        unexpected_options = [v for v in value if v.startswith("-")]
        if unexpected_options:
            raise click.NoSuchOption(unexpected_options[0])
        raise click.UsageError(
            "Got unexpected extra argument%s: (%s)"
            % ("s" if len(value) > 1 else "", " ".join(value))
        )

    # Web addresses are skipped here: only locally importable preloads are
    # inspected for a click-configurable setup command.
    preload_modules = {
        name: _import_module(name)
        for name in preload_names
        if not is_webaddress(name)
    }

    preload_commands = [
        getattr(m, "dask_setup", None)
        for m in preload_modules.values()
        if isinstance(getattr(m, "dask_setup", None), click.Command)
    ]

    if len(preload_commands) > 1:
        raise click.UsageError(
            "Multiple --preload modules with click-configurable setup: %s"
            % list(preload_modules.keys())
        )

    if not preload_commands:
        if value:
            # The offending arguments are interpolated into the message; the
            # one-element tuple keeps a tuple-valued ``value`` from being
            # unpacked by the ``%`` operator.
            raise click.UsageError(
                "Unknown argument specified: %r Was click-configurable --preload target provided?"
                % (list(value),)
            )
        return value

    # Parse the extra arguments against the single click-configurable setup
    # command so that malformed arguments are rejected at validation time.
    preload_command = preload_commands[0]
    preload_ctx = click.Context(preload_command, allow_extra_args=False)
    preload_command.parse_args(preload_ctx, list(value))
    return value
def is_webaddress(s: str) -> bool:
    """Return True when *s* begins with an ``http://`` or ``https://`` scheme."""
    # str.startswith accepts a tuple of candidate prefixes.
    return s.startswith(("http://", "https://"))
def _import_module(name, file_dir=None) -> ModuleType:
    """Import a preload module given by name, file path, or source text.

    Parameters
    ----------
    name: str
        Module name, file path (ending in ``.py``), or text of module or script
    file_dir: string
        Path of a directory where files should be copied

    Returns
    -------
    The imported module object.
    """
    is_file_path = name.endswith(".py")

    if not is_file_path and " " in name:
        # Not a path or a module name: treat it as the text of a script,
        # write it to a temporary ``.py`` file and import that file instead.
        with tmpfile(extension=".py") as script_path:
            with open(script_path, mode="w") as script:
                script.write(name)
            return _import_module(script_path, file_dir=file_dir)

    if is_file_path:
        if file_dir is None:
            module = import_file(name)[0]
        else:
            # Copy the file into ``file_dir`` and import the copy.
            basename = os.path.basename(name)
            copy_dst = os.path.join(file_dir, basename)
            if os.path.exists(copy_dst) and not filecmp.cmp(name, copy_dst):
                # A different file with the same name is already there; it is
                # reported and then overwritten by the copy below.
                logger.error("File name collision: %s", basename)
            shutil.copy(name, copy_dst)
            module = import_file(copy_dst)[0]
    else:
        # A dotted module name: import it only if it is not loaded already.
        if name not in sys.modules:
            import_module(name)
        module = sys.modules[name]

    logger.info("Import preload module: %s", name)
    return module
async def _download_module(url: str) -> ModuleType:
    """Fetch Python source from *url* and execute it as a new module.

    The returned module is named after the URL and populated by running the
    downloaded source in the module's namespace.
    """
    logger.info("Downloading preload at %s", url)
    assert is_webaddress(url)

    response = await AsyncHTTPClient().fetch(url)
    code = compile(response.body.decode(), url, "exec")

    # NOTE(review): this executes whatever the remote server returns; the URL
    # must point at a trusted source.
    module = ModuleType(url)
    exec(code, module.__dict__)
    return module
class Preload:
    """
    Manage state for setup/teardown of a preload module

    Parameters
    ----------
    dask_server: dask.distributed.Server
        The Worker or Scheduler
    name: str
        module name, file name, or web address to load
    argv: [string]
        List of string arguments passed to click-configurable `dask_setup`.
    file_dir: string
        Path of a directory where files should be copied
    """

    def __init__(self, dask_server, name: str, argv: List[str], file_dir: str):
        self.dask_server = dask_server
        self.name = name
        self.argv = argv
        self.file_dir = file_dir

        # Web addresses are downloaded later, in ``start``; everything else is
        # imported immediately.
        self.module = None if is_webaddress(name) else _import_module(name, file_dir)

    async def start(self):
        """ Run when the server finishes its start method """
        if is_webaddress(self.name):
            self.module = await _download_module(self.name)

        dask_setup = getattr(self.module, "dask_setup", None)
        if not dask_setup:
            return

        if isinstance(dask_setup, click.Command):
            # Let click parse ``argv`` and hand the parsed parameters to the
            # command's underlying callback together with the server.
            context = dask_setup.make_context(
                "dask_setup", list(self.argv), allow_extra_args=False
            )
            outcome = dask_setup.callback(
                self.dask_server, *context.args, **context.params
            )
            log_message = "Run preload setup click command: %s"
        else:
            outcome = dask_setup(self.dask_server)
            log_message = "Run preload setup function: %s"

        # The setup hook may be a coroutine function or a plain function.
        if inspect.isawaitable(outcome):
            await outcome
        logger.info(log_message, self.name)

    async def teardown(self):
        """ Run when the server starts its close method """
        dask_teardown = getattr(self.module, "dask_teardown", None)
        if not dask_teardown:
            return

        outcome = dask_teardown(self.dask_server)
        if inspect.isawaitable(outcome):
            await outcome
def process_preloads(
    dask_server, preload: List[str], preload_argv: List[List], file_dir: str = None
) -> List[Preload]:
    """Build one ``Preload`` per entry in *preload*.

    A single string is treated as a one-element list. Every ``Preload``
    receives the same ``dask_server``, ``preload_argv`` and ``file_dir``.
    """
    names = [preload] if isinstance(preload, str) else preload

    preloads = []
    for name in names:
        preloads.append(Preload(dask_server, name, preload_argv, file_dir))
    return preloads
|