# llama_mesh_nodes.py
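"""ComfyUI custom nodes for LLaMA-Mesh: chat with the LLaMA-Mesh model to
generate 3D meshes as OBJ text, visualize them, and apply gradient vertex
colors before exporting to GLB."""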
import os

# Avoid tokenizer fork warnings when generation runs in a background thread.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import tempfile
from threading import Thread

import numpy as np
import trimesh
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from trimesh.exchange.gltf import export_glb

# Read the Hugging Face token from the environment (used for gated model downloads).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
class ApplyGradientColor:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mesh_text": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "apply_gradient_color"
    CATEGORY = "LLaMA-Mesh"

    def apply_gradient_color(self, mesh_text):
        """
        Apply a gradient color to the mesh vertices based on the Y-axis and save as GLB.

        Args:
            mesh_text (str): The input mesh in OBJ format as a string.

        Returns:
            str: Path to the GLB file with gradient colors applied.
        """
        # Write the OBJ text to a temporary file so trimesh can parse it.
        temp_file = tempfile.NamedTemporaryFile(delete=False).name
        with open(temp_file + ".obj", "w") as f:
            f.write(mesh_text)
        mesh = trimesh.load_mesh(temp_file + ".obj", file_type="obj")

        # Normalize Y values to [0, 1] for color mapping; guard against a
        # degenerate mesh whose vertices all share the same Y value.
        vertices = mesh.vertices
        y_values = vertices[:, 1]
        y_range = y_values.max() - y_values.min()
        y_normalized = (y_values - y_values.min()) / y_range if y_range > 0 else np.zeros(len(y_values))

        # Map normalized Y values to an RGBA gradient from blue (bottom) to red (top).
        colors = np.zeros((len(vertices), 4))
        colors[:, 0] = y_normalized      # Red channel
        colors[:, 2] = 1 - y_normalized  # Blue channel
        colors[:, 3] = 1.0               # Alpha channel (fully opaque)
        mesh.visual.vertex_colors = colors

        # Export to GLB format; ComfyUI expects a tuple matching RETURN_TYPES.
        glb_path = temp_file + ".glb"
        with open(glb_path, "wb") as f:
            f.write(export_glb(mesh))
        return (glb_path,)
class VisualizeMesh:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mesh_text": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "visualize_mesh"
    CATEGORY = "LLaMA-Mesh"

    def visualize_mesh(self, mesh_text):
        """
        Write the provided 3D mesh text to an OBJ file for visualization.
        This function assumes the input is in OBJ format.
        """
        temp_file = "temp_mesh.obj"
        with open(temp_file, "w") as f:
            f.write(mesh_text)
        return (temp_file,)
class ChatLLaMaMesh:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "message": ("STRING", {"multiline": True}),
                "temperature": ("FLOAT", {
                    "default": 0.95,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.1
                }),
                "max_new_tokens": ("INT", {
                    "default": 4096,
                    "min": 128,
                    "max": 8192,
                    "step": 1
                }),
            },
            "optional": {
                "history": ("CHAT_HISTORY", {"default": None}),
            }
        }

    RETURN_TYPES = ("STRING", "CHAT_HISTORY", "STRING",)
    RETURN_NAMES = ("response", "history", "mesh_text")
    FUNCTION = "chat_llama3_8b"
    CATEGORY = "LLaMA-Mesh"

    def __init__(self):
        # Load the tokenizer and model once per node instance. The attributes
        # must be bound to self so chat_llama3_8b can reach them.
        model_path = "Zhengyi/LLaMA-Mesh"
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

    def chat_llama3_8b(self, message: str,
                       temperature: float,
                       max_new_tokens: int,
                       history: list = None
                       ):
        """
        Generate a response using the LLaMA-Mesh (llama3-8b) model.

        Args:
            message (str): The input message.
            temperature (float): The temperature for generating the response.
            max_new_tokens (int): The maximum number of new tokens to generate.
            history (list): The conversation history as (user, assistant) pairs.

        Returns:
            tuple: (response, updated history, mesh_text).
        """
        history = history or []
        conversation = []
        for user, assistant in history:
            conversation.extend([{"role": "user", "content": user},
                                 {"role": "assistant", "content": assistant}])
        conversation.append({"role": "user", "content": message})

        input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt").to(self.model.device)
        streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

        generate_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            eos_token_id=self.terminators,
        )
        # Enforce greedy generation (do_sample=False) when the temperature is 0,
        # avoiding a crash in the sampler.
        if temperature == 0:
            generate_kwargs["do_sample"] = False

        # Run generation in a background thread and drain the streamer.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        outputs = []
        for text in streamer:
            outputs.append(text)
        t.join()
        response = "".join(outputs)

        # The model emits OBJ geometry inline with its prose, so the raw
        # response is also passed through as mesh_text for downstream nodes.
        history = history + [(message, response)]
        return (response, history, response)
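# NOTE: Assumed registration block, not shown in the excerpt above. ComfyUI
# discovers custom nodes through a module-level NODE_CLASS_MAPPINGS dict
# (and, optionally, NODE_DISPLAY_NAME_MAPPINGS); the display names below are
# illustrative placeholders.
NODE_CLASS_MAPPINGS = {
    "ApplyGradientColor": ApplyGradientColor,
    "VisualizeMesh": VisualizeMesh,
    "ChatLLaMaMesh": ChatLLaMaMesh,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ApplyGradientColor": "Apply Gradient Color (LLaMA-Mesh)",
    "VisualizeMesh": "Visualize Mesh (LLaMA-Mesh)",
    "ChatLLaMaMesh": "Chat with LLaMA-Mesh",
}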