notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

model.rs (19555B)


      1 use glam::{Vec3, Vec4};
      2 
      3 use crate::material::{MaterialGpu, MaterialUniform};
      4 use crate::texture::upload_rgba8_texture_2d;
      5 use std::collections::HashMap;
      6 use wgpu::util::DeviceExt;
      7 
      8 #[repr(C)]
      9 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
     10 pub struct Vertex {
     11     pub pos: [f32; 3],
     12     pub normal: [f32; 3],
     13     pub uv: [f32; 2],
     14     pub tangent: [f32; 4],
     15 }
     16 
     17 pub struct Mesh {
     18     pub num_indices: u32,
     19     pub vert_buf: wgpu::Buffer,
     20     pub ind_buf: wgpu::Buffer,
     21 }
     22 
     23 pub struct ModelDraw {
     24     pub mesh: Mesh,
     25     pub material_index: usize,
     26 }
     27 
     28 pub struct ModelData {
     29     pub draws: Vec<ModelDraw>,
     30     pub materials: Vec<MaterialGpu>,
     31     pub bounds: Aabb,
     32 }
     33 
     34 /// A model handle
     35 #[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Copy, Clone)]
     36 pub struct Model {
     37     pub id: u64,
     38 }
     39 
     40 struct GltfWgpuCache {
     41     samplers: Vec<Option<wgpu::Sampler>>,
     42     tex_views: HashMap<(usize, bool), wgpu::TextureView>,
     43 }
     44 
     45 impl GltfWgpuCache {
     46     pub fn new(doc: &gltf::Document) -> Self {
     47         Self {
     48             samplers: (0..doc.samplers().len()).map(|_| None).collect(),
     49             tex_views: HashMap::new(),
     50         }
     51     }
     52 
     53     fn ensure_sampler(
     54         &mut self,
     55         device: &wgpu::Device,
     56         sam: gltf::texture::Sampler<'_>,
     57     ) -> Option<usize> {
     58         let idx = sam.index()?;
     59         if self.samplers[idx].is_none() {
     60             let (min_f, mip_f) = map_min_filter(sam.min_filter());
     61             let samp = device.create_sampler(&wgpu::SamplerDescriptor {
     62                 label: Some("gltf_sampler"),
     63                 address_mode_u: map_wrap_mode(sam.wrap_s()),
     64                 address_mode_v: map_wrap_mode(sam.wrap_t()),
     65                 address_mode_w: wgpu::AddressMode::Repeat,
     66                 mag_filter: map_mag_filter(sam.mag_filter()),
     67                 min_filter: min_f,
     68                 mipmap_filter: mip_f,
     69                 ..Default::default()
     70             });
     71             self.samplers[idx] = Some(samp);
     72         }
     73         Some(idx)
     74     }
     75 
     76     fn sampler_ref(&self, idx: usize) -> &wgpu::Sampler {
     77         self.samplers[idx].as_ref().unwrap()
     78     }
     79 
     80     fn ensure_texture_view(
     81         &mut self,
     82         images: &[gltf::image::Data],
     83         device: &wgpu::Device,
     84         queue: &wgpu::Queue,
     85         tex: gltf::Texture<'_>,
     86         srgb: bool,
     87     ) -> (usize, bool) {
     88         let key = (tex.index(), srgb);
     89         self.tex_views.entry(key).or_insert_with(|| {
     90             let img = &images[tex.source().index()];
     91             let rgba8 = build_rgba(img);
     92 
     93             let format = if srgb {
     94                 wgpu::TextureFormat::Rgba8UnormSrgb
     95             } else {
     96                 wgpu::TextureFormat::Rgba8Unorm
     97             };
     98 
     99             upload_rgba8_texture_2d(
    100                 device, queue, img.width, img.height, &rgba8, format, "gltf_tex",
    101             )
    102         });
    103         key
    104     }
    105 
    106     fn view_ref(&self, key: (usize, bool)) -> &wgpu::TextureView {
    107         self.tex_views.get(&key).unwrap()
    108     }
    109 }
    110 
    111 impl Vertex {
    112     pub fn desc<'a>() -> wgpu::VertexBufferLayout<'a> {
    113         use std::mem;
    114         wgpu::VertexBufferLayout {
    115             array_stride: mem::size_of::<Vertex>() as wgpu::BufferAddress,
    116             step_mode: wgpu::VertexStepMode::Vertex,
    117             attributes: &[
    118                 // position
    119                 wgpu::VertexAttribute {
    120                     offset: 0,
    121                     shader_location: 0,
    122                     format: wgpu::VertexFormat::Float32x3,
    123                 },
    124                 // normal
    125                 wgpu::VertexAttribute {
    126                     offset: mem::size_of::<[f32; 3]>() as u64,
    127                     shader_location: 1,
    128                     format: wgpu::VertexFormat::Float32x3,
    129                 },
    130                 // uv
    131                 wgpu::VertexAttribute {
    132                     offset: (mem::size_of::<[f32; 3]>() + mem::size_of::<[f32; 3]>()) as u64,
    133                     shader_location: 2,
    134                     format: wgpu::VertexFormat::Float32x2,
    135                 },
    136                 // tangent
    137                 wgpu::VertexAttribute {
    138                     offset: (mem::size_of::<[f32; 3]>()
    139                         + mem::size_of::<[f32; 3]>()
    140                         + mem::size_of::<[f32; 2]>()) as u64, // 12+12+8 = 32
    141                     shader_location: 3,
    142                     format: wgpu::VertexFormat::Float32x4,
    143                 },
    144             ],
    145         }
    146     }
    147 }
    148 
    149 fn build_rgba(img: &gltf::image::Data) -> Vec<u8> {
    150     match img.format {
    151         gltf::image::Format::R8 => img.pixels.iter().flat_map(|&r| [r, r, r, 255]).collect(),
    152         gltf::image::Format::R8G8B8 => img
    153             .pixels
    154             .chunks_exact(3)
    155             .flat_map(|p| [p[0], p[1], p[2], 255])
    156             .collect(),
    157         gltf::image::Format::R8G8B8A8 => img.pixels.clone(),
    158         gltf::image::Format::R16 => {
    159             // super rare for your target; quick & dirty downconvert
    160             img.pixels
    161                 .chunks_exact(2)
    162                 .flat_map(|p| {
    163                     let r = p[0];
    164                     [r, r, r, 255]
    165                 })
    166                 .collect()
    167         }
    168         gltf::image::Format::R16G16B16 => img
    169             .pixels
    170             .chunks_exact(6)
    171             .flat_map(|p| {
    172                 let r = p[0];
    173                 let g = p[2];
    174                 let b = p[4];
    175                 [r, g, b, 255]
    176             })
    177             .collect(),
    178         gltf::image::Format::R16G16B16A16 => img
    179             .pixels
    180             .chunks_exact(8)
    181             .flat_map(|p| {
    182                 let r = p[0];
    183                 let g = p[2];
    184                 let b = p[4];
    185                 let a = p[6];
    186                 [r, g, b, a]
    187             })
    188             .collect(),
    189         _ => panic!("Unhandled image format {:?}", img.format),
    190     }
    191 }
    192 
    193 pub fn load_gltf_model(
    194     device: &wgpu::Device,
    195     queue: &wgpu::Queue,
    196     material_bgl: &wgpu::BindGroupLayout,
    197     path: impl AsRef<std::path::Path>,
    198 ) -> Result<ModelData, gltf::Error> {
    199     let path = path.as_ref();
    200 
    201     let (doc, buffers, images) = gltf::import(path)?;
    202 
    203     // --- default textures
    204     let default_sampler = make_default_sampler(device);
    205     let default_basecolor = upload_rgba8_texture_2d(
    206         device,
    207         queue,
    208         1,
    209         1,
    210         &[255, 255, 255, 255],
    211         wgpu::TextureFormat::Rgba8UnormSrgb,
    212         "basecolor_1x1",
    213     );
    214     let default_mr = upload_rgba8_texture_2d(
    215         device,
    216         queue,
    217         1,
    218         1,
    219         &[0, 255, 0, 255],
    220         wgpu::TextureFormat::Rgba8Unorm,
    221         "mr_1x1",
    222     );
    223     let default_normal = upload_rgba8_texture_2d(
    224         device,
    225         queue,
    226         1,
    227         1,
    228         &[128, 128, 255, 255],
    229         wgpu::TextureFormat::Rgba8Unorm,
    230         "normal_1x1",
    231     );
    232 
    233     let mut cache = GltfWgpuCache::new(&doc);
    234 
    235     let mut materials: Vec<MaterialGpu> = Vec::new();
    236 
    237     for mat in doc.materials() {
    238         let pbr = mat.pbr_metallic_roughness();
    239         let bc_factor = pbr.base_color_factor();
    240         let metallic_factor = pbr.metallic_factor();
    241         let roughness_factor = pbr.roughness_factor();
    242         let ao_strength = mat.occlusion_texture().map(|o| o.strength()).unwrap_or(1.0);
    243 
    244         let mut chosen_sampler_idx: Option<usize> = None;
    245 
    246         let basecolor_key = pbr.base_color_texture().map(|info| {
    247             let s_idx = cache.ensure_sampler(device, info.texture().sampler());
    248             if chosen_sampler_idx.is_none() {
    249                 chosen_sampler_idx = s_idx;
    250             }
    251             cache.ensure_texture_view(&images, device, queue, info.texture(), true)
    252         });
    253 
    254         let mr_key = pbr.metallic_roughness_texture().map(|info| {
    255             let s_idx = cache.ensure_sampler(device, info.texture().sampler());
    256             if chosen_sampler_idx.is_none() {
    257                 chosen_sampler_idx = s_idx;
    258             }
    259             cache.ensure_texture_view(&images, device, queue, info.texture(), false)
    260         });
    261 
    262         let normal_key = mat.normal_texture().map(|norm_tex| {
    263             let s_idx = cache.ensure_sampler(device, norm_tex.texture().sampler());
    264             if chosen_sampler_idx.is_none() {
    265                 chosen_sampler_idx = s_idx;
    266             }
    267             cache.ensure_texture_view(&images, device, queue, norm_tex.texture(), false)
    268         });
    269 
    270         let uniform = MaterialUniform {
    271             base_color_factor: Vec4::new(bc_factor[0], bc_factor[1], bc_factor[2], bc_factor[3]),
    272             metallic_factor,
    273             roughness_factor,
    274             ao_strength,
    275             _pad0: 0.0,
    276         };
    277 
    278         let chosen_sampler: &wgpu::Sampler = chosen_sampler_idx
    279             .map(|i| cache.sampler_ref(i))
    280             .unwrap_or(&default_sampler);
    281 
    282         let normal_view: &wgpu::TextureView = normal_key
    283             .map(|k| cache.view_ref(k))
    284             .unwrap_or(&default_normal);
    285 
    286         let basecolor_view: &wgpu::TextureView = basecolor_key
    287             .map(|k| cache.view_ref(k))
    288             .unwrap_or(&default_basecolor);
    289 
    290         let mr_view: &wgpu::TextureView = mr_key.map(|k| cache.view_ref(k)).unwrap_or(&default_mr);
    291 
    292         materials.push(make_material_gpu(
    293             device,
    294             queue,
    295             material_bgl,
    296             chosen_sampler,
    297             basecolor_view,
    298             mr_view,
    299             normal_view,
    300             uniform,
    301         ));
    302     }
    303 
    304     let default_material_index = materials.len();
    305     materials.push(make_material_gpu(
    306         device,
    307         queue,
    308         material_bgl,
    309         &default_sampler,
    310         &default_basecolor,
    311         &default_mr,
    312         &default_normal,
    313         MaterialUniform {
    314             base_color_factor: Vec4::ONE,
    315             metallic_factor: 0.0,
    316             roughness_factor: 1.0,
    317             ao_strength: 1.0,
    318             _pad0: 0.0,
    319         },
    320     ));
    321 
    322     let mut draws: Vec<ModelDraw> = Vec::new();
    323     let mut bounds = Aabb::empty();
    324 
    325     for mesh in doc.meshes() {
    326         for prim in mesh.primitives() {
    327             if prim.mode() != gltf::mesh::Mode::Triangles {
    328                 continue;
    329             }
    330 
    331             let reader = prim.reader(|b| Some(&buffers[b.index()]));
    332 
    333             let positions: Vec<[f32; 3]> = match reader.read_positions() {
    334                 Some(it) => it.collect(),
    335                 None => continue,
    336             };
    337 
    338             let normals: Vec<[f32; 3]> = reader
    339                 .read_normals()
    340                 .map(|it| it.collect())
    341                 .unwrap_or_else(|| vec![[0.0, 0.0, 1.0]; positions.len()]);
    342 
    343             let uvs: Vec<[f32; 2]> = reader
    344                 .read_tex_coords(0)
    345                 .map(|tc| tc.into_f32().collect())
    346                 .unwrap_or_else(|| vec![[0.0, 0.0]; positions.len()]);
    347 
    348             let indices: Vec<u32> = if let Some(read) = reader.read_indices() {
    349                 read.into_u32().collect()
    350             } else {
    351                 (0..positions.len() as u32).collect()
    352             };
    353 
    354             /*
    355             let tangents: Vec<[f32; 4]> = reader
    356                 .read_tangents()
    357                 .map(|it| it.collect())
    358                 .unwrap_or_else(|| vec![[1.0, 0.0, 0.0, 1.0]; positions.len()]);
    359                 */
    360 
    361             let mut verts: Vec<Vertex> = Vec::with_capacity(positions.len());
    362             for i in 0..positions.len() {
    363                 let pos = positions[i];
    364                 bounds.include_point(Vec3::new(pos[0], pos[1], pos[2]));
    365 
    366                 verts.push(Vertex {
    367                     pos,
    368                     normal: normals[i],
    369                     uv: uvs[i],
    370                     tangent: [0.0, 0.0, 0.0, 0.0],
    371                 })
    372             }
    373 
    374             compute_tangents(&mut verts, &indices);
    375 
    376             let vert_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
    377                 label: Some("gltf_vert_buf"),
    378                 contents: bytemuck::cast_slice(&verts),
    379                 usage: wgpu::BufferUsages::VERTEX,
    380             });
    381 
    382             let ind_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
    383                 label: Some("gltf_ind_buf"),
    384                 contents: bytemuck::cast_slice(&indices),
    385                 usage: wgpu::BufferUsages::INDEX,
    386             });
    387 
    388             let material_index = prim.material().index().unwrap_or(default_material_index);
    389 
    390             draws.push(ModelDraw {
    391                 mesh: Mesh {
    392                     num_indices: indices.len() as u32,
    393                     vert_buf,
    394                     ind_buf,
    395                 },
    396                 material_index,
    397             });
    398         }
    399     }
    400 
    401     Ok(ModelData {
    402         draws,
    403         materials,
    404         bounds,
    405     })
    406 }
    407 
    408 fn make_default_sampler(device: &wgpu::Device) -> wgpu::Sampler {
    409     device.create_sampler(&wgpu::SamplerDescriptor {
    410         label: Some("gltf_default_sampler"),
    411         address_mode_u: wgpu::AddressMode::Repeat,
    412         address_mode_v: wgpu::AddressMode::Repeat,
    413         address_mode_w: wgpu::AddressMode::Repeat,
    414         mag_filter: wgpu::FilterMode::Linear,
    415         min_filter: wgpu::FilterMode::Linear,
    416         mipmap_filter: wgpu::FilterMode::Nearest,
    417         ..Default::default()
    418     })
    419 }
    420 
    421 /// Keep wrap modes consistent when mapping glTF sampler -> wgpu sampler
    422 fn map_wrap_mode(wrap_mode: gltf::texture::WrappingMode) -> wgpu::AddressMode {
    423     match wrap_mode {
    424         gltf::texture::WrappingMode::ClampToEdge => wgpu::AddressMode::ClampToEdge,
    425         gltf::texture::WrappingMode::MirroredRepeat => wgpu::AddressMode::MirrorRepeat,
    426         gltf::texture::WrappingMode::Repeat => wgpu::AddressMode::Repeat,
    427     }
    428 }
    429 
    430 fn map_min_filter(f: Option<gltf::texture::MinFilter>) -> (wgpu::FilterMode, wgpu::FilterMode) {
    431     // (min, mipmap)
    432     match f {
    433         Some(gltf::texture::MinFilter::Nearest) => {
    434             (wgpu::FilterMode::Nearest, wgpu::FilterMode::Nearest)
    435         }
    436         Some(gltf::texture::MinFilter::Linear) => {
    437             (wgpu::FilterMode::Linear, wgpu::FilterMode::Nearest)
    438         }
    439 
    440         Some(gltf::texture::MinFilter::NearestMipmapNearest) => {
    441             (wgpu::FilterMode::Nearest, wgpu::FilterMode::Nearest)
    442         }
    443         Some(gltf::texture::MinFilter::LinearMipmapNearest) => {
    444             (wgpu::FilterMode::Linear, wgpu::FilterMode::Nearest)
    445         }
    446         Some(gltf::texture::MinFilter::NearestMipmapLinear) => {
    447             (wgpu::FilterMode::Nearest, wgpu::FilterMode::Linear)
    448         }
    449         Some(gltf::texture::MinFilter::LinearMipmapLinear) => {
    450             (wgpu::FilterMode::Linear, wgpu::FilterMode::Linear)
    451         }
    452 
    453         None => (wgpu::FilterMode::Linear, wgpu::FilterMode::Nearest),
    454     }
    455 }
    456 
    457 fn map_mag_filter(f: Option<gltf::texture::MagFilter>) -> wgpu::FilterMode {
    458     match f {
    459         Some(gltf::texture::MagFilter::Nearest) => wgpu::FilterMode::Nearest,
    460         Some(gltf::texture::MagFilter::Linear) => wgpu::FilterMode::Linear,
    461         None => wgpu::FilterMode::Linear,
    462     }
    463 }
    464 
    465 #[allow(clippy::too_many_arguments)]
    466 pub(crate) fn make_material_gpu(
    467     device: &wgpu::Device,
    468     queue: &wgpu::Queue,
    469     material_bgl: &wgpu::BindGroupLayout,
    470     sampler: &wgpu::Sampler,
    471     basecolor: &wgpu::TextureView,
    472     mr: &wgpu::TextureView,
    473     normal: &wgpu::TextureView,
    474     uniform: MaterialUniform,
    475 ) -> MaterialGpu {
    476     let buffer = device.create_buffer(&wgpu::BufferDescriptor {
    477         label: Some("material_ubo"),
    478         size: std::mem::size_of::<MaterialUniform>() as u64,
    479         usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
    480         mapped_at_creation: false,
    481     });
    482 
    483     // write uniform once
    484     queue.write_buffer(&buffer, 0, bytemuck::bytes_of(&uniform));
    485 
    486     let bindgroup = device.create_bind_group(&wgpu::BindGroupDescriptor {
    487         label: Some("material_bg"),
    488         layout: material_bgl,
    489         entries: &[
    490             wgpu::BindGroupEntry {
    491                 binding: 0,
    492                 resource: buffer.as_entire_binding(),
    493             },
    494             wgpu::BindGroupEntry {
    495                 binding: 1,
    496                 resource: wgpu::BindingResource::Sampler(sampler),
    497             },
    498             wgpu::BindGroupEntry {
    499                 binding: 2,
    500                 resource: wgpu::BindingResource::TextureView(basecolor),
    501             },
    502             wgpu::BindGroupEntry {
    503                 binding: 3,
    504                 resource: wgpu::BindingResource::TextureView(mr),
    505             },
    506             wgpu::BindGroupEntry {
    507                 binding: 4,
    508                 resource: wgpu::BindingResource::TextureView(normal),
    509             },
    510         ],
    511     });
    512 
    513     MaterialGpu {
    514         _uniform: uniform,
    515         _buffer: buffer,
    516         bindgroup,
    517     }
    518 }
    519 
    520 fn compute_tangents(verts: &mut [Vertex], indices: &[u32]) {
    521     use glam::{Vec2, Vec3};
    522 
    523     let n = verts.len();
    524     let mut tan1 = vec![Vec3::ZERO; n];
    525     let mut tan2 = vec![Vec3::ZERO; n];
    526 
    527     let to_v3 = |a: [f32; 3]| Vec3::new(a[0], a[1], a[2]);
    528     let to_v2 = |a: [f32; 2]| Vec2::new(a[0], a[1]);
    529 
    530     // Accumulate per-triangle tangents/bitangents
    531     for tri in indices.chunks_exact(3) {
    532         let i0 = tri[0] as usize;
    533         let i1 = tri[1] as usize;
    534         let i2 = tri[2] as usize;
    535 
    536         let p0 = to_v3(verts[i0].pos);
    537         let p1 = to_v3(verts[i1].pos);
    538         let p2 = to_v3(verts[i2].pos);
    539 
    540         let w0 = to_v2(verts[i0].uv);
    541         let w1 = to_v2(verts[i1].uv);
    542         let w2 = to_v2(verts[i2].uv);
    543 
    544         let e1 = p1 - p0;
    545         let e2 = p2 - p0;
    546 
    547         let d1 = w1 - w0;
    548         let d2 = w2 - w0;
    549 
    550         let denom = d1.x * d2.y - d1.y * d2.x;
    551         if denom.abs() < 1e-8 {
    552             continue; // degenerate UV mapping; skip
    553         }
    554         let r = 1.0 / denom;
    555 
    556         let sdir = (e1 * d2.y - e2 * d1.y) * r; // tangent direction
    557         let tdir = (e2 * d1.x - e1 * d2.x) * r; // bitangent direction
    558 
    559         tan1[i0] += sdir;
    560         tan1[i1] += sdir;
    561         tan1[i2] += sdir;
    562         tan2[i0] += tdir;
    563         tan2[i1] += tdir;
    564         tan2[i2] += tdir;
    565     }
    566 
    567     // Orthonormalize & store handedness in w
    568     for i in 0..n {
    569         let nrm = to_v3(verts[i].normal).normalize_or_zero();
    570         let t = tan1[i];
    571 
    572         // Gram–Schmidt: make T perpendicular to N
    573         let t_ortho = (t - nrm * nrm.dot(t)).normalize_or_zero();
    574 
    575         // Handedness: +1 or -1
    576         let w = if nrm.cross(t_ortho).dot(tan2[i]) < 0.0 {
    577             -1.0
    578         } else {
    579             1.0
    580         };
    581 
    582         verts[i].tangent = [t_ortho.x, t_ortho.y, t_ortho.z, w];
    583     }
    584 }
    585 
    586 #[derive(Debug, Copy, Clone)]
    587 pub struct Aabb {
    588     pub min: Vec3,
    589     pub max: Vec3,
    590 }
    591 
    592 impl Aabb {
    593     pub fn empty() -> Self {
    594         Self {
    595             min: Vec3::splat(f32::INFINITY),
    596             max: Vec3::splat(f32::NEG_INFINITY),
    597         }
    598     }
    599 
    600     pub fn include_point(&mut self, p: Vec3) {
    601         self.min = self.min.min(p);
    602         self.max = self.max.max(p);
    603     }
    604 
    605     pub fn center(&self) -> Vec3 {
    606         (self.min + self.max) * 0.5
    607     }
    608 
    609     pub fn half_extents(&self) -> Vec3 {
    610         (self.max - self.min) * 0.5
    611     }
    612 
    613     pub fn radius(&self) -> f32 {
    614         self.half_extents().length()
    615     }
    616 
    617     /// Clamp a point's XZ to the AABB's XZ extent. Y unchanged.
    618     pub fn clamp_xz(&self, p: Vec3) -> Vec3 {
    619         Vec3::new(
    620             p.x.clamp(self.min.x, self.max.x),
    621             p.y,
    622             p.z.clamp(self.min.z, self.max.z),
    623         )
    624     }
    625 
    626     /// Distance the point's XZ overshoots the AABB boundary. 0 if inside.
    627     pub fn xz_overshoot(&self, p: Vec3) -> f32 {
    628         let dx = (p.x - self.max.x).max(self.min.x - p.x).max(0.0);
    629         let dz = (p.z - self.max.z).max(self.min.z - p.z).max(0.0);
    630         (dx * dx + dz * dz).sqrt()
    631     }
    632 }