-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtextureGenPipeline.py
More file actions
178 lines (150 loc) · 6.96 KB
/
textureGenPipeline.py
File metadata and controls
178 lines (150 loc) · 6.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import torch
import copy
import trimesh
import numpy as np
from PIL import Image
from typing import List
from DifferentiableRenderer.MeshRender import MeshRender
from utils.simplify_mesh_utils import remesh_mesh
from utils.multiview_utils import multiviewDiffusionNet
from utils.pipeline_utils import ViewProcessor
from utils.image_super_utils import imageSuperNet
from utils.uvwrap_utils import mesh_uv_wrap
from DifferentiableRenderer.mesh_utils import convert_obj_to_glb
import warnings
# Silence every Python warning process-wide. Note this also hides deprecation
# and user warnings raised by any other library imported alongside this module.
warnings.filterwarnings("ignore")
from diffusers.utils import logging as diffusers_logging
# 50 is the numeric value of logging.CRITICAL: only critical diffusers messages are emitted.
diffusers_logging.set_verbosity(50)
class MaterialMVPConfig:
    """Settings consumed by ``MaterialMVPPipeline`` and the models it builds.

    Both constructor arguments now have defaults so ``MaterialMVPConfig()``
    works; that matters because ``MaterialMVPPipeline`` constructs one when
    no config is supplied.

    Args:
        max_num_view: View budget stored as ``max_selected_view_num``; the
            pipeline passes it to its view-selection step.
        resolution: Per-view size; the pipeline passes it to the multiview
            model as ``custom_view_size``.
    """

    def __init__(self, max_num_view=6, resolution=512):
        self.device = "cuda"
        self.multiview_cfg_path = "cfgs/v1.yaml"
        self.multiview_pretrained_path = "tencent/Hunyuan3D-2.1"
        self.dino_ckpt_path = "facebook/dinov2-giant"
        self.realesrgan_ckpt_path = "ckpt/RealESRGAN_x4plus.pth"
        self.raster_mode = "cr"
        self.bake_mode = "back_sample"
        self.render_size = 1024 * 2
        self.texture_size = 1024 * 4
        self.max_selected_view_num = max_num_view
        self.resolution = resolution
        self.bake_exp = 4
        self.merge_method = "fast"

        # View-selection candidates: six axis-aligned views (front/side/back and
        # straight up/down) with hand-tuned weights...
        self.candidate_camera_azims = [0, 90, 180, 270, 0, 180]
        self.candidate_camera_elevs = [0, 0, 0, 0, 90, -90]
        self.candidate_view_weights = [1, 0.1, 0.5, 0.1, 0.05, 0.05]
        # ...plus a low-weight ring every 30 degrees of azimuth, once at +20
        # and once at -20 degrees elevation (in that order per azimuth).
        for azim in range(0, 360, 30):
            for elev in (20, -20):
                self.candidate_camera_azims.append(azim)
                self.candidate_camera_elevs.append(elev)
                self.candidate_view_weights.append(0.01)
