Skip to content

Commit 0d26c62

Browse files
committed
added a jax example
1 parent 48dbbbe commit 0d26c62

File tree

5 files changed

+261
-1
lines changed

5 files changed

+261
-1
lines changed

Quick_Deploy/JAX/README.md

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
<!--
2+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
-->
28+
29+
# Deploying a JAX Model
30+
31+
This README showcases how to deploy a simple ResNet model on Triton Inference Server. While Triton doesn't yet have a dedicated JAX backend, JAX/Flax models can be deployed using [Python Backend](https://github.com/triton-inference-server/python_backend). If you are new to Triton, it is recommended to watch this [getting started video](https://www.youtube.com/watch?v=NQDtfSi5QF4) and review [Part 1](https://github.com/triton-inference-server/tutorials/tree/main/Conceptual_Guide/Part_1-model_deployment) of the conceptual guide before proceeding. For the purposes of demonstration, we are using a pre-trained model provided by [flaxmodels](https://github.com/matthias-wright/flaxmodels).
32+
33+
## Step 1: Set Up Triton Inference Server
34+
35+
To use Triton, we need to build a model repository. The structure of the repository as follows:
36+
```
37+
model_repository
38+
|
39+
+-- resnet50
40+
|
41+
+-- config.pbtxt
42+
+-- 1
43+
|
44+
+-- model.py
45+
```
46+
For this example, we have pre-built the model repository. Next, we install the required dependencies and launch the Triton Inference Server.
47+
48+
```
49+
# Replace the yy.mm in the image name with the release year and month
50+
# of the Triton version needed, eg. 22.12
51+
docker run --gpus=all -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd):/workspace/ -v/$(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:<yy.mm>-py3 bash
52+
53+
pip install --upgrade pip
54+
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
55+
pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git
56+
57+
```
58+
59+
## Step 2: Using a Triton Client to Query the Server
60+
61+
Let's breakdown the client application. First, we setup a connection with the Triton Inference Server.
62+
```
63+
client = httpclient.InferenceServerClient(url="localhost:8000")
64+
```
65+
Then we set the input and output arrays.
66+
```
67+
# Set Inputs
68+
input_tensors = [
69+
httpclient.InferInput("image", image.shape, datatype="FP32")
70+
]
71+
input_tensors[0].set_data_from_numpy(image)
72+
73+
# Set outputs
74+
outputs = [
75+
httpclient.InferRequestedOutput("fc_out")
76+
]
77+
```
78+
Lastly, we query send a request to the Triton Inference Server.
79+
80+
```
81+
# Query
82+
query_response = client.infer(model_name="resnet50",
83+
inputs=input_tensors,
84+
outputs=outputs)
85+
86+
# Output
87+
out = query_response.as_numpy("fc_out")
88+
```
89+

Quick_Deploy/JAX/client.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
import numpy as np
28+
from tritonclient.utils import *
29+
from PIL import Image
30+
import tritonclient.http as httpclient
31+
import requests
32+
33+
34+
def main():
35+
client = httpclient.InferenceServerClient(url="localhost:8000")
36+
37+
# Inputs
38+
url = "http://images.cocodataset.org/val2017/000000161642.jpg"
39+
image = np.asarray(Image.open(requests.get(url, stream=True).raw)).astype(np.float32)
40+
image = np.expand_dims(image, axis=0)
41+
42+
# Set Inputs
43+
input_tensors = [
44+
httpclient.InferInput("image", image.shape, datatype="FP32")
45+
]
46+
input_tensors[0].set_data_from_numpy(image)
47+
48+
# Set outputs
49+
outputs = [
50+
httpclient.InferRequestedOutput("fc_out")
51+
]
52+
53+
# Query
54+
query_response = client.infer(model_name="resnet50",
55+
inputs=input_tensors,
56+
outputs=outputs)
57+
58+
# Output
59+
out = query_response.as_numpy("fc_out")
60+
print(out.shape)
61+
62+
if __name__ == "__main__":
63+
main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
import triton_python_backend_utils as pb_utils
28+
import jax
29+
import jax.numpy as jnp
30+
import flaxmodels as fm
31+
32+
import numpy as np
33+
from flax.jax_utils import replicate
34+
35+
class TritonPythonModel:
36+
37+
def initialize(self, args):
38+
self.key = jax.random.PRNGKey(0)
39+
self.resnet18 = fm.ResNet18(output='logits', pretrained='imagenet')
40+
41+
42+
def execute(self, requests):
43+
responses = []
44+
for request in requests:
45+
inp = pb_utils.get_input_tensor_by_name(request, "image")
46+
input_image = inp.as_numpy()
47+
48+
params = self.resnet18.init(self.key, input_image)
49+
out = self.resnet18.apply(params, input_image, train=False)
50+
51+
inference_response = pb_utils.InferenceResponse(output_tensors=[
52+
pb_utils.Tensor(
53+
"fc_out",
54+
np.array(out),
55+
)
56+
])
57+
responses.append(inference_response)
58+
return responses
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
name: "resnet50"
28+
backend: "python"
29+
max_batch_size: 8
30+
31+
input [
32+
{
33+
name: "image"
34+
data_type: TYPE_FP32
35+
dims: [-1, -1, -1]
36+
}
37+
]
38+
output [
39+
{
40+
name: "fc_out"
41+
data_type: TYPE_FP32
42+
dims: [-1, -1]
43+
}
44+
]
45+
46+
instance_group [
47+
{
48+
kind: KIND_GPU
49+
}
50+
]

Quick_Deploy/ONNX/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ wget -O model_repository/densenet_onnx/1/model.onnx \
5353
docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:<xx.yy>-py3 tritonserver --model-repository=/models
5454
```
5555

56-
## Step 3: Using a Triton Client to Query the Server
56+
## Step 2: Using a Triton Client to Query the Server
5757

5858
Install dependencies & download an example image to test inference.
5959

0 commit comments

Comments
 (0)