"""
Model loader.
"""
import importlib
import json
import logging
import os
import uuid
from abc import ABCMeta, abstractmethod
from builtins import str
from typing import Optional
from ts.metrics.metric_cache_yaml_impl import MetricsCacheYamlImpl
from ts.service import Service
from .utils.util import list_classes_from_module


class ModelLoaderFactory(object):
    """
    Factory that returns the model loader implementation.
    """

    @staticmethod
    def get_model_loader():
        return TsModelLoader()
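
# Typical entry point (illustrative sketch): callers obtain a loader through
# the factory rather than instantiating TsModelLoader directly:
#
#     loader = ModelLoaderFactory.get_model_loader()  # -> TsModelLoader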


class ModelLoader(object, metaclass=ABCMeta):
    """
    Base Model Loader class.
    """

    @abstractmethod
    def load(
        self,
        model_name: str,
        model_dir: str,
        handler: Optional[str] = None,
        gpu_id: Optional[int] = None,
        batch_size: Optional[int] = None,
        envelope: Optional[str] = None,
        limit_max_image_pixels: Optional[bool] = True,
    ):
        """
        Load model from file.

        :param model_name: name of the model
        :param model_dir: path to the unpacked model archive directory
        :param handler: entry point to the model code; either a path of the
            form "module[:function]" or the name of a default handler
        :param gpu_id: ID of the GPU to load the model on, or None for CPU
        :param batch_size: maximum batch size for inference
        :param envelope: name of the request envelope used to wrap the handler
        :param limit_max_image_pixels: whether to keep the default
            PIL.Image.MAX_IMAGE_PIXELS limit when handling images
        :return: Model
        """
        # pylint: disable=unnecessary-pass
        pass


class TsModelLoader(ModelLoader):
    """
    TorchServe 1.0 Model Loader.
    """

    def load(
        self,
        model_name: str,
        model_dir: str,
        handler: Optional[str] = None,
        gpu_id: Optional[int] = None,
        batch_size: Optional[int] = None,
        envelope: Optional[str] = None,
        limit_max_image_pixels: Optional[bool] = True,
        metrics_cache: Optional[MetricsCacheYamlImpl] = None,
    ) -> Service:
"""
Load TorchServe 1.0 model from file.
:param model_name:
:param model_dir:
:param handler:
:param gpu_id:
:param batch_size:
:param envelope:
:param limit_max_image_pixels:
:param metrics_cache: MetricsCacheYamlImpl object
:return:
"""
        logging.debug("Loading model - working dir: %s", os.getcwd())

        # Backwards compatibility with releases <= 0.6.0:
        # a request ID is not set for model load requests.
        # TODO: UUID serves as a temporary request ID for model load requests
        if metrics_cache is not None:
            metrics_cache.set_request_ids(str(uuid.uuid4()))

        # Read the model archive manifest if one is present.
        manifest_file = os.path.join(model_dir, "MAR-INF", "MANIFEST.json")
        manifest = None
        if os.path.exists(manifest_file):
            with open(manifest_file) as f:
                manifest = json.load(f)

        # Resolve the handler: first try it as a module path of the form
        # "module[:function]", then fall back to a default handler name.
        function_name = None
        try:
            module, function_name = self._load_handler_file(handler)
        except ImportError:
            module = self._load_default_handler(handler)
        if module is None:
            raise ValueError(
                "Unable to load module {}, make sure it is added to python path".format(
                    handler
                )
            )

        # The entry point is either a module-level function (default name
        # "handle") or the handle method of the single handler class.
        function_name = function_name or "handle"
        if hasattr(module, function_name):
            entry_point, initialize_fn = self._get_function_entry_point(
                module, function_name
            )
        else:
            entry_point, initialize_fn = self._get_class_entry_point(module)

        # Optionally wrap the entry point in a request envelope.
        if envelope is not None:
            envelope_class = self._load_default_envelope(envelope)
            if envelope_class is not None:
                envelope_instance = envelope_class(entry_point)
                entry_point = envelope_instance.handle

        service = Service(
            model_name,
            model_dir,
            manifest,
            entry_point,
            gpu_id,
            batch_size,
            limit_max_image_pixels,
            metrics_cache,
        )
        initialize_fn(service.context)

        return service

    def _load_handler_file(self, handler):
        """
        Import a handler given as "module[:function]"; the ".py" suffix and
        any leading directories are stripped from the module part.
        """
        temp = handler.split(":", 1)
        module_name = temp[0]
        if module_name.endswith(".py"):
            module_name = module_name[:-3]
        module_name = module_name.split("/")[-1]
        module = importlib.import_module(module_name)
        function_name = None if len(temp) == 1 else temp[1]
        return module, function_name
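
    # Illustrative parse results (the handler strings below are hypothetical):
    #   "model_handler.py:handle"  -> module "model_handler", function "handle"
    #   "model_handler"            -> module "model_handler", function None
    #   "dir/model_handler.py"     -> module "model_handler", function None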

    def _load_default_handler(self, handler):
        """
        Import one of the default handlers shipped in ts.torch_handler.
        """
        module_name = ".{0}".format(handler)
        module = importlib.import_module(module_name, "ts.torch_handler")
        return module

    def _load_default_envelope(self, envelope):
        """
        Import a request envelope from ts.torch_handler.request_envelope and
        return the first class defined in its module.
        """
        module_name = ".{0}".format(envelope)
        module = importlib.import_module(
            module_name, "ts.torch_handler.request_envelope"
        )
        envelope_class = list_classes_from_module(module)[0]
        return envelope_class

    def _get_function_entry_point(self, module, function_name):
        """
        Use a module-level function as the entry point; initialization calls
        the same function with a None batch and the service context.
        """
        entry_point = getattr(module, function_name)
        initialize_fn = lambda ctx: entry_point(None, ctx)
        return entry_point, initialize_fn

    def _get_class_entry_point(self, module):
        """
        Instantiate the single handler class defined in the module and use
        its handle method as the entry point.
        """
        model_class_definitions = list_classes_from_module(module)
        if len(model_class_definitions) != 1:
            raise ValueError(
                "Expected only one class in custom service code or a function entry point {}".format(
                    model_class_definitions
                )
            )

        model_class = model_class_definitions[0]
        model_service = model_class()

        if not hasattr(model_service, "handle"):
            raise ValueError(
                "Expected a handle method in class {}".format(str(model_class))
            )

        return model_service.handle, model_service.initialize
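
# Minimal usage sketch (illustrative only; the model name, directory, and
# handler string below are hypothetical, and a real model archive must
# already be unpacked at model_dir):
#
#     loader = ModelLoaderFactory.get_model_loader()
#     service = loader.load(
#         model_name="my_model",
#         model_dir="/tmp/models/my_model",
#         handler="model_handler.py:handle",
#         gpu_id=None,
#         batch_size=1,
#     )
#     # service.context is now initialized and the service is ready to
#     # dispatch inference requests to the handler's entry point.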