class MaterialMVPPipeline:
    """Generate UV textures for a mesh from one or more reference images.

    Steps performed by ``__call__``:
    - optionally remesh, then UV-wrap and load the mesh into ``MeshRender``;
    - select camera views and render normal and position maps for them;
    - run the multiview diffusion model, which yields per-view ``"albedo"``
      and ``"mr"`` images (``"mr"`` presumably metallic/roughness; confirm
      against ``multiviewDiffusionNet``);
    - upscale each view with the super-resolution model;
    - bake both image sets into textures and inpaint the uncovered texels;
    - save the textured mesh, optionally also as GLB.
    """

    # Defaults used when no config is supplied; mirror MaterialMVPConfig's
    # intended defaults and are passed explicitly so the no-config path works.
    _DEFAULT_MAX_NUM_VIEW = 6
    _DEFAULT_RESOLUTION = 512

    def __init__(self, config=None) -> None:
        """Build the renderer and view processor, then load the models.

        Args:
            config: A ``MaterialMVPConfig``-like object. When ``None``, one is
                created with ``max_num_view=6`` and ``resolution=512``.
        """
        if config is None:
            config = MaterialMVPConfig(
                max_num_view=self._DEFAULT_MAX_NUM_VIEW,
                resolution=self._DEFAULT_RESOLUTION,
            )
        self.config = config
        self.models = {}
        self.stats_logs = {}
        self.render = MeshRender(
            default_resolution=self.config.render_size,
            texture_size=self.config.texture_size,
            bake_mode=self.config.bake_mode,
            raster_mode=self.config.raster_mode,
        )
        self.view_processor = ViewProcessor(self.config, self.render)
        self.load_models()

    def load_models(self):
        """Release cached CUDA memory, then build the super-resolution and multiview models."""
        torch.cuda.empty_cache()
        self.models["super_model"] = imageSuperNet(self.config)
        self.models["multiview_model"] = multiviewDiffusionNet(self.config)
        print("Models Loaded.")

    @staticmethod
    def _normalize_image_prompt(image_path):
        """Return the reference image(s) as a non-empty list, opening any file paths.

        Args:
            image_path: A path (``str`` or ``os.PathLike``), an image object,
                or a list/tuple mixing those.

        Raises:
            ValueError: if ``image_path`` is ``None`` or an empty sequence.
        """
        if image_path is None:
            raise ValueError("image_path is required")
        items = list(image_path) if isinstance(image_path, (list, tuple)) else [image_path]
        if not items:
            raise ValueError("image_path must contain at least one image")
        return [Image.open(item) if isinstance(item, (str, os.PathLike)) else item for item in items]

    @staticmethod
    def _prepare_style_images(images, size=512):
        """Resize each reference image to ``size`` x ``size`` and return RGB copies.

        RGBA images are composited over a white background (using their alpha
        channel as the paste mask) before the RGB conversion.
        """
        style_images = []
        for image in images:
            image = image.resize((size, size))
            if image.mode == "RGBA":
                white_bg = Image.new("RGB", image.size, (255, 255, 255))
                white_bg.paste(image, mask=image.getchannel("A"))
                image = white_bg
            style_images.append(image.convert("RGB"))
        return style_images

    @staticmethod
    def _resize_views(images, size):
        """Return a new list with every view resized to ``size`` x ``size``."""
        return [image.resize((size, size)) for image in images]

    def _bake_views(self, images, elevs, azims, weights):
        """Bake per-view images into a texture.

        Returns:
            ``(texture, mask_np)`` where ``mask_np`` is the bake mask with its
            trailing axis squeezed, scaled by 255 and cast to ``uint8``
            (assumes the mask is in [0, 1]; confirm against ViewProcessor).
        """
        texture, mask = self.view_processor.bake_from_multiview(images, elevs, azims, weights)
        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
        return texture, mask_np

    @torch.no_grad()
    def __call__(self, mesh_path=None, image_path=None, output_mesh_path=None, use_remesh=True, save_glb=True):
        """Generate texture for 3D mesh using multiview diffusion.

        Args:
            mesh_path: Input mesh file. Required; its directory receives the
                remeshed mesh (``white_mesh_remesh.obj``) and, by default, the
                output.
            image_path: Reference image(s): a path, an image object, or a
                list/tuple of those.
            output_mesh_path: Where to save the textured mesh. Defaults to
                ``textured_mesh.obj`` next to ``mesh_path``.
            use_remesh: Remesh the input before UV wrapping.
            save_glb: Also write a ``.glb`` with the same stem as the output.

        Returns:
            ``output_mesh_path``.

        Raises:
            ValueError: if ``mesh_path`` or ``image_path`` is missing or empty.
        """
        if mesh_path is None:
            raise ValueError("mesh_path is required")
        image_prompt = self._normalize_image_prompt(image_path)

        # Process mesh
        path = os.path.dirname(mesh_path)
        if use_remesh:
            processed_mesh_path = os.path.join(path, "white_mesh_remesh.obj")
            remesh_mesh(mesh_path, processed_mesh_path)
        else:
            processed_mesh_path = mesh_path
        if output_mesh_path is None:
            output_mesh_path = os.path.join(path, "textured_mesh.obj")

        mesh = trimesh.load(processed_mesh_path)
        mesh = mesh_uv_wrap(mesh)
        self.render.load_mesh(mesh=mesh)

        ########### View Selection #########
        selected_camera_elevs, selected_camera_azims, selected_view_weights = self.view_processor.bake_view_selection(
            self.config.candidate_camera_elevs,
            self.config.candidate_camera_azims,
            self.config.candidate_view_weights,
            self.config.max_selected_view_num,
        )
        normal_maps = self.view_processor.render_normal_multiview(
            selected_camera_elevs, selected_camera_azims, use_abs_coor=True
        )
        position_maps = self.view_processor.render_position_multiview(selected_camera_elevs, selected_camera_azims)

        ########## Style ###########
        image_caption = "high quality"
        image_style = self._prepare_style_images(image_prompt)

        ########### Multiview ##########
        multiviews_pbr = self.models["multiview_model"](
            image_style,
            normal_maps + position_maps,
            prompt=image_caption,
            custom_view_size=self.config.resolution,
            resize_input=True,
        )

        ########### Enhance ##########
        # Upscale every view, then resize to render_size for baking. The old
        # loop iterated len(dict) == 2, so only the first two views were resized.
        super_model = self.models["super_model"]
        render_size = self.config.render_size
        albedo_views = self._resize_views([super_model(view) for view in multiviews_pbr["albedo"]], render_size)
        mr_views = self._resize_views([super_model(view) for view in multiviews_pbr["mr"]], render_size)

        ########### Bake ##########
        texture, mask_np = self._bake_views(
            albedo_views, selected_camera_elevs, selected_camera_azims, selected_view_weights
        )
        texture_mr, mask_mr_np = self._bake_views(
            mr_views, selected_camera_elevs, selected_camera_azims, selected_view_weights
        )

        ########## Inpaint ###########
        texture = self.view_processor.texture_inpaint(texture, mask_np)
        self.render.set_texture(texture, force_set=True)
        texture_mr = self.view_processor.texture_inpaint(texture_mr, mask_mr_np)
        self.render.set_texture_mr(texture_mr)

        self.render.save_mesh(output_mesh_path, downsample=True)
        if save_glb:
            # splitext swaps only the final extension (str.replace would hit
            # every ".obj" in the path and no-op on other extensions).
            convert_obj_to_glb(output_mesh_path, os.path.splitext(output_mesh_path)[0] + ".glb")
        return output_mesh_path