Initial media depth project backup

This commit is contained in:
Codex
2026-05-20 12:25:12 +08:00
commit 4a0aebb2bd
358 changed files with 182095 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
[flake8]
max-line-length = 100
ignore = E203 E741 W503 E731

View File

View File

@@ -0,0 +1,36 @@
# Python cache
__pycache__/
*.py[cod]
# Distribution / packaging
workspace/
build/
dist/
*.egg-info/
.gradio/
# Test/coverage
.coverage
.pytest_cache/
htmlcov/
.tox/
gallery*/
debug*/
DA3HF*/
gradio_workspace/
eval_workspace/
FILTER*/
input_images*/
*.gradio/
# Jupyter notebooks
.ipynb_checkpoints
# OS files
.DS_Store
.vscode
src/debug_main.py
temp*.png
/outputs

View File

@@ -0,0 +1,59 @@
repos:
- repo: 'https://github.com/pre-commit/pre-commit-hooks'
rev: v4.5.0
hooks:
- id: check-added-large-files
args:
- '--maxkb=125'
- id: check-ast
- id: check-executables-have-shebangs
- id: check-merge-conflict
- id: check-symlinks
- id: check-toml
- id: check-yaml
- id: debug-statements
- id: detect-private-key
- id: end-of-file-fixer
- id: no-commit-to-branch
args:
- '--branch'
- 'master'
- id: pretty-format-json
exclude: '.*\.ipynb$'
args:
- '--autofix'
- '--indent'
- '4'
- id: trailing-whitespace
args:
- '--markdown-linebreak-ext=md'
- repo: 'https://github.com/pycqa/isort'
rev: 5.13.2
hooks:
- id: isort
args:
- '--settings-file'
- 'pyproject.toml'
- '--filter-files'
- repo: 'https://github.com/asottile/pyupgrade'
rev: v3.15.2
hooks:
- id: pyupgrade
args: [--py38-plus, --keep-runtime-typing]
- repo: 'https://github.com/psf/black.git'
rev: 24.3.0
hooks:
- id: black
args:
- '--config=pyproject.toml'
- repo: 'https://github.com/PyCQA/flake8'
rev: 7.0.0
hooks:
- id: flake8
args:
- '--config=.flake8'
- repo: 'https://github.com/myint/autoflake'
rev: v1.4
hooks:
- id: autoflake
args: [ '--remove-all-unused-imports', '--recursive', '--remove-unused-variables', '--in-place']

View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on
the same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 The Depth Anything 3 Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,270 @@
<div align="center">
<h1 style="border-bottom: none; margin-bottom: 0px ">Depth Anything 3: Recovering the Visual Space from Any Views</h1>
<!-- <h2 style="border-top: none; margin-top: 3px;">Recovering the Visual Space from Any Views</h2> -->
[**Haotong Lin**](https://haotongl.github.io/)<sup>&ast;</sup> · [**Sili Chen**](https://github.com/SiliChen321)<sup>&ast;</sup> · [**Jun Hao Liew**](https://liewjunhao.github.io/)<sup>&ast;</sup> · [**Donny Y. Chen**](https://donydchen.github.io)<sup>&ast;</sup> · [**Zhenyu Li**](https://zhyever.github.io/) · [**Guang Shi**](https://scholar.google.com/citations?user=MjXxWbUAAAAJ&hl=en) · [**Jiashi Feng**](https://scholar.google.com.sg/citations?user=Q8iay0gAAAAJ&hl=en)
<br>
[**Bingyi Kang**](https://bingykang.github.io/)<sup>&ast;&dagger;</sup>
&dagger;project lead&emsp;&ast;Equal Contribution
<a href="https://arxiv.org/abs/2511.10647"><img src='https://img.shields.io/badge/arXiv-Depth Anything 3-red' alt='Paper PDF'></a>
<a href='https://depth-anything-3.github.io'><img src='https://img.shields.io/badge/Project_Page-Depth Anything 3-green' alt='Project Page'></a>
<a href='https://huggingface.co/spaces/depth-anything/Depth-Anything-3'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
<!-- <a href='https://huggingface.co/datasets/depth-anything/VGB'><img src='https://img.shields.io/badge/Benchmark-VisGeo-yellow' alt='Benchmark'></a> -->
<!-- <a href='https://huggingface.co/datasets/depth-anything/data'><img src='https://img.shields.io/badge/Benchmark-xxx-yellow' alt='Data'></a> -->
</div>
This work presents **Depth Anything 3 (DA3)**, a model that predicts spatially consistent geometry from
arbitrary visual inputs, with or without known camera poses.
In pursuit of minimal modeling, DA3 yields two key insights:
- 💎 A **single plain transformer** (e.g., vanilla DINO encoder) is sufficient as a backbone without architectural specialization,
- ✨ A singular **depth-ray representation** obviates the need for complex multi-task learning.
🏆 DA3 significantly outperforms
[DA2](https://github.com/DepthAnything/Depth-Anything-V2) for monocular depth estimation,
and [VGGT](https://github.com/facebookresearch/vggt) for multi-view depth estimation and pose estimation.
All models are trained exclusively on **public academic datasets**.
<!-- <p align="center">
<img src="assets/images/da3_teaser.png" alt="Depth Anything 3" width="100%">
</p> -->
<p align="center">
<img src="assets/images/demo320-2.gif" alt="Depth Anything 3 - Left" width="70%">
</p>
<p align="center">
<img src="assets/images/da3_radar.png" alt="Depth Anything 3" width="100%">
</p>
## 📰 News
- **30-11-2025:** Add [`use_ray_pose`](#use-ray-pose) and [`ref_view_strategy`](docs/funcs/ref_view_strategy.md) (reference view selection for multi-view inputs).
- **25-11-2025:** Add [Awesome DA3 Projects](#-awesome-da3-projects), a community-driven section featuring DA3-based applications.
- **14-11-2025:** Paper, project page, code and models are all released.
## ✨ Highlights
### 🏆 Model Zoo
We release three series of models, each tailored for specific use cases in visual geometry.
- 🌟 **DA3 Main Series** (`DA3-Giant`, `DA3-Large`, `DA3-Base`, `DA3-Small`) These are our flagship foundation models, trained with a unified depth-ray representation. By varying the input configuration, a single model can perform a wide range of tasks:
+ 🌊 **Monocular Depth Estimation**: Predicts a depth map from a single RGB image.
+ 🌊 **Multi-View Depth Estimation**: Generates consistent depth maps from multiple images for high-quality fusion.
+ 🎯 **Pose-Conditioned Depth Estimation**: Achieves superior depth consistency when camera poses are provided as input.
+ 📷 **Camera Pose Estimation**: Estimates camera extrinsics and intrinsics from one or more images.
+ 🟡 **3D Gaussian Estimation**: Directly predicts 3D Gaussians, enabling high-fidelity novel view synthesis.
- 📐 **DA3 Metric Series** (`DA3Metric-Large`) A specialized model fine-tuned for metric depth estimation in monocular settings, ideal for applications requiring real-world scale.
- 🔍 **DA3 Monocular Series** (`DA3Mono-Large`). A dedicated model for high-quality relative monocular depth estimation. Unlike disparity-based models (e.g., [Depth Anything 2](https://github.com/DepthAnything/Depth-Anything-V2)), it directly predicts depth, resulting in superior geometric accuracy.
🔗 Leveraging these available models, we developed a **nested series** (`DA3Nested-Giant-Large`). This series combines a any-view giant model with a metric model to reconstruct visual geometry at a real-world metric scale.
### 🛠️ Codebase Features
Our repository is designed to be a powerful and user-friendly toolkit for both practical application and future research.
- 🎨 **Interactive Web UI & Gallery**: Visualize model outputs and compare results with an easy-to-use Gradio-based web interface.
-**Flexible Command-Line Interface (CLI)**: Powerful and scriptable CLI for batch processing and integration into custom workflows.
- 💾 **Multiple Export Formats**: Save your results in various formats, including `glb`, `npz`, depth images, `ply`, 3DGS videos, etc, to seamlessly connect with other tools.
- 🔧 **Extensible and Modular Design**: The codebase is structured to facilitate future research and the integration of new models or functionalities.
<!-- ### 🎯 Visual Geometry Benchmark
We introduce a new benchmark to rigorously evaluate geometry prediction models on three key tasks: pose estimation, 3D reconstruction, and visual rendering (novel view synthesis) quality.
- 🔄 **Broad Model Compatibility**: Our benchmark is designed to be versatile, supporting the evaluation of various models, including both monocular and multi-view depth estimation approaches.
- 🔬 **Robust Evaluation Pipeline**: We provide a standardized pipeline featuring RANSAC-based pose alignment, TSDF fusion for dense reconstruction, and a principled view selection strategy for novel view synthesis.
- 📊 **Standardized Metrics**: Performance is measured using established metrics: AUC for pose accuracy, F1-score and Chamfer Distance for reconstruction, and PSNR/SSIM/LPIPS for rendering quality.
- 🌍 **Diverse and Challenging Datasets**: The benchmark spans a wide range of scenes from datasets like HiRoom, ETH3D, DTU, 7Scenes, ScanNet++, DL3DV, Tanks and Temples, and MegaDepth. -->
## 🚀 Quick Start
### 📦 Installation
```bash
pip install xformers torch\>=2 torchvision
pip install -e . # Basic
pip install --no-build-isolation git+https://github.com/nerfstudio-project/gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70 # for gaussian head
pip install -e ".[app]" # Gradio, python>=3.10
pip install -e ".[all]" # ALL
```
For detailed model information, please refer to the [Model Cards](#-model-cards) section below.
### 💻 Basic Usage
```python
import glob, os, torch
from depth_anything_3.api import DepthAnything3
device = torch.device("cuda")
model = DepthAnything3.from_pretrained("depth-anything/DA3NESTED-GIANT-LARGE")
model = model.to(device=device)
example_path = "assets/examples/SOH"
images = sorted(glob.glob(os.path.join(example_path, "*.png")))
prediction = model.inference(
images,
)
# prediction.processed_images : [N, H, W, 3] uint8 array
print(prediction.processed_images.shape)
# prediction.depth : [N, H, W] float32 array
print(prediction.depth.shape)
# prediction.conf : [N, H, W] float32 array
print(prediction.conf.shape)
# prediction.extrinsics : [N, 3, 4] float32 array # opencv w2c or colmap format
print(prediction.extrinsics.shape)
# prediction.intrinsics : [N, 3, 3] float32 array
print(prediction.intrinsics.shape)
```
```bash
export MODEL_DIR=depth-anything/DA3NESTED-GIANT-LARGE
# This can be a Hugging Face repository or a local directory
# If you encounter network issues, consider using the following mirror: export HF_ENDPOINT=https://hf-mirror.com
# Alternatively, you can download the model directly from Hugging Face
export GALLERY_DIR=workspace/gallery
mkdir -p $GALLERY_DIR
# CLI auto mode with backend reuse
da3 backend --model-dir ${MODEL_DIR} --gallery-dir ${GALLERY_DIR} # Cache model to gpu
da3 auto assets/examples/SOH \
--export-format glb \
--export-dir ${GALLERY_DIR}/TEST_BACKEND/SOH \
--use-backend
# CLI video processing with feature visualization
da3 video assets/examples/robot_unitree.mp4 \
--fps 15 \
--use-backend \
--export-dir ${GALLERY_DIR}/TEST_BACKEND/robo \
--export-format glb-feat_vis \
--feat-vis-fps 15 \
--process-res-method lower_bound_resize \
--export-feat "11,21,31"
# CLI auto mode without backend reuse
da3 auto assets/examples/SOH \
--export-format glb \
--export-dir ${GALLERY_DIR}/TEST_CLI/SOH \
--model-dir ${MODEL_DIR}
```
The model architecture is defined in [`DepthAnything3Net`](src/depth_anything_3/model/da3.py), and specified with a Yaml config file located at [`src/depth_anything_3/configs`](src/depth_anything_3/configs). The input and output processing are handled by [`DepthAnything3`](src/depth_anything_3/api.py). To customize the model architecture, simply create a new config file (*e.g.*, `path/to/new/config`) as:
```yaml
__object__:
path: depth_anything_3.model.da3
name: DepthAnything3Net
args: as_params
net:
__object__:
path: depth_anything_3.model.dinov2.dinov2
name: DinoV2
args: as_params
name: vitb
out_layers: [5, 7, 9, 11]
alt_start: 4
qknorm_start: 4
rope_start: 4
cat_token: True
head:
__object__:
path: depth_anything_3.model.dualdpt
name: DualDPT
args: as_params
dim_in: &head_dim_in 1536
output_dim: 2
features: &head_features 128
out_channels: &head_out_channels [96, 192, 384, 768]
```
Then, the model can be created with the following code snippet.
```python
from depth_anything_3.cfg import create_object, load_config
Model = create_object(load_config("path/to/new/config"))
```
## 📚 Useful Documentation
- 🖥️ [Command Line Interface](docs/CLI.md)
- 📑 [Python API](docs/API.md)
<!-- - 🏁 [Visual Geometry Benchmark](docs/BENCHMARK.md) -->
## 🗂️ Model Cards
Generally, you should observe that DA3-LARGE achieves comparable results to VGGT.
The Nested series uses an Any-view model to estimate pose and depth, and a monocular metric depth estimator for scaling.
| 🗃️ Model Name | 📏 Params | 📊 Rel. Depth | 📷 Pose Est. | 🧭 Pose Cond. | 🎨 GS | 📐 Met. Depth | ☁️ Sky Seg | 📄 License |
|-------------------------------|-----------|---------------|--------------|---------------|-------|---------------|-----------|----------------|
| **Nested** | | | | | | | | |
| [DA3NESTED-GIANT-LARGE](https://huggingface.co/depth-anything/DA3NESTED-GIANT-LARGE) | 1.40B | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | CC BY-NC 4.0 |
| **Any-view Model** | | | | | | | | |
| [DA3-GIANT](https://huggingface.co/depth-anything/DA3-GIANT) | 1.15B | ✅ | ✅ | ✅ | ✅ | | | CC BY-NC 4.0 |
| [DA3-LARGE](https://huggingface.co/depth-anything/DA3-LARGE) | 0.35B | ✅ | ✅ | ✅ | | | | CC BY-NC 4.0 |
| [DA3-BASE](https://huggingface.co/depth-anything/DA3-BASE) | 0.12B | ✅ | ✅ | ✅ | | | | Apache 2.0 |
| [DA3-SMALL](https://huggingface.co/depth-anything/DA3-SMALL) | 0.08B | ✅ | ✅ | ✅ | | | | Apache 2.0 |
| | | | | | | | | |
| **Monocular Metric Depth** | | | | | | | | |
| [DA3METRIC-LARGE](https://huggingface.co/depth-anything/DA3METRIC-LARGE) | 0.35B | ✅ | | | | ✅ | ✅ | Apache 2.0 |
| | | | | | | | | |
| **Monocular Depth** | | | | | | | | |
| [DA3MONO-LARGE](https://huggingface.co/depth-anything/DA3MONO-LARGE) | 0.35B | ✅ | | | | | ✅ | Apache 2.0 |
## ❓ FAQ
- **Monocular Metric Depth**: To obtain metric depth in meters from `DA3METRIC-LARGE`, use `metric_depth = focal * net_output / 300.`, where `focal` is the focal length in pixels (typically the average of fx and fy from the camera intrinsic matrix K). Note that the output from `DA3NESTED-GIANT-LARGE` is already in meters.
- <a id="use-ray-pose"></a>**Ray Head (`use_ray_pose`)**: Our API and CLI support `use_ray_pose` arg, which means that the model will derive camera pose from ray head, which is generally slightly slower, but more accurate. Note that the default is `False` for faster inference speed.
<details>
<summary>AUC3 Results for DA3NESTED-GIANT-LARGE</summary>
| Model | HiRoom | ETH3D | DTU | 7Scenes | ScanNet++ |
|-------|------|-------|-----|---------|-----------|
| `ray_head` | 84.4 | 52.6 | 93.9 | 29.5 | 89.4 |
| `cam_head` | 80.3 | 48.4 | 94.1 | 28.5 | 85.0 |
</details>
- **Older GPUs without XFormers support**: See [Issue #11](https://github.com/ByteDance-Seed/Depth-Anything-3/issues/11). Thanks to [@S-Mahoney](https://github.com/S-Mahoney) for the solution!
## 🏢 Awesome DA3 Projects
A community-curated list of Depth Anything 3 integrations across 3D tools, creative pipelines, robotics, and web/VR viewers, including but not limited to these. You are welcome to submit your DA3-based project via PR, and we will review and feature it if applicable.
- [DA3-blender](https://github.com/xy-gao/DA3-blender): Blender addon for DA3-based 3D reconstruction from a set of images.
- [ComfyUI-DepthAnythingV3](https://github.com/PozzettiAndrea/ComfyUI-DepthAnythingV3): ComfyUI nodes for Depth Anything 3, supporting single/multi-view and video-consistent depth with optional pointcloud export.
- [DA3-ROS2-Wrapper](https://github.com/GerdsenAI/GerdsenAI-Depth-Anything-3-ROS2-Wrapper): Real-time DA3 depth in ROS2 with multi-camera support.
- [VideoDepthViewer3D](https://github.com/amariichi/VideoDepthViewer3D): Streaming videos with DA3 metric depth to a Three.js/WebXR 3D viewer for VR/stereo playback.
## 📝 Citations
If you find Depth Anything 3 useful in your research or projects, please cite our work:
```
@article{depthanything3,
title={Depth Anything 3: Recovering the visual space from any views},
author={Haotong Lin and Sili Chen and Jun Hao Liew and Donny Y. Chen and Zhenyu Li and Guang Shi and Jiashi Feng and Bingyi Kang},
journal={arXiv preprint arXiv:2511.10647},
year={2025}
}
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 210 KiB

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd463e7e5fb4b30c14b5f618acf0de6c046d0b26d13e1eefc384f7b17e9dd3f
size 2252109

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 MiB

View File

@@ -0,0 +1,465 @@
# 📚 DepthAnything3 API Documentation
## 📑 Table of Contents
1. [📖 Overview](#overview)
2. [💡 Usage Examples](#usage-examples)
3. [🔧 Core API](#core-api)
- [DepthAnything3 Class](#depthanything3-class)
- [inference() Method](#inference-method)
4. [⚙️ Parameters](#parameters)
- [Input Parameters](#input-parameters)
- [Pose Alignment Parameters](#pose-alignment-parameters)
- [Feature Export Parameters](#feature-export-parameters)
- [Rendering Parameters](#rendering-parameters)
- [Processing Parameters](#processing-parameters)
- [Export Parameters](#export-parameters)
5. [📤 Export Formats](#export-formats)
6. [↩️ Return Value](#return-value)
## 📖 Overview
This documentation provides comprehensive API reference for DepthAnything3, including usage examples, parameter specifications, export formats, and advanced features. It covers both basic pose and depth estimation workflows and advanced pose-conditioned processing with multiple export capabilities.
## 💡 Usage Examples
Here are quick examples to get you started:
### 🚀 Basic Depth Estimation
```python
from depth_anything_3.api import DepthAnything3
# Initialize and run inference
model = DepthAnything3.from_pretrained("depth-anything/DA3NESTED-GIANT-LARGE").to("cuda")
prediction = model.inference(["image1.jpg", "image2.jpg"])
```
### 📷 Pose-Conditioned Depth Estimation
```python
import numpy as np
# With camera parameters for better consistency
prediction = model.inference(
image=["image1.jpg", "image2.jpg"],
extrinsics=extrinsics_array, # (N, 4, 4)
intrinsics=intrinsics_array # (N, 3, 3)
)
```
### 📤 Export Results
```python
# Export depth data and 3D visualization
prediction = model.inference(
image=image_paths,
export_dir="./output",
export_format="mini_npz-glb"
)
```
### 🔍 Feature Extraction
```python
# Export intermediate features from specific layers
prediction = model.inference(
image=image_paths,
export_dir="./output",
export_format="feat_vis",
export_feat_layers=[0, 1, 2] # Export features from layers 0, 1, 2
)
```
### ✨ Advanced Export with Gaussian Splatting
```python
# Export multiple formats including Gaussian Splatting
# Note: infer_gs=True requires da3-giant or da3nested-giant-large model
model = DepthAnything3(model_name="da3-giant").to("cuda")
prediction = model.inference(
image=image_paths,
extrinsics=extrinsics_array,
intrinsics=intrinsics_array,
export_dir="./output",
export_format="npz-glb-gs_ply-gs_video",
align_to_input_ext_scale=True,
infer_gs=True, # Required for gs_ply and gs_video exports
)
```
### 🎨 Advanced Export with Feature Visualization
```python
# Export with intermediate feature visualization
prediction = model.inference(
image=image_paths,
export_dir="./output",
export_format="mini_npz-glb-depth_vis-feat_vis",
export_feat_layers=[0, 5, 10, 15, 20],
feat_vis_fps=30,
)
```
### 📐 Using Ray-Based Pose Estimation
```python
# Use ray-based pose estimation instead of camera decoder
prediction = model.inference(
image=image_paths,
export_dir="./output",
export_format="glb",
use_ray_pose=True, # Enable ray-based pose estimation
)
```
### 🎯 Reference View Selection
```python
# For multi-view inputs, automatically select the best reference view
prediction = model.inference(
image=image_paths,
ref_view_strategy="saddle_balanced", # Default: balanced selection
)
# For video sequences, use middle frame as reference
prediction = model.inference(
image=video_frames,
ref_view_strategy="middle", # Good for temporally ordered inputs
)
```
## 🔧 Core API
### 🔨 DepthAnything3 Class
The main API class that provides depth estimation capabilities with optional pose conditioning.
#### 🎯 Initialization
```python
from depth_anything_3 import DepthAnything3
# Initialize the model with a model name
model = DepthAnything3(model_name="da3-large")
model = model.to("cuda") # Move to GPU
```
**Parameters:**
- `model_name` (str, default: "da3-large"): The name of the model preset to use.
- **Available models:**
- 🦾 `"da3-giant"` - 1.15B params, any-view model with GS support
-`"da3-large"` - 0.35B params, any-view model (recommended for most use cases)
- 📦 `"da3-base"` - 0.12B params, any-view model
- 🪶 `"da3-small"` - 0.08B params, any-view model
- 👁️ `"da3mono-large"` - 0.35B params, monocular depth only
- 📏 `"da3metric-large"` - 0.35B params, metric depth with sky segmentation
- 🎯 `"da3nested-giant-large"` - 1.40B params, nested model with all features
### 🚀 inference() Method
The primary inference method that processes images and returns depth predictions.
```python
prediction = model.inference(
image=image_list,
extrinsics=extrinsics_array, # Optional
intrinsics=intrinsics_array, # Optional
align_to_input_ext_scale=True, # Whether to align predicted poses to input scale
infer_gs=True, # Enable Gaussian branch for gs exports
use_ray_pose=False, # Use ray-based pose estimation instead of camera decoder
ref_view_strategy="saddle_balanced", # Reference view selection strategy
render_exts=render_extrinsics, # Optional renders for gs_video
render_ixts=render_intrinsics, # Optional renders for gs_video
render_hw=(height, width), # Optional renders for gs_video
process_res=504,
process_res_method="upper_bound_resize",
export_dir="output_directory", # Optional
export_format="mini_npz",
export_feat_layers=[], # List of layer indices to export features from
conf_thresh_percentile=40.0, # Confidence threshold percentile for depth map in GLB export
num_max_points=1_000_000, # Maximum number of points to export in GLB export
show_cameras=True, # Whether to show cameras in GLB export
feat_vis_fps=15, # Frames per second for feature visualization in feat_vis export
export_kwargs={} # Optional, additional arguments to export functions. export_format:key:val, see 'Parameters/Export Parameters' for details
)
```
## ⚙️ Parameters
### 📸 Input Parameters
#### `image` (required)
- **Type**: `List[Union[np.ndarray, Image.Image, str]]`
- **Description**: List of input images. Can be numpy arrays, PIL Images, or file paths.
- **Example**:
```python
# From file paths
image = ["image1.jpg", "image2.jpg", "image3.jpg"]
# From numpy arrays
image = [np.array(img1), np.array(img2)]
# From PIL Images
image = [Image.open("image1.jpg"), Image.open("image2.jpg")]
```
#### `extrinsics` (optional)
- **Type**: `Optional[np.ndarray]`
- **Shape**: `(N, 4, 4)` where N is the number of input images
- **Description**: Camera extrinsic matrices (world-to-camera transformation). When provided, enables pose-conditioned depth estimation mode.
- **Note**: If not provided, the model operates in standard depth estimation mode.
#### `intrinsics` (optional)
- **Type**: `Optional[np.ndarray]`
- **Shape**: `(N, 3, 3)` where N is the number of input images
- **Description**: Camera intrinsic matrices containing focal length and principal point information. When provided, enables pose-conditioned depth estimation mode.
### 🎯 Pose Alignment Parameters
#### `align_to_input_ext_scale` (default: True)
- **Type**: `bool`
- **Description**: When True the predicted extrinsics are replaced with the input
ones and the depth maps are rescaled to match their metric scale. When False the
function returns the internally aligned poses computed via Umeyama alignment.
#### `infer_gs` (default: False)
- **Type**: `bool`
- **Description**: Enable Gaussian Splatting branch for gaussian splatting exports. Required when using `gs_ply` or `gs_video` export formats.
#### `use_ray_pose` (default: False)
- **Type**: `bool`
- **Description**: Use ray-based pose estimation instead of camera decoder for pose prediction. When True, the model uses ray prediction heads to estimate camera poses; when False, it uses the camera decoder approach.
#### `ref_view_strategy` (default: "saddle_balanced")
- **Type**: `str`
- **Description**: Strategy for selecting the reference view from multiple input views. Options: `"first"`, `"middle"`, `"saddle_balanced"`, `"saddle_sim_range"`. Only applied when number of views ≥ 3. See [detailed documentation](funcs/ref_view_strategy.md) for strategy comparisons.
- **Available strategies**:
- `"saddle_balanced"`: Selects view with balanced features across multiple metrics (recommended default)
- `"saddle_sim_range"`: Selects view with largest similarity range
- `"first"`: Always uses first view (not recommended, equivalent to no reordering for views < 3)
- `"middle"`: Uses middle view (recommended for video sequences)
### 🔍 Feature Export Parameters
#### `export_feat_layers` (default: [])
- **Type**: `List[int]`
- **Description**: List of layer indices to export intermediate features from. Features are stored in the `aux` dictionary of the Prediction object with keys like `feat_layer_0`, `feat_layer_1`, etc.
### 🎥 Rendering Parameters
These arguments are only used when exporting Gaussian-splatting videos (include
`"gs_video"` in `export_format`). They describe an auxiliary camera trajectory
with ``M`` views.
#### `render_exts` (optional)
- **Type**: `Optional[np.ndarray]`
- **Shape**: `(M, 4, 4)`
- **Description**: Camera extrinsics for the synthesized trajectory. If omitted,
the exporter falls back to the predicted poses.
#### `render_ixts` (optional)
- **Type**: `Optional[np.ndarray]`
- **Shape**: `(M, 3, 3)`
- **Description**: Camera intrinsics for each rendered frame. Leave `None` to
reuse the input intrinsics.
#### `render_hw` (optional)
- **Type**: `Optional[Tuple[int, int]]`
- **Description**: Explicit output resolution `(height, width)` for the rendered
frames. Defaults to the input resolution when not provided.
### ⚡ Processing Parameters
#### `process_res` (default: 504)
- **Type**: `int`
- **Description**: Base resolution for processing. The model will resize images to this resolution for inference.
#### `process_res_method` (default: "upper_bound_resize")
- **Type**: `str`
- **Description**: Method for resizing images to the target resolution.
- **Options**:
- `"upper_bound_resize"`: Resize so that the specified dimension (504) becomes the longer side
- `"lower_bound_resize"`: Resize so that the specified dimension (504) becomes the shorter side
- **Example**:
- Input: 1200×1600 → Output: 378×504 (with `process_res=504`, `process_res_method="upper_bound_resize"`)
- Input: 504×672 → Output: 504×672 (no change needed)
### 📦 Export Parameters
#### `export_dir` (optional)
- **Type**: `Optional[str]`
- **Description**: Directory path where exported files will be saved. If not provided, no files will be exported.
#### `export_format` (default: "mini_npz")
- **Type**: `str`
- **Description**: Format for exporting results. Supports multiple formats separated by `-`.
- **Example**: `"mini_npz-glb"` exports both mini_npz and glb formats.
#### 🌐 GLB Export Parameters
These parameters are passed directly to the `inference()` method and only apply when `export_format` includes `"glb"`.
##### `conf_thresh_percentile` (default: 40.0)
- **Type**: `float`
- **Description**: Lower percentile for adaptive confidence threshold. Points below this confidence percentile will be filtered out from the point cloud.
##### `num_max_points` (default: 1,000,000)
- **Type**: `int`
- **Description**: Maximum number of points in the exported point cloud. If the point cloud exceeds this limit, it will be downsampled.
##### `show_cameras` (default: True)
- **Type**: `bool`
- **Description**: Whether to include camera wireframes in the exported GLB file for visualization.
#### 🎨 Feature Visualization Parameters
These parameters are passed directly to the `inference()` method and only apply when `export_format` includes `"feat_vis"`.
##### `feat_vis_fps` (default: 15)
- **Type**: `int`
- **Description**: Frame rate for the output video when visualizing features across multiple images.
#### ✨🎥 3DGS and 3DGS Video Parameters
These parameters are passed directly to the `inference()` method and only apply when `export_format` includes `"gs_ply"` or `"gs_video"`.
##### `export_kwargs` (default: `{}`)
- Type: `dict[str, dict[str, Any]]`
- Description: Per-format extra arguments passed to export functions, mainly for `"gs_ply"` and `"gs_video"`.
- Access pattern: `export_kwargs[export_format][key] = value`
- Example:
```python
{
"gs_ply": {
"gs_views_interval": 1,
},
"gs_video": {
"trj_mode": "interpolate_smooth",
"chunk_size": 1,
"vis_depth": None,
},
}
```
## 📤 Export Formats
The API supports multiple export formats for different use cases:
### 📊 `mini_npz`
- **Description**: Minimal NPZ format containing essential data
- **Contents**: `depth`, `conf`, `exts`, `ixts`
- **Use case**: Lightweight storage for depth data with camera parameters
### 📦 `npz`
- **Description**: Full NPZ format with comprehensive data
- **Contents**: `depth`, `conf`, `exts`, `ixts`, `image`, etc.
- **Use case**: Complete data export for advanced processing
### 🌐 `glb`
- **Description**: 3D visualization format with point cloud and camera poses
- **Contents**:
- Point cloud with colors from original images
- Camera wireframes for visualization
- Confidence-based filtering and downsampling
- **Use case**: 3D visualization, inspection, and analysis
- **Features**:
- Automatic sky depth handling
- Confidence threshold filtering
- Background filtering (black/white)
- Scene scale normalization
- **Parameters** (passed via `inference()` method directly):
- `conf_thresh_percentile` (float, default: 40.0): Lower percentile for adaptive confidence threshold. Points below this confidence percentile will be filtered out.
- `num_max_points` (int, default: 1,000,000): Maximum number of points in the exported point cloud. If exceeded, points will be downsampled.
- `show_cameras` (bool, default: True): Whether to include camera wireframes in the exported GLB file for visualization.
### ✨ `gs_ply`
- **Description**: Gaussian Splatting point cloud format
- **Contents**: 3DGS data in PLY format. Compatible with standard 3DGS viewers such as [SuperSplat](https://superspl.at/editor) (recommended), [SPARK](https://sparkjs.dev/viewer/).
- **Use case**: Gaussian Splatting reconstruction
- **Requirements**: Must set `infer_gs=True` when calling `inference()`. Only supported by `da3-giant` and `da3nested-giant-large` models.
- **Additional configs**, provided via `export_kwargs` (see [Export Parameters](#export-parameters)):
- `gs_views_interval`: Export to 3DGS every N views, default: `1`.
### 🎥 `gs_video`
- **Description**: Rasterized 3DGS to obtain videos
- **Contents**: A video of 3DGS-rasterized views using either provided viewpoints or a predefined camera trajectory.
- **Use case**: Video rendering for Gaussian Splatting
- **Requirements**: Must set `infer_gs=True` when calling `inference()`. Only supported by `da3-giant` and `da3nested-giant-large` models.
- **Note**: Can optionally use `render_exts`, `render_ixts`, and `render_hw` parameters in `inference()` method to specify novel viewpoints.
- **Additional configs**, provided via `export_kwargs` (see [Export Parameters](#export-parameters)):
- `extrinsics`: Optional world-to-camera poses for novel views. Falls back to the predicted poses of input views if not provided. (Alternatively, use `render_exts` parameter in `inference()`)
- `intrinsics`: Optional camera intrinsics for novel views. Falls back to the predicted intrinsics of input views if not provided. (Alternatively, use `render_ixts` parameter in `inference()`)
- `out_image_hw`: Optional output resolution `H x W`. Falls back to input resolution if not provided. (Alternatively, use `render_hw` parameter in `inference()`)
- `chunk_size`: Number of views rasterized per batch. Default: `8`.
- `trj_mode`: Predefined camera trajectory for novel-view rendering.
- `color_mode`: Same as `render_mode` in [gsplat](https://docs.gsplat.studio/main/apis/rasterization.html#gsplat.rasterization).
- `vis_depth`: How depth is combined with RGB. Default: `hcat` (horizontal concatenation).
- `enable_tqdm`: Whether to display a tqdm progress bar during rendering.
- `output_name`: File name of the rendered video.
- `video_quality`: Video quality to save. Default: `high`.
- `high`: High quality video (default)
- `medium`: Medium quality video (balance of storage space and quality)
- `low`: Low quality video (fewer storage space)
### 🔍 `feat_vis`
- **Description**: Feature visualization format
- **Contents**: PCA-visualized intermediate features from specified layers
- **Use case**: Model interpretability and feature analysis
- **Note**: Requires `export_feat_layers` to be specified
- **Parameters** (passed via `inference()` method directly):
- `feat_vis_fps` (int, default: 15): Frame rate for the output video when visualizing features across multiple images.
### 🎨 `depth_vis`
- **Description**: Depth visualization format
- **Contents**: Color-coded depth maps alongside original images
- **Use case**: Visual inspection of depth estimation quality
### 🔗 Multiple Format Export
You can export multiple formats simultaneously by separating them with `-`:
```python
# Export both mini_npz and glb formats
export_format = "mini_npz-glb"
# Export multiple formats
export_format = "npz-glb-gs_ply"
```
## ↩️ Return Value
The `inference()` method returns a `Prediction` object with the following attributes:
### 📊 Core Outputs
- **depth**: `np.ndarray` - Estimated depth maps with shape `(N, H, W)` where N is the number of images, H is height, and W is width.
- **conf**: `np.ndarray` - Confidence maps with shape `(N, H, W)` indicating prediction reliability (optional, depends on model).
### 📷 Camera Parameters
- **extrinsics**: `np.ndarray` - Camera extrinsic matrices with shape `(N, 3, 4)` representing world-to-camera transformations. Only present if camera poses were estimated or provided as input.
- **intrinsics**: `np.ndarray` - Camera intrinsic matrices with shape `(N, 3, 3)` containing focal length and principal point information. Only present if poses were estimated or provided as input.
### 🎁 Additional Outputs
- **processed_images**: `np.ndarray` - Preprocessed input images with shape `(N, H, W, 3)` in RGB format (0-255 uint8).
- **aux**: `dict` - Auxiliary outputs including:
- `feat_layer_X`: Intermediate features from layer X (if `export_feat_layers` was specified)
- `gaussians`: 3D Gaussian Splats data (if `infer_gs=True`)
### 💻 Usage Example
```python
prediction = model.inference(image=["img1.jpg", "img2.jpg"])
# Access depth maps
depth_maps = prediction.depth # shape: (2, H, W)
# Access confidence
if hasattr(prediction, 'conf'):
confidence = prediction.conf
# Access camera parameters (if available)
if hasattr(prediction, 'extrinsics'):
camera_poses = prediction.extrinsics # shape: (2, 4, 4)
if hasattr(prediction, 'intrinsics'):
camera_intrinsics = prediction.intrinsics # shape: (2, 3, 3)
# Access intermediate features (if export_feat_layers was set)
if hasattr(prediction, 'aux') and 'feat_layer_0' in prediction.aux:
features = prediction.aux['feat_layer_0']
```

View File

@@ -0,0 +1,654 @@
# 🚀 Depth Anything 3 Command Line Interface
## 📋 Table of Contents
- [📖 Overview](#overview)
- [⚡ Quick Start](#quick-start)
- [📚 Command Reference](#command-reference)
- [🤖 auto - Auto Mode](#auto---auto-mode)
- [🖼️ image - Single Image Processing](#image---single-image-processing)
- [🗂️ images - Image Directory Processing](#images---image-directory-processing)
- [🎬 video - Video Processing](#video---video-processing)
- [📐 colmap - COLMAP Dataset Processing](#colmap---colmap-dataset-processing)
- [🔧 backend - Backend Service](#backend---backend-service)
- [🎨 gradio - Gradio Application](#gradio---gradio-application)
- [🖼️ gallery - Gallery Server](#gallery---gallery-server)
- [⚙️ Parameter Details](#parameter-details)
- [💡 Usage Examples](#usage-examples)
## 📖 Overview
The Depth Anything 3 CLI provides a comprehensive command-line toolkit supporting image depth estimation, video processing, COLMAP dataset handling, and web applications.
The backend service enables cache model to GPU so that we do not need to reload model for each command.
## ⚡ Quick Start
The CLI can run fully offline or connect to the backend for cached weights and task scheduling:
```bash
# 🔧 Start backend service (optional, keeps model resident in GPU memory)
da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE
# 🚀 Use auto mode to process input
da3 auto path/to/input --export-dir ./workspace/scene001
# ♻️ Reuse backend for next job
da3 auto path/to/video.mp4 \
--export-dir ./workspace/scene002 \
--use-backend \
--backend-url http://localhost:8008
```
Each export directory contains `scene.glb`, `scene.jpg`, and optional extras such as `depth_vis/` or `gs_video/` depending on the requested format.
## 📚 Command Reference
### 🤖 auto - Auto Mode
Automatically detect input type and dispatch to the appropriate handler.
**Usage:**
```bash
da3 auto INPUT_PATH [OPTIONS]
```
**Input Type Detection:**
- 🖼️ Single image file (.jpg, .png, .jpeg, .webp, .bmp, .tiff, .tif)
- 📁 Image directory
- 🎬 Video file (.mp4, .avi, .mov, .mkv, .flv, .wmv, .webm, .m4v)
- 📐 COLMAP directory (containing `images/` and `sparse/` subdirectories)
**Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `INPUT_PATH` | str | Required | Input path (image, directory, video, or COLMAP) |
| `--model-dir` | str | Default model | Model directory path |
| `--export-dir` | str | `debug` | Export directory |
| `--export-format` | str | `glb` | Export format (supports `mini_npz`, `glb`, `feat_vis`, etc., can be combined with hyphens) |
| `--device` | str | `cuda` | Device to use |
| `--use-backend` | bool | `False` | Use backend service for inference |
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
| `--process-res` | int | `504` | Processing resolution |
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
| `--export-feat` | str | `""` | Export features from specified layers, comma-separated (e.g., `"0,1,2"`) |
| `--auto-cleanup` | bool | `False` | Automatically clean export directory without confirmation |
| `--fps` | float | `1.0` | [Video] Frame sampling FPS |
| `--sparse-subdir` | str | `""` | [COLMAP] Sparse reconstruction subdirectory (e.g., `"0"` for `sparse/0/`) |
| `--align-to-input-ext-scale` | bool | `True` | [COLMAP] Align prediction to input extrinsics scale |
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy: `first`, `middle`, `saddle_balanced`, `saddle_sim_range`. See [docs](funcs/ref_view_strategy.md) |
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Lower percentile for adaptive confidence threshold |
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points in the point cloud |
| `--show-cameras` | bool | `True` | [GLB] Show camera wireframes in the exported scene |
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Frame rate for output video |
**Examples:**
```bash
# 🖼️ Auto-process an image
da3 auto path/to/image.jpg --export-dir ./output
# 🎬 Auto-process a video
da3 auto path/to/video.mp4 --fps 2.0 --export-dir ./output
# 🔧 Use backend service
da3 auto path/to/input \
--export-format mini_npz-glb \
--use-backend \
--backend-url http://localhost:8008 \
--export-dir ./output
```
---
### 🖼️ image - Single Image Processing
Process a single image for camera pose and depth estimation.
**Usage:**
```bash
da3 image IMAGE_PATH [OPTIONS]
```
**Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `IMAGE_PATH` | str | Required | Input image file path |
| `--model-dir` | str | Default model | Model directory path |
| `--export-dir` | str | `debug` | Export directory |
| `--export-format` | str | `glb` | Export format |
| `--device` | str | `cuda` | Device to use |
| `--use-backend` | bool | `False` | Use backend service for inference |
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
| `--process-res` | int | `504` | Processing resolution |
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
| `--export-feat` | str | `""` | Export feature layer indices (comma-separated) |
| `--auto-cleanup` | bool | `False` | Automatically clean export directory |
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
| `--show-cameras` | bool | `True` | [GLB] Show cameras |
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
**Examples:**
```bash
# ✨ Basic usage
da3 image path/to/image.png --export-dir ./output
# ⚡ With backend acceleration
da3 image path/to/image.png \
--use-backend \
--backend-url http://localhost:8008 \
--export-dir ./output
# 🔍 Export feature visualization
da3 image image.jpg \
--export-format feat_vis \
--export-feat "9,19,29,39" \
--export-dir ./results
```
---
### 🗂️ images - Image Directory Processing
Process a directory of images for batch depth estimation.
**Usage:**
```bash
da3 images IMAGES_DIR [OPTIONS]
```
**Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `IMAGES_DIR` | str | Required | Directory path containing images |
| `--image-extensions` | str | `png,jpg,jpeg` | Image file extensions to process (comma-separated) |
| `--model-dir` | str | Default model | Model directory path |
| `--export-dir` | str | `debug` | Export directory |
| `--export-format` | str | `glb` | Export format |
| `--device` | str | `cuda` | Device to use |
| `--use-backend` | bool | `False` | Use backend service for inference |
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
| `--process-res` | int | `504` | Processing resolution |
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
| `--export-feat` | str | `""` | Export feature layer indices |
| `--auto-cleanup` | bool | `False` | Automatically clean export directory |
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
| `--show-cameras` | bool | `True` | [GLB] Show cameras |
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
**Examples:**
```bash
# 📁 Process directory (defaults to png/jpg/jpeg)
da3 images ./image_folder --export-dir ./output
# 🎯 Custom extensions
da3 images ./dataset --image-extensions "png,jpg,webp" --export-dir ./output
# 🔧 Use backend service
da3 images ./dataset \
--use-backend \
--backend-url http://localhost:8008 \
--export-dir ./output
```
---
### 🎬 video - Video Processing
Process video by extracting frames for depth estimation.
**Usage:**
```bash
da3 video VIDEO_PATH [OPTIONS]
```
**Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `VIDEO_PATH` | str | Required | Input video file path |
| `--fps` | float | `1.0` | Frame extraction sampling FPS |
| `--model-dir` | str | Default model | Model directory path |
| `--export-dir` | str | `debug` | Export directory |
| `--export-format` | str | `glb` | Export format |
| `--device` | str | `cuda` | Device to use |
| `--use-backend` | bool | `False` | Use backend service for inference |
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
| `--process-res` | int | `504` | Processing resolution |
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
| `--export-feat` | str | `""` | Export feature layer indices |
| `--auto-cleanup` | bool | `False` | Automatically clean export directory |
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
| `--show-cameras` | bool | `True` | [GLB] Show cameras |
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
**Examples:**
```bash
# ✨ Basic video processing
da3 video path/to/video.mp4 --export-dir ./output
# ⚙️ Control frame sampling and resolution
da3 video path/to/video.mp4 \
--fps 2.0 \
--process-res 1024 \
--export-dir ./output
# 🔧 Use backend service
da3 video path/to/video.mp4 \
--use-backend \
--backend-url http://localhost:8008 \
--export-dir ./output
```
---
### 📐 colmap - COLMAP Dataset Processing
Run pose-conditioned depth estimation on COLMAP data.
**Usage:**
```bash
da3 colmap COLMAP_DIR [OPTIONS]
```
**Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `COLMAP_DIR` | str | Required | COLMAP directory containing `images/` and `sparse/` subdirectories |
| `--sparse-subdir` | str | `""` | Sparse reconstruction subdirectory (e.g., `"0"` for `sparse/0/`) |
| `--align-to-input-ext-scale` | bool | `True` | Align prediction to input extrinsics scale |
| `--model-dir` | str | Default model | Model directory path |
| `--export-dir` | str | `debug` | Export directory |
| `--export-format` | str | `glb` | Export format |
| `--device` | str | `cuda` | Device to use |
| `--use-backend` | bool | `False` | Use backend service for inference |
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
| `--process-res` | int | `504` | Processing resolution |
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
| `--export-feat` | str | `""` | Export feature layer indices |
| `--auto-cleanup` | bool | `False` | Automatically clean export directory |
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
| `--show-cameras` | bool | `True` | [GLB] Show cameras |
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
**Examples:**
```bash
# 📐 Process COLMAP dataset
da3 colmap ./colmap_dataset --export-dir ./output
# 🎯 Use specific sparse subdirectory and align scale
da3 colmap ./colmap_dataset \
--sparse-subdir 0 \
--align-to-input-ext-scale \
--export-dir ./output
# 🔧 Use backend service
da3 colmap ./colmap_dataset \
--use-backend \
--backend-url http://localhost:8008 \
--export-dir ./output
```
---
### 🔧 backend - Backend Service
Start model backend service with integrated gallery.
**Usage:**
```bash
da3 backend [OPTIONS]
```
**Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--model-dir` | str | Default model | Model directory path |
| `--device` | str | `cuda` | Device to use |
| `--host` | str | `127.0.0.1` | Host address to bind to |
| `--port` | int | `8008` | Port number to bind to |
| `--gallery-dir` | str | Default gallery dir | Gallery directory path (optional) |
**Features:**
- 🎯 Keeps model resident in GPU memory
- 🔌 Provides REST inference API
- 📊 Integrated dashboard and status monitoring
- 🖼️ Optional gallery browser (if `--gallery-dir` is provided)
**Available Endpoints:**
- 🏠 `/` - Home page
- 📊 `/dashboard` - Dashboard
-`/status` - API status
- 🖼️ `/gallery/` - Gallery browser (if enabled)
**Examples:**
```bash
# 🚀 Basic backend service
da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE
# 🖼️ Backend with gallery
da3 backend \
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
--device cuda \
--host 0.0.0.0 \
--port 8008 \
--gallery-dir ./workspace
# 💻 Use CPU
da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE --device cpu
```
---
### 🎨 gradio - Gradio Application
Launch Depth Anything 3 Gradio interactive web application.
**Usage:**
```bash
da3 gradio [OPTIONS]
```
**Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--model-dir` | str | Required | Model directory path |
| `--workspace-dir` | str | Required | Workspace directory path |
| `--gallery-dir` | str | Required | Gallery directory path |
| `--host` | str | `127.0.0.1` | Host address to bind to |
| `--port` | int | `7860` | Port number to bind to |
| `--share` | bool | `False` | Create a public link |
| `--debug` | bool | `False` | Enable debug mode |
| `--cache-examples` | bool | `False` | Pre-cache all example scenes at startup |
| `--cache-gs-tag` | str | `""` | Tag to match scene names for high-res+3DGS caching |
**Examples:**
```bash
# 🎨 Basic Gradio application
da3 gradio \
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
--workspace-dir ./workspace \
--gallery-dir ./gallery
# 🌐 Enable sharing and debug
da3 gradio \
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
--workspace-dir ./workspace \
--gallery-dir ./gallery \
--share \
--debug
# ⚡ Pre-cache examples
da3 gradio \
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
--workspace-dir ./workspace \
--gallery-dir ./gallery \
--cache-examples \
--cache-gs-tag "dl3dv"
```
---
### 🖼️ gallery - Gallery Server
Launch standalone Depth Anything 3 Gallery server.
**Usage:**
```bash
da3 gallery [OPTIONS]
```
**Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--gallery-dir` | str | Default gallery dir | Gallery root directory |
| `--host` | str | `127.0.0.1` | Host address to bind to |
| `--port` | int | `8007` | Port number to bind to |
| `--open-browser` | bool | `False` | Open browser after launch |
**Note:**
The gallery expects each scene folder to contain at least `scene.glb` and `scene.jpg`, with optional subfolders such as `depth_vis/` or `gs_video/`.
**Examples:**
```bash
# 🖼️ Basic gallery server
da3 gallery --gallery-dir ./workspace
# 🌐 Custom host and port
da3 gallery \
--gallery-dir ./workspace \
--host 0.0.0.0 \
--port 8007
# 🚀 Auto-open browser
da3 gallery --gallery-dir ./workspace --open-browser
```
---
## ⚙️ Parameter Details
### 🔧 Common Parameters
- **`--export-dir`**: Output directory, defaults to `debug`
- **`--export-format`**: Export format, supports combining multiple formats with hyphens:
- 📦 `mini_npz`: Compressed NumPy format
- 🎨 `glb`: glTF binary format (3D scene)
- 🔍 `feat_vis`: Feature visualization
- Example: `mini_npz-glb` exports both formats
- **`--process-res`** / **`--process-res-method`**: Control preprocessing resolution strategy
- `process-res`: Target resolution (default 504)
- `process-res-method`: Resize method (default `upper_bound_resize`)
- **`--auto-cleanup`**: Remove existing export directory without confirmation
- **`--use-backend`** / **`--backend-url`**: Reuse running backend service
- ⚡ Reduces model loading time
- 🌐 Supports distributed processing
- **`--export-feat`**: Layer indices for exporting intermediate features (comma-separated)
- Example: `"9,19,29,39"`
### 🎨 GLB Export Parameters
- **`--conf-thresh-percentile`**: Lower percentile for adaptive confidence threshold (default 40.0)
- Used to filter low-confidence points
- **`--num-max-points`**: Maximum number of points in point cloud (default 1,000,000)
- Controls output file size and performance
- **`--show-cameras`**: Show camera wireframes in exported scene (default True)
### 🔍 Feature Visualization Parameters
- **`--feat-vis-fps`**: Frame rate for feature visualization output video (default 15)
### 🎬 Video-Specific Parameters
- **`--fps`**: Video frame extraction sampling rate (default 1.0 FPS)
- Higher values extract more frames
### 📐 COLMAP-Specific Parameters
- **`--sparse-subdir`**: Sparse reconstruction subdirectory
- Empty string uses `sparse/` directory
- `"0"` uses `sparse/0/` directory
- **`--align-to-input-ext-scale`**: Align prediction to input extrinsics scale (default True)
- Ensures depth estimation is consistent with COLMAP scale
---
## 💡 Usage Examples
### 1⃣ Basic Workflow
```bash
# 🔧 Start backend service
da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE --host 0.0.0.0 --port 8008
# 🖼️ Process single image
da3 image image.jpg --export-dir ./output1 --use-backend
# 🎬 Process video
da3 video video.mp4 --fps 2.0 --export-dir ./output2 --use-backend
# 📐 Process COLMAP dataset
da3 colmap ./colmap_data --export-dir ./output3 --use-backend
```
### 2⃣ Using Auto Mode
```bash
# 🤖 Auto-detect and process
da3 auto ./unknown_input --export-dir ./output
# ⚡ With backend acceleration
da3 auto ./unknown_input \
--use-backend \
--backend-url http://localhost:8008 \
--export-dir ./output
```
### 3⃣ Multi-Format Export
```bash
# 📦 Export both NPZ and GLB formats
da3 auto assets/examples/SOH \
--export-format mini_npz-glb \
--export-dir ./workspace/soh
# 🔍 Export feature visualization
da3 image image.jpg \
--export-format feat_vis \
--export-feat "9,19,29,39" \
--export-dir ./results
```
### 4⃣ Advanced Configuration
```bash
# ⚙️ Custom resolution and point cloud density
da3 image image.jpg \
--process-res 1024 \
--num-max-points 2000000 \
--conf-thresh-percentile 30.0 \
--export-dir ./output
# 📐 COLMAP advanced options
da3 colmap ./colmap_data \
--sparse-subdir 0 \
--align-to-input-ext-scale \
--process-res 756 \
--export-dir ./output
```
### 5⃣ Batch Processing Workflow
```bash
# 🔧 Start backend
da3 backend \
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
--device cuda \
--host 0.0.0.0 \
--port 8008 \
--gallery-dir ./workspace
# 🔄 Batch process multiple scenes
for scene in scene1 scene2 scene3; do
da3 auto ./data/$scene \
--export-dir ./workspace/$scene \
--use-backend \
--auto-cleanup
done
# 🖼️ Launch gallery to view results
da3 gallery --gallery-dir ./workspace --open-browser
```
### 6⃣ Web Applications
```bash
# 🎨 Launch Gradio application
da3 gradio \
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
--workspace-dir workspace/gradio \
--gallery-dir ./gallery \
--host 0.0.0.0 \
--port 7860 \
--share
```
### 7⃣ Transformer Feature Visualization
```bash
# 🔍 Export Transformer features
# 📦 Combined with numerical output
da3 auto video.mp4 \
--export-format glb-feat_vis \
--export-feat "11,21,31" \
--export-dir ./debug \
--use-backend
```
---
## 📝 Notes
1. **🔧 Backend Service**: Recommended for processing multiple tasks to improve efficiency
2. **💾 GPU Memory**: Be mindful of GPU memory usage when processing high-resolution inputs
3. **📁 Export Directory**: Use `--auto-cleanup` to avoid manual confirmation for deletion
4. **🔀 Format Combination**: Multiple export formats can be combined with hyphens (e.g., `mini_npz-glb-feat_vis`)
5. **📐 COLMAP Data**: Ensure COLMAP directory structure is correct (contains `images/` and `sparse/` subdirectories)
---
## ❓ Getting Help
View detailed help for any command:
```bash
# 📖 View main help
da3 --help
# 🔍 View specific command help
da3 auto --help
da3 image --help
da3 backend --help
```

View File

@@ -0,0 +1,183 @@
# 📐 Reference View Selection Strategy
## 📖 Overview
Reference view selection is a component in multi-view depth estimation. When processing multiple input views, the model needs to determine which view should serve as the primary reference frame for depth prediction, defining the world coordinate system.
Different reference view will leads to different reconstruction results. This is a known consideration in multi-view geometry and was analyzed in [PI3](https://arxiv.org/abs/2507.13347). The choice of reference view can affect the quality and consistency of depth predictions across the scene.
## 🚀 Our Simple Solution: Automatic Reference View Selection
DA3 provides a simple approach to address this through **automatic reference view selection** based on **class tokens**. Instead of relying on heuristics or manual selection, the model analyzes the class token features from all input views and intelligently selects the most suitable reference frame.
---
## 🎨 Available Strategies
### 1. ⚖️ `saddle_balanced` (Recommended, Default)
**Philosophy:**
Select a view that achieves balance across multiple feature metrics. This strategy looks for a "middle ground" view that is neither too similar nor too different from other views, making it a stable reference point.
**How it works:**
1. Extracts and normalizes class tokens from all views
2. Computes three complementary metrics for each view:
- **Similarity score**: Average cosine similarity with other views
- **Feature norm**: L2 norm of the original features
- **Feature variance**: Variance across feature dimensions
3. Normalizes each metric to [0, 1] range
4. Selects the view closest to 0.5 (median) across all three metrics
### 2. 🎢 `saddle_sim_range`
**Philosophy:**
Select a view with the largest similarity range to other views. This identifies "saddle point" views that are highly similar to some views but dissimilar to others, making them information-rich anchor points.
**How it works:**
1. Computes pairwise cosine similarity between all views
2. For each view, calculates the range (max - min) of similarities to other views
3. Selects the view with the maximum similarity range
---
### 3. 1⃣ `first` (Not Recommended)
**Philosophy:**
Always use the first view in the input sequence as the reference.
**How it works:**
Simply returns index 0.
**When to use:**
-**Not recommended** in general
- 🔧 Only use when you have manually pre-sorted your views and know the first view is optimal
- 🐛 Debugging or baseline comparisons
---
### 4. ⏸️ `middle`
**Philosophy:**
Select the view in the middle of the input sequence.
**How it works:**
Returns the view at index `S // 2` where S is the number of views.
**When to use:**
- ⏱️ **Only recommended when input images are temporally ordered**
- 🎬 Video sequences (e.g., **DA3-LONG** setting)
- 📹 Sequential captures where the middle frame likely has the most stable viewpoint
**Specific use case: DA3-LONG** 🎬
In video-based depth estimation scenarios (like DA3-LONG), where inputs are consecutive frames, `middle` is often the **optimal choice** because that it has maximum overlap with all other frames.
## 💻 Usage
### 🐍 Python API
```python
from depth_anything_3 import DepthAnything3
model = DepthAnything3.from_pretrained("depth-anything/DA3NESTED-GIANT-LARGE")
# Use default (saddle_balanced)
prediction = model.inference(
images,
ref_view_strategy="saddle_balanced"
)
# For video sequences, consider using middle
prediction = model.inference(
video_frames,
ref_view_strategy="middle" # Good for temporal sequences
)
# For complex scenes with wide baselines
prediction = model.inference(
images,
ref_view_strategy="saddle_sim_range"
)
```
### 🖥️ Command Line Interface
```bash
# Default (saddle_balanced)
da3 auto input/ --export-dir output/
# Explicitly specify strategy
da3 auto input/ --ref-view-strategy saddle_balanced
# For video processing
da3 video input.mp4 --ref-view-strategy middle
# For wide-baseline multi-view
da3 images captures/ --ref-view-strategy saddle_sim_range
```
---
### 🎯 When Selection Is Applied
Reference view selection is applied when:
- 3⃣ Number of views S ≥ 3
---
## 💡 Recommendations
### 📋 Quick Guide
| Scenario | Recommended Strategy | Rationale |
|----------|---------------------|-----------|
| **Default / Unknown** | `saddle_balanced` | Robust, balanced, works well across diverse scenarios |
| **Video frames** | `middle` | Temporal coherence, stable middle frame |
| **Wide-baseline multi-view** | `saddle_sim_range` | Maximizes information coverage |
| **Pre-sorted inputs** | `first` | Use only if you've manually optimized ordering |
| **Single image** | `first` | Automatically used (no reordering needed for S ≤ 2) |
### ✨ Best Practices
1. 🎯 **Start with defaults**: `saddle_balanced` works well in most cases
2. 🎬 **Consider your input type**: Use `middle` for videos, `saddle_balanced` for photos
3. 🔬 **Experiment if needed**: Try different strategies if results are suboptimal
4. 📊 **Monitor performance**: Check `glb` quality and consistency across views.
---
## 🔧 Technical Details
### 🎚️ Selection Threshold
The reference view selection is only triggered when:
```python
num_views >= 3 # At least 3 views required
```
For 1-2 views, no reordering is performed (equivalent to using `first`).
### ⚙️ Implementation
The selection happens at layer `alt_start - 1` in the vision transformer, before the first global attention layer. This ensures the selected reference view influences the entire depth prediction pipeline.
---
## ❓ FAQ
**Q: 🤔 Why is this feature provided?**
A: The model can handle any view order, but this feature provides automatic optimization for reference view selection, which can help improve depth prediction quality in multi-view scenarios.
**Q: ⏱️ Does this add computational cost?**
A: The overhead is totally negligible.
**Q: 🎮 Can I manually specify which view to use as reference?**
A: Not directly through this parameter. You can pre-sort your input images to place your preferred reference view first and use `ref_view_strategy="first"`.
**Q: ⚙️ What happens if I don't specify this parameter?**
A: The default `saddle_balanced` strategy is used automatically.
**Q: 📊 Is this feature used in the DA3 paper benchmarks?**
A: No, the paper used `first` as the default strategy for all multi-view experiments. The current default has been updated to `saddle_balanced` for better robustness.

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,94 @@
[build-system]
requires = ["hatchling>=1.25", "hatch-vcs>=0.4"]
build-backend = "hatchling.build"
[project]
name = "depth-anything-3"
version = "0.0.0"
description = "Depth Anything 3"
readme = "README.md"
requires-python = ">=3.9, <=3.13"
license = { text = "Apache-2.0" }
authors = [{ name = "Your Name" }]
dependencies = [
"pre-commit",
"trimesh",
"torch>=2",
"torchvision",
"einops",
"huggingface_hub",
"imageio",
"numpy<2",
"opencv-python",
"xformers",
"open3d",
"fastapi",
"uvicorn",
"requests",
"typer",
"pillow",
"omegaconf",
"evo",
"e3nn",
"moviepy",
"plyfile",
"pillow_heif",
"safetensors",
"uvicorn",
"moviepy==1.0.3",
"typer>=0.9.0",
"pycolmap",
]
[project.optional-dependencies]
app = ["gradio>=5", "pillow>=9.0"] # requires that python3>=3.10
gs = ["gsplat @ git+https://github.com/nerfstudio-project/gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70"]
all = ["depth-anything-3[app,gs]"]
[project.scripts]
da3 = "depth_anything_3.cli:app"
[project.urls]
Homepage = "https://github.com/ByteDance-Seed/Depth-Anything-3"
[tool.hatch.version]
source = "vcs"
[tool.hatch.build.targets.wheel]
packages = ["src/depth_anything_3"]
[tool.hatch.build.targets.sdist]
include = [
"/README.md",
"/pyproject.toml",
"/src/depth_anything_3",
]
[tool.hatch.metadata]
allow-direct-references = true
[tool.mypy]
plugins = ["jaxtyping.mypy_plugin"]
[tool.black]
line-length = 99
target-version = ['py37', 'py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'
exclude = '''
/(
| \.git
)/
'''
[tool.isort]
profile = "black"
multi_line_output = 3
include_trailing_comma = true
known_third_party = ["bson","cruise","cv2","dataloader","diffusers","omegaconf","tensorflow","torch","torchvision","transformers","gsplat"]
known_first_party = ["common", "data", "models", "projects"]
sections = ["FUTURE","STDLIB","THIRDPARTY","FIRSTPARTY","LOCALFOLDER"]
skip_gitignore = true
line_length = 99
no_lines_before="THIRDPARTY"

View File

@@ -0,0 +1,24 @@
pre-commit
trimesh
torch>=2
torchvision
einops
huggingface_hub
imageio
numpy<2
opencv-python
xformers
open3d
fastapi
uvicorn
requests
typer
pillow
omegaconf
evo
e3nn
moviepy
plyfile
pillow_heif
safetensors
pycolmap

View File

@@ -0,0 +1,446 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Depth Anything 3 API module.
This module provides the main API for Depth Anything 3, including model loading,
inference, and export capabilities. It supports both single and nested model architectures.
"""
from __future__ import annotations
import time
from typing import Optional, Sequence
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from PIL import Image
from depth_anything_3.cfg import create_object, load_config
from depth_anything_3.registry import MODEL_REGISTRY
from depth_anything_3.specs import Prediction
from depth_anything_3.utils.export import export
from depth_anything_3.utils.geometry import affine_inverse
from depth_anything_3.utils.io.input_processor import InputProcessor
from depth_anything_3.utils.io.output_processor import OutputProcessor
from depth_anything_3.utils.logger import logger
from depth_anything_3.utils.pose_align import align_poses_umeyama
torch.backends.cudnn.benchmark = False
# logger.info("CUDNN Benchmark Disabled")
SAFETENSORS_NAME = "model.safetensors"
CONFIG_NAME = "config.json"
class DepthAnything3(nn.Module, PyTorchModelHubMixin):
"""
Depth Anything 3 main API class.
This class provides a high-level interface for depth estimation using Depth Anything 3.
It supports both single and nested model architectures with metric scaling capabilities.
Features:
- Hugging Face Hub integration via PyTorchModelHubMixin
- Support for multiple model presets (vitb, vitg, nested variants)
- Automatic mixed precision inference
- Export capabilities for various formats (GLB, PLY, NPZ, etc.)
- Camera pose estimation and metric depth scaling
Usage:
# Load from Hugging Face Hub
model = DepthAnything3.from_pretrained("huggingface/model-name")
# Or create with specific preset
model = DepthAnything3(preset="vitg")
# Run inference
prediction = model.inference(images, export_dir="output", export_format="glb")
"""
_commit_hash: str | None = None # Set by mixin when loading from Hub
def __init__(self, model_name: str = "da3-large", **kwargs):
"""
Initialize DepthAnything3 with specified preset.
Args:
model_name: The name of the model preset to use.
Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'.
**kwargs: Additional keyword arguments (currently unused).
"""
super().__init__()
self.model_name = model_name
# Build the underlying network
self.config = load_config(MODEL_REGISTRY[self.model_name])
self.model = create_object(self.config)
self.model.eval()
# Initialize processors
self.input_processor = InputProcessor()
self.output_processor = OutputProcessor()
# Device management (set by user)
self.device = None
@torch.inference_mode()
def forward(
self,
image: torch.Tensor,
extrinsics: torch.Tensor | None = None,
intrinsics: torch.Tensor | None = None,
export_feat_layers: list[int] | None = None,
infer_gs: bool = False,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
) -> dict[str, torch.Tensor]:
"""
Forward pass through the model.
Args:
image: Input batch with shape ``(B, N, 3, H, W)`` on the model device.
extrinsics: Optional camera extrinsics with shape ``(B, N, 4, 4)``.
intrinsics: Optional camera intrinsics with shape ``(B, N, 3, 3)``.
export_feat_layers: Layer indices to return intermediate features for.
infer_gs: Enable Gaussian Splatting branch.
use_ray_pose: Use ray-based pose estimation instead of camera decoder.
ref_view_strategy: Strategy for selecting reference view from multiple views.
Returns:
Dictionary containing model predictions
"""
# Determine optimal autocast dtype
autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
with torch.no_grad():
with torch.autocast(device_type=image.device.type, dtype=autocast_dtype):
return self.model(
image, extrinsics, intrinsics, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy
)
def inference(
self,
image: list[np.ndarray | Image.Image | str],
extrinsics: np.ndarray | None = None,
intrinsics: np.ndarray | None = None,
align_to_input_ext_scale: bool = True,
infer_gs: bool = False,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
render_exts: np.ndarray | None = None,
render_ixts: np.ndarray | None = None,
render_hw: tuple[int, int] | None = None,
process_res: int = 504,
process_res_method: str = "upper_bound_resize",
export_dir: str | None = None,
export_format: str = "mini_npz",
export_feat_layers: Sequence[int] | None = None,
# GLB export parameters
conf_thresh_percentile: float = 40.0,
num_max_points: int = 1_000_000,
show_cameras: bool = True,
# Feat_vis export parameters
feat_vis_fps: int = 15,
# Other export parameters, e.g., gs_ply, gs_video
export_kwargs: Optional[dict] = {},
) -> Prediction:
"""
Run inference on input images.
Args:
image: List of input images (numpy arrays, PIL Images, or file paths)
extrinsics: Camera extrinsics (N, 4, 4)
intrinsics: Camera intrinsics (N, 3, 3)
align_to_input_ext_scale: whether to align the input pose scale to the prediction
infer_gs: Enable the 3D Gaussian branch (needed for `gs_ply`/`gs_video` exports)
use_ray_pose: Use ray-based pose estimation instead of camera decoder (default: False)
ref_view_strategy: Strategy for selecting reference view from multiple views.
Options: "first", "middle", "saddle_balanced", "saddle_sim_range".
Default: "saddle_balanced". For single view input (S ≤ 2), no reordering is performed.
render_exts: Optional render extrinsics for Gaussian video export
render_ixts: Optional render intrinsics for Gaussian video export
render_hw: Optional render resolution for Gaussian video export
process_res: Processing resolution
process_res_method: Resize method for processing
export_dir: Directory to export results
export_format: Export format (mini_npz, npz, glb, ply, gs, gs_video)
export_feat_layers: Layer indices to export intermediate features from
conf_thresh_percentile: [GLB] Lower percentile for adaptive confidence threshold (default: 40.0) # noqa: E501
num_max_points: [GLB] Maximum number of points in the point cloud (default: 1,000,000)
show_cameras: [GLB] Show camera wireframes in the exported scene (default: True)
feat_vis_fps: [FEAT_VIS] Frame rate for output video (default: 15)
export_kwargs: additional arguments to export functions.
Returns:
Prediction object containing depth maps and camera parameters
"""
if "gs" in export_format:
assert infer_gs, "must set `infer_gs=True` to perform gs-related export."
if "colmap" in export_format:
assert isinstance(image[0], str), "`image` must be image paths for COLMAP export."
# Preprocess images
imgs_cpu, extrinsics, intrinsics = self._preprocess_inputs(
image, extrinsics, intrinsics, process_res, process_res_method
)
# Prepare tensors for model
imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics)
# Normalize extrinsics
ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None)
# Run model forward pass
export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else []
raw_output = self._run_model_forward(
imgs, ex_t_norm, in_t, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy
)
# Convert raw output to prediction
prediction = self._convert_to_prediction(raw_output)
# Align prediction to extrinsincs
prediction = self._align_to_input_extrinsics_intrinsics(
extrinsics, intrinsics, prediction, align_to_input_ext_scale
)
# Add processed images for visualization
prediction = self._add_processed_images(prediction, imgs_cpu)
# Export if requested
if export_dir is not None:
if "gs" in export_format:
if infer_gs and "gs_video" not in export_format:
export_format = f"{export_format}-gs_video"
if "gs_video" in export_format:
if "gs_video" not in export_kwargs:
export_kwargs["gs_video"] = {}
export_kwargs["gs_video"].update(
{
"extrinsics": render_exts,
"intrinsics": render_ixts,
"out_image_hw": render_hw,
}
)
# Add GLB export parameters
if "glb" in export_format:
if "glb" not in export_kwargs:
export_kwargs["glb"] = {}
export_kwargs["glb"].update(
{
"conf_thresh_percentile": conf_thresh_percentile,
"num_max_points": num_max_points,
"show_cameras": show_cameras,
}
)
# Add Feat_vis export parameters
if "feat_vis" in export_format:
if "feat_vis" not in export_kwargs:
export_kwargs["feat_vis"] = {}
export_kwargs["feat_vis"].update(
{
"fps": feat_vis_fps,
}
)
# Add COLMAP export parameters
if "colmap" in export_format:
if "colmap" not in export_kwargs:
export_kwargs["colmap"] = {}
export_kwargs["colmap"].update(
{
"image_paths": image,
"conf_thresh_percentile": conf_thresh_percentile,
"process_res_method": process_res_method,
}
)
self._export_results(prediction, export_format, export_dir, **export_kwargs)
return prediction
def _preprocess_inputs(
self,
image: list[np.ndarray | Image.Image | str],
extrinsics: np.ndarray | None = None,
intrinsics: np.ndarray | None = None,
process_res: int = 504,
process_res_method: str = "upper_bound_resize",
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""Preprocess input images using input processor."""
start_time = time.time()
imgs_cpu, extrinsics, intrinsics = self.input_processor(
image,
extrinsics.copy() if extrinsics is not None else None,
intrinsics.copy() if intrinsics is not None else None,
process_res,
process_res_method,
)
end_time = time.time()
logger.info(
"Processed Images Done taking",
end_time - start_time,
"seconds. Shape: ",
imgs_cpu.shape,
)
return imgs_cpu, extrinsics, intrinsics
def _prepare_model_inputs(
self,
imgs_cpu: torch.Tensor,
extrinsics: torch.Tensor | None,
intrinsics: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""Prepare tensors for model input."""
device = self._get_model_device()
# Move images to model device
imgs = imgs_cpu.to(device, non_blocking=True)[None].float()
# Convert camera parameters to tensors
ex_t = (
extrinsics.to(device, non_blocking=True)[None].float()
if extrinsics is not None
else None
)
in_t = (
intrinsics.to(device, non_blocking=True)[None].float()
if intrinsics is not None
else None
)
return imgs, ex_t, in_t
def _normalize_extrinsics(self, ex_t: torch.Tensor | None) -> torch.Tensor | None:
"""Normalize extrinsics"""
if ex_t is None:
return None
transform = affine_inverse(ex_t[:, :1])
ex_t_norm = ex_t @ transform
c2ws = affine_inverse(ex_t_norm)
translations = c2ws[..., :3, 3]
dists = translations.norm(dim=-1)
median_dist = torch.median(dists)
median_dist = torch.clamp(median_dist, min=1e-1)
ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist
return ex_t_norm
def _align_to_input_extrinsics_intrinsics(
self,
extrinsics: torch.Tensor | None,
intrinsics: torch.Tensor | None,
prediction: Prediction,
align_to_input_ext_scale: bool = True,
ransac_view_thresh: int = 10,
) -> Prediction:
"""Align depth map to input extrinsics"""
if extrinsics is None:
return prediction
prediction.intrinsics = intrinsics.numpy()
_, _, scale, aligned_extrinsics = align_poses_umeyama(
prediction.extrinsics,
extrinsics.numpy(),
ransac=len(extrinsics) >= ransac_view_thresh,
return_aligned=True,
random_state=42,
)
if align_to_input_ext_scale:
prediction.extrinsics = extrinsics[..., :3, :].numpy()
prediction.depth /= scale
else:
prediction.extrinsics = aligned_extrinsics
return prediction
def _run_model_forward(
self,
imgs: torch.Tensor,
ex_t: torch.Tensor | None,
in_t: torch.Tensor | None,
export_feat_layers: Sequence[int] | None = None,
infer_gs: bool = False,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
) -> dict[str, torch.Tensor]:
"""Run model forward pass."""
device = imgs.device
need_sync = device.type == "cuda"
if need_sync:
torch.cuda.synchronize(device)
start_time = time.time()
feat_layers = list(export_feat_layers) if export_feat_layers is not None else None
output = self.forward(imgs, ex_t, in_t, feat_layers, infer_gs, use_ray_pose, ref_view_strategy)
if need_sync:
torch.cuda.synchronize(device)
end_time = time.time()
logger.info(f"Model Forward Pass Done. Time: {end_time - start_time} seconds")
return output
def _convert_to_prediction(self, raw_output: dict[str, torch.Tensor]) -> Prediction:
"""Convert raw model output to Prediction object."""
start_time = time.time()
output = self.output_processor(raw_output)
end_time = time.time()
logger.info(f"Conversion to Prediction Done. Time: {end_time - start_time} seconds")
return output
def _add_processed_images(self, prediction: Prediction, imgs_cpu: torch.Tensor) -> Prediction:
"""Add processed images to prediction for visualization."""
# Convert from (N, 3, H, W) to (N, H, W, 3) and denormalize
processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy() # (N, H, W, 3)
# Denormalize from ImageNet normalization
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
processed_imgs = processed_imgs * std + mean
processed_imgs = np.clip(processed_imgs, 0, 1)
processed_imgs = (processed_imgs * 255).astype(np.uint8)
prediction.processed_images = processed_imgs
return prediction
def _export_results(
self, prediction: Prediction, export_format: str, export_dir: str, **kwargs
) -> None:
"""Export results to specified format and directory."""
start_time = time.time()
export(prediction, export_format, export_dir, **kwargs)
end_time = time.time()
logger.info(f"Export Results Done. Time: {end_time - start_time} seconds")
def _get_model_device(self) -> torch.device:
"""
Get the device where the model is located.
Returns:
Device where the model parameters are located
Raises:
ValueError: If no tensors are found in the model
"""
if self.device is not None:
return self.device
# Find device from parameters
for param in self.parameters():
self.device = param.device
return param.device
# Find device from buffers
for buffer in self.buffers():
self.device = buffer.device
return buffer.device
raise ValueError("No tensor found in model")

View File

@@ -0,0 +1,594 @@
# flake8: noqa: E501
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CSS and HTML content for the Depth Anything 3 Gradio application.
This module contains all the CSS styles and HTML content blocks
used in the Gradio interface.
"""
# CSS Styles for the Gradio interface
GRADIO_CSS = """
/* Add Font Awesome CDN with all styles including brands and colors */
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css');
/* Add custom styles for colored icons */
.fa-color-blue {
color: #3b82f6;
}
.fa-color-purple {
color: #8b5cf6;
}
.fa-color-cyan {
color: #06b6d4;
}
.fa-color-green {
color: #10b981;
}
.fa-color-yellow {
color: #f59e0b;
}
.fa-color-red {
color: #ef4444;
}
.link-btn {
display: inline-flex;
align-items: center;
gap: 8px;
text-decoration: none;
padding: 12px 24px;
border-radius: 50px;
font-weight: 500;
transition: all 0.3s ease;
}
/* Dark mode tech theme */
@media (prefers-color-scheme: dark) {
html, body {
background: #1e293b;
color: #ffffff;
}
.gradio-container {
background: #1e293b;
color: #ffffff;
}
.link-btn {
background: rgba(255, 255, 255, 0.2);
color: white;
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.3);
}
.link-btn:hover {
background: rgba(255, 255, 255, 0.3);
transform: translateY(-2px);
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
}
.tech-bg {
background: linear-gradient(135deg, #0f172a, #1e293b); /* Darker colors */
position: relative;
overflow: hidden;
}
.tech-bg::before {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background:
radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.1) 0%, transparent 50%); /* Reduced opacity */
animation: techPulse 8s ease-in-out infinite;
}
.gradio-container .panel,
.gradio-container .block,
.gradio-container .form {
background: rgba(0, 0, 0, 0.3);
border: 1px solid rgba(59, 130, 246, 0.2);
border-radius: 10px;
}
.gradio-container * {
color: #ffffff;
}
.gradio-container label {
color: #e0e0e0;
}
.gradio-container .markdown {
color: #e0e0e0;
}
}
/* Light mode tech theme */
@media (prefers-color-scheme: light) {
html, body {
background: #ffffff;
color: #1e293b;
}
.gradio-container {
background: #ffffff;
color: #1e293b;
}
.tech-bg {
background: linear-gradient(135deg, #ffffff, #f1f5f9);
position: relative;
overflow: hidden;
}
.link-btn {
background: rgba(59, 130, 246, 0.15);
color: var(--body-text-color);
border: 1px solid rgba(59, 130, 246, 0.3);
}
.link-btn:hover {
background: rgba(59, 130, 246, 0.25);
transform: translateY(-2px);
box-shadow: 0 8px 25px rgba(59, 130, 246, 0.2);
}
.tech-bg::before {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background:
radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.1) 0%, transparent 50%),
radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.1) 0%, transparent 50%),
radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.08) 0%, transparent 50%);
animation: techPulse 8s ease-in-out infinite;
}
.gradio-container .panel,
.gradio-container .block,
.gradio-container .form {
background: rgba(255, 255, 255, 0.8);
border: 1px solid rgba(59, 130, 246, 0.3);
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.gradio-container * {
color: #1e293b;
}
.gradio-container label {
color: #334155;
}
.gradio-container .markdown {
color: #334155;
}
}
@keyframes techPulse {
0%, 100% { opacity: 0.5; }
50% { opacity: 0.8; }
}
/* Custom log with tech gradient */
.custom-log * {
font-style: italic;
font-size: 22px !important;
background: linear-gradient(135deg, #3b82f6, #8b5cf6);
background-size: 400% 400%;
-webkit-background-clip: text;
background-clip: text;
font-weight: bold !important;
color: transparent !important;
text-align: center !important;
animation: techGradient 3s ease infinite;
}
@keyframes techGradient {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
@keyframes metricPulse {
0%, 100% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
}
@keyframes pointcloudPulse {
0%, 100% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
}
@keyframes camerasPulse {
0%, 100% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
}
@keyframes gaussiansPulse {
0%, 100% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
}
/* Special colors for key terms - Global styles */
.metric-text {
background: linear-gradient(45deg, #ff6b6b, #ff8e53, #ff6b6b);
background-size: 200% 200%;
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
animation: metricPulse 2s ease-in-out infinite;
font-weight: 700;
text-shadow: 0 0 10px rgba(255, 107, 107, 0.5);
}
.pointcloud-text {
background: linear-gradient(45deg, #4ecdc4, #44a08d, #4ecdc4);
background-size: 200% 200%;
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
animation: pointcloudPulse 2.5s ease-in-out infinite;
font-weight: 700;
text-shadow: 0 0 10px rgba(78, 205, 196, 0.5);
}
.cameras-text {
background: linear-gradient(45deg, #667eea, #764ba2, #667eea);
background-size: 200% 200%;
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
animation: camerasPulse 3s ease-in-out infinite;
font-weight: 700;
text-shadow: 0 0 10px rgba(102, 126, 234, 0.5);
}
.gaussians-text {
background: linear-gradient(45deg, #f093fb, #f5576c, #f093fb);
background-size: 200% 200%;
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
animation: gaussiansPulse 2.2s ease-in-out infinite;
font-weight: 700;
text-shadow: 0 0 10px rgba(240, 147, 251, 0.5);
}
.example-log * {
font-style: italic;
font-size: 16px !important;
background: linear-gradient(135deg, #3b82f6, #8b5cf6);
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
}
#my_radio .wrap {
display: flex;
flex-wrap: nowrap;
justify-content: center;
align-items: center;
}
#my_radio .wrap label {
display: flex;
width: 50%;
justify-content: center;
align-items: center;
margin: 0;
padding: 10px 0;
box-sizing: border-box;
}
/* Align navigation buttons with dropdown bottom */
.navigation-row {
display: flex !important;
align-items: flex-end !important;
gap: 8px !important;
}
.navigation-row > div:nth-child(1),
.navigation-row > div:nth-child(3) {
align-self: flex-end !important;
}
.navigation-row > div:nth-child(2) {
flex: 1 !important;
}
/* Make thumbnails clickable with pointer cursor */
.clickable-thumbnail img {
cursor: pointer !important;
}
.clickable-thumbnail:hover img {
cursor: pointer !important;
opacity: 0.8;
transition: opacity 0.3s ease;
}
/* Make thumbnail containers narrower horizontally */
.clickable-thumbnail {
padding: 5px 2px !important;
margin: 0 2px !important;
}
.clickable-thumbnail .image-container {
margin: 0 !important;
padding: 0 !important;
}
.scene-info {
text-align: center !important;
padding: 5px 2px !important;
margin: 0 !important;
}
"""
def get_header_html(logo_base64=None):
"""
Generate the main header HTML with logo and title.
Args:
logo_base64 (str, optional): Base64 encoded logo image
Returns:
str: HTML string for the header
"""
return """
<div class="tech-bg" style="text-align: center; margin-bottom: 5px; padding: 40px 20px; border-radius: 15px; position: relative; overflow: hidden;">
<div style="position: relative; z-index: 2;">
<h1 style="margin: 0; font-size: 3.5em; font-weight: 700;
background: linear-gradient(135deg, #3b82f6, #8b5cf6);
background-size: 400% 400%;
-webkit-background-clip: text;
background-clip: text;
color: transparent;
animation: techGradient 3s ease infinite;
text-shadow: 0 0 30px rgba(59, 130, 246, 0.5);
letter-spacing: 2px;">
Depth Anything 3
</h1>
<p style="margin: 15px 0 0 0; font-size: 2.16em; font-weight: 300;" class="header-subtitle">
Recovering the Visual Space from Any Views
</p>
<div style="margin-top: 20px;">
<!-- Revert buttons to original inline styles -->
<a href="https://depth-anything-3.github.io" target="_blank" class="link-btn">
<i class="fas fa-globe" style="margin-right: 8px;"></i> Project Page
</a>
<a href="https://arxiv.org/abs/2406.09414" target="_blank" class="link-btn">
<i class="fas fa-file-pdf" style="margin-right: 8px;"></i> Paper
</a>
<a href="https://github.com/ByteDance-Seed/Depth-Anything-3" target="_blank" class="link-btn">
<i class="fab fa-github" style="margin-right: 8px;"></i> Code
</a>
</div>
</div>
</div>
<style>
/* Ensure tech-bg class is properly applied in dark mode */
@media (prefers-color-scheme: dark) {
.header-subtitle {
color: #cbd5e1;
}
/* Increase priority to ensure background color is properly applied */
.tech-bg {
background: linear-gradient(135deg, #0f172a, #1e293b) !important;
}
}
@media (prefers-color-scheme: light) {
.header-subtitle {
color: #475569;
}
/* Also add explicit background color for light mode */
.tech-bg {
background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%) !important;
}
}
</style>
"""
def get_description_html():
"""
Generate the main description and getting started HTML.
Returns:
str: HTML string for the description
"""
return """
<div class="description-container" style="padding: 25px; border-radius: 15px; margin: 0 0 20px 0;">
<h2 class="description-title" style="margin-top: 0; font-size: 1.6em; text-align: center;">
<i class="fas fa-bullseye fa-color-red" style="margin-right: 8px;"></i> What This Demo Does
</h2>
<div class="description-content" style="padding: 20px; border-radius: 10px; margin: 15px 0; text-align: center;">
<p class="description-main" style="line-height: 1.6; margin: 0; font-size: 1.45em;">
<strong>Upload images or videos</strong> → <strong>Get <span class="metric-text">Metric</span> <span class="pointcloud-text">Point Clouds</span>, <span class="cameras-text">Cameras</span> and <span class="gaussians-text">Novel Views</span></strong> → <strong>Explore in 3D</strong>
</p>
</div>
<div style="text-align: center; margin-top: 15px;">
<p class="description-tip" style="font-style: italic; margin: 0;">
<i class="fas fa-lightbulb fa-color-yellow" style="margin-right: 8px;"></i> <strong>Tip:</strong> Landscape-oriented images or videos are preferred for best 3D recovering.
</p>
</div>
</div>
<style>
@media (prefers-color-scheme: dark) {
.description-container {
background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
border: 1px solid rgba(59, 130, 246, 0.2);
}
.description-title { color: #3b82f6; }
.description-content { background: rgba(0, 0, 0, 0.3); }
.description-main { color: #e0e0e0; }
.description-text { color: #cbd5e1; }
.description-tip { color: #cbd5e1; }
}
@media (prefers-color-scheme: light) {
.description-container {
background: linear-gradient(135deg, rgba(59, 130, 246, 0.05) 0%, rgba(139, 92, 246, 0.05) 100%);
border: 1px solid rgba(59, 130, 246, 0.3);
}
.description-title { color: #3b82f6; }
.description-content { background: transparent; }
.description-main { color: #1e293b; }
.description-text { color: #475569; }
.description-tip { color: #475569; }
}
</style>
"""
def get_acknowledgements_html():
"""
Generate the acknowledgements section HTML.
Returns:
str: HTML string for the acknowledgements
"""
return """
<div style="background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
padding: 25px; border-radius: 15px; margin: 20px 0; border: 1px solid rgba(59, 130, 246, 0.2);">
<h3 style="color: #3b82f6; margin-top: 0; text-align: center; font-size: 1.4em;">
<i class="fas fa-trophy fa-color-yellow" style="margin-right: 8px;"></i> Research Credits & Acknowledgments
</h3>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin: 15px 0;">
<!-- Original Research Section (Left) -->
<div style="text-align: center;">
<h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-flask fa-color-green" style="margin-right: 8px;"></i> Original Research</h4>
<p style="color: #e0e0e0; margin: 5px 0;">
<a href="https://depth-anything-3.github.io" target="_blank"
style="color: #3b82f6; text-decoration: none; font-weight: 600;">
Depth Anything 3
</a>
</p>
</div>
<!-- Previous Versions Section (Right) -->
<div style="text-align: center;">
<h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-history fa-color-blue" style="margin-right: 8px;"></i> Previous Versions</h4>
<div style="display: flex; flex-direction: row; gap: 15px; justify-content: center; align-items: center;">
<p style="color: #e0e0e0; margin: 0;">
<a href="https://huggingface.co/spaces/LiheYoung/Depth-Anything" target="_blank"
style="color: #3b82f6; text-decoration: none; font-weight: 600;">
Depth-Anything
</a>
</p>
<span style="color: #e0e0e0;">•</span>
<p style="color: #e0e0e0; margin: 0;">
<a href="https://huggingface.co/spaces/depth-anything/Depth-Anything-V2" target="_blank"
style="color: #3b82f6; text-decoration: none; font-weight: 600;">
Depth-Anything-V2
</a>
</p>
</div>
</div>
</div>
<!-- HF Demo Adapted from - Centered at the bottom of the whole block -->
<div style="margin-top: 20px; padding-top: 15px; border-top: 1px solid rgba(59, 130, 246, 0.3); text-align: center;">
<p style="color: #a0a0a0; font-size: 0.9em; margin: 0;">
<i class="fas fa-code-branch fa-color-gray" style="margin-right: 5px;"></i> HF demo adapted from <a href="https://huggingface.co/spaces/facebook/map-anything" target="_blank" style="color: inherit; text-decoration: none;">Map Anything</a>
</p>
</div>
</div>
"""
def get_gradio_theme():
"""
Get the configured Gradio theme with adaptive tech colors.
Returns:
gr.themes.Base: Configured Gradio theme
"""
import gradio as gr
return gr.themes.Base(
primary_hue=gr.themes.Color(
c50="#eff6ff",
c100="#dbeafe",
c200="#bfdbfe",
c300="#93c5fd",
c400="#60a5fa",
c500="#3b82f6",
c600="#2563eb",
c700="#1d4ed8",
c800="#1e40af",
c900="#1e3a8a",
c950="#172554",
),
secondary_hue=gr.themes.Color(
c50="#f5f3ff",
c100="#ede9fe",
c200="#ddd6fe",
c300="#c4b5fd",
c400="#a78bfa",
c500="#8b5cf6",
c600="#7c3aed",
c700="#6d28d9",
c800="#5b21b6",
c900="#4c1d95",
c950="#2e1065",
),
neutral_hue=gr.themes.Color(
c50="#f8fafc",
c100="#f1f5f9",
c200="#e2e8f0",
c300="#cbd5e1",
c400="#94a3b8",
c500="#64748b",
c600="#475569",
c700="#334155",
c800="#1e293b",
c900="#0f172a",
c950="#020617",
),
)
# Measure tab instructions HTML
MEASURE_INSTRUCTIONS_HTML = """
### Click points on the image to compute distance.
> <i class="fas fa-triangle-exclamation fa-color-red" style="margin-right: 5px;"></i> Metric scale estimation is difficult on aerial/drone images.
"""

View File

@@ -0,0 +1,724 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Refactored Gradio App for Depth Anything 3.
This is the main application file that orchestrates all components.
The original functionality has been split into modular components for better maintainability.
"""
import argparse
import os
from typing import Any, Dict, List
import gradio as gr
from depth_anything_3.app.css_and_html import GRADIO_CSS, get_gradio_theme
from depth_anything_3.app.modules.event_handlers import EventHandlers
from depth_anything_3.app.modules.ui_components import UIComponents
# Set environment variables
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
class DepthAnything3App:
"""
Main application class for Depth Anything 3 Gradio app.
"""
def __init__(self, model_dir: str = None, workspace_dir: str = None, gallery_dir: str = None):
"""
Initialize the application.
Args:
model_dir: Path to the model directory
workspace_dir: Path to the workspace directory
gallery_dir: Path to the gallery directory
"""
self.model_dir = model_dir
self.workspace_dir = workspace_dir
self.gallery_dir = gallery_dir
# Set environment variables for directories
if self.model_dir:
os.environ["DA3_MODEL_DIR"] = self.model_dir
if self.workspace_dir:
os.environ["DA3_WORKSPACE_DIR"] = self.workspace_dir
if self.gallery_dir:
os.environ["DA3_GALLERY_DIR"] = self.gallery_dir
self.event_handlers = EventHandlers()
self.ui_components = UIComponents()
def cache_examples(
self,
show_cam: bool = True,
filter_black_bg: bool = False,
filter_white_bg: bool = False,
save_percentage: float = 20.0,
num_max_points: int = 1000,
cache_gs_tag: str = "",
gs_trj_mode: str = "smooth",
gs_video_quality: str = "low",
) -> None:
"""
Pre-cache all example scenes at startup.
Args:
show_cam: Whether to show camera in visualization
filter_black_bg: Whether to filter black background
filter_white_bg: Whether to filter white background
save_percentage: Filter percentage for point cloud
num_max_points: Maximum number of points
cache_gs_tag: Tag to match scene names for high-res+3DGS caching (e.g., "dl3dv")
gs_trj_mode: Trajectory mode for 3DGS
gs_video_quality: Video quality for 3DGS
"""
from depth_anything_3.app.modules.utils import get_scene_info
examples_dir = os.path.join(self.workspace_dir, "examples")
if not os.path.exists(examples_dir):
print(f"Examples directory not found: {examples_dir}")
return
scenes = get_scene_info(examples_dir)
if not scenes:
print("No example scenes found to cache.")
return
print(f"\n{'='*60}")
print(f"Caching {len(scenes)} example scenes...")
print(f"{'='*60}\n")
for i, scene in enumerate(scenes, 1):
scene_name = scene["name"]
# Check if scene name matches the gs tag for high-res+3DGS caching
use_high_res_gs = cache_gs_tag and cache_gs_tag.lower() in scene_name.lower()
if use_high_res_gs:
print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (HIGH-RES + 3DGS)")
print(f" - Number of images: {scene['num_images']}")
print(f" - Matched tag: '{cache_gs_tag}' - using high_res + 3DGS")
else:
print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (LOW-RES)")
print(f" - Number of images: {scene['num_images']}")
try:
# Load example scene
_, target_dir, _, _, _, _, _, _, _ = self.event_handlers.load_example_scene(
scene_name
)
if target_dir and target_dir != "None":
# Run reconstruction with appropriate settings
print(" - Running reconstruction...")
result = self.event_handlers.gradio_demo(
target_dir=target_dir,
show_cam=show_cam,
filter_black_bg=filter_black_bg,
filter_white_bg=filter_white_bg,
process_res_method="high_res" if use_high_res_gs else "low_res",
save_percentage=save_percentage,
num_max_points=num_max_points,
infer_gs=use_high_res_gs,
ref_view_strategy="saddle_balanced",
gs_trj_mode=gs_trj_mode,
gs_video_quality=gs_video_quality,
)
# Check if successful
if result[0] is not None: # reconstruction_output
print(f" ✓ Scene '{scene_name}' cached successfully")
else:
print(f" ✗ Scene '{scene_name}' caching failed: {result[1]}")
else:
print(f" ✗ Scene '{scene_name}' loading failed")
except Exception as e:
print(f" ✗ Error caching scene '{scene_name}': {str(e)}")
print()
print("=" * 60)
print("Example scene caching completed!")
print("=" * 60 + "\n")
def create_app(self) -> gr.Blocks:
"""
Create and configure the Gradio application.
Returns:
Configured Gradio Blocks interface
"""
# Initialize theme
def get_theme():
return get_gradio_theme()
with gr.Blocks(theme=get_theme(), css=GRADIO_CSS) as demo:
# State variables for the tabbed interface
is_example = gr.Textbox(label="is_example", visible=False, value="None")
processed_data_state = gr.State(value=None)
measure_points_state = gr.State(value=[])
selected_image_index_state = gr.State(value=0) # Track selected image index
# current_view_index = gr.State(value=0) # noqa: F841 Track current view index
# Header and description
self.ui_components.create_header_section()
self.ui_components.create_description_section()
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
# Main content area
with gr.Row():
with gr.Column(scale=2):
# Upload section
(
input_video,
s_time_interval,
input_images,
image_gallery,
) = self.ui_components.create_upload_section()
with gr.Column(scale=4):
with gr.Column():
# gr.Markdown("**Metric 3D Reconstruction (Point Cloud and Camera Poses)**")
# Reconstruction control section (buttons) - moved below tabs
log_output = gr.Markdown(
"Please upload a video or images, then click Reconstruct.",
elem_classes=["custom-log"],
)
# Tabbed interface
with gr.Tabs():
with gr.Tab("Point Cloud & Cameras"):
reconstruction_output = (
self.ui_components.create_3d_viewer_section()
)
with gr.Tab("Metric Depth"):
(
prev_measure_btn,
measure_view_selector,
next_measure_btn,
measure_image,
measure_depth_image,
measure_text,
) = self.ui_components.create_measure_section()
with gr.Tab("3DGS Rendered Novel Views"):
gs_video, gs_info = self.ui_components.create_nvs_video()
# Inference control section (before inference)
(process_res_method_dropdown, infer_gs, ref_view_strategy_dropdown) = (
self.ui_components.create_inference_control_section()
)
# Display control section - includes 3DGS options, buttons, and Visualization Options # noqa: E501
(
show_cam,
filter_black_bg,
filter_white_bg,
save_percentage,
num_max_points,
gs_trj_mode,
gs_video_quality,
submit_btn,
clear_btn,
) = self.ui_components.create_display_control_section()
# bind visibility of gs_trj_mode to infer_gs
infer_gs.change(
fn=lambda checked: (
gr.update(visible=checked),
gr.update(visible=checked),
gr.update(visible=checked),
gr.update(visible=(not checked)),
),
inputs=infer_gs,
outputs=[gs_trj_mode, gs_video_quality, gs_video, gs_info],
)
# Example scenes section
gr.Markdown("## Example Scenes")
scenes = self.ui_components.create_example_scenes_section()
scene_components = self.ui_components.create_example_scene_grid(scenes)
# Set up event handlers
self._setup_event_handlers(
demo,
is_example,
processed_data_state,
measure_points_state,
target_dir_output,
input_video,
input_images,
s_time_interval,
image_gallery,
reconstruction_output,
log_output,
show_cam,
filter_black_bg,
filter_white_bg,
process_res_method_dropdown,
save_percentage,
submit_btn,
clear_btn,
num_max_points,
infer_gs,
ref_view_strategy_dropdown,
selected_image_index_state,
measure_view_selector,
measure_image,
measure_depth_image,
measure_text,
prev_measure_btn,
next_measure_btn,
scenes,
scene_components,
gs_video,
gs_info,
gs_trj_mode,
gs_video_quality,
)
# Acknowledgements
self.ui_components.create_acknowledgements_section()
return demo
def _setup_event_handlers(
self,
demo: gr.Blocks,
is_example: gr.Textbox,
processed_data_state: gr.State,
measure_points_state: gr.State,
target_dir_output: gr.Textbox,
input_video: gr.Video,
input_images: gr.File,
s_time_interval: gr.Slider,
image_gallery: gr.Gallery,
reconstruction_output: gr.Model3D,
log_output: gr.Markdown,
show_cam: gr.Checkbox,
filter_black_bg: gr.Checkbox,
filter_white_bg: gr.Checkbox,
process_res_method_dropdown: gr.Dropdown,
save_percentage: gr.Slider,
submit_btn: gr.Button,
clear_btn: gr.ClearButton,
num_max_points: gr.Slider,
infer_gs: gr.Checkbox,
ref_view_strategy_dropdown: gr.Dropdown,
selected_image_index_state: gr.State,
measure_view_selector: gr.Dropdown,
measure_image: gr.Image,
measure_depth_image: gr.Image,
measure_text: gr.Markdown,
prev_measure_btn: gr.Button,
next_measure_btn: gr.Button,
scenes: List[Dict[str, Any]],
scene_components: List[gr.Image],
gs_video: gr.Video,
gs_info: gr.Markdown,
gs_trj_mode: gr.Dropdown,
gs_video_quality: gr.Dropdown,
) -> None:
"""
Set up all event handlers for the application.
Args:
demo: Gradio Blocks interface
All other arguments: Gradio components to connect
"""
# Configure clear button
clear_btn.add(
[
input_video,
input_images,
reconstruction_output,
log_output,
target_dir_output,
image_gallery,
gs_video,
]
)
# Main reconstruction button
submit_btn.click(
fn=self.event_handlers.clear_fields, inputs=[], outputs=[reconstruction_output]
).then(fn=self.event_handlers.update_log, inputs=[], outputs=[log_output]).then(
fn=self.event_handlers.gradio_demo,
inputs=[
target_dir_output,
show_cam,
filter_black_bg,
filter_white_bg,
process_res_method_dropdown,
save_percentage,
# pass num_max_points
num_max_points,
infer_gs,
ref_view_strategy_dropdown,
gs_trj_mode,
gs_video_quality,
],
outputs=[
reconstruction_output,
log_output,
processed_data_state,
measure_image,
measure_depth_image,
measure_text,
measure_view_selector,
gs_video,
gs_video, # gs_video visibility
gs_info, # gs_info visibility
],
).then(
fn=lambda: "False",
inputs=[],
outputs=[is_example], # set is_example to "False"
)
# Real-time visualization updates
self._setup_visualization_handlers(
show_cam,
filter_black_bg,
filter_white_bg,
process_res_method_dropdown,
target_dir_output,
is_example,
reconstruction_output,
log_output,
)
# File upload handlers
input_video.change(
fn=self.event_handlers.handle_uploads,
inputs=[input_video, input_images, s_time_interval],
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
)
input_images.change(
fn=self.event_handlers.handle_uploads,
inputs=[input_video, input_images, s_time_interval],
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
)
# Navigation handlers
self._setup_navigation_handlers(
prev_measure_btn,
next_measure_btn,
measure_view_selector,
measure_image,
measure_depth_image,
measure_points_state,
processed_data_state,
)
# Measurement handler
measure_image.select(
fn=self.event_handlers.measure,
inputs=[processed_data_state, measure_points_state, measure_view_selector],
outputs=[measure_image, measure_depth_image, measure_points_state, measure_text],
)
# Example scene handlers
self._setup_example_scene_handlers(
scenes,
scene_components,
reconstruction_output,
target_dir_output,
image_gallery,
log_output,
is_example,
processed_data_state,
measure_view_selector,
measure_image,
measure_depth_image,
gs_video,
gs_info,
)
def _setup_visualization_handlers(
self,
show_cam: gr.Checkbox,
filter_black_bg: gr.Checkbox,
filter_white_bg: gr.Checkbox,
process_res_method_dropdown: gr.Dropdown,
target_dir_output: gr.Textbox,
is_example: gr.Textbox,
reconstruction_output: gr.Model3D,
log_output: gr.Markdown,
) -> None:
"""Set up visualization update handlers."""
# Common inputs for visualization updates
viz_inputs = [
target_dir_output,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
process_res_method_dropdown,
]
# Set up change handlers for all visualization controls
for component in [show_cam, filter_black_bg, filter_white_bg]:
component.change(
fn=self.event_handlers.update_visualization,
inputs=viz_inputs,
outputs=[reconstruction_output, log_output],
)
def _setup_navigation_handlers(
self,
prev_measure_btn: gr.Button,
next_measure_btn: gr.Button,
measure_view_selector: gr.Dropdown,
measure_image: gr.Image,
measure_depth_image: gr.Image,
measure_points_state: gr.State,
processed_data_state: gr.State,
) -> None:
"""Set up navigation handlers for measure tab."""
# Measure tab navigation
prev_measure_btn.click(
fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
processed_data, current_selector, -1
),
inputs=[processed_data_state, measure_view_selector],
outputs=[
measure_view_selector,
measure_image,
measure_depth_image,
measure_points_state,
],
)
next_measure_btn.click(
fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
processed_data, current_selector, 1
),
inputs=[processed_data_state, measure_view_selector],
outputs=[
measure_view_selector,
measure_image,
measure_depth_image,
measure_points_state,
],
)
measure_view_selector.change(
fn=lambda processed_data, selector_value: (
self.event_handlers.update_measure_view(
processed_data, int(selector_value.split()[1]) - 1
)
if selector_value
else (None, None, [])
),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_image, measure_depth_image, measure_points_state],
)
def _setup_example_scene_handlers(
self,
scenes: List[Dict[str, Any]],
scene_components: List[gr.Image],
reconstruction_output: gr.Model3D,
target_dir_output: gr.Textbox,
image_gallery: gr.Gallery,
log_output: gr.Markdown,
is_example: gr.Textbox,
processed_data_state: gr.State,
measure_view_selector: gr.Dropdown,
measure_image: gr.Image,
measure_depth_image: gr.Image,
gs_video: gr.Video,
gs_info: gr.Markdown,
) -> None:
"""Set up example scene handlers."""
def load_and_update_measure(name):
result = self.event_handlers.load_example_scene(name)
# result = (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501
# Update measure view if processed_data is available
measure_img = None
measure_depth = None
if result[4] is not None: # processed_data exists
measure_img, measure_depth, _ = (
self.event_handlers.visualization_handler.update_measure_view(result[4], 0)
)
return result + ("True", measure_img, measure_depth)
for i, scene in enumerate(scenes):
if i < len(scene_components):
scene_components[i].select(
fn=lambda name=scene["name"]: load_and_update_measure(name),
outputs=[
reconstruction_output,
target_dir_output,
image_gallery,
log_output,
processed_data_state,
measure_view_selector,
gs_video,
gs_video, # gs_video_visibility
gs_info, # gs_info_visibility
is_example,
measure_image,
measure_depth_image,
],
)
def launch(self, host: str = "127.0.0.1", port: int = 7860, **kwargs) -> None:
"""
Launch the application.
Args:
host: Host address to bind to
port: Port number to bind to
**kwargs: Additional arguments for demo.launch()
"""
demo = self.create_app()
demo.queue(max_size=20).launch(
show_error=True, ssr_mode=False, server_name=host, server_port=port, **kwargs
)
def main():
"""Main function to run the application."""
parser = argparse.ArgumentParser(
description="Depth Anything 3 Gradio Application",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Basic usage
python gradio_app.py --help
python gradio_app.py --host 0.0.0.0 --port 8080
python gradio_app.py --model-dir /path/to/model --workspace-dir /path/to/workspace
# Cache examples at startup (all low-res)
python gradio_app.py --cache-examples
# Cache with selective high-res+3DGS for scenes matching tag
python gradio_app.py --cache-examples --cache-gs-tag dl3dv
# This will use high-res + 3DGS for scenes containing "dl3dv" in their name,
# and low-res only for other scenes
""",
)
# Server configuration
parser.add_argument(
"--host", default="127.0.0.1", help="Host address to bind to (default: 127.0.0.1)"
)
parser.add_argument(
"--port", type=int, default=7860, help="Port number to bind to (default: 7860)"
)
# Directory configuration
parser.add_argument(
"--model-dir",
default="depth-anything/DA3NESTED-GIANT-LARGE",
help="Path to the model directory (default: depth-anything/DA3NESTED-GIANT-LARGE)",
)
parser.add_argument(
"--workspace-dir",
default="workspace/gradio", # noqa: E501
help="Path to the workspace directory (default: workspace/gradio)", # noqa: E501
)
parser.add_argument(
"--gallery-dir",
default="workspace/gallery",
help="Path to the gallery directory (default: workspace/gallery)", # noqa: E501
)
# Additional Gradio options
parser.add_argument("--share", action="store_true", help="Create a public link for the app")
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
# Example caching options
parser.add_argument(
"--cache-examples",
action="store_true",
help="Pre-cache all example scenes at startup for faster loading",
)
parser.add_argument(
"--cache-gs-tag",
type=str,
default="",
help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.", # noqa: E501
)
args = parser.parse_args()
# Create directories if they don't exist
os.makedirs(args.workspace_dir, exist_ok=True)
os.makedirs(args.gallery_dir, exist_ok=True)
# Initialize and launch the application
app = DepthAnything3App(
model_dir=args.model_dir, workspace_dir=args.workspace_dir, gallery_dir=args.gallery_dir
)
# Prepare launch arguments
launch_kwargs = {"share": args.share, "debug": args.debug}
print("Starting Depth Anything 3 Gradio App...")
print(f"Host: {args.host}")
print(f"Port: {args.port}")
print(f"Model Directory: {args.model_dir}")
print(f"Workspace Directory: {args.workspace_dir}")
print(f"Gallery Directory: {args.gallery_dir}")
print(f"Share: {args.share}")
print(f"Debug: {args.debug}")
print(f"Cache Examples: {args.cache_examples}")
if args.cache_examples:
if args.cache_gs_tag:
print(
f"Cache GS Tag: '{args.cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)" # noqa: E501
) # noqa: E501
else:
print("Cache GS Tag: None (all scenes will use low-res only)")
# Pre-cache examples if requested
if args.cache_examples:
print("\n" + "=" * 60)
print("Pre-caching mode enabled")
if args.cache_gs_tag:
print(f"Scenes containing '{args.cache_gs_tag}' will use HIGH-RES + 3DGS")
print("Other scenes will use LOW-RES only")
else:
print("All scenes will use LOW-RES only")
print("=" * 60)
app.cache_examples(
show_cam=True,
filter_black_bg=False,
filter_white_bg=False,
save_percentage=5.0,
num_max_points=1000,
cache_gs_tag=args.cache_gs_tag,
gs_trj_mode="smooth",
gs_video_quality="low",
)
app.launch(host=args.host, port=args.port, **launch_kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,43 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Modules package for Depth Anything 3 Gradio app.
This package contains all the modular components for the Gradio application.
"""
from depth_anything_3.app.modules.event_handlers import EventHandlers
from depth_anything_3.app.modules.file_handlers import FileHandler
from depth_anything_3.app.modules.model_inference import ModelInference
from depth_anything_3.app.modules.ui_components import UIComponents
from depth_anything_3.app.modules.utils import (
create_depth_visualization,
get_logo_base64,
get_scene_info,
save_to_gallery_func,
)
from depth_anything_3.app.modules.visualization import VisualizationHandler
__all__ = [
"ModelInference",
"FileHandler",
"VisualizationHandler",
"EventHandlers",
"UIComponents",
"create_depth_visualization",
"save_to_gallery_func",
"get_scene_info",
"get_logo_base64",
]

View File

@@ -0,0 +1,619 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Event handling module for Depth Anything 3 Gradio app.
This module handles all event callbacks and user interactions.
"""
import os
import time
from glob import glob
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import numpy as np
import torch
from depth_anything_3.app.modules.file_handlers import FileHandler
from depth_anything_3.app.modules.model_inference import ModelInference
from depth_anything_3.utils.memory import cleanup_cuda_memory
from depth_anything_3.app.modules.visualization import VisualizationHandler
class EventHandlers:
"""
Handles all event callbacks and user interactions for the Gradio app.
"""
def __init__(self):
"""Initialize the event handlers."""
self.model_inference = ModelInference()
self.file_handler = FileHandler()
self.visualization_handler = VisualizationHandler()
def clear_fields(self) -> None:
"""
Clears the 3D viewer, the stored target_dir, and empties the gallery.
"""
return None
def update_log(self) -> str:
"""
Display a quick log message while waiting.
"""
return "Loading and Reconstructing..."
def save_current_visualization(
self,
target_dir: str,
save_percentage: float,
show_cam: bool,
filter_black_bg: bool,
filter_white_bg: bool,
processed_data: Optional[Dict],
scene_name: str = "",
) -> str:
"""
Save current visualization results to gallery with specified save percentage.
Args:
target_dir: Directory containing results
save_percentage: Percentage of points to save (0-100)
show_cam: Whether to show cameras
filter_black_bg: Whether to filter black background
filter_white_bg: Whether to filter white background
processed_data: Processed data from reconstruction
Returns:
Status message
"""
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
return "No reconstruction available. Please run 'Reconstruct' first."
if processed_data is None:
return "No processed data available. Please run 'Reconstruct' first."
try:
# Add debug information
print("[DEBUG] save_current_visualization called with:")
print(f" target_dir: {target_dir}")
print(f" save_percentage: {save_percentage}")
print(f" show_cam: {show_cam}")
print(f" filter_black_bg: {filter_black_bg}")
print(f" filter_white_bg: {filter_white_bg}")
print(f" processed_data: {processed_data is not None}")
# Import the gallery save function
# Create gallery name with user input or auto-generated
import datetime
from .utils import save_to_gallery_func
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
if scene_name and scene_name.strip():
gallery_name = f"{scene_name.strip()}_{timestamp}_pct{save_percentage:.0f}"
else:
gallery_name = f"save_{timestamp}_pct{save_percentage:.0f}"
print(f"[DEBUG] Saving to gallery with name: {gallery_name}")
# Save entire process folder to gallery
success, message = save_to_gallery_func(
target_dir=target_dir, processed_data=processed_data, gallery_name=gallery_name
)
if success:
print(f"[DEBUG] Gallery save completed successfully: {message}")
return (
"Successfully saved to gallery!\n"
f"Gallery name: {gallery_name}\n"
f"Save percentage: {save_percentage}%\n"
f"Show cameras: {show_cam}\n"
f"Filter black bg: {filter_black_bg}\n"
f"Filter white bg: {filter_white_bg}\n\n"
f"{message}"
)
else:
print(f"[DEBUG] Gallery save failed: {message}")
return f"Failed to save to gallery: {message}"
except Exception as e:
return f"Error saving visualization: {str(e)}"
def gradio_demo(
self,
target_dir: str,
show_cam: bool = True,
filter_black_bg: bool = False,
filter_white_bg: bool = False,
process_res_method: str = "upper_bound_resize",
save_percentage: float = 30.0,
num_max_points: int = 1_000_000,
infer_gs: bool = False,
ref_view_strategy: str = "saddle_balanced",
gs_trj_mode: str = "extend",
gs_video_quality: str = "high",
) -> Tuple[
Optional[str],
str,
Optional[Dict],
Optional[np.ndarray],
Optional[np.ndarray],
str,
gr.Dropdown,
Optional[str], # gs video path
gr.update, # gs video visibility update
gr.update, # gs info visibility update
]:
"""
Perform reconstruction using the already-created target_dir/images.
Args:
target_dir: Directory containing images
show_cam: Whether to show camera
filter_black_bg: Whether to filter black background
filter_white_bg: Whether to filter white background
process_res_method: Method for resizing input images
save_percentage: Filter percentage for point cloud
num_max_points: Maximum number of points
infer_gs: Whether to infer 3D Gaussian Splatting
ref_view_strategy: Reference view selection strategy
Returns:
Tuple of reconstruction results
"""
if not os.path.isdir(target_dir) or target_dir == "None":
return (
None,
"No valid target directory found. Please upload first.",
None,
None,
None,
"",
None,
None,
gr.update(visible=False), # gs_video
gr.update(visible=True), # gs_info
)
start_time = time.time()
cleanup_cuda_memory()
# Get image files for logging
target_dir_images = os.path.join(target_dir, "images")
all_files = (
sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
)
print("Running DepthAnything3 model...")
print(f"Reference view strategy: {ref_view_strategy}")
with torch.no_grad():
prediction, processed_data = self.model_inference.run_inference(
target_dir,
process_res_method=process_res_method,
show_camera=show_cam,
save_percentage=save_percentage,
num_max_points=int(num_max_points * 1000), # Convert K to actual count
infer_gs=infer_gs,
ref_view_strategy=ref_view_strategy,
gs_trj_mode=gs_trj_mode,
gs_video_quality=gs_video_quality,
)
# The GLB file is already generated by the API
glbfile = os.path.join(target_dir, "scene.glb")
# Handle 3DGS video based on infer_gs flag
gsvideo_path = None
gs_video_visible = False
gs_info_visible = True
if infer_gs:
try:
gsvideo_path = sorted(glob(os.path.join(target_dir, "gs_video", "*.mp4")))[-1]
gs_video_visible = True
gs_info_visible = False
except IndexError:
gsvideo_path = None
print("3DGS video not found, but infer_gs was enabled")
# Cleanup
cleanup_cuda_memory()
end_time = time.time()
print(f"Total time: {end_time - start_time:.2f} seconds")
log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
# Populate visualization tabs with processed data
depth_vis, measure_img, measure_depth_vis, measure_pts = (
self.visualization_handler.populate_visualization_tabs(processed_data)
)
# Update view selectors based on available views
depth_selector, measure_selector = self.visualization_handler.update_view_selectors(
processed_data
)
return (
glbfile,
log_msg,
processed_data,
measure_img, # measure_image
measure_depth_vis, # measure_depth_image
"", # measure_text (empty initially)
measure_selector, # measure_view_selector
gsvideo_path,
gr.update(visible=gs_video_visible), # gs_video visibility
gr.update(visible=gs_info_visible), # gs_info visibility
)
def update_visualization(
self,
target_dir: str,
show_cam: bool,
is_example: str,
filter_black_bg: bool = False,
filter_white_bg: bool = False,
process_res_method: str = "upper_bound_resize",
) -> Tuple[gr.update, str]:
"""
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
and return it for the 3D viewer.
Args:
target_dir: Directory containing results
show_cam: Whether to show camera
is_example: Whether this is an example scene
filter_black_bg: Whether to filter black background
filter_white_bg: Whether to filter white background
process_res_method: Method for resizing input images
Returns:
Tuple of (glb_file, log_message)
"""
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
return (
gr.update(),
"No reconstruction available. Please click the Reconstruct button first.",
)
# Check if GLB exists (could be cached example or reconstructed scene)
glbfile = os.path.join(target_dir, "scene.glb")
if os.path.exists(glbfile):
return (
glbfile,
(
"Visualization loaded from cache."
if is_example == "True"
else "Visualization updated."
),
)
# If no GLB but it's an example that hasn't been reconstructed yet
if is_example == "True":
return (
gr.update(),
"No reconstruction available. Please click the Reconstruct button first.",
)
# For non-examples, check predictions.npz
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
error_message = (
f"No reconstruction available at {predictions_path}. "
"Please run 'Reconstruct' first."
)
return gr.update(), error_message
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()} # noqa: F841
return (
glbfile,
"Visualization updated.",
)
def handle_uploads(
self,
input_video: Optional[str],
input_images: Optional[List],
s_time_interval: float = 10.0,
) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
"""
Handle file uploads and update gallery.
Args:
input_video: Path to input video file
input_images: List of input image files
s_time_interval: Sampling FPS (frames per second) for frame extraction
Returns:
Tuple of (reconstruction_output, target_dir, image_paths, log_message)
"""
return self.file_handler.update_gallery_on_upload(
input_video, input_images, s_time_interval
)
def load_example_scene(self, scene_name: str, examples_dir: str = None) -> Tuple[
Optional[str],
Optional[str],
Optional[List],
str,
Optional[Dict],
gr.Dropdown,
Optional[str],
gr.update,
gr.update,
]:
"""
Load a scene from examples directory.
Args:
scene_name: Name of the scene to load
examples_dir: Path to examples directory (if None, uses workspace_dir/examples)
Returns:
Tuple of (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501
"""
if examples_dir is None:
# Get workspace directory from environment variable
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
examples_dir = os.path.join(workspace_dir, "examples")
reconstruction_output, target_dir, image_paths, log_message = (
self.file_handler.load_example_scene(scene_name, examples_dir)
)
# Try to load cached processed data if available
processed_data = None
measure_view_selector = gr.Dropdown(choices=["View 1"], value="View 1")
gs_video_path = None
gs_video_visible = False
gs_info_visible = True
if target_dir and target_dir != "None":
predictions_path = os.path.join(target_dir, "predictions.npz")
if os.path.exists(predictions_path):
try:
# Load predictions from cache
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
# Reconstruct processed_data structure
num_images = len(predictions.get("images", []))
processed_data = {}
for i in range(num_images):
processed_data[i] = {
"image": predictions["images"][i] if "images" in predictions else None,
"depth": predictions["depths"][i] if "depths" in predictions else None,
"depth_image": os.path.join(
target_dir, "depth_vis", f"{i:04d}.jpg" # Fixed: use .jpg not .png
),
"intrinsics": (
predictions["intrinsics"][i]
if "intrinsics" in predictions
and i < len(predictions["intrinsics"])
else None
),
"mask": None,
}
# Update measure view selector
choices = [f"View {i + 1}" for i in range(num_images)]
measure_view_selector = gr.Dropdown(choices=choices, value=choices[0])
except Exception as e:
print(f"Error loading cached data: {e}")
# Check for cached 3DGS video
gs_video_dir = os.path.join(target_dir, "gs_video")
if os.path.exists(gs_video_dir):
try:
from glob import glob
gs_videos = sorted(glob(os.path.join(gs_video_dir, "*.mp4")))
if gs_videos:
gs_video_path = gs_videos[-1]
gs_video_visible = True
gs_info_visible = False
print(f"Loaded cached 3DGS video: {gs_video_path}")
except Exception as e:
print(f"Error loading cached 3DGS video: {e}")
return (
reconstruction_output,
target_dir,
image_paths,
log_message,
processed_data,
measure_view_selector,
gs_video_path,
gr.update(visible=gs_video_visible),
gr.update(visible=gs_info_visible),
)
def navigate_depth_view(
self,
processed_data: Optional[Dict[int, Dict[str, Any]]],
current_selector: str,
direction: int,
) -> Tuple[str, Optional[str]]:
"""
Navigate depth view.
Args:
processed_data: Processed data dictionary
current_selector: Current selector value
direction: Direction to navigate
Returns:
Tuple of (new_selector_value, depth_vis)
"""
return self.visualization_handler.navigate_depth_view(
processed_data, current_selector, direction
)
def update_depth_view(
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
) -> Optional[str]:
"""
Update depth view for a specific view index.
Args:
processed_data: Processed data dictionary
view_index: Index of the view to update
Returns:
Path to depth visualization image or None
"""
return self.visualization_handler.update_depth_view(processed_data, view_index)
def navigate_measure_view(
self,
processed_data: Optional[Dict[int, Dict[str, Any]]],
current_selector: str,
direction: int,
) -> Tuple[str, Optional[np.ndarray], Optional[np.ndarray], List]:
"""
Navigate measure view.
Args:
processed_data: Processed data dictionary
current_selector: Current selector value
direction: Direction to navigate
Returns:
Tuple of (new_selector_value, measure_image, depth_right_half, measure_points)
"""
return self.visualization_handler.navigate_measure_view(
processed_data, current_selector, direction
)
def update_measure_view(
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
"""
Update measure view for a specific view index.
Args:
processed_data: Processed data dictionary
view_index: Index of the view to update
Returns:
Tuple of (measure_image, depth_right_half, measure_points)
"""
return self.visualization_handler.update_measure_view(processed_data, view_index)
def measure(
self,
processed_data: Optional[Dict[int, Dict[str, Any]]],
measure_points: List,
current_view_selector: str,
event: gr.SelectData,
) -> List:
"""
Handle measurement on images.
Args:
processed_data: Processed data dictionary
measure_points: List of current measure points
current_view_selector: Current view selector value
event: Gradio select event
Returns:
List of [image, depth_right_half, measure_points, text]
"""
return self.visualization_handler.measure(
processed_data, measure_points, current_view_selector, event
)
def select_first_frame(
self, image_gallery: List, selected_index: int = 0
) -> Tuple[List, str, str]:
"""
Select the first frame from the image gallery.
Args:
image_gallery: List of images in the gallery
selected_index: Index of the selected image (default: 0)
Returns:
Tuple of (updated_image_gallery, log_message, selected_frame_path)
"""
try:
if not image_gallery or len(image_gallery) == 0:
return image_gallery, "No images available to select as first frame.", ""
# Handle None or invalid selected_index
if (
selected_index is None
or selected_index < 0
or selected_index >= len(image_gallery)
):
selected_index = 0
print(f"Invalid selected_index: {selected_index}, using default: 0")
# Get the selected image based on index
selected_image = image_gallery[selected_index]
print(f"Selected image index: {selected_index}")
print(f"Total images: {len(image_gallery)}")
# Extract the file path from the selected image
selected_frame_path = ""
print(f"Selected image type: {type(selected_image)}")
print(f"Selected image: {selected_image}")
if isinstance(selected_image, tuple):
# Gradio Gallery returns tuple (path, None)
selected_frame_path = selected_image[0]
elif isinstance(selected_image, str):
selected_frame_path = selected_image
elif hasattr(selected_image, "name"):
selected_frame_path = selected_image.name
elif isinstance(selected_image, dict):
if "name" in selected_image:
selected_frame_path = selected_image["name"]
elif "path" in selected_image:
selected_frame_path = selected_image["path"]
elif "src" in selected_image:
selected_frame_path = selected_image["src"]
else:
# Try to convert to string
selected_frame_path = str(selected_image)
print(f"Extracted path: {selected_frame_path}")
# Extract filename from the path for matching
import os
selected_filename = os.path.basename(selected_frame_path)
print(f"Selected filename: {selected_filename}")
# Move the selected image to the front
updated_gallery = [selected_image] + [
img for img in image_gallery if img != selected_image
]
log_message = (
f"Selected frame: {selected_filename}. "
f"Moved to first position. Total frames: {len(updated_gallery)}"
)
return updated_gallery, log_message, selected_filename
except Exception as e:
print(f"Error selecting first frame: {e}")
return image_gallery, f"Error selecting first frame: {e}", ""

View File

@@ -0,0 +1,304 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File handling module for Depth Anything 3 Gradio app.
This module handles file uploads, video processing, and file operations.
"""
import os
import shutil
import time
from datetime import datetime
from typing import List, Optional, Tuple
import cv2
from PIL import Image
from pillow_heif import register_heif_opener
register_heif_opener()
class FileHandler:
"""
Handles file uploads and processing for the Gradio app.
"""
def __init__(self):
"""Initialize the file handler."""
def handle_uploads(
self,
input_video: Optional[str],
input_images: Optional[List],
s_time_interval: float = 10.0,
) -> Tuple[str, List[str]]:
"""
Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
images or extracted frames from video into it.
Args:
input_video: Path to input video file
input_images: List of input image files
s_time_interval: Sampling FPS (frames per second) for frame extraction
Returns:
Tuple of (target_dir, image_paths)
"""
start_time = time.time()
# Get workspace directory from environment variable or use default
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
if not os.path.exists(workspace_dir):
os.makedirs(workspace_dir)
# Create input_images subdirectory
input_images_dir = os.path.join(workspace_dir, "input_images")
if not os.path.exists(input_images_dir):
os.makedirs(input_images_dir)
# Create a unique folder name within input_images
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
target_dir = os.path.join(input_images_dir, f"session_{timestamp}")
target_dir_images = os.path.join(target_dir, "images")
# Clean up if somehow that folder already exists
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
os.makedirs(target_dir)
os.makedirs(target_dir_images)
image_paths = []
# Handle images
if input_images is not None:
image_paths.extend(self._process_images(input_images, target_dir_images))
# Handle video
if input_video is not None:
image_paths.extend(
self._process_video(input_video, target_dir_images, s_time_interval)
)
# Sort final images for gallery
image_paths = sorted(image_paths)
end_time = time.time()
print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
return target_dir, image_paths
def _process_images(self, input_images: List, target_dir_images: str) -> List[str]:
"""
Process uploaded images.
Args:
input_images: List of input image files
target_dir_images: Target directory for images
Returns:
List of processed image paths
"""
image_paths = []
for file_data in input_images:
if isinstance(file_data, dict) and "name" in file_data:
file_path = file_data["name"]
else:
file_path = file_data
# Check if the file is a HEIC image
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext in [".heic", ".heif"]:
# Convert HEIC to JPEG for better gallery compatibility
try:
with Image.open(file_path) as img:
# Convert to RGB if necessary (HEIC can have different color modes)
if img.mode not in ("RGB", "L"):
img = img.convert("RGB")
# Create JPEG filename
base_name = os.path.splitext(os.path.basename(file_path))[0]
dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
# Save as JPEG with high quality
img.save(dst_path, "JPEG", quality=95)
image_paths.append(dst_path)
print(
f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> "
f"{os.path.basename(dst_path)}"
)
except Exception as e:
print(f"Error converting HEIC file {file_path}: {e}")
# Fall back to copying as is
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
else:
# Regular image files - copy as is
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
return image_paths
def _process_video(
self, input_video: str, target_dir_images: str, s_time_interval: float
) -> List[str]:
"""
Process video file and extract frames.
Args:
input_video: Path to input video file
target_dir_images: Target directory for extracted frames
s_time_interval: Sampling FPS (frames per second) for frame extraction
Returns:
List of extracted frame paths
"""
image_paths = []
if isinstance(input_video, dict) and "name" in input_video:
video_path = input_video["name"]
else:
video_path = input_video
vs = cv2.VideoCapture(video_path)
fps = vs.get(cv2.CAP_PROP_FPS)
frame_interval = max(1, int(fps / s_time_interval)) # Convert FPS to frame interval
count = 0
video_frame_num = 0
while True:
gotit, frame = vs.read()
if not gotit:
break
count += 1
if count % frame_interval == 0:
image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
cv2.imwrite(image_path, frame)
image_paths.append(image_path)
video_frame_num += 1
return image_paths
def update_gallery_on_upload(
self,
input_video: Optional[str],
input_images: Optional[List],
s_time_interval: float = 10.0,
) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
"""
Handle file uploads and update gallery.
Args:
input_video: Path to input video file
input_images: List of input image files
s_time_interval: Sampling FPS (frames per second) for frame extraction
Returns:
Tuple of (reconstruction_output, target_dir, image_paths, log_message)
"""
if not input_video and not input_images:
return None, None, None, None
target_dir, image_paths = self.handle_uploads(input_video, input_images, s_time_interval)
return (
None,
target_dir,
image_paths,
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
)
def load_example_scene(
self, scene_name: str, examples_dir: str = "examples"
) -> Tuple[Optional[str], Optional[str], Optional[List], str]:
"""
Load a scene from examples directory.
Args:
scene_name: Name of the scene to load
examples_dir: Path to examples directory
Returns:
Tuple of (reconstruction_output, target_dir, image_paths, log_message)
"""
from depth_anything_3.app.modules.utils import get_scene_info
scenes = get_scene_info(examples_dir)
# Find the selected scene
selected_scene = None
for scene in scenes:
if scene["name"] == scene_name:
selected_scene = scene
break
if selected_scene is None:
return None, None, None, "Scene not found"
# Use fixed directory name for examples (not timestamp-based)
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
input_images_dir = os.path.join(workspace_dir, "input_images")
if not os.path.exists(input_images_dir):
os.makedirs(input_images_dir)
# Create a fixed folder name based on scene name
target_dir = os.path.join(input_images_dir, f"example_{scene_name}")
target_dir_images = os.path.join(target_dir, "images")
# Check if already cached (GLB file exists)
glb_path = os.path.join(target_dir, "scene.glb")
is_cached = os.path.exists(glb_path)
# Create directory if it doesn't exist
if not os.path.exists(target_dir):
os.makedirs(target_dir)
os.makedirs(target_dir_images)
# Copy images if directory is new or empty
if not os.path.exists(target_dir_images) or len(os.listdir(target_dir_images)) == 0:
os.makedirs(target_dir_images, exist_ok=True)
image_paths = []
for file_path in selected_scene["image_files"]:
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
else:
# Use existing images
image_paths = sorted(
[
os.path.join(target_dir_images, f)
for f in os.listdir(target_dir_images)
if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"))
]
)
# Return cached GLB if available
if is_cached:
return (
glb_path, # Return cached reconstruction
target_dir, # Set target directory
image_paths, # Set gallery
f"Loaded cached scene '{scene_name}' with {selected_scene['num_images']} images.",
)
else:
return (
None, # No cached reconstruction
target_dir, # Set target directory
image_paths, # Set gallery
(
f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. "
"Click 'Reconstruct' to begin 3D processing."
),
)

View File

@@ -0,0 +1,260 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model inference module for Depth Anything 3 Gradio app.
This module handles all model-related operations including inference,
data processing, and result preparation.
"""
import glob
import os
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
from depth_anything_3.api import DepthAnything3
from depth_anything_3.utils.memory import cleanup_cuda_memory
from depth_anything_3.utils.export.glb import export_to_glb
from depth_anything_3.utils.export.gs import export_to_gs_video
class ModelInference:
"""
Handles model inference and data processing for Depth Anything 3.
"""
def __init__(self):
"""Initialize the model inference handler."""
self.model = None
def initialize_model(self, device: str = "cuda") -> None:
"""
Initialize the DepthAnything3 model.
Args:
device: Device to load the model on
"""
if self.model is None:
# Get model directory from environment variable or use default
model_dir = os.environ.get(
"DA3_MODEL_DIR", "/dev/shm/da3_models/DA3HF-VITG-METRIC_VITL"
)
self.model = DepthAnything3.from_pretrained(model_dir)
self.model = self.model.to(device)
else:
self.model = self.model.to(device)
self.model.eval()
def run_inference(
self,
target_dir: str,
filter_black_bg: bool = False,
filter_white_bg: bool = False,
process_res_method: str = "upper_bound_resize",
show_camera: bool = True,
save_percentage: float = 30.0,
num_max_points: int = 1_000_000,
infer_gs: bool = False,
ref_view_strategy: str = "saddle_balanced",
gs_trj_mode: str = "extend",
gs_video_quality: str = "high",
) -> Tuple[Any, Dict[int, Dict[str, Any]]]:
"""
Run DepthAnything3 model inference on images.
Args:
target_dir: Directory containing images
filter_black_bg: Whether to filter black background
filter_white_bg: Whether to filter white background
process_res_method: Method for resizing input images
show_camera: Whether to show camera in 3D view
save_percentage: Percentage of points to save (0-100)
num_max_points: Maximum number of points in point cloud
infer_gs: Whether to infer 3D Gaussian Splatting
ref_view_strategy: Reference view selection strategy
gs_trj_mode: Trajectory mode for 3DGS
gs_video_quality: Video quality for 3DGS
Returns:
Tuple of (prediction, processed_data)
"""
print(f"Processing images from {target_dir}")
# Device check
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
# Initialize model if needed
self.initialize_model(device)
# Get image paths
print("Loading images...")
image_folder_path = os.path.join(target_dir, "images")
all_image_paths = sorted(glob.glob(os.path.join(image_folder_path, "*")))
# Filter for image files
image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"]
all_image_paths = [
path
for path in all_image_paths
if any(path.lower().endswith(ext) for ext in image_extensions)
]
print(f"Found {len(all_image_paths)} images")
print(f"All image paths: {all_image_paths}")
# Use sorted image order (reference view will be selected automatically)
image_paths = all_image_paths
print(f"Reference view selection strategy: {ref_view_strategy}")
if len(image_paths) == 0:
raise ValueError("No images found. Check your upload.")
# Map UI options to actual method names
method_mapping = {"high_res": "lower_bound_resize", "low_res": "upper_bound_resize"}
actual_method = method_mapping.get(process_res_method, "upper_bound_crop")
# Run model inference
print(f"Running inference with method: {actual_method}")
with torch.no_grad():
prediction = self.model.inference(
image_paths,
export_dir=None,
process_res_method=actual_method,
infer_gs=infer_gs,
ref_view_strategy=ref_view_strategy,
)
# num_max_points: int = 1_000_000,
export_to_glb(
prediction,
filter_black_bg=filter_black_bg,
filter_white_bg=filter_white_bg,
export_dir=target_dir,
show_cameras=show_camera,
conf_thresh_percentile=save_percentage,
num_max_points=int(num_max_points),
)
# export to gs video if needed
if infer_gs:
mode_mapping = {"extend": "extend", "smooth": "interpolate_smooth"}
print(f"GS mode: {gs_trj_mode}; Backend mode: {mode_mapping[gs_trj_mode]}")
export_to_gs_video(
prediction,
export_dir=target_dir,
chunk_size=4,
trj_mode=mode_mapping.get(gs_trj_mode, "extend"),
enable_tqdm=True,
vis_depth="hcat",
video_quality=gs_video_quality,
)
# Save predictions.npz for caching metric depth data
self._save_predictions_cache(target_dir, prediction)
# Process results
processed_data = self._process_results(target_dir, prediction, image_paths)
# Clean up using centralized memory utilities for consistency with backend
cleanup_cuda_memory()
return prediction, processed_data
def _save_predictions_cache(self, target_dir: str, prediction: Any) -> None:
"""
Save predictions data to predictions.npz for caching.
Args:
target_dir: Directory to save the cache
prediction: Model prediction object
"""
try:
output_file = os.path.join(target_dir, "predictions.npz")
# Build save dict with prediction data
save_dict = {}
# Save processed images if available
if prediction.processed_images is not None:
save_dict["images"] = prediction.processed_images
# Save depth data
if prediction.depth is not None:
save_dict["depths"] = np.round(prediction.depth, 6)
# Save confidence if available
if prediction.conf is not None:
save_dict["conf"] = np.round(prediction.conf, 2)
# Save camera parameters
if prediction.extrinsics is not None:
save_dict["extrinsics"] = prediction.extrinsics
if prediction.intrinsics is not None:
save_dict["intrinsics"] = prediction.intrinsics
# Save to file
np.savez_compressed(output_file, **save_dict)
print(f"Saved predictions cache to: {output_file}")
except Exception as e:
print(f"Warning: Failed to save predictions cache: {e}")
def _process_results(
self, target_dir: str, prediction: Any, image_paths: list
) -> Dict[int, Dict[str, Any]]:
"""
Process model results into structured data.
Args:
target_dir: Directory containing results
prediction: Model prediction object
image_paths: List of input image paths
Returns:
Dictionary containing processed data for each view
"""
processed_data = {}
# Read generated depth visualization files
depth_vis_dir = os.path.join(target_dir, "depth_vis")
if os.path.exists(depth_vis_dir):
depth_files = sorted(glob.glob(os.path.join(depth_vis_dir, "*.jpg")))
for i, depth_file in enumerate(depth_files):
# Use processed images directly from API
processed_image = None
if prediction.processed_images is not None and i < len(
prediction.processed_images
):
processed_image = prediction.processed_images[i]
processed_data[i] = {
"depth_image": depth_file,
"image": processed_image,
"original_image_path": image_paths[i] if i < len(image_paths) else None,
"depth": prediction.depth[i] if i < len(prediction.depth) else None,
"intrinsics": (
prediction.intrinsics[i]
if prediction.intrinsics is not None and i < len(prediction.intrinsics)
else None
),
"mask": None, # No mask information available
}
return processed_data
# cleanup() removed: call cleanup_cuda_memory() directly where needed.

View File

@@ -0,0 +1,477 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
UI components module for Depth Anything 3 Gradio app.
This module contains UI component definitions and layout functions.
"""
import os
from typing import Any, Dict, List, Tuple
import gradio as gr
from depth_anything_3.app.modules.utils import get_logo_base64, get_scene_info
class UIComponents:
"""
Handles UI component creation and layout for the Gradio app.
"""
def __init__(self):
"""Initialize the UI components handler."""
def create_upload_section(self) -> Tuple[gr.Video, gr.Slider, gr.File, gr.Gallery]:
"""
Create the upload section with video, images, and gallery components.
Returns:
A tuple of Gradio components: (input_video, s_time_interval, input_images, image_gallery).
"""
input_video = gr.Video(label="Upload Video", interactive=True)
s_time_interval = gr.Slider(
minimum=0.1,
maximum=60,
value=10,
step=0.1,
label="Sampling FPS (Frames Per Second)",
interactive=True,
visible=True,
)
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
image_gallery = gr.Gallery(
label="Preview",
columns=4,
height="300px",
show_download_button=True,
object_fit="contain",
preview=True,
interactive=False,
)
return input_video, s_time_interval, input_images, image_gallery
def create_3d_viewer_section(self) -> gr.Model3D:
"""
Create the 3D viewer component.
Returns:
3D model viewer component
"""
return gr.Model3D(
height=520,
zoom_speed=0.5,
pan_speed=0.5,
clear_color=[0.0, 0.0, 0.0, 0.0],
key="persistent_3d_viewer",
elem_id="reconstruction_3d_viewer",
)
def create_nvs_video(self) -> Tuple[gr.Video, gr.Markdown]:
"""
Create the 3DGS rendered video display component and info message.
Returns:
Tuple of (video component, info message component)
"""
with gr.Column():
gs_info = gr.Markdown(
(
"‼️ **3D Gaussian Splatting rendering is currently DISABLED.** <br><br><br>"
"To render novel views from 3DGS, "
"enable **Infer 3D Gaussian Splatting** below. <br>"
"Next, in **Visualization Options**, "
"*optionally* configure the **rendering trajectory** (default: smooth) "
"and **video quality** (default: low), "
"then click **Reconstruct**."
),
visible=True,
height=520,
)
gs_video = gr.Video(
height=520,
label="3DGS Rendered NVS Video (depth shown for reference only)",
interactive=False,
visible=False,
)
return gs_video, gs_info
def create_depth_section(self) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image]:
"""
Create the depth visualization section.
Returns:
A tuple of (prev_depth_btn, depth_view_selector, next_depth_btn, depth_map)
"""
with gr.Row(elem_classes=["navigation-row"]):
prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
depth_view_selector = gr.Dropdown(
choices=["View 1"],
value="View 1",
label="Select View",
scale=2,
interactive=True,
allow_custom_value=True,
)
next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
depth_map = gr.Image(
type="numpy",
label="Colorized Depth Map",
format="png",
interactive=False,
)
return prev_depth_btn, depth_view_selector, next_depth_btn, depth_map
def create_measure_section(
self,
) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image, gr.Image, gr.Markdown]:
"""
Create the measurement section.
Returns:
A tuple of (prev_measure_btn, measure_view_selector, next_measure_btn, measure_image,
measure_depth_image, measure_text)
"""
from depth_anything_3.app.css_and_html import MEASURE_INSTRUCTIONS_HTML
gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
with gr.Row(elem_classes=["navigation-row"]):
prev_measure_btn = gr.Button("◀ Previous", size="sm", scale=1)
measure_view_selector = gr.Dropdown(
choices=["View 1"],
value="View 1",
label="Select View",
scale=2,
interactive=True,
allow_custom_value=True,
)
next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
with gr.Row():
measure_image = gr.Image(
type="numpy",
show_label=False,
format="webp",
interactive=False,
sources=[],
label="RGB Image",
scale=1,
height=275,
)
measure_depth_image = gr.Image(
type="numpy",
show_label=False,
format="webp",
interactive=False,
sources=[],
label="Depth Visualization (Right Half)",
scale=1,
height=275,
)
gr.Markdown(
"**Note:** Images have been adjusted to model processing size. "
"Click two points on the RGB image to measure distance."
)
measure_text = gr.Markdown("")
return (
prev_measure_btn,
measure_view_selector,
next_measure_btn,
measure_image,
measure_depth_image,
measure_text,
)
def create_inference_control_section(self) -> Tuple[gr.Dropdown, gr.Checkbox, gr.Dropdown]:
"""
Create the inference control section (before inference).
Returns:
Tuple of (process_res_method_dropdown, infer_gs, ref_view_strategy)
"""
with gr.Row():
process_res_method_dropdown = gr.Dropdown(
choices=["high_res", "low_res"],
value="low_res",
label="Image Processing Method",
info="low_res for much more images",
scale=1,
)
# Modify line 220, add color class
infer_gs = gr.Checkbox(
label="Infer 3D Gaussian Splatting",
value=False,
info=(
'Enable novel view rendering from 3DGS (<i class="fas fa-triangle-exclamation '
'fa-color-red"></i> requires extra processing time)'
),
scale=1,
)
ref_view_strategy = gr.Dropdown(
choices=["saddle_balanced", "saddle_sim_range", "first", "middle"],
value="saddle_balanced",
label="Reference View Strategy",
info="Strategy for selecting reference view from multiple inputs",
scale=1,
)
return (process_res_method_dropdown, infer_gs, ref_view_strategy)
def create_display_control_section(
self,
) -> Tuple[
gr.Checkbox,
gr.Checkbox,
gr.Checkbox,
gr.Slider,
gr.Slider,
gr.Dropdown,
gr.Dropdown,
gr.Button,
gr.ClearButton,
]:
"""
Create the display control section (options for visualization).
Returns:
Tuple of display control components including buttons
"""
with gr.Column():
# 3DGS options at the top
with gr.Row():
gs_trj_mode = gr.Dropdown(
choices=["smooth", "extend"],
value="smooth",
label=("Rendering trajectory for 3DGS viewpoints (requires n_views ≥ 2)"),
info=("'smooth' for view interpolation; 'extend' for longer trajectory"),
visible=False, # initially hidden
)
gs_video_quality = gr.Dropdown(
choices=["low", "medium", "high"],
value="low",
label=("Video quality for 3DGS rendered outputs"),
info=("'low' for faster loading speed; 'high' for better visual quality"),
visible=False, # initially hidden
)
# Reconstruct and Clear buttons (before Visualization Options)
with gr.Row():
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
clear_btn = gr.ClearButton(scale=1)
gr.Markdown("### Visualization Options: (Click Reconstruct to update)")
show_cam = gr.Checkbox(label="Show Camera", value=True)
filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
save_percentage = gr.Slider(
minimum=0,
maximum=100,
value=10,
step=1,
label="Filter Percentage",
info="Confidence Threshold (%): Higher values filter more points.",
)
num_max_points = gr.Slider(
minimum=1000,
maximum=100000,
value=1000,
step=1000,
label="Max Points (K points)",
info="Maximum number of points to export to GLB (in thousands)",
)
return (
show_cam,
filter_black_bg,
filter_white_bg,
save_percentage,
num_max_points,
gs_trj_mode,
gs_video_quality,
submit_btn,
clear_btn,
)
def create_control_section(
self,
) -> Tuple[
gr.Button,
gr.ClearButton,
gr.Dropdown,
gr.Checkbox,
gr.Checkbox,
gr.Checkbox,
gr.Checkbox,
gr.Checkbox,
gr.Dropdown,
gr.Checkbox,
gr.Textbox,
]:
"""
Create the control section with buttons and options.
Returns:
Tuple of control components
"""
with gr.Row():
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
clear_btn = gr.ClearButton(
scale=1,
)
with gr.Row():
frame_filter = gr.Dropdown(
choices=["All"], value="All", label="Show Points from Frame"
)
with gr.Column():
gr.Markdown("### Visualization Option: (Click Reconstruct to update)")
show_cam = gr.Checkbox(label="Show Camera", value=True)
show_mesh = gr.Checkbox(label="Show Mesh", value=True)
filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
gr.Markdown("### Reconstruction Options: (updated on next run)")
apply_mask_checkbox = gr.Checkbox(
label="Apply mask for predicted ambiguous depth classes & edges",
value=True,
)
process_res_method_dropdown = gr.Dropdown(
choices=[
"upper_bound_resize",
"upper_bound_crop",
"lower_bound_resize",
"lower_bound_crop",
],
value="upper_bound_resize",
label="Image Processing Method",
info="Method for resizing input images",
)
save_to_gallery_checkbox = gr.Checkbox(
label="Save to Gallery",
value=False,
info="Save current reconstruction results to gallery directory",
)
gallery_name_input = gr.Textbox(
label="Gallery Name",
placeholder="Enter a name for the gallery folder",
value="",
info="Leave empty for auto-generated name with timestamp",
)
return (
submit_btn,
clear_btn,
frame_filter,
show_cam,
show_mesh,
filter_black_bg,
filter_white_bg,
apply_mask_checkbox,
process_res_method_dropdown,
save_to_gallery_checkbox,
gallery_name_input,
)
def create_example_scenes_section(self) -> List[Dict[str, Any]]:
"""
Create the example scenes section.
Returns:
List of scene information dictionaries
"""
# Get workspace directory from environment variable
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
examples_dir = os.path.join(workspace_dir, "examples")
# Get scene information
scenes = get_scene_info(examples_dir)
return scenes
def create_example_scene_grid(self, scenes: List[Dict[str, Any]]) -> List[gr.Image]:
"""
Create the example scene grid.
Args:
scenes: List of scene information dictionaries
Returns:
List of scene image components
"""
scene_components = []
if scenes:
for i in range(0, len(scenes), 4): # Process 4 scenes per row
with gr.Row():
for j in range(4):
scene_idx = i + j
if scene_idx < len(scenes):
scene = scenes[scene_idx]
with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
# Clickable thumbnail
scene_img = gr.Image(
value=scene["thumbnail"],
height=150,
interactive=False,
show_label=False,
elem_id=f"scene_thumb_{scene['name']}",
sources=[],
)
scene_components.append(scene_img)
# Scene name and image count as text below thumbnail
gr.Markdown(
f"**{scene['name']}** \n {scene['num_images']} images",
elem_classes=["scene-info"],
)
else:
# Empty column to maintain grid structure
with gr.Column(scale=1):
pass
return scene_components
def create_header_section(self) -> gr.HTML:
"""
Create the header section with logo and title.
Returns:
Header HTML component
"""
from depth_anything_3.app.css_and_html import get_header_html
return gr.HTML(get_header_html(get_logo_base64()))
def create_description_section(self) -> gr.HTML:
"""
Create the description section.
Returns:
Description HTML component
"""
from depth_anything_3.app.css_and_html import get_description_html
return gr.HTML(get_description_html())
def create_acknowledgements_section(self) -> gr.HTML:
"""
Create the acknowledgements section.
Returns:
Acknowledgements HTML component
"""
from depth_anything_3.app.css_and_html import get_acknowledgements_html
return gr.HTML(get_acknowledgements_html())

View File

@@ -0,0 +1,207 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions for Depth Anything 3 Gradio app.
This module contains helper functions for data processing, visualization,
and file operations.
"""
import json
import os
import shutil
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
def create_depth_visualization(depth: np.ndarray) -> Optional[np.ndarray]:
"""
Create a colored depth visualization.
Args:
depth: Depth array
Returns:
Colored depth visualization or None
"""
if depth is None:
return None
# Normalize depth to 0-1 range
depth_min = depth[depth > 0].min() if (depth > 0).any() else 0
depth_max = depth.max()
if depth_max <= depth_min:
return None
# Normalize depth
depth_norm = (depth - depth_min) / (depth_max - depth_min)
depth_norm = np.clip(depth_norm, 0, 1)
# Apply colormap (using matplotlib's viridis colormap)
import matplotlib.cm as cm
# Convert to colored image
depth_colored = cm.viridis(depth_norm)[:, :, :3] # Remove alpha channel
depth_colored = (depth_colored * 255).astype(np.uint8)
return depth_colored
def save_to_gallery_func(
target_dir: str, processed_data: Dict[int, Dict[str, Any]], gallery_name: Optional[str] = None
) -> Tuple[bool, str]:
"""
Save the current reconstruction results to the gallery directory.
Args:
target_dir: Source directory containing reconstruction results
processed_data: Processed data dictionary
gallery_name: Name for the gallery folder
Returns:
Tuple of (success, message)
"""
try:
# Get gallery directory from environment variable or use default
gallery_dir = os.environ.get(
"DA3_GALLERY_DIR",
"workspace/gallery",
)
if not os.path.exists(gallery_dir):
os.makedirs(gallery_dir)
# Use provided name or create a unique name
if gallery_name is None or gallery_name.strip() == "":
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
gallery_name = f"reconstruction_{timestamp}"
gallery_path = os.path.join(gallery_dir, gallery_name)
# Check if directory already exists
if os.path.exists(gallery_path):
return False, f"Save failed: folder '{gallery_name}' already exists"
# Create the gallery directory
os.makedirs(gallery_path, exist_ok=True)
# Copy GLB file
glb_source = os.path.join(target_dir, "scene.glb")
glb_dest = os.path.join(gallery_path, "scene.glb")
if os.path.exists(glb_source):
shutil.copy2(glb_source, glb_dest)
# Copy depth visualization images
depth_vis_dir = os.path.join(target_dir, "depth_vis")
if os.path.exists(depth_vis_dir):
gallery_depth_vis = os.path.join(gallery_path, "depth_vis")
shutil.copytree(depth_vis_dir, gallery_depth_vis)
# Copy original images
images_source = os.path.join(target_dir, "images")
if os.path.exists(images_source):
gallery_images = os.path.join(gallery_path, "images")
shutil.copytree(images_source, gallery_images)
scene_preview_source = os.path.join(target_dir, "scene.jpg")
scene_preview_dest = os.path.join(gallery_path, "scene.jpg")
shutil.copy2(scene_preview_source, scene_preview_dest)
# Save metadata
metadata = {
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
"num_images": len(processed_data) if processed_data else 0,
"gallery_name": gallery_name,
}
with open(os.path.join(gallery_path, "metadata.json"), "w") as f:
json.dump(metadata, f, indent=2)
print(f"Saved reconstruction to gallery: {gallery_path}")
return True, f"Save successful: saved to {gallery_path}"
except Exception as e:
print(f"Error saving to gallery: {e}")
return False, f"Save failed: {str(e)}"
def get_scene_info(examples_dir: str) -> List[Dict[str, Any]]:
"""
Get information about scenes in the examples directory.
Args:
examples_dir: Path to examples directory
Returns:
List of scene information dictionaries
"""
import glob
scenes = []
if not os.path.exists(examples_dir):
return scenes
for scene_folder in sorted(os.listdir(examples_dir)):
scene_path = os.path.join(examples_dir, scene_folder)
if os.path.isdir(scene_path):
# Find all image files in the scene folder
image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(scene_path, ext)))
image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
if image_files:
# Sort images and get the first one for thumbnail
image_files = sorted(image_files)
first_image = image_files[0]
num_images = len(image_files)
scenes.append(
{
"name": scene_folder,
"path": scene_path,
"thumbnail": first_image,
"num_images": num_images,
"image_files": image_files,
}
)
return scenes
# NOTE: cleanup was moved to a single canonical helper in
# `depth_anything_3.utils.memory.cleanup_cuda_memory`.
# Callers should import and call that directly instead of using this module.
def get_logo_base64() -> Optional[str]:
"""
Convert WAI logo to base64 for embedding in HTML.
Returns:
Base64 encoded logo string or None
"""
import base64
logo_path = "examples/WAI-Logo/wai_logo.png"
try:
with open(logo_path, "rb") as img_file:
img_data = img_file.read()
base64_str = base64.b64encode(img_data).decode()
return f"data:image/png;base64,{base64_str}"
except FileNotFoundError:
return None

View File

@@ -0,0 +1,434 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visualization module for Depth Anything 3 Gradio app.
This module handles visualization updates, navigation, and measurement functionality.
"""
import os
from typing import Any, Dict, List, Optional, Tuple
import cv2
import gradio as gr
import numpy as np
class VisualizationHandler:
"""
Handles visualization updates and navigation for the Gradio app.
"""
def __init__(self):
"""Initialize the visualization handler."""
def update_view_selectors(
self, processed_data: Optional[Dict[int, Dict[str, Any]]]
) -> Tuple[gr.Dropdown, gr.Dropdown]:
"""
Update view selector dropdowns based on available views.
Args:
processed_data: Processed data dictionary
Returns:
Tuple of (depth_view_selector, measure_view_selector)
"""
if processed_data is None or len(processed_data) == 0:
choices = ["View 1"]
else:
num_views = len(processed_data)
choices = [f"View {i + 1}" for i in range(num_views)]
return (
gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
)
def get_view_data_by_index(
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
) -> Optional[Dict[str, Any]]:
"""
Get view data by index, handling bounds.
Args:
processed_data: Processed data dictionary
view_index: Index of the view to get
Returns:
View data dictionary or None
"""
if processed_data is None or len(processed_data) == 0:
return None
view_keys = list(processed_data.keys())
if view_index < 0 or view_index >= len(view_keys):
view_index = 0
return processed_data[view_keys[view_index]]
def update_depth_view(
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
) -> Optional[str]:
"""
Update depth view for a specific view index.
Args:
processed_data: Processed data dictionary
view_index: Index of the view to update
Returns:
Path to depth visualization image or None
"""
view_data = self.get_view_data_by_index(processed_data, view_index)
if view_data is None or view_data.get("depth_image") is None:
return None
# Return the depth visualization image directly
return view_data["depth_image"]
def navigate_depth_view(
self,
processed_data: Optional[Dict[int, Dict[str, Any]]],
current_selector_value: str,
direction: int,
) -> Tuple[str, Optional[str]]:
"""
Navigate depth view (direction: -1 for previous, +1 for next).
Args:
processed_data: Processed data dictionary
current_selector_value: Current selector value
direction: Direction to navigate (-1 for previous, +1 for next)
Returns:
Tuple of (new_selector_value, depth_vis)
"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
except: # noqa
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
depth_vis = self.update_depth_view(processed_data, new_view)
return new_selector_value, depth_vis
def update_measure_view(
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
"""
Update measure view for a specific view index.
Args:
processed_data: Processed data dictionary
view_index: Index of the view to update
Returns:
Tuple of (measure_image, depth_right_half, measure_points)
"""
view_data = self.get_view_data_by_index(processed_data, view_index)
if view_data is None:
return None, None, [] # image, depth_right_half, measure_points
# Get the processed (resized) image
if "image" in view_data and view_data["image"] is not None:
image = view_data["image"].copy()
else:
return None, None, []
# Ensure image is in uint8 format
if image.dtype != np.uint8:
if image.max() <= 1.0:
image = (image * 255).astype(np.uint8)
else:
image = image.astype(np.uint8)
# Extract right half of the depth visualization (pure depth part)
depth_image_path = view_data.get("depth_image", None)
depth_right_half = None
if depth_image_path and os.path.exists(depth_image_path):
try:
# Load the combined depth visualization image
depth_combined = cv2.imread(depth_image_path)
depth_combined = cv2.cvtColor(depth_combined, cv2.COLOR_BGR2RGB)
if depth_combined is not None:
height, width = depth_combined.shape[:2]
# Extract right half (depth visualization part)
depth_right_half = depth_combined[:, width // 2 :]
except Exception as e:
print(f"Error extracting depth right half: {e}")
return image, depth_right_half, []
def navigate_measure_view(
self,
processed_data: Optional[Dict[int, Dict[str, Any]]],
current_selector_value: str,
direction: int,
) -> Tuple[str, Optional[np.ndarray], Optional[str], List]:
"""
Navigate measure view (direction: -1 for previous, +1 for next).
Args:
processed_data: Processed data dictionary
current_selector_value: Current selector value
direction: Direction to navigate (-1 for previous, +1 for next)
Returns:
Tuple of (new_selector_value, measure_image, depth_image_path, measure_points)
"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None, None, []
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
except: # noqa
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
measure_image, depth_right_half, measure_points = self.update_measure_view(
processed_data, new_view
)
return new_selector_value, measure_image, depth_right_half, measure_points
def populate_visualization_tabs(
self, processed_data: Optional[Dict[int, Dict[str, Any]]]
) -> Tuple[Optional[str], Optional[np.ndarray], Optional[str], List]:
"""
Populate the depth and measure tabs with processed data.
Args:
processed_data: Processed data dictionary
Returns:
Tuple of (depth_vis, measure_img, depth_image_path, measure_points)
"""
if processed_data is None or len(processed_data) == 0:
return None, None, None, []
# Use update function to get depth visualization
depth_vis = self.update_depth_view(processed_data, 0)
measure_img, depth_right_half, _ = self.update_measure_view(processed_data, 0)
return depth_vis, measure_img, depth_right_half, []
def reset_measure(
self, processed_data: Optional[Dict[int, Dict[str, Any]]]
) -> Tuple[Optional[np.ndarray], List, str]:
"""
Reset measure points.
Args:
processed_data: Processed data dictionary
Returns:
Tuple of (image, measure_points, text)
"""
if processed_data is None or len(processed_data) == 0:
return None, [], ""
# Return the first view image
first_view = list(processed_data.values())[0]
return first_view["image"], [], ""
def measure(
self,
processed_data: Optional[Dict[int, Dict[str, Any]]],
measure_points: List,
current_view_selector: str,
event: gr.SelectData,
) -> List:
"""
Handle measurement on images.
Args:
processed_data: Processed data dictionary
measure_points: List of current measure points
current_view_selector: Current view selector value
event: Gradio select event
Returns:
List of [image, depth_right_half, measure_points, text]
"""
try:
print(f"Measure function called with selector: {current_view_selector}")
if processed_data is None or len(processed_data) == 0:
return [None, [], "No data available"]
# Use the currently selected view instead of always using the first view
try:
current_view_index = int(current_view_selector.split()[1]) - 1
except: # noqa
current_view_index = 0
print(f"Using view index: {current_view_index}")
# Get view data safely
if current_view_index < 0 or current_view_index >= len(processed_data):
current_view_index = 0
view_keys = list(processed_data.keys())
current_view = processed_data[view_keys[current_view_index]]
if current_view is None:
return [None, [], "No view data available"]
point2d = event.index[0], event.index[1]
print(f"Clicked point: {point2d}")
measure_points.append(point2d)
# Get image and depth visualization
image, depth_right_half, _ = self.update_measure_view(
processed_data, current_view_index
)
if image is None:
return [None, [], "No image available"]
image = image.copy()
# Ensure image is in uint8 format for proper cv2 operations
try:
if image.dtype != np.uint8:
if image.max() <= 1.0:
# Image is in [0, 1] range, convert to [0, 255]
image = (image * 255).astype(np.uint8)
else:
# Image is already in [0, 255] range
image = image.astype(np.uint8)
except Exception as e:
print(f"Image conversion error: {e}")
return [None, [], f"Image conversion error: {e}"]
# Draw circles for points
try:
for p in measure_points:
if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
except Exception as e:
print(f"Drawing error: {e}")
return [None, [], f"Drawing error: {e}"]
# Get depth information from processed_data
depth_text = ""
try:
for i, p in enumerate(measure_points):
if (
current_view["depth"] is not None
and 0 <= p[1] < current_view["depth"].shape[0]
and 0 <= p[0] < current_view["depth"].shape[1]
):
d = current_view["depth"][p[1], p[0]]
depth_text += f"- **P{i + 1} depth: {d:.2f}m**\n"
else:
depth_text += f"- **P{i + 1}: Click position ({p[0]}, {p[1]}) - No depth information**\n" # noqa: E501
except Exception as e:
print(f"Depth text error: {e}")
depth_text = f"Error computing depth: {e}\n"
if len(measure_points) == 2:
try:
point1, point2 = measure_points
# Draw line
if (
0 <= point1[0] < image.shape[1]
and 0 <= point1[1] < image.shape[0]
and 0 <= point2[0] < image.shape[1]
and 0 <= point2[1] < image.shape[0]
):
image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
# Compute 3D distance using depth information and camera intrinsics
distance_text = "- **Distance: Unable to calculate 3D distance**"
if (
current_view["depth"] is not None
and 0 <= point1[1] < current_view["depth"].shape[0]
and 0 <= point1[0] < current_view["depth"].shape[1]
and 0 <= point2[1] < current_view["depth"].shape[0]
and 0 <= point2[0] < current_view["depth"].shape[1]
):
try:
# Get depth values at the two points
d1 = current_view["depth"][point1[1], point1[0]]
d2 = current_view["depth"][point2[1], point2[0]]
# Convert 2D pixel coordinates to 3D world coordinates
if current_view["intrinsics"] is not None:
# Get camera intrinsics
K = current_view["intrinsics"] # 3x3 intrinsic matrix
fx, fy = K[0, 0], K[1, 1] # focal lengths
cx, cy = K[0, 2], K[1, 2] # principal point
# Convert pixel coordinates to normalized camera coordinates
# Point 1: (u1, v1) -> (x1, y1, z1)
u1, v1 = point1[0], point1[1]
x1 = (u1 - cx) * d1 / fx
y1 = (v1 - cy) * d1 / fy
z1 = d1
# Point 2: (u2, v2) -> (x2, y2, z2)
u2, v2 = point2[0], point2[1]
x2 = (u2 - cx) * d2 / fx
y2 = (v2 - cy) * d2 / fy
z2 = d2
# Calculate 3D Euclidean distance
p1_3d = np.array([x1, y1, z1])
p2_3d = np.array([x2, y2, z2])
distance_3d = np.linalg.norm(p1_3d - p2_3d)
distance_text = f"- **Distance: {distance_3d:.2f}m**"
else:
# Fallback to simplified calculation if no intrinsics
pixel_distance = np.sqrt(
(point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
)
avg_depth = (d1 + d2) / 2
scale_factor = avg_depth / 1000 # Rough scaling factor
estimated_3d_distance = pixel_distance * scale_factor
distance_text = f"- **Distance: {estimated_3d_distance:.2f}m (estimated, no intrinsics)**" # noqa: E501
except Exception as e:
print(f"Distance computation error: {e}")
distance_text = f"- **Distance computation error: {e}**"
measure_points = []
text = depth_text + distance_text
print(f"Measurement complete: {text}")
return [image, depth_right_half, measure_points, text]
except Exception as e:
print(f"Final measurement error: {e}")
return [None, [], f"Measurement error: {e}"]
else:
print(f"Single point measurement: {depth_text}")
return [image, depth_right_half, measure_points, depth_text]
except Exception as e:
print(f"Overall measure function error: {e}")
return [None, [], f"Measure function error: {e}"]

View File

@@ -0,0 +1,144 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Configuration utility functions
"""
import importlib
from pathlib import Path
from typing import Any, Callable, List, Union
from omegaconf import DictConfig, ListConfig, OmegaConf
try:
OmegaConf.register_new_resolver("eval", eval)
except Exception as e:
# if eval is not available, we can just pass
print(f"Error registering eval resolver: {e}")
def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
"""
Load a configuration. Will resolve inheritance.
Supports both file paths and module paths (e.g., depth_anything_3.configs.giant).
"""
# Check if path is a module path (contains dots but no slashes and doesn't end with .yaml)
if "." in path and "/" not in path and not path.endswith(".yaml"):
# It's a module path, load from package resources
path_parts = path.split(".")[1:]
config_path = Path(__file__).resolve().parent
for part in path_parts:
config_path = config_path.joinpath(part)
config_path = config_path.with_suffix(".yaml")
config = OmegaConf.load(str(config_path))
else:
# It's a file path (absolute, relative, or with .yaml extension)
config = OmegaConf.load(path)
if argv is not None:
config_argv = OmegaConf.from_dotlist(argv)
config = OmegaConf.merge(config, config_argv)
config = resolve_recursive(config, resolve_inheritance)
return config
def resolve_recursive(
config: Any,
resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
) -> Any:
config = resolver(config)
if isinstance(config, DictConfig):
for k in config.keys():
v = config.get(k)
if isinstance(v, (DictConfig, ListConfig)):
config[k] = resolve_recursive(v, resolver)
if isinstance(config, ListConfig):
for i in range(len(config)):
v = config.get(i)
if isinstance(v, (DictConfig, ListConfig)):
config[i] = resolve_recursive(v, resolver)
return config
def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
"""
Recursively resolve inheritance if the config contains:
__inherit__: path/to/parent.yaml or a ListConfig of such paths.
"""
if isinstance(config, DictConfig):
inherit = config.pop("__inherit__", None)
if inherit:
inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]
parent_config = None
for parent_path in inherit_list:
assert isinstance(parent_path, str)
parent_config = (
load_config(parent_path)
if parent_config is None
else OmegaConf.merge(parent_config, load_config(parent_path))
)
if len(config.keys()) > 0:
config = OmegaConf.merge(parent_config, config)
else:
config = parent_config
return config
def import_item(path: str, name: str) -> Any:
"""
Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
"""
return getattr(importlib.import_module(path), name)
def create_object(config: DictConfig) -> Any:
"""
Create an object from config.
The config is expected to contains the following:
__object__:
path: path.to.module
name: MyClass
args: as_config | as_params (default to as_config)
"""
config = DictConfig(config)
item = import_item(
path=config.__object__.path,
name=config.__object__.name,
)
args = config.__object__.get("args", "as_config")
if args == "as_config":
return item(config)
if args == "as_params":
config = OmegaConf.to_object(config)
config.pop("__object__")
return item(**config)
raise NotImplementedError(f"Unknown args type: {args}")
def create_dataset(path: str, *args, **kwargs) -> Any:
"""
Create a dataset. Requires the file to contain a "create_dataset" function.
"""
return import_item(path, "create_dataset")(*args, **kwargs)
def to_dict_recursive(config_obj):
if isinstance(config_obj, DictConfig):
return {k: to_dict_recursive(v) for k, v in config_obj.items()}
elif isinstance(config_obj, ListConfig):
return [to_dict_recursive(item) for item in config_obj]
return config_obj

View File

@@ -0,0 +1,803 @@
# flake8: noqa: E402
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Refactored Depth Anything 3 CLI
Clean, modular command-line interface
"""
from __future__ import annotations
import os
import typer
from depth_anything_3.services import start_server
from depth_anything_3.services.gallery import gallery as gallery_main
from depth_anything_3.services.inference_service import run_inference
from depth_anything_3.services.input_handlers import (
ColmapHandler,
ImageHandler,
ImagesHandler,
InputHandler,
VideoHandler,
parse_export_feat,
)
from depth_anything_3.utils.constants import (
DEFAULT_EXPORT_DIR,
DEFAULT_GALLERY_DIR,
DEFAULT_GRADIO_DIR,
DEFAULT_MODEL,
)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
app = typer.Typer(help="Depth Anything 3 - Video depth estimation CLI", add_completion=False)
# ============================================================================
# Input type detection utilities
# ============================================================================
# Supported file extensions
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff", ".tif"}
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}
def detect_input_type(input_path: str) -> str:
"""
Detect input type from path.
Returns:
- "image": Single image file
- "images": Directory containing images
- "video": Video file
- "colmap": COLMAP directory structure
- "unknown": Cannot determine type
"""
if not os.path.exists(input_path):
return "unknown"
# Check if it's a file
if os.path.isfile(input_path):
ext = os.path.splitext(input_path)[1].lower()
if ext in IMAGE_EXTENSIONS:
return "image"
elif ext in VIDEO_EXTENSIONS:
return "video"
return "unknown"
# Check if it's a directory
if os.path.isdir(input_path):
# Check for COLMAP structure
images_dir = os.path.join(input_path, "images")
sparse_dir = os.path.join(input_path, "sparse")
if os.path.isdir(images_dir) and os.path.isdir(sparse_dir):
return "colmap"
# Check if directory contains image files
for item in os.listdir(input_path):
item_path = os.path.join(input_path, item)
if os.path.isfile(item_path):
ext = os.path.splitext(item)[1].lower()
if ext in IMAGE_EXTENSIONS:
return "images"
return "unknown"
return "unknown"
# ============================================================================
# Common parameters and configuration
# ============================================================================
# ============================================================================
# Inference commands
# ============================================================================
@app.command()
def auto(
input_path: str = typer.Argument(
..., help="Path to input (image, directory, video, or COLMAP)"
),
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
export_format: str = typer.Option("glb", help="Export format"),
device: str = typer.Option("cuda", help="Device to use"),
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
backend_url: str = typer.Option(
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
),
process_res: int = typer.Option(504, help="Processing resolution"),
process_res_method: str = typer.Option(
"upper_bound_resize", help="Processing resolution method"
),
export_feat: str = typer.Option(
"",
help="[FEAT_VIS]Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
),
auto_cleanup: bool = typer.Option(
False, help="Automatically clean export directory if it exists (no prompt)"
),
# Video-specific options
fps: float = typer.Option(1.0, help="[Video] Sampling FPS for frame extraction"),
# COLMAP-specific options
sparse_subdir: str = typer.Option(
"", help="[COLMAP] Sparse reconstruction subdirectory (e.g., '0' for sparse/0/)"
),
align_to_input_ext_scale: bool = typer.Option(
True, help="[COLMAP] Align prediction to input extrinsics scale"
),
# Pose estimation options
use_ray_pose: bool = typer.Option(
False, help="Use ray-based pose estimation instead of camera decoder"
),
ref_view_strategy: str = typer.Option(
"saddle_balanced",
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
),
# GLB export options
conf_thresh_percentile: float = typer.Option(
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
),
num_max_points: int = typer.Option(
1_000_000, help="[GLB] Maximum number of points in the point cloud"
),
show_cameras: bool = typer.Option(
True, help="[GLB] Show camera wireframes in the exported scene"
),
# Feat_vis export options
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
"""
Automatically detect input type and run appropriate processing.
Supports:
- Single image file (.jpg, .png, etc.)
- Directory of images
- Video file (.mp4, .avi, etc.)
- COLMAP directory (with 'images' and 'sparse' subdirectories)
"""
# Detect input type
input_type = detect_input_type(input_path)
if input_type == "unknown":
typer.echo(f"❌ Error: Cannot determine input type for: {input_path}", err=True)
typer.echo("Supported inputs:", err=True)
typer.echo(" - Single image file (.jpg, .png, etc.)", err=True)
typer.echo(" - Directory containing images", err=True)
typer.echo(" - Video file (.mp4, .avi, etc.)", err=True)
typer.echo(" - COLMAP directory (with 'images/' and 'sparse/' subdirectories)", err=True)
raise typer.Exit(1)
# Display detected type
typer.echo(f"🔍 Detected input type: {input_type.upper()}")
typer.echo(f"📁 Input path: {input_path}")
typer.echo()
# Determine backend URL based on use_backend flag
final_backend_url = backend_url if use_backend else None
# Parse export_feat parameter
export_feat_layers = parse_export_feat(export_feat)
# Route to appropriate handler
if input_type == "image":
typer.echo("Processing single image...")
# Process input
image_files = ImageHandler.process(input_path)
# Handle export directory
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
# Run inference
run_inference(
image_paths=image_files,
export_dir=export_dir,
model_dir=model_dir,
device=device,
backend_url=final_backend_url,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)
elif input_type == "images":
typer.echo("Processing directory of images...")
# Process input - use default extensions
image_files = ImagesHandler.process(input_path, "png,jpg,jpeg")
# Handle export directory
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
# Run inference
run_inference(
image_paths=image_files,
export_dir=export_dir,
model_dir=model_dir,
device=device,
backend_url=final_backend_url,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)
elif input_type == "video":
typer.echo(f"Processing video with FPS={fps}...")
# Handle export directory
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
# Process input
image_files = VideoHandler.process(input_path, export_dir, fps)
# Run inference
run_inference(
image_paths=image_files,
export_dir=export_dir,
model_dir=model_dir,
device=device,
backend_url=final_backend_url,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)
elif input_type == "colmap":
typer.echo(
f"Processing COLMAP directory (sparse subdirectory: '{sparse_subdir or 'default'}')..."
)
# Process input
image_files, extrinsics, intrinsics = ColmapHandler.process(input_path, sparse_subdir)
# Handle export directory
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
# Run inference
run_inference(
image_paths=image_files,
export_dir=export_dir,
model_dir=model_dir,
device=device,
backend_url=final_backend_url,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
extrinsics=extrinsics,
intrinsics=intrinsics,
align_to_input_ext_scale=align_to_input_ext_scale,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)
typer.echo()
typer.echo("✅ Processing completed successfully!")
@app.command()
def image(
image_path: str = typer.Argument(..., help="Path to input image file"),
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
export_format: str = typer.Option("glb", help="Export format"),
device: str = typer.Option("cuda", help="Device to use"),
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
backend_url: str = typer.Option(
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
),
process_res: int = typer.Option(504, help="Processing resolution"),
process_res_method: str = typer.Option(
"upper_bound_resize", help="Processing resolution method"
),
export_feat: str = typer.Option(
"",
help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
),
auto_cleanup: bool = typer.Option(
False, help="Automatically clean export directory if it exists (no prompt)"
),
# Pose estimation options
use_ray_pose: bool = typer.Option(
False, help="Use ray-based pose estimation instead of camera decoder"
),
ref_view_strategy: str = typer.Option(
"saddle_balanced",
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
),
# GLB export options
conf_thresh_percentile: float = typer.Option(
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
),
num_max_points: int = typer.Option(
1_000_000, help="[GLB] Maximum number of points in the point cloud"
),
show_cameras: bool = typer.Option(
True, help="[GLB] Show camera wireframes in the exported scene"
),
# Feat_vis export options
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
"""Run camera pose and depth estimation on a single image."""
# Process input
image_files = ImageHandler.process(image_path)
# Handle export directory
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
# Parse export_feat parameter
export_feat_layers = parse_export_feat(export_feat)
# Determine backend URL based on use_backend flag
final_backend_url = backend_url if use_backend else None
# Run inference
run_inference(
image_paths=image_files,
export_dir=export_dir,
model_dir=model_dir,
device=device,
backend_url=final_backend_url,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)
@app.command()
def images(
images_dir: str = typer.Argument(..., help="Path to directory containing input images"),
image_extensions: str = typer.Option(
"png,jpg,jpeg", help="Comma-separated image file extensions to process"
),
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
export_format: str = typer.Option("glb", help="Export format"),
device: str = typer.Option("cuda", help="Device to use"),
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
backend_url: str = typer.Option(
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
),
process_res: int = typer.Option(504, help="Processing resolution"),
process_res_method: str = typer.Option(
"upper_bound_resize", help="Processing resolution method"
),
export_feat: str = typer.Option(
"",
help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
),
auto_cleanup: bool = typer.Option(
False, help="Automatically clean export directory if it exists (no prompt)"
),
# Pose estimation options
use_ray_pose: bool = typer.Option(
False, help="Use ray-based pose estimation instead of camera decoder"
),
ref_view_strategy: str = typer.Option(
"saddle_balanced",
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
),
# GLB export options
conf_thresh_percentile: float = typer.Option(
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
),
num_max_points: int = typer.Option(
1_000_000, help="[GLB] Maximum number of points in the point cloud"
),
show_cameras: bool = typer.Option(
True, help="[GLB] Show camera wireframes in the exported scene"
),
# Feat_vis export options
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
"""Run camera pose and depth estimation on a directory of images."""
# Process input
image_files = ImagesHandler.process(images_dir, image_extensions)
# Handle export directory
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
# Parse export_feat parameter
export_feat_layers = parse_export_feat(export_feat)
# Determine backend URL based on use_backend flag
final_backend_url = backend_url if use_backend else None
# Run inference
run_inference(
image_paths=image_files,
export_dir=export_dir,
model_dir=model_dir,
device=device,
backend_url=final_backend_url,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)
@app.command()
def colmap(
colmap_dir: str = typer.Argument(
..., help="Path to COLMAP directory containing 'images' and 'sparse' subdirectories"
),
sparse_subdir: str = typer.Option(
"", help="Sparse reconstruction subdirectory (e.g., '0' for sparse/0/, empty for sparse/)"
),
align_to_input_ext_scale: bool = typer.Option(
True, help="Align prediction to input extrinsics scale"
),
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
export_format: str = typer.Option("glb", help="Export format"),
device: str = typer.Option("cuda", help="Device to use"),
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
backend_url: str = typer.Option(
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
),
process_res: int = typer.Option(504, help="Processing resolution"),
process_res_method: str = typer.Option(
"upper_bound_resize", help="Processing resolution method"
),
export_feat: str = typer.Option(
"",
help="Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
),
auto_cleanup: bool = typer.Option(
False, help="Automatically clean export directory if it exists (no prompt)"
),
# Pose estimation options
use_ray_pose: bool = typer.Option(
False, help="Use ray-based pose estimation instead of camera decoder"
),
ref_view_strategy: str = typer.Option(
"saddle_balanced",
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
),
# GLB export options
conf_thresh_percentile: float = typer.Option(
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
),
num_max_points: int = typer.Option(
1_000_000, help="[GLB] Maximum number of points in the point cloud"
),
show_cameras: bool = typer.Option(
True, help="[GLB] Show camera wireframes in the exported scene"
),
# Feat_vis export options
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
"""Run pose conditioned depth estimation on COLMAP data."""
# Process input
image_files, extrinsics, intrinsics = ColmapHandler.process(colmap_dir, sparse_subdir)
# Handle export directory
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
# Parse export_feat parameter
export_feat_layers = parse_export_feat(export_feat)
# Determine backend URL based on use_backend flag
final_backend_url = backend_url if use_backend else None
# Run inference
run_inference(
image_paths=image_files,
export_dir=export_dir,
model_dir=model_dir,
device=device,
backend_url=final_backend_url,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
extrinsics=extrinsics,
intrinsics=intrinsics,
align_to_input_ext_scale=align_to_input_ext_scale,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)
@app.command()
def video(
video_path: str = typer.Argument(..., help="Path to input video file"),
fps: float = typer.Option(1.0, help="Sampling FPS for frame extraction"),
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
export_format: str = typer.Option("glb", help="Export format"),
device: str = typer.Option("cuda", help="Device to use"),
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
backend_url: str = typer.Option(
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
),
process_res: int = typer.Option(504, help="Processing resolution"),
process_res_method: str = typer.Option(
"upper_bound_resize", help="Processing resolution method"
),
export_feat: str = typer.Option(
"",
help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
),
auto_cleanup: bool = typer.Option(
False, help="Automatically clean export directory if it exists (no prompt)"
),
# Pose estimation options
use_ray_pose: bool = typer.Option(
False, help="Use ray-based pose estimation instead of camera decoder"
),
ref_view_strategy: str = typer.Option(
"saddle_balanced",
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
),
# GLB export options
conf_thresh_percentile: float = typer.Option(
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
),
num_max_points: int = typer.Option(
1_000_000, help="[GLB] Maximum number of points in the point cloud"
),
show_cameras: bool = typer.Option(
True, help="[GLB] Show camera wireframes in the exported scene"
),
# Feat_vis export options
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
"""Run depth estimation on video by extracting frames and processing them."""
# Handle export directory
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
# Process input
image_files = VideoHandler.process(video_path, export_dir, fps)
# Parse export_feat parameter
export_feat_layers = parse_export_feat(export_feat)
# Determine backend URL based on use_backend flag
final_backend_url = backend_url if use_backend else None
# Run inference
run_inference(
image_paths=image_files,
export_dir=export_dir,
model_dir=model_dir,
device=device,
backend_url=final_backend_url,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)
# ============================================================================
# Service management commands
# ============================================================================
@app.command()
def backend(
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
device: str = typer.Option("cuda", help="Device to use"),
host: str = typer.Option("127.0.0.1", help="Host to bind to"),
port: int = typer.Option(8008, help="Port to bind to"),
gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path (optional)"),
):
"""Start model backend service with integrated gallery."""
typer.echo("=" * 60)
typer.echo("🚀 Starting Depth Anything 3 Backend Server")
typer.echo("=" * 60)
typer.echo(f"Model directory: {model_dir}")
typer.echo(f"Device: {device}")
# Check if gallery directory exists
if gallery_dir and os.path.exists(gallery_dir):
typer.echo(f"Gallery directory: {gallery_dir}")
else:
gallery_dir = None # Disable gallery if directory doesn't exist
typer.echo()
typer.echo("📡 Server URLs (Ctrl/CMD+Click to open):")
typer.echo(f" 🏠 Home: http://{host}:{port}")
typer.echo(f" 📊 Dashboard: http://{host}:{port}/dashboard")
typer.echo(f" 📈 API Status: http://{host}:{port}/status")
if gallery_dir:
typer.echo(f" 🎨 Gallery: http://{host}:{port}/gallery/")
typer.echo("=" * 60)
try:
start_server(model_dir, device, host, port, gallery_dir)
except KeyboardInterrupt:
typer.echo("\n👋 Backend server stopped.")
except Exception as e:
typer.echo(f"❌ Failed to start backend: {e}")
raise typer.Exit(1)
# ============================================================================
# Application launch commands
# ============================================================================
@app.command()
def gradio(
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
workspace_dir: str = typer.Option(DEFAULT_GRADIO_DIR, help="Workspace directory path"),
gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path"),
host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
port: int = typer.Option(7860, help="Port number to bind to"),
share: bool = typer.Option(False, help="Create a public link for the app"),
debug: bool = typer.Option(False, help="Enable debug mode"),
cache_examples: bool = typer.Option(
False, help="Pre-cache all example scenes at startup for faster loading"
),
cache_gs_tag: str = typer.Option(
"",
help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.",
),
):
"""Launch Depth Anything 3 Gradio interactive web application"""
from depth_anything_3.app.gradio_app import DepthAnything3App
# Create necessary directories
os.makedirs(workspace_dir, exist_ok=True)
os.makedirs(gallery_dir, exist_ok=True)
typer.echo("Launching Depth Anything 3 Gradio application...")
typer.echo(f"Model directory: {model_dir}")
typer.echo(f"Workspace directory: {workspace_dir}")
typer.echo(f"Gallery directory: {gallery_dir}")
typer.echo(f"Host: {host}")
typer.echo(f"Port: {port}")
typer.echo(f"Share: {share}")
typer.echo(f"Debug mode: {debug}")
typer.echo(f"Cache examples: {cache_examples}")
if cache_examples:
if cache_gs_tag:
typer.echo(
f"Cache GS Tag: '{cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)"
)
else:
typer.echo(f"Cache GS Tag: None (all scenes will use low-res only)")
try:
# Initialize and launch application
app = DepthAnything3App(
model_dir=model_dir, workspace_dir=workspace_dir, gallery_dir=gallery_dir
)
# Pre-cache examples if requested
if cache_examples:
typer.echo("\n" + "=" * 60)
typer.echo("Pre-caching mode enabled")
if cache_gs_tag:
typer.echo(f"Scenes containing '{cache_gs_tag}' will use HIGH-RES + 3DGS")
typer.echo(f"Other scenes will use LOW-RES only")
else:
typer.echo(f"All scenes will use LOW-RES only")
typer.echo("=" * 60)
app.cache_examples(
show_cam=True,
filter_black_bg=False,
filter_white_bg=False,
save_percentage=20.0,
num_max_points=1000,
cache_gs_tag=cache_gs_tag,
gs_trj_mode="smooth",
gs_video_quality="low",
)
# Prepare launch arguments
launch_kwargs = {"share": share, "debug": debug}
app.launch(host=host, port=port, **launch_kwargs)
except KeyboardInterrupt:
typer.echo("\nGradio application stopped.")
except Exception as e:
typer.echo(f"Failed to launch Gradio application: {e}")
raise typer.Exit(1)
@app.command()
def gallery(
gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery root directory"),
host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
port: int = typer.Option(8007, help="Port number to bind to"),
open_browser: bool = typer.Option(False, help="Open browser after launch"),
):
"""Launch Depth Anything 3 Gallery server"""
# Validate gallery directory
if not os.path.exists(gallery_dir):
raise typer.BadParameter(f"Gallery directory not found: {gallery_dir}")
typer.echo("Launching Depth Anything 3 Gallery server...")
typer.echo(f"Gallery directory: {gallery_dir}")
typer.echo(f"Host: {host}")
typer.echo(f"Port: {port}")
typer.echo(f"Auto-open browser: {open_browser}")
try:
# Set command line arguments
import sys
sys.argv = ["gallery", "--dir", gallery_dir, "--host", host, "--port", str(port)]
if open_browser:
sys.argv.append("--open")
# Launch gallery server
gallery_main()
except KeyboardInterrupt:
typer.echo("\nGallery server stopped.")
except Exception as e:
typer.echo(f"Failed to launch Gallery server: {e}")
raise typer.Exit(1)
if __name__ == "__main__":
app()

View File

@@ -0,0 +1,45 @@
__object__:
path: depth_anything_3.model.da3
name: DepthAnything3Net
args: as_params
net:
__object__:
path: depth_anything_3.model.dinov2.dinov2
name: DinoV2
args: as_params
name: vitb
out_layers: [5, 7, 9, 11]
alt_start: 4
qknorm_start: 4
rope_start: 4
cat_token: True
head:
__object__:
path: depth_anything_3.model.dualdpt
name: DualDPT
args: as_params
dim_in: &head_dim_in 1536
output_dim: 2
features: &head_features 128
out_channels: &head_out_channels [96, 192, 384, 768]
cam_enc:
__object__:
path: depth_anything_3.model.cam_enc
name: CameraEnc
args: as_params
dim_out: 768
cam_dec:
__object__:
path: depth_anything_3.model.cam_dec
name: CameraDec
args: as_params
dim_in: 1536

View File

@@ -0,0 +1,71 @@
__object__:
path: depth_anything_3.model.da3
name: DepthAnything3Net
args: as_params
net:
__object__:
path: depth_anything_3.model.dinov2.dinov2
name: DinoV2
args: as_params
name: vitg
out_layers: [19, 27, 33, 39]
alt_start: 13
qknorm_start: 13
rope_start: 13
cat_token: True
head:
__object__:
path: depth_anything_3.model.dualdpt
name: DualDPT
args: as_params
dim_in: &head_dim_in 3072
output_dim: 2
features: &head_features 256
out_channels: &head_out_channels [256, 512, 1024, 1024]
cam_enc:
__object__:
path: depth_anything_3.model.cam_enc
name: CameraEnc
args: as_params
dim_out: 1536
cam_dec:
__object__:
path: depth_anything_3.model.cam_dec
name: CameraDec
args: as_params
dim_in: 3072
gs_head:
__object__:
path: depth_anything_3.model.gsdpt
name: GSDPT
args: as_params
dim_in: *head_dim_in
output_dim: 38 # should align with gs_adapter's setting, for gs params
features: *head_features
out_channels: *head_out_channels
gs_adapter:
__object__:
path: depth_anything_3.model.gs_adapter
name: GaussianAdapter
args: as_params
sh_degree: 2
pred_color: false # predict SH coefficient if false
pred_offset_depth: true
pred_offset_xy: true
gaussian_scale_min: 1e-5
gaussian_scale_max: 30.0

View File

@@ -0,0 +1,45 @@
__object__:
path: depth_anything_3.model.da3
name: DepthAnything3Net
args: as_params
net:
__object__:
path: depth_anything_3.model.dinov2.dinov2
name: DinoV2
args: as_params
name: vitl
out_layers: [11, 15, 19, 23]
alt_start: 8
qknorm_start: 8
rope_start: 8
cat_token: True
head:
__object__:
path: depth_anything_3.model.dualdpt
name: DualDPT
args: as_params
dim_in: &head_dim_in 2048
output_dim: 2
features: &head_features 256
out_channels: &head_out_channels [256, 512, 1024, 1024]
cam_enc:
__object__:
path: depth_anything_3.model.cam_enc
name: CameraEnc
args: as_params
dim_out: 1024
cam_dec:
__object__:
path: depth_anything_3.model.cam_dec
name: CameraDec
args: as_params
dim_in: 2048

View File

@@ -0,0 +1,45 @@
__object__:
path: depth_anything_3.model.da3
name: DepthAnything3Net
args: as_params
net:
__object__:
path: depth_anything_3.model.dinov2.dinov2
name: DinoV2
args: as_params
name: vits
out_layers: [5, 7, 9, 11]
alt_start: 4
qknorm_start: 4
rope_start: 4
cat_token: True
head:
__object__:
path: depth_anything_3.model.dualdpt
name: DualDPT
args: as_params
dim_in: &head_dim_in 768
output_dim: 2
features: &head_features 64
out_channels: &head_out_channels [48, 96, 192, 384]
cam_enc:
__object__:
path: depth_anything_3.model.cam_enc
name: CameraEnc
args: as_params
dim_out: 384
cam_dec:
__object__:
path: depth_anything_3.model.cam_dec
name: CameraDec
args: as_params
dim_in: 768

View File

@@ -0,0 +1,28 @@
__object__:
path: depth_anything_3.model.da3
name: DepthAnything3Net
args: as_params
net:
__object__:
path: depth_anything_3.model.dinov2.dinov2
name: DinoV2
args: as_params
name: vitl
out_layers: [4, 11, 17, 23]
alt_start: -1 # -1 means disable
qknorm_start: -1
rope_start: -1
cat_token: False
head:
__object__:
path: depth_anything_3.model.dpt
name: DPT
args: as_params
dim_in: 1024
output_dim: 1
features: 256
out_channels: [256, 512, 1024, 1024]

View File

@@ -0,0 +1,28 @@
__object__:
path: depth_anything_3.model.da3
name: DepthAnything3Net
args: as_params
net:
__object__:
path: depth_anything_3.model.dinov2.dinov2
name: DinoV2
args: as_params
name: vitl
out_layers: [4, 11, 17, 23]
alt_start: -1 # -1 means disable
qknorm_start: -1
rope_start: -1
cat_token: False
head:
__object__:
path: depth_anything_3.model.dpt
name: DPT
args: as_params
dim_in: 1024
output_dim: 1
features: 256
out_channels: [256, 512, 1024, 1024]

View File

@@ -0,0 +1,10 @@
__object__:
path: depth_anything_3.model.da3
name: NestedDepthAnything3Net
args: as_params
anyview:
__inherit__: depth_anything_3.configs.da3-giant
metric:
__inherit__: depth_anything_3.configs.da3metric-large

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from depth_anything_3.model.da3 import DepthAnything3Net, NestedDepthAnything3Net
__export__ = [
NestedDepthAnything3Net,
DepthAnything3Net,
]

View File

@@ -0,0 +1,45 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class CameraDec(nn.Module):
def __init__(self, dim_in=1536):
super().__init__()
output_dim = dim_in
self.backbone = nn.Sequential(
nn.Linear(output_dim, output_dim),
nn.ReLU(),
nn.Linear(output_dim, output_dim),
nn.ReLU(),
)
self.fc_t = nn.Linear(output_dim, 3)
self.fc_qvec = nn.Linear(output_dim, 4)
self.fc_fov = nn.Sequential(nn.Linear(output_dim, 2), nn.ReLU())
def forward(self, feat, camera_encoding=None, *args, **kwargs):
B, N = feat.shape[:2]
feat = feat.reshape(B * N, -1)
feat = self.backbone(feat)
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
if camera_encoding is None:
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
else:
out_qvec = camera_encoding[..., 3:7]
out_fov = camera_encoding[..., -2:]
pose_enc = torch.cat([out_t, out_qvec, out_fov], dim=-1)
return pose_enc

View File

@@ -0,0 +1,80 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from depth_anything_3.model.utils.attention import Mlp
from depth_anything_3.model.utils.block import Block
from depth_anything_3.model.utils.transform import extri_intri_to_pose_encoding
from depth_anything_3.utils.geometry import affine_inverse
class CameraEnc(nn.Module):
"""
CameraHead predicts camera parameters from token representations using iterative refinement.
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
"""
def __init__(
self,
dim_out: int = 1024,
dim_in: int = 9,
trunk_depth: int = 4,
target_dim: int = 9,
num_heads: int = 16,
mlp_ratio: int = 4,
init_values: float = 0.01,
**kwargs,
):
super().__init__()
self.target_dim = target_dim
self.trunk_depth = trunk_depth
self.trunk = nn.Sequential(
*[
Block(
dim=dim_out,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
init_values=init_values,
)
for _ in range(trunk_depth)
]
)
self.token_norm = nn.LayerNorm(dim_out)
self.trunk_norm = nn.LayerNorm(dim_out)
self.pose_branch = Mlp(
in_features=dim_in,
hidden_features=dim_out // 2,
out_features=dim_out,
drop=0,
)
def forward(
self,
ext,
ixt,
image_size,
) -> tuple:
c2ws = affine_inverse(ext)
pose_encoding = extri_intri_to_pose_encoding(
c2ws,
ixt,
image_size,
)
pose_tokens = self.pose_branch(pose_encoding)
pose_tokens = self.token_norm(pose_tokens)
pose_tokens = self.trunk(pose_tokens)
pose_tokens = self.trunk_norm(pose_tokens)
return pose_tokens

View File

@@ -0,0 +1,442 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn as nn
from addict import Dict
from omegaconf import DictConfig, OmegaConf
from depth_anything_3.cfg import create_object
from depth_anything_3.model.utils.transform import pose_encoding_to_extri_intri
from depth_anything_3.utils.alignment import (
apply_metric_scaling,
compute_alignment_mask,
compute_sky_mask,
least_squares_scale_scalar,
sample_tensor_for_quantile,
set_sky_regions_to_max_depth,
)
from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, map_pdf_to_opacity
from depth_anything_3.utils.ray_utils import get_extrinsic_from_camray
def _wrap_cfg(cfg_obj):
return OmegaConf.create(cfg_obj)
class DepthAnything3Net(nn.Module):
"""
Depth Anything 3 network for depth estimation and camera pose estimation.
This network consists of:
- Backbone: DinoV2 feature extractor
- Head: DPT or DualDPT for depth prediction
- Optional camera decoders for pose estimation
- Optional GSDPT for 3DGS prediction
Args:
preset: Configuration preset containing network dimensions and settings
Returns:
Dictionary containing:
- depth: Predicted depth map (B, H, W)
- depth_conf: Depth confidence map (B, H, W)
- extrinsics: Camera extrinsics (B, N, 4, 4)
- intrinsics: Camera intrinsics (B, N, 3, 3)
- gaussians: 3D Gaussian Splats (world space), type: model.gs_adapter.Gaussians
- aux: Auxiliary features for specified layers
"""
# Patch size for feature extraction
PATCH_SIZE = 14
def __init__(self, net, head, cam_dec=None, cam_enc=None, gs_head=None, gs_adapter=None):
"""
Initialize DepthAnything3Net with given yaml-initialized configuration.
"""
super().__init__()
self.backbone = net if isinstance(net, nn.Module) else create_object(_wrap_cfg(net))
self.head = head if isinstance(head, nn.Module) else create_object(_wrap_cfg(head))
self.cam_dec, self.cam_enc = None, None
if cam_dec is not None:
self.cam_dec = (
cam_dec if isinstance(cam_dec, nn.Module) else create_object(_wrap_cfg(cam_dec))
)
self.cam_enc = (
cam_enc if isinstance(cam_enc, nn.Module) else create_object(_wrap_cfg(cam_enc))
)
self.gs_adapter, self.gs_head = None, None
if gs_head is not None and gs_adapter is not None:
self.gs_adapter = (
gs_adapter
if isinstance(gs_adapter, nn.Module)
else create_object(_wrap_cfg(gs_adapter))
)
gs_out_dim = self.gs_adapter.d_in + 1
if isinstance(gs_head, nn.Module):
assert (
gs_head.out_dim == gs_out_dim
), f"gs_head.out_dim should be {gs_out_dim}, got {gs_head.out_dim}"
self.gs_head = gs_head
else:
assert (
gs_head["output_dim"] == gs_out_dim
), f"gs_head output_dim should set to {gs_out_dim}, got {gs_head['output_dim']}"
self.gs_head = create_object(_wrap_cfg(gs_head))
def forward(
self,
x: torch.Tensor,
extrinsics: torch.Tensor | None = None,
intrinsics: torch.Tensor | None = None,
export_feat_layers: list[int] | None = [],
infer_gs: bool = False,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
) -> Dict[str, torch.Tensor]:
"""
Forward pass through the network.
Args:
x: Input images (B, N, 3, H, W)
extrinsics: Camera extrinsics (B, N, 4, 4)
intrinsics: Camera intrinsics (B, N, 3, 3)
feat_layers: List of layer indices to extract features from
infer_gs: Enable Gaussian Splatting branch
use_ray_pose: Use ray-based pose estimation
ref_view_strategy: Strategy for selecting reference view
Returns:
Dictionary containing predictions and auxiliary features
"""
# Extract features using backbone
if extrinsics is not None:
with torch.autocast(device_type=x.device.type, enabled=False):
cam_token = self.cam_enc(extrinsics, intrinsics, x.shape[-2:])
else:
cam_token = None
feats, aux_feats = self.backbone(
x, cam_token=cam_token, export_feat_layers=export_feat_layers, ref_view_strategy=ref_view_strategy
)
# feats = [[item for item in feat] for feat in feats]
H, W = x.shape[-2], x.shape[-1]
# Process features through depth head
with torch.autocast(device_type=x.device.type, enabled=False):
output = self._process_depth_head(feats, H, W)
if use_ray_pose:
output = self._process_ray_pose_estimation(output, H, W)
else:
output = self._process_camera_estimation(feats, H, W, output)
if infer_gs:
output = self._process_gs_head(feats, H, W, output, x, extrinsics, intrinsics)
output = self._process_mono_sky_estimation(output)
# Extract auxiliary features if requested
output.aux = self._extract_auxiliary_features(aux_feats, export_feat_layers, H, W)
return output
def _process_mono_sky_estimation(
self, output: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Process mono sky estimation."""
if "sky" not in output:
return output
non_sky_mask = compute_sky_mask(output.sky, threshold=0.3)
if non_sky_mask.sum() <= 10:
return output
if (~non_sky_mask).sum() <= 10:
return output
non_sky_depth = output.depth[non_sky_mask]
if non_sky_depth.numel() > 100000:
idx = torch.randint(0, non_sky_depth.numel(), (100000,), device=non_sky_depth.device)
sampled_depth = non_sky_depth[idx]
else:
sampled_depth = non_sky_depth
non_sky_max = torch.quantile(sampled_depth, 0.99)
# Set sky regions to maximum depth and high confidence
output.depth, _ = set_sky_regions_to_max_depth(
output.depth, None, non_sky_mask, max_depth=non_sky_max
)
return output
def _process_ray_pose_estimation(
self, output: Dict[str, torch.Tensor], height: int, width: int
) -> Dict[str, torch.Tensor]:
"""Process ray pose estimation if ray pose decoder is available."""
if "ray" in output and "ray_conf" in output:
pred_extrinsic, pred_focal_lengths, pred_principal_points = get_extrinsic_from_camray(
output.ray,
output.ray_conf,
output.ray.shape[-3],
output.ray.shape[-2],
)
pred_extrinsic = affine_inverse(pred_extrinsic) # w2c -> c2w
pred_extrinsic = pred_extrinsic[:, :, :3, :]
pred_intrinsic = torch.eye(3, 3)[None, None].repeat(pred_extrinsic.shape[0], pred_extrinsic.shape[1], 1, 1).clone().to(pred_extrinsic.device)
pred_intrinsic[:, :, 0, 0] = pred_focal_lengths[:, :, 0] / 2 * width
pred_intrinsic[:, :, 1, 1] = pred_focal_lengths[:, :, 1] / 2 * height
pred_intrinsic[:, :, 0, 2] = pred_principal_points[:, :, 0] * width * 0.5
pred_intrinsic[:, :, 1, 2] = pred_principal_points[:, :, 1] * height * 0.5
del output.ray
del output.ray_conf
output.extrinsics = pred_extrinsic
output.intrinsics = pred_intrinsic
return output
def _process_depth_head(
self, feats: list[torch.Tensor], H: int, W: int
) -> Dict[str, torch.Tensor]:
"""Process features through the depth prediction head."""
return self.head(feats, H, W, patch_start_idx=0)
def _process_camera_estimation(
self, feats: list[torch.Tensor], H: int, W: int, output: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Process camera pose estimation if camera decoder is available."""
if self.cam_dec is not None:
pose_enc = self.cam_dec(feats[-1][1])
# Remove ray information as it's not needed for pose estimation
if "ray" in output:
del output.ray
if "ray_conf" in output:
del output.ray_conf
# Convert pose encoding to extrinsics and intrinsics
c2w, ixt = pose_encoding_to_extri_intri(pose_enc, (H, W))
output.extrinsics = affine_inverse(c2w)
output.intrinsics = ixt
return output
def _process_gs_head(
self,
feats: list[torch.Tensor],
H: int,
W: int,
output: Dict[str, torch.Tensor],
in_images: torch.Tensor,
extrinsics: torch.Tensor | None = None,
intrinsics: torch.Tensor | None = None,
) -> Dict[str, torch.Tensor]:
"""Process 3DGS parameters estimation if 3DGS head is available."""
if self.gs_head is None or self.gs_adapter is None:
return output
assert output.get("depth", None) is not None, "must provide MV depth for the GS head."
# The depth is defined in the DA3 model's camera space,
# so even with provided GT camera poses,
# we instead use the predicted camera poses for better alignment.
ctx_extr = output.get("extrinsics", None)
ctx_intr = output.get("intrinsics", None)
assert (
ctx_extr is not None and ctx_intr is not None
), "must process camera info first if GT is not available"
gt_extr = extrinsics
# homo the extr if needed
ctx_extr = as_homogeneous(ctx_extr)
if gt_extr is not None:
gt_extr = as_homogeneous(gt_extr)
# forward through the gs_dpt head to get 'camera space' parameters
gs_outs = self.gs_head(
feats=feats,
H=H,
W=W,
patch_start_idx=0,
images=in_images,
)
raw_gaussians = gs_outs.raw_gs
densities = gs_outs.raw_gs_conf
# convert to 'world space' 3DGS parameters; ready to export and render
# gt_extr could be None, and will be used to align the pose scale if available
gs_world = self.gs_adapter(
extrinsics=ctx_extr,
intrinsics=ctx_intr,
depths=output.depth,
opacities=map_pdf_to_opacity(densities),
raw_gaussians=raw_gaussians,
image_shape=(H, W),
gt_extrinsics=gt_extr,
)
output.gaussians = gs_world
return output
def _extract_auxiliary_features(
self, feats: list[torch.Tensor], feat_layers: list[int], H: int, W: int
) -> Dict[str, torch.Tensor]:
"""Extract auxiliary features from specified layers."""
aux_features = Dict()
assert len(feats) == len(feat_layers)
for feat, feat_layer in zip(feats, feat_layers):
# Reshape features to spatial dimensions
feat_reshaped = feat.reshape(
[
feat.shape[0],
feat.shape[1],
H // self.PATCH_SIZE,
W // self.PATCH_SIZE,
feat.shape[-1],
]
)
aux_features[f"feat_layer_{feat_layer}"] = feat_reshaped
return aux_features
class NestedDepthAnything3Net(nn.Module):
"""
Nested Depth Anything 3 network with metric scaling capabilities.
This network combines two DepthAnything3Net branches:
- Main branch: Standard depth estimation
- Metric branch: Metric depth estimation for scaling alignment
The network performs depth alignment using least squares scaling
and handles sky region masking for improved depth estimation.
Args:
preset: Configuration for the main depth estimation branch
second_preset: Configuration for the metric depth branch
"""
def __init__(self, anyview: DictConfig, metric: DictConfig):
"""
Initialize NestedDepthAnything3Net with two branches.
Args:
preset: Configuration for main depth estimation branch
second_preset: Configuration for metric depth branch
"""
super().__init__()
self.da3 = create_object(anyview)
self.da3_metric = create_object(metric)
def forward(
self,
x: torch.Tensor,
extrinsics: torch.Tensor | None = None,
intrinsics: torch.Tensor | None = None,
export_feat_layers: list[int] | None = [],
infer_gs: bool = False,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
) -> Dict[str, torch.Tensor]:
"""
Forward pass through both branches with metric scaling alignment.
Args:
x: Input images (B, N, 3, H, W)
extrinsics: Camera extrinsics (B, N, 4, 4) - unused
intrinsics: Camera intrinsics (B, N, 3, 3) - unused
feat_layers: List of layer indices to extract features from
infer_gs: Enable Gaussian Splatting branch
use_ray_pose: Use ray-based pose estimation
ref_view_strategy: Strategy for selecting reference view
Returns:
Dictionary containing aligned depth predictions and camera parameters
"""
# Get predictions from both branches
output = self.da3(
x, extrinsics, intrinsics, export_feat_layers=export_feat_layers, infer_gs=infer_gs, use_ray_pose=use_ray_pose, ref_view_strategy=ref_view_strategy
)
metric_output = self.da3_metric(x)
# Apply metric scaling and alignment
output = self._apply_metric_scaling(output, metric_output)
output = self._apply_depth_alignment(output, metric_output)
output = self._handle_sky_regions(output, metric_output)
return output
def _apply_metric_scaling(
self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Apply metric scaling to the metric depth output."""
# Scale metric depth based on camera intrinsics
metric_output.depth = apply_metric_scaling(
metric_output.depth,
output.intrinsics,
)
return output
def _apply_depth_alignment(
self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Apply depth alignment using least squares scaling."""
# Compute non-sky mask
non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)
# Ensure we have enough non-sky pixels
assert non_sky_mask.sum() > 10, "Insufficient non-sky pixels for alignment"
# Sample depth confidence for quantile computation
depth_conf_ns = output.depth_conf[non_sky_mask]
depth_conf_sampled = sample_tensor_for_quantile(depth_conf_ns, max_samples=100000)
median_conf = torch.quantile(depth_conf_sampled, 0.5)
# Compute alignment mask
align_mask = compute_alignment_mask(
output.depth_conf, non_sky_mask, output.depth, metric_output.depth, median_conf
)
# Compute scale factor using least squares
valid_depth = output.depth[align_mask]
valid_metric_depth = metric_output.depth[align_mask]
scale_factor = least_squares_scale_scalar(valid_metric_depth, valid_depth)
# Apply scaling to depth and extrinsics
output.depth *= scale_factor
output.extrinsics[:, :, :3, 3] *= scale_factor
output.is_metric = 1
output.scale_factor = scale_factor.item()
return output
def _handle_sky_regions(
self,
output: Dict[str, torch.Tensor],
metric_output: Dict[str, torch.Tensor],
sky_depth_def: float = 200.0,
) -> Dict[str, torch.Tensor]:
"""Handle sky regions by setting them to maximum depth."""
non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)
# Compute maximum depth for non-sky regions
# Use sampling to safely compute quantile on large tensors
non_sky_depth = output.depth[non_sky_mask]
if non_sky_depth.numel() > 100000:
idx = torch.randint(0, non_sky_depth.numel(), (100000,), device=non_sky_depth.device)
sampled_depth = non_sky_depth[idx]
else:
sampled_depth = non_sky_depth
non_sky_max = min(torch.quantile(sampled_depth, 0.99), sky_depth_def)
# Set sky regions to maximum depth and high confidence
output.depth, output.depth_conf = set_sky_regions_to_max_depth(
output.depth, output.depth_conf, non_sky_mask, max_depth=non_sky_max
)
return output

View File

@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from typing import List
import torch.nn as nn
from depth_anything_3.model.dinov2.vision_transformer import (
vit_base,
vit_giant2,
vit_large,
vit_small,
)
class DinoV2(nn.Module):
def __init__(
self,
name: str,
out_layers: List[int],
alt_start: int = -1,
qknorm_start: int = -1,
rope_start: int = -1,
cat_token: bool = True,
**kwargs,
):
super().__init__()
assert name in {"vits", "vitb", "vitl", "vitg"}
self.name = name
self.out_layers = out_layers
self.alt_start = alt_start
self.qknorm_start = qknorm_start
self.rope_start = rope_start
self.cat_token = cat_token
encoder_map = {
"vits": vit_small,
"vitb": vit_base,
"vitl": vit_large,
"vitg": vit_giant2,
}
encoder_fn = encoder_map[self.name]
ffn_layer = "swiglufused" if self.name == "vitg" else "mlp"
self.pretrained = encoder_fn(
img_size=518,
patch_size=14,
ffn_layer=ffn_layer,
alt_start=alt_start,
qknorm_start=qknorm_start,
rope_start=rope_start,
cat_token=cat_token,
)
def forward(self, x, **kwargs):
return self.pretrained.get_intermediate_layers(
x,
self.out_layers,
**kwargs,
)

View File

@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# from .attention import MemEffAttention
from .block import Block
from .layer_scale import LayerScale
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .rope import PositionGetter, RotaryPositionEmbedding2D
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
__all__ = [
Mlp,
PatchEmbed,
SwiGLUFFN,
SwiGLUFFNFused,
Block,
# MemEffAttention,
LayerScale,
PositionGetter,
RotaryPositionEmbedding2D,
]

View File

@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
import torch.nn.functional as F
from torch import Tensor, nn
logger = logging.getLogger("dinov2")
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
qk_norm: bool = False,
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
rope=None,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.fused_attn = fused_attn
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = rope
def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k = self.q_norm(q), self.k_norm(k)
if self.rope is not None and pos is not None:
q = self.rope(q, pos)
k = self.rope(k, pos)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
attn_mask=(
(attn_mask)[:, None].repeat(1, self.num_heads, 1, 1)
if attn_mask is not None
else None
),
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _forward(self, x: Tensor) -> Tensor:
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x

View File

@@ -0,0 +1,143 @@
# flake8: noqa: F821
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
from typing import Callable, Optional
import torch
from torch import Tensor, nn
from .attention import Attention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
logger = logging.getLogger("dinov2")
XFORMERS_AVAILABLE = True
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
rope=None,
ln_eps: float = 1e-6,
) -> None:
super().__init__()
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
self.norm1 = norm_layer(dim, eps=ln_eps)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
qk_norm=qk_norm,
rope=rope,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim, eps=ln_eps)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
pos=pos,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos, attn_mask=attn_mask))
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
else:
x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
x = x + ffn_residual_func(x)
return x
def drop_add_residual_stochastic_depth(
x: Tensor,
residual_func: Callable[[Tensor], Tensor],
sample_drop_ratio: float = 0.0,
pos: Optional[Tensor] = None,
) -> Tensor:
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
if pos is not None:
# if necessary, apply rope to the subset
pos = pos[brange]
residual = residual_func(x_subset, pos=pos)
else:
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
)
return x_plus_residual.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor

View File

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)

View File

@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa: E501
from typing import Union
import torch
from torch import Tensor, nn
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.dim = dim
self.inplace = inplace
self.init_values = init_values
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
def extra_repr(self) -> str:
return f"{self.dim}, init_values={self.init_values}, inplace={self.inplace}"

View File

@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x

View File

@@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
import torch.nn as nn
from torch import Tensor
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (
image_HW[0] // patch_HW[0],
image_HW[1] // patch_HW[1],
)
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert (
H % patch_H == 0
), f"Input image height {H} is not a multiple of patch height {patch_H}"
assert (
W % patch_W == 0
), f"Input image width {W} is not a multiple of patch width: {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = (
Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
)
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops

View File

@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Implementation of 2D Rotary Position Embeddings (RoPE).
# This module provides a clean implementation of 2D Rotary Position Embeddings,
# which extends the original RoPE concept to handle 2D spatial positions.
# Inspired by:
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
# https://github.com/naver-ai/rope-vit
from typing import Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionGetter:
"""Generates and caches 2D spatial positions for patches in a grid.
This class efficiently manages the generation of spatial coordinates for patches
in a 2D grid, caching results to avoid redundant computations.
Attributes:
position_cache: Dictionary storing precomputed position tensors for different
grid dimensions.
"""
def __init__(self):
"""Initializes the position generator with an empty cache."""
self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
def __call__(
self, batch_size: int, height: int, width: int, device: torch.device
) -> torch.Tensor:
"""Generates spatial positions for a batch of patches.
Args:
batch_size: Number of samples in the batch.
height: Height of the grid in patches.
width: Width of the grid in patches.
device: Target device for the position tensor.
Returns:
Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
for each position in the grid, repeated for each batch item.
"""
if (height, width) not in self.position_cache:
y_coords = torch.arange(height, device=device)
x_coords = torch.arange(width, device=device)
positions = torch.cartesian_prod(y_coords, x_coords)
self.position_cache[height, width] = positions
cached_positions = self.position_cache[height, width]
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
"""2D Rotary Position Embedding implementation.
This module applies rotary position embeddings to input tokens based on their
2D spatial positions. It handles the position-dependent rotation of features
separately for vertical and horizontal dimensions.
Args:
frequency: Base frequency for the position embeddings. Default: 100.0
scaling_factor: Scaling factor for frequency computation. Default: 1.0
Attributes:
base_frequency: Base frequency for computing position embeddings.
scaling_factor: Factor to scale the computed frequencies.
frequency_cache: Cache for storing precomputed frequency components.
"""
def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
"""Initializes the 2D RoPE module."""
super().__init__()
self.base_frequency = frequency
self.scaling_factor = scaling_factor
self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
def _compute_frequency_components(
self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes frequency components for rotary embeddings.
Args:
dim: Feature dimension (must be even).
seq_len: Maximum sequence length.
device: Target device for computations.
dtype: Data type for the computed tensors.
Returns:
Tuple of (cosine, sine) tensors for frequency components.
"""
cache_key = (dim, seq_len, device, dtype)
if cache_key not in self.frequency_cache:
# Compute frequency bands
exponents = torch.arange(0, dim, 2, device=device).float() / dim
inv_freq = 1.0 / (self.base_frequency**exponents)
# Generate position-dependent frequencies
positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
angles = torch.einsum("i,j->ij", positions, inv_freq)
# Compute and cache frequency components
angles = angles.to(dtype)
angles = torch.cat((angles, angles), dim=-1)
cos_components = angles.cos().to(dtype)
sin_components = angles.sin().to(dtype)
self.frequency_cache[cache_key] = (cos_components, sin_components)
return self.frequency_cache[cache_key]
@staticmethod
def _rotate_features(x: torch.Tensor) -> torch.Tensor:
"""Performs feature rotation by splitting and recombining feature dimensions.
Args:
x: Input tensor to rotate.
Returns:
Rotated feature tensor.
"""
feature_dim = x.shape[-1]
x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_1d_rope(
self,
tokens: torch.Tensor,
positions: torch.Tensor,
cos_comp: torch.Tensor,
sin_comp: torch.Tensor,
) -> torch.Tensor:
"""Applies 1D rotary position embeddings along one dimension.
Args:
tokens: Input token features.
positions: Position indices.
cos_comp: Cosine components for rotation.
sin_comp: Sine components for rotation.
Returns:
Tokens with applied rotary position embeddings.
"""
# Embed positions with frequency components
cos = F.embedding(positions, cos_comp)[:, None, :, :]
sin = F.embedding(positions, sin_comp)[:, None, :, :]
# Apply rotation
return (tokens * cos) + (self._rotate_features(tokens) * sin)
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
"""Applies 2D rotary position embeddings to input tokens.
Args:
tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
The feature dimension (dim) must be divisible by 4.
positions: Position tensor of shape (batch_size, n_tokens, 2) containing
the y and x coordinates for each token.
Returns:
Tensor of same shape as input with applied 2D rotary position embeddings.
Raises:
AssertionError: If input dimensions are invalid or positions are malformed.
"""
# Validate inputs
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
assert (
positions.ndim == 3 and positions.shape[-1] == 2
), "Positions must have shape (batch_size, n_tokens, 2)"
# Compute feature dimension for each spatial direction
feature_dim = tokens.size(-1) // 2
# Get frequency components
max_position = int(positions.max()) + 1
cos_comp, sin_comp = self._compute_frequency_components(
feature_dim, max_position, tokens.device, tokens.dtype
)
# Split features for vertical and horizontal processing
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
# Apply RoPE separately for each dimension
vertical_features = self._apply_1d_rope(
vertical_features, positions[..., 0], cos_comp, sin_comp
)
horizontal_features = self._apply_1d_rope(
horizontal_features, positions[..., 1], cos_comp, sin_comp
)
# Combine processed features
return torch.cat((vertical_features, horizontal_features), dim=-1)

View File

@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
import torch.nn.functional as F
from torch import Tensor, nn
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)
try:
from xformers.ops import SwiGLU
XFORMERS_AVAILABLE = True
except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
bias=bias,
)

View File

@@ -0,0 +1,456 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import math
from typing import Callable, List, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange
from depth_anything_3.utils.logger import logger
from .layers import LayerScale # noqa: F401
from .layers import Mlp # noqa: F401
from .layers import ( # noqa: F401
Block,
PatchEmbed,
PositionGetter,
RotaryPositionEmbedding2D,
SwiGLUFFNFused,
)
from depth_anything_3.model.reference_view_selector import (
RefViewStrategy,
select_reference_view,
reorder_by_reference,
restore_original_order,
)
from depth_anything_3.utils.constants import THRESH_FOR_REF_SELECTION
# logger = logging.getLogger("dinov2")
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
def named_apply(
fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = ".".join((name, child_name)) if name else child_name
named_apply(
fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True
)
if depth_first and include_root:
fn(module=module, name=name)
return module
class BlockChunk(nn.ModuleList):
def forward(self, x):
for b in self:
x = b(x)
return x
class DinoVisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
ffn_bias=True,
proj_bias=True,
drop_path_rate=0.0,
drop_path_uniform=False,
init_values=1.0, # for layerscale: None or 0 => no layerscale
embed_layer=PatchEmbed,
act_layer=nn.GELU,
block_fn=Block,
ffn_layer="mlp",
block_chunks=1,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1,
alt_start=-1,
qknorm_start=-1,
rope_start=-1,
rope_freq=100,
plus_cam_token=False,
cat_token=True,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
proj_bias (bool): enable bias for proj in attn if True
ffn_bias (bool): enable bias for ffn if True
weight_init (str): weight init scheme
init_values (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
act_layer (nn.Module): MLP activation layer
block_fn (nn.Module): transformer block class
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating
positional embeddings
interpolate_offset: (float) work-around offset to apply when interpolating
positional embeddings
"""
super().__init__()
self.patch_start_idx = 1
norm_layer = nn.LayerNorm
self.num_features = self.embed_dim = (
embed_dim # num_features for consistency with other models
)
self.alt_start = alt_start
self.qknorm_start = qknorm_start
self.rope_start = rope_start
self.cat_token = cat_token
self.num_tokens = 1
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.num_register_tokens = num_register_tokens
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if self.alt_start != -1:
self.camera_token = nn.Parameter(torch.randn(1, 2, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
assert num_register_tokens >= 0
self.register_tokens = (
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
if num_register_tokens
else None
)
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
if ffn_layer == "mlp":
logger.info("using MLP layer as FFN")
ffn_layer = Mlp
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
logger.info("using SwiGLU layer as FFN")
ffn_layer = SwiGLUFFNFused
elif ffn_layer == "identity":
logger.info("using Identity layer as FFN")
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
else:
raise NotImplementedError
if self.rope_start != -1:
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
self.position_getter = PositionGetter() if self.rope is not None else None
else:
self.rope = None
blocks_list = [
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
qk_norm=i >= qknorm_start if qknorm_start != -1 else False,
rope=self.rope if i >= rope_start and rope_start != -1 else None,
)
for i in range(depth)
]
self.blocks = nn.ModuleList(blocks_list)
self.norm = norm_layer(embed_dim)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
assert N == M * M
kwargs = {}
if self.interpolate_offset:
# Historical kludge: add a small number to avoid floating point error in the
# interpolation, see https://github.com/facebookresearch/dino/issues/8
# Note: still needed for backward-compatibility, the underlying operators are using
# both output size and scale factors
sx = float(w0 + self.interpolate_offset) / M
sy = float(h0 + self.interpolate_offset) / M
kwargs["scale_factor"] = (sx, sy)
else:
# Simply specify an output size instead of a scale factor
kwargs["size"] = (w0, h0)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
mode="bicubic",
antialias=self.interpolate_antialias,
**kwargs,
)
assert (w0, h0) == patch_pos_embed.shape[-2:]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
def prepare_cls_token(self, B, S):
cls_token = self.cls_token.expand(B, S, -1)
cls_token = cls_token.reshape(B * S, -1, self.embed_dim)
return cls_token
def prepare_tokens_with_masks(self, x, masks=None, cls_token=None, **kwargs):
B, S, nc, w, h = x.shape
x = rearrange(x, "b s c h w -> (b s) c h w")
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
cls_token = self.prepare_cls_token(B, S)
x = torch.cat((cls_token, x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
if self.register_tokens is not None:
x = torch.cat(
(
x[:, :1],
self.register_tokens.expand(x.shape[0], -1, -1),
x[:, 1:],
),
dim=1,
)
x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S)
return x
def _prepare_rope(self, B, S, H, W, device):
pos = None
pos_nodiff = None
if self.rope is not None:
pos = self.position_getter(
B * S, H // self.patch_size, W // self.patch_size, device=device
)
pos = rearrange(pos, "(b s) n c -> b s n c", b=B)
pos_nodiff = torch.zeros_like(pos).to(pos.dtype)
if self.patch_start_idx > 0:
pos = pos + 1
pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(device).to(pos.dtype)
pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B)
pos = torch.cat([pos_special, pos], dim=2)
pos_nodiff = pos_nodiff + 1
pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2)
return pos, pos_nodiff
def _get_intermediate_layers_not_chunked(self, x, n=1, export_feat_layers=[], **kwargs):
B, S, _, H, W = x.shape
x = self.prepare_tokens_with_masks(x)
output, total_block_len, aux_output = [], len(self.blocks), []
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device)
for i, blk in enumerate(self.blocks):
if i < self.rope_start or self.rope is None:
g_pos, l_pos = None, None
else:
g_pos = pos_nodiff
l_pos = pos
if self.alt_start != -1 and (i == self.alt_start - 1) and x.shape[1] >= THRESH_FOR_REF_SELECTION:
# Select reference view using configured strategy
strategy = kwargs.get("ref_view_strategy", "saddle_balanced")
logger.info(f"Selecting reference view using strategy: {strategy}")
b_idx = select_reference_view(x, strategy=strategy)
# Reorder views to place reference view first
x = reorder_by_reference(x, b_idx)
local_x = reorder_by_reference(local_x, b_idx)
if self.alt_start != -1 and i == self.alt_start:
if kwargs.get("cam_token", None) is not None:
logger.info("Using camera conditions provided by the user")
cam_token = kwargs.get("cam_token")
else:
ref_token = self.camera_token[:, :1].expand(B, -1, -1)
src_token = self.camera_token[:, 1:].expand(B, S - 1, -1)
cam_token = torch.cat([ref_token, src_token], dim=1)
x[:, :, 0] = cam_token
if self.alt_start != -1 and i >= self.alt_start and i % 2 == 1:
x = self.process_attention(
x, blk, "global", pos=g_pos, attn_mask=kwargs.get("attn_mask", None)
)
else:
x = self.process_attention(x, blk, "local", pos=l_pos)
local_x = x
if i in blocks_to_take:
out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x
# Restore original view order if reordering was applied
if x.shape[1] >= THRESH_FOR_REF_SELECTION and self.alt_start != -1 and 'b_idx' in locals():
out_x = restore_original_order(out_x, b_idx)
output.append((out_x[:, :, 0], out_x))
if i in export_feat_layers:
aux_output.append(x)
return output, aux_output
def process_attention(self, x, block, attn_type="global", pos=None, attn_mask=None):
b, s, n = x.shape[:3]
if attn_type == "local":
x = rearrange(x, "b s n c -> (b s) n c")
if pos is not None:
pos = rearrange(pos, "b s n c -> (b s) n c")
elif attn_type == "global":
x = rearrange(x, "b s n c -> b (s n) c")
if pos is not None:
pos = rearrange(pos, "b s n c -> b (s n) c")
else:
raise ValueError(f"Invalid attention type: {attn_type}")
x = block(x, pos=pos, attn_mask=attn_mask)
if attn_type == "local":
x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s)
elif attn_type == "global":
x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s)
return x
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1, # Layers or n last layers to take
export_feat_layers: List[int] = [],
**kwargs,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
outputs, aux_outputs = self._get_intermediate_layers_not_chunked(
x, n, export_feat_layers=export_feat_layers, **kwargs
)
camera_tokens = [out[0] for out in outputs]
if outputs[0][1].shape[-1] == self.embed_dim:
outputs = [self.norm(out[1]) for out in outputs]
elif outputs[0][1].shape[-1] == (self.embed_dim * 2):
outputs = [
torch.cat(
[out[1][..., : self.embed_dim], self.norm(out[1][..., self.embed_dim :])],
dim=-1,
)
for out in outputs
]
else:
raise ValueError(f"Invalid output shape: {outputs[0][1].shape}")
aux_outputs = [self.norm(out) for out in aux_outputs]
outputs = [out[..., 1 + self.num_register_tokens :, :] for out in outputs]
aux_outputs = [out[..., 1 + self.num_register_tokens :, :] for out in aux_outputs]
return tuple(zip(outputs, camera_tokens)), aux_outputs
def vit_small(patch_size=16, num_register_tokens=0, depth=12, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=384,
depth=depth,
num_heads=6,
mlp_ratio=4,
# block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_base(patch_size=16, num_register_tokens=0, depth=12, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=768,
depth=depth,
num_heads=12,
mlp_ratio=4,
# block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_large(patch_size=16, num_register_tokens=0, depth=24, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1024,
depth=depth,
num_heads=16,
mlp_ratio=4,
# block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_giant2(patch_size=16, num_register_tokens=0, depth=40, **kwargs):
"""
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
"""
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1536,
depth=depth,
num_heads=24,
mlp_ratio=4,
num_register_tokens=num_register_tokens,
**kwargs,
)
return model

View File

@@ -0,0 +1,458 @@
# flake8: noqa E501
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict as TyDict
from typing import List, Sequence, Tuple
import torch
import torch.nn as nn
from addict import Dict
from einops import rearrange
from depth_anything_3.model.utils.head_utils import (
Permute,
create_uv_grid,
custom_interpolate,
position_grid_to_embed,
)
class DPT(nn.Module):
"""
DPT for dense prediction (main head + optional sky head, sky always 1 channel).
Returns:
- Main head:
* If output_dim>1: { head_name, f"{head_name}_conf" }
* If output_dim==1: { head_name }
- Sky head (if use_sky_head=True): { sky_name } # [B, S, 1, H/down_ratio, W/down_ratio]
"""
def __init__(
self,
dim_in: int,
*,
patch_size: int = 14,
output_dim: int = 1,
activation: str = "exp",
conf_activation: str = "expp1",
features: int = 256,
out_channels: Sequence[int] = (256, 512, 1024, 1024),
pos_embed: bool = False,
down_ratio: int = 1,
head_name: str = "depth",
# ---- sky head (fixed 1 channel) ----
use_sky_head: bool = True,
sky_name: str = "sky",
sky_activation: str = "relu", # 'sigmoid' / 'relu' / 'linear'
use_ln_for_heads: bool = False, # If needed, apply LayerNorm on intermediate features of both heads
norm_type: str = "idt", # use to match legacy GS-DPT head, "idt" / "layer"
fusion_block_inplace: bool = False,
) -> None:
super().__init__()
# -------------------- configuration --------------------
self.patch_size = patch_size
self.activation = activation
self.conf_activation = conf_activation
self.pos_embed = pos_embed
self.down_ratio = down_ratio
# Names
self.head_main = head_name
self.sky_name = sky_name
# Main head: output dimension and confidence switch
self.out_dim = output_dim
self.has_conf = output_dim > 1
# Sky head parameters (always 1 channel)
self.use_sky_head = use_sky_head
self.sky_activation = sky_activation
# Fixed 4 intermediate outputs
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
# -------------------- token pre-norm + per-stage projection --------------------
if norm_type == "layer":
self.norm = nn.LayerNorm(dim_in)
elif norm_type == "idt":
self.norm = nn.Identity()
else:
raise Exception(f"Unknown norm_type {norm_type}, should be 'layer' or 'idt'.")
self.projects = nn.ModuleList(
[nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
)
# -------------------- Spatial re-size (align to common scale before fusion) --------------------
# Design consistent with original: relative to patch grid (x4, x2, x1, /2)
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1),
]
)
# -------------------- scratch: stage adapters + main fusion chain --------------------
self.scratch = _make_scratch(list(out_channels), features, expand=False)
# Main fusion chain
self.scratch.refinenet1 = _make_fusion_block(features, inplace=fusion_block_inplace)
self.scratch.refinenet2 = _make_fusion_block(features, inplace=fusion_block_inplace)
self.scratch.refinenet3 = _make_fusion_block(features, inplace=fusion_block_inplace)
self.scratch.refinenet4 = _make_fusion_block(
features, has_residual=False, inplace=fusion_block_inplace
)
# Heads (shared neck1; then split into two heads)
head_features_1 = features
head_features_2 = 32
self.scratch.output_conv1 = nn.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
)
ln_seq = (
[Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))]
if use_ln_for_heads
else []
)
# Main head
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
*ln_seq,
nn.ReLU(inplace=True),
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
)
# Sky head (fixed 1 channel)
if self.use_sky_head:
self.scratch.sky_output_conv2 = nn.Sequential(
nn.Conv2d(
head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1
),
*ln_seq,
nn.ReLU(inplace=True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
)
# -------------------------------------------------------------------------
# Public forward (supports frame chunking to save memory)
# -------------------------------------------------------------------------
def forward(
self,
feats: List[torch.Tensor],
H: int,
W: int,
patch_start_idx: int,
chunk_size: int = 8,
**kwargs,
) -> Dict:
"""
Args:
feats: List of 4 entries, each entry is a tensor like [B, S, T, C] (or the 0th element of tuple/list is that tensor).
H, W: Original image dimensions
patch_start_idx: Starting index of patch tokens in sequence (for cropping non-patch tokens)
chunk_size: Chunk size along time dimension S
Returns:
Dict[str, Tensor]
"""
B, S, N, C = feats[0][0].shape
feats = [feat[0].reshape(B * S, N, C) for feat in feats]
# update image info, used by the GS-DPT head
extra_kwargs = {}
if "images" in kwargs:
extra_kwargs.update({"images": rearrange(kwargs["images"], "B S ... -> (B S) ...")})
if chunk_size is None or chunk_size >= S:
out_dict = self._forward_impl(feats, H, W, patch_start_idx, **extra_kwargs)
out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
return Dict(out_dict)
out_dicts: List[TyDict[str, torch.Tensor]] = []
for s0 in range(0, S, chunk_size):
s1 = min(s0 + chunk_size, S)
kw = {}
if "images" in extra_kwargs:
kw.update({"images": extra_kwargs["images"][s0:s1]})
out_dicts.append(
self._forward_impl([f[s0:s1] for f in feats], H, W, patch_start_idx, **kw)
)
out_dict = {k: torch.cat([od[k] for od in out_dicts], dim=0) for k in out_dicts[0].keys()}
out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
return Dict(out_dict)
# -------------------------------------------------------------------------
# Internal forward (single chunk)
# -------------------------------------------------------------------------
def _forward_impl(
self,
feats: List[torch.Tensor],
H: int,
W: int,
patch_start_idx: int,
) -> TyDict[str, torch.Tensor]:
B, _, C = feats[0].shape
ph, pw = H // self.patch_size, W // self.patch_size
resized_feats = []
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
x = feats[take_idx][:, patch_start_idx:] # [B*S, N_patch, C]
x = self.norm(x)
# permute -> contiguous before reshape to keep conv input contiguous
x = x.permute(0, 2, 1).contiguous().reshape(B, C, ph, pw) # [B*S, C, ph, pw]
x = self.projects[stage_idx](x)
if self.pos_embed:
x = self._add_pos_embed(x, W, H)
x = self.resize_layers[stage_idx](x) # Align scale
resized_feats.append(x)
# 2) Fusion pyramid (main branch only)
fused = self._fuse(resized_feats)
# 3) Upsample to target resolution, optionally add position encoding again
h_out = int(ph * self.patch_size / self.down_ratio)
w_out = int(pw * self.patch_size / self.down_ratio)
fused = self.scratch.output_conv1(fused)
fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
if self.pos_embed:
fused = self._add_pos_embed(fused, W, H)
# 4) Shared neck1
feat = fused
# 5) Main head: logits -> activation
main_logits = self.scratch.output_conv2(feat)
outs: TyDict[str, torch.Tensor] = {}
if self.has_conf:
fmap = main_logits.permute(0, 2, 3, 1)
pred = self._apply_activation_single(fmap[..., :-1], self.activation)
conf = self._apply_activation_single(fmap[..., -1], self.conf_activation)
outs[self.head_main] = pred.squeeze(1)
outs[f"{self.head_main}_conf"] = conf.squeeze(1)
else:
outs[self.head_main] = self._apply_activation_single(
main_logits, self.activation
).squeeze(1)
# 6) Sky head (fixed 1 channel)
if self.use_sky_head:
sky_logits = self.scratch.sky_output_conv2(feat)
outs[self.sky_name] = self._apply_sky_activation(sky_logits).squeeze(1)
return outs
# -------------------------------------------------------------------------
# Subroutines
# -------------------------------------------------------------------------
def _fuse(self, feats: List[torch.Tensor]) -> torch.Tensor:
"""
4-layer top-down fusion, returns finest scale features (after fusion, before neck1).
"""
l1, l2, l3, l4 = feats
l1_rn = self.scratch.layer1_rn(l1)
l2_rn = self.scratch.layer2_rn(l2)
l3_rn = self.scratch.layer3_rn(l3)
l4_rn = self.scratch.layer4_rn(l4)
# 4 -> 3 -> 2 -> 1
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
out = self.scratch.refinenet1(out, l1_rn)
return out
def _apply_activation_single(
self, x: torch.Tensor, activation: str = "linear"
) -> torch.Tensor:
"""
Apply activation to single channel output, maintaining semantic consistency with value branch in multi-channel case.
Supports: exp / relu / sigmoid / softplus / tanh / linear / expp1
"""
act = activation.lower() if isinstance(activation, str) else activation
if act == "exp":
return torch.exp(x)
if act == "expp1":
return torch.exp(x) + 1
if act == "expm1":
return torch.expm1(x)
if act == "relu":
return torch.relu(x)
if act == "sigmoid":
return torch.sigmoid(x)
if act == "softplus":
return torch.nn.functional.softplus(x)
if act == "tanh":
return torch.tanh(x)
# Default linear
return x
def _apply_sky_activation(self, x: torch.Tensor) -> torch.Tensor:
"""
Sky head activation (fixed 1 channel):
* 'sigmoid' -> Sigmoid probability map
* 'relu' -> ReLU positive domain output
* 'linear' -> Original value (logits)
"""
act = (
self.sky_activation.lower()
if isinstance(self.sky_activation, str)
else self.sky_activation
)
if act == "sigmoid":
return torch.sigmoid(x)
if act == "relu":
return torch.relu(x)
# 'linear'
return x
def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
"""Simple UV position encoding directly added to feature map."""
pw, ph = x.shape[-1], x.shape[-2]
pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
pe = position_grid_to_embed(pe, x.shape[1]) * ratio
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
return x + pe
# -----------------------------------------------------------------------------
# Building blocks (preserved, consistent with original)
# -----------------------------------------------------------------------------
def _make_fusion_block(
features: int,
size: Tuple[int, int] = None,
has_residual: bool = True,
groups: int = 1,
inplace: bool = False,
) -> nn.Module:
return FeatureFusionBlock(
features=features,
activation=nn.ReLU(inplace=inplace),
deconv=False,
bn=False,
expand=False,
align_corners=True,
size=size,
has_residual=has_residual,
groups=groups,
)
def _make_scratch(
in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
) -> nn.Module:
scratch = nn.Module()
# Optional expansion by stage
c1 = out_shape
c2 = out_shape * (2 if expand else 1)
c3 = out_shape * (4 if expand else 1)
c4 = out_shape * (8 if expand else 1)
scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups)
scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups)
scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups)
scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups)
return scratch
class ResidualConvUnit(nn.Module):
"""Lightweight residual convolution block for fusion"""
def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> None:
super().__init__()
self.bn = bn
self.groups = groups
self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
self.norm1 = None
self.norm2 = None
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override]
out = self.activation(x)
out = self.conv1(out)
if self.norm1 is not None:
out = self.norm1(out)
out = self.activation(out)
out = self.conv2(out)
if self.norm2 is not None:
out = self.norm2(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Top-down fusion block: (optional) residual merge + upsampling + 1x1 contraction"""
def __init__(
self,
features: int,
activation: nn.Module,
deconv: bool = False,
bn: bool = False,
expand: bool = False,
align_corners: bool = True,
size: Tuple[int, int] = None,
has_residual: bool = True,
groups: int = 1,
) -> None:
super().__init__()
self.align_corners = align_corners
self.size = size
self.has_residual = has_residual
self.resConfUnit1 = (
ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None
)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups)
out_features = (features // 2) if expand else features
self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor: # type: ignore[override]
"""
xs:
- xs[0]: Top branch input
- xs[1]: Lateral input (can do residual addition with top branch)
"""
y = xs[0]
if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
y = self.skip_add.add(y, self.resConfUnit1(xs[1]))
y = self.resConfUnit2(y)
# Upsampling
if (size is None) and (self.size is None):
up_kwargs = {"scale_factor": 2}
elif size is None:
up_kwargs = {"size": self.size}
else:
up_kwargs = {"size": size}
y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
y = self.out_conv(y)
return y

View File

@@ -0,0 +1,488 @@
# flake8: noqa E501
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Sequence, Tuple
import torch
import torch.nn as nn
from addict import Dict
from depth_anything_3.model.dpt import _make_fusion_block, _make_scratch
from depth_anything_3.model.utils.head_utils import (
Permute,
create_uv_grid,
custom_interpolate,
position_grid_to_embed,
)
class DualDPT(nn.Module):
"""
Dual-head DPT for dense prediction with an always-on auxiliary head.
Architectural notes:
- Sky/object branches are removed.
- `intermediate_layer_idx` is fixed to (0, 1, 2, 3).
- Auxiliary head has its **own** fusion blocks (no fusion_inplace / no sharing).
- Auxiliary head is internally multi-level; **only the final level** is returned.
- Returns a **dict** with keys from `head_names`, e.g.:
{ main_name, f"{main_name}_conf", aux_name, f"{aux_name}_conf" }
- `feature_only` is fixed to False.
"""
def __init__(
self,
dim_in: int,
*,
patch_size: int = 14,
output_dim: int = 2,
activation: str = "exp",
conf_activation: str = "expp1",
features: int = 256,
out_channels: Sequence[int] = (256, 512, 1024, 1024),
pos_embed: bool = True,
down_ratio: int = 1,
aux_pyramid_levels: int = 4,
aux_out1_conv_num: int = 5,
head_names: Tuple[str, str] = ("depth", "ray"),
) -> None:
super().__init__()
# -------------------- configuration --------------------
self.patch_size = patch_size
self.activation = activation
self.conf_activation = conf_activation
self.pos_embed = pos_embed
self.down_ratio = down_ratio
self.aux_levels = aux_pyramid_levels
self.aux_out1_conv_num = aux_out1_conv_num
# names ONLY come from config (no hard-coded strings elsewhere)
self.head_main, self.head_aux = head_names
# Always expect 4 scales; enforce intermediate idx = (0, 1, 2, 3)
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
# -------------------- token pre-norm + per-stage projection --------------------
self.norm = nn.LayerNorm(dim_in)
self.projects = nn.ModuleList(
[nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
)
# -------------------- spatial re-sizers (align to common scale before fusion) --------------------
# design: stage strides (x4, x2, x1, /2) relative to patch grid to align to a common pivot scale
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1),
]
)
# -------------------- scratch: stage adapters + fusion (main & aux are separate) --------------------
self.scratch = _make_scratch(list(out_channels), features, expand=False)
# Main fusion chain (independent)
self.scratch.refinenet1 = _make_fusion_block(features)
self.scratch.refinenet2 = _make_fusion_block(features)
self.scratch.refinenet3 = _make_fusion_block(features)
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
# Primary head neck + head (independent)
head_features_1 = features
head_features_2 = 32
self.scratch.output_conv1 = nn.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
)
# Auxiliary fusion chain (completely separate; no sharing, i.e., "fusion_inplace=False")
self.scratch.refinenet1_aux = _make_fusion_block(features)
self.scratch.refinenet2_aux = _make_fusion_block(features)
self.scratch.refinenet3_aux = _make_fusion_block(features)
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False)
# Aux pre-head per level (we will only *return final level*)
self.scratch.output_conv1_aux = nn.ModuleList(
[self._make_aux_out1_block(head_features_1) for _ in range(self.aux_levels)]
)
# Aux final projection per level
use_ln = True
ln_seq = (
[Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))]
if use_ln
else []
)
self.scratch.output_conv2_aux = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(
head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1
),
*ln_seq,
nn.ReLU(inplace=True),
nn.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0),
)
for _ in range(self.aux_levels)
]
)
# -------------------------------------------------------------------------
# Public forward (supports frame chunking for memory)
# -------------------------------------------------------------------------
def forward(
self,
feats: List[torch.Tensor],
H: int,
W: int,
patch_start_idx: int,
chunk_size: int = 8,
) -> Dict[str, torch.Tensor]:
"""
Args:
aggregated_tokens_list: List of 4 tensors [B, S, T, C] from transformer.
images: [B, S, 3, H, W], in [0, 1].
patch_start_idx: Patch-token start in the token sequence (to drop non-patch tokens).
frames_chunk_size: Optional chunking along S for memory.
Returns:
Dict[str, Tensor] with keys based on `head_names`, e.g.:
self.head_main, f"{self.head_main}_conf",
self.head_aux, f"{self.head_aux}_conf"
Shapes:
main: [B, S, out_dim, H/down_ratio, W/down_ratio]
main_cf: [B, S, 1, H/down_ratio, W/down_ratio]
aux: [B, S, 7, H/down_ratio, W/down_ratio]
aux_cf: [B, S, 1, H/down_ratio, W/down_ratio]
"""
B, S, N, C = feats[0][0].shape
feats = [feat[0].reshape(B * S, N, C) for feat in feats]
if chunk_size is None or chunk_size >= S:
out_dict = self._forward_impl(feats, H, W, patch_start_idx)
out_dict = {k: v.reshape(B, S, *v.shape[1:]) for k, v in out_dict.items()}
return Dict(out_dict)
out_dicts = []
for s0 in range(0, S, chunk_size):
s1 = min(s0 + chunk_size, S)
out_dict = self._forward_impl(
[feat[s0:s1] for feat in feats],
H,
W,
patch_start_idx,
)
out_dicts.append(out_dict)
out_dict = {
k: torch.cat([out_dict[k] for out_dict in out_dicts], dim=0)
for k in out_dicts[0].keys()
}
out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
return Dict(out_dict)
# -------------------------------------------------------------------------
# Internal forward (single chunk)
# -------------------------------------------------------------------------
def _forward_impl(
self,
feats: List[torch.Tensor],
H: int,
W: int,
patch_start_idx: int,
) -> Dict[str, torch.Tensor]:
B, _, C = feats[0].shape
ph, pw = H // self.patch_size, W // self.patch_size
resized_feats = []
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
x = feats[take_idx][:, patch_start_idx:]
x = self.norm(x)
x = x.permute(0, 2, 1).reshape(B, C, ph, pw) # [B*S, C, ph, pw]
x = self.projects[stage_idx](x)
if self.pos_embed:
x = self._add_pos_embed(x, W, H)
x = self.resize_layers[stage_idx](x) # align scales
resized_feats.append(x)
# 2) Fuse pyramid (main & aux are completely independent)
fused_main, fused_aux_pyr = self._fuse(resized_feats)
# 3) Upsample to target resolution and (optional) add pos-embed again
h_out = int(ph * self.patch_size / self.down_ratio)
w_out = int(pw * self.patch_size / self.down_ratio)
fused_main = custom_interpolate(
fused_main, (h_out, w_out), mode="bilinear", align_corners=True
)
if self.pos_embed:
fused_main = self._add_pos_embed(fused_main, W, H)
# Primary head: conv1 -> conv2 -> activate
# fused_main = self.scratch.output_conv1(fused_main)
main_logits = self.scratch.output_conv2(fused_main)
fmap = main_logits.permute(0, 2, 3, 1)
main_pred = self._apply_activation_single(fmap[..., :-1], self.activation)
main_conf = self._apply_activation_single(fmap[..., -1], self.conf_activation)
# Auxiliary head (multi-level inside) -> only last level returned (after activation)
last_aux = fused_aux_pyr[-1]
if self.pos_embed:
last_aux = self._add_pos_embed(last_aux, W, H)
# neck (per-level pre-conv) then final projection (only for last level)
# last_aux = self.scratch.output_conv1_aux[-1](last_aux)
last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
fmap_last = last_aux_logits.permute(0, 2, 3, 1)
aux_pred = self._apply_activation_single(fmap_last[..., :-1], "linear")
aux_conf = self._apply_activation_single(fmap_last[..., -1], self.conf_activation)
return {
self.head_main: main_pred.squeeze(-1),
f"{self.head_main}_conf": main_conf,
self.head_aux: aux_pred,
f"{self.head_aux}_conf": aux_conf,
}
# -------------------------------------------------------------------------
# Subroutines
# -------------------------------------------------------------------------
def _fuse(self, feats: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Feature pyramid fusion.
Returns:
fused_main: Tensor at finest scale (after refinenet1)
aux_pyr: List of aux tensors at each level (pre out_conv1_aux)
"""
l1, l2, l3, l4 = feats
l1_rn = self.scratch.layer1_rn(l1)
l2_rn = self.scratch.layer2_rn(l2)
l3_rn = self.scratch.layer3_rn(l3)
l4_rn = self.scratch.layer4_rn(l4)
# level 4 -> 3
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
aux_out = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
aux_list: List[torch.Tensor] = []
if self.aux_levels >= 4:
aux_list.append(aux_out)
# level 3 -> 2
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
aux_out = self.scratch.refinenet3_aux(aux_out, l3_rn, size=l2_rn.shape[2:])
if self.aux_levels >= 3:
aux_list.append(aux_out)
# level 2 -> 1
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
aux_out = self.scratch.refinenet2_aux(aux_out, l2_rn, size=l1_rn.shape[2:])
if self.aux_levels >= 2:
aux_list.append(aux_out)
# level 1 (final)
out = self.scratch.refinenet1(out, l1_rn)
aux_out = self.scratch.refinenet1_aux(aux_out, l1_rn)
aux_list.append(aux_out)
out = self.scratch.output_conv1(out)
aux_list = [self.scratch.output_conv1_aux[i](aux) for i, aux in enumerate(aux_list)]
return out, aux_list
def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
"""Simple UV positional embedding added to feature maps."""
pw, ph = x.shape[-1], x.shape[-2]
pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
pe = position_grid_to_embed(pe, x.shape[1]) * ratio
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
return x + pe
def _make_aux_out1_block(self, in_ch: int) -> nn.Sequential:
"""Factory for the aux pre-head stack before the final 1x1 projection."""
if self.aux_out1_conv_num == 5:
return nn.Sequential(
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
)
if self.aux_out1_conv_num == 3:
return nn.Sequential(
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
)
if self.aux_out1_conv_num == 1:
return nn.Sequential(nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1))
raise ValueError(f"aux_out1_conv_num {self.aux_out1_conv_num} not supported")
def _apply_activation_single(
self, x: torch.Tensor, activation: str = "linear"
) -> torch.Tensor:
"""
Apply activation to single channel output, maintaining semantic consistency with value branch in multi-channel case.
Supports: exp / relu / sigmoid / softplus / tanh / linear / expp1
"""
act = activation.lower() if isinstance(activation, str) else activation
if act == "exp":
return torch.exp(x)
if act == "expm1":
return torch.expm1(x)
if act == "expp1":
return torch.exp(x) + 1
if act == "relu":
return torch.relu(x)
if act == "sigmoid":
return torch.sigmoid(x)
if act == "softplus":
return torch.nn.functional.softplus(x)
if act == "tanh":
return torch.tanh(x)
# Default linear
return x
# # -----------------------------------------------------------------------------
# # Building blocks (tidy)
# # -----------------------------------------------------------------------------
# def _make_fusion_block(
# features: int,
# size: Tuple[int, int] = None,
# has_residual: bool = True,
# groups: int = 1,
# inplace: bool = False, # <- activation uses inplace=True by default; not related to "fusion_inplace"
# ) -> nn.Module:
# return FeatureFusionBlock(
# features=features,
# activation=nn.ReLU(inplace=inplace),
# deconv=False,
# bn=False,
# expand=False,
# align_corners=True,
# size=size,
# has_residual=has_residual,
# groups=groups,
# )
# def _make_scratch(
# in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
# ) -> nn.Module:
# scratch = nn.Module()
# # optionally expand widths by stage
# c1 = out_shape
# c2 = out_shape * (2 if expand else 1)
# c3 = out_shape * (4 if expand else 1)
# c4 = out_shape * (8 if expand else 1)
# scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups)
# scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups)
# scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups)
# scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups)
# return scratch
# class ResidualConvUnit(nn.Module):
# """Lightweight residual conv block used within fusion."""
# def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> None:
# super().__init__()
# self.bn = bn
# self.groups = groups
# self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
# self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
# self.norm1 = None
# self.norm2 = None
# self.activation = activation
# self.skip_add = nn.quantized.FloatFunctional()
# def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override]
# out = self.activation(x)
# out = self.conv1(out)
# if self.norm1 is not None:
# out = self.norm1(out)
# out = self.activation(out)
# out = self.conv2(out)
# if self.norm2 is not None:
# out = self.norm2(out)
# return self.skip_add.add(out, x)
# class FeatureFusionBlock(nn.Module):
# """Top-down fusion block: (optional) residual merge + upsample + 1x1 shrink."""
# def __init__(
# self,
# features: int,
# activation: nn.Module,
# deconv: bool = False,
# bn: bool = False,
# expand: bool = False,
# align_corners: bool = True,
# size: Tuple[int, int] = None,
# has_residual: bool = True,
# groups: int = 1,
# ) -> None:
# super().__init__()
# self.align_corners = align_corners
# self.size = size
# self.has_residual = has_residual
# self.resConfUnit1 = (
# ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None
# )
# self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups)
# out_features = (features // 2) if expand else features
# self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups)
# self.skip_add = nn.quantized.FloatFunctional()
# def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor: # type: ignore[override]
# """
# xs:
# - xs[0]: top input
# - xs[1]: (optional) lateral (to be added with residual)
# """
# y = xs[0]
# if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
# y = self.skip_add.add(y, self.resConfUnit1(xs[1]))
# y = self.resConfUnit2(y)
# # upsample
# if (size is None) and (self.size is None):
# up_kwargs = {"scale_factor": 2}
# elif size is None:
# up_kwargs = {"size": self.size}
# else:
# up_kwargs = {"size": size}
# y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
# y = self.out_conv(y)
# return y

View File

@@ -0,0 +1,200 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from einops import einsum, rearrange, repeat
from torch import nn
from depth_anything_3.model.utils.transform import cam_quat_xyzw_to_world_quat_wxyz
from depth_anything_3.specs import Gaussians
from depth_anything_3.utils.geometry import affine_inverse, get_world_rays, sample_image_grid
from depth_anything_3.utils.pose_align import batch_align_poses_umeyama
from depth_anything_3.utils.sh_helpers import rotate_sh
class GaussianAdapter(nn.Module):
def __init__(
self,
sh_degree: int = 0,
pred_color: bool = False,
pred_offset_depth: bool = False,
pred_offset_xy: bool = True,
gaussian_scale_min: float = 1e-5,
gaussian_scale_max: float = 30.0,
):
super().__init__()
self.sh_degree = sh_degree
self.pred_color = pred_color
self.pred_offset_depth = pred_offset_depth
self.pred_offset_xy = pred_offset_xy
self.gaussian_scale_min = gaussian_scale_min
self.gaussian_scale_max = gaussian_scale_max
# Create a mask for the spherical harmonics coefficients. This ensures that at
# initialization, the coefficients are biased towards having a large DC
# component and small view-dependent components.
if not pred_color:
self.register_buffer(
"sh_mask",
torch.ones((self.d_sh,), dtype=torch.float32),
persistent=False,
)
for degree in range(1, sh_degree + 1):
self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree
def forward(
self,
extrinsics: torch.Tensor, # "*#batch 4 4"
intrinsics: torch.Tensor, # "*#batch 3 3"
depths: torch.Tensor, # "*#batch"
opacities: torch.Tensor, # "*#batch" | "*#batch _"
raw_gaussians: torch.Tensor, # "*#batch _"
image_shape: tuple[int, int],
eps: float = 1e-8,
gt_extrinsics: Optional[torch.Tensor] = None, # "*#batch 4 4"
**kwargs,
) -> Gaussians:
device = extrinsics.device
dtype = raw_gaussians.dtype
H, W = image_shape
b, v = raw_gaussians.shape[:2]
# get cam2worlds and intr_normed to adapt to 3DGS codebase
cam2worlds = affine_inverse(extrinsics)
intr_normed = intrinsics.clone().detach()
intr_normed[..., 0, :] /= W
intr_normed[..., 1, :] /= H
# 1. compute 3DGS means
# 1.1) offset the predicted depth if needed
if self.pred_offset_depth:
gs_depths = depths + raw_gaussians[..., -1]
raw_gaussians = raw_gaussians[..., :-1]
else:
gs_depths = depths
# 1.2) align predicted poses with GT if needed
if gt_extrinsics is not None and not torch.equal(extrinsics, gt_extrinsics):
try:
_, _, pose_scales = batch_align_poses_umeyama(
gt_extrinsics.detach().float(),
extrinsics.detach().float(),
)
except Exception:
pose_scales = torch.ones_like(extrinsics[:, 0, 0, 0])
pose_scales = torch.clamp(pose_scales, min=1 / 3.0, max=3.0)
cam2worlds[:, :, :3, 3] = cam2worlds[:, :, :3, 3] * rearrange(
pose_scales, "b -> b () ()"
) # [b, i, j]
gs_depths = gs_depths * rearrange(pose_scales, "b -> b () () ()") # [b, v, h, w]
# 1.3) casting xy in image space
xy_ray, _ = sample_image_grid((H, W), device)
xy_ray = xy_ray[None, None, ...].expand(b, v, -1, -1, -1) # b v h w xy
# offset xy if needed
if self.pred_offset_xy:
pixel_size = 1 / torch.tensor((W, H), dtype=xy_ray.dtype, device=device)
offset_xy = raw_gaussians[..., :2]
xy_ray = xy_ray + offset_xy * pixel_size
raw_gaussians = raw_gaussians[..., 2:] # skip the offset_xy
# 1.4) unproject depth + xy to world ray
origins, directions = get_world_rays(
xy_ray,
repeat(cam2worlds, "b v i j -> b v h w i j", h=H, w=W),
repeat(intr_normed, "b v i j -> b v h w i j", h=H, w=W),
)
gs_means_world = origins + directions * gs_depths[..., None]
gs_means_world = rearrange(gs_means_world, "b v h w d -> b (v h w) d")
# 2. compute other GS attributes
scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)
# 2.1) 3DGS scales
# make the scale invarient to resolution
scale_min = self.gaussian_scale_min
scale_max = self.gaussian_scale_max
scales = scale_min + (scale_max - scale_min) * scales.sigmoid()
pixel_size = 1 / torch.tensor((W, H), dtype=dtype, device=device)
multiplier = self.get_scale_multiplier(intr_normed, pixel_size)
gs_scales = scales * gs_depths[..., None] * multiplier[..., None, None, None]
gs_scales = rearrange(gs_scales, "b v h w d -> b (v h w) d")
# 2.2) 3DGS quaternion (world space)
# due to historical issue, assume quaternion in order xyzw, not wxyz
# Normalize the quaternion features to yield a valid quaternion.
rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
# rotate them to world space
cam_quat_xyzw = rearrange(rotations, "b v h w c -> b (v h w) c")
c2w_mat = repeat(
cam2worlds,
"b v i j -> b (v h w) i j",
h=H,
w=W,
)
world_quat_wxyz = cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w_mat)
gs_rotations_world = world_quat_wxyz # b (v h w) c
# 2.3) 3DGS color / SH coefficient (world space)
sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
if not self.pred_color:
sh = sh * self.sh_mask
if self.pred_color or self.sh_degree == 0:
# predict pre-computed color or predict only DC band, no need to transform
gs_sh_world = sh
else:
gs_sh_world = rotate_sh(sh, cam2worlds[:, :, None, None, None, :3, :3])
gs_sh_world = rearrange(gs_sh_world, "b v h w xyz d_sh -> b (v h w) xyz d_sh")
# 2.4) 3DGS opacity
gs_opacities = rearrange(opacities, "b v h w ... -> b (v h w) ...")
return Gaussians(
means=gs_means_world,
harmonics=gs_sh_world,
opacities=gs_opacities,
scales=gs_scales,
rotations=gs_rotations_world,
)
def get_scale_multiplier(
self,
intrinsics: torch.Tensor, # "*#batch 3 3"
pixel_size: torch.Tensor, # "*#batch 2"
multiplier: float = 0.1,
) -> torch.Tensor: # " *batch"
xy_multipliers = multiplier * einsum(
intrinsics[..., :2, :2].float().inverse().to(intrinsics),
pixel_size,
"... i j, j -> ... i",
)
return xy_multipliers.sum(dim=-1)
@property
def d_sh(self) -> int:
return 1 if self.pred_color else (self.sh_degree + 1) ** 2
@property
def d_in(self) -> int:
# provided as reference to the gs_dpt output dim
raw_gs_dim = 0
if self.pred_offset_xy:
raw_gs_dim += 2
raw_gs_dim += 3 # scales
raw_gs_dim += 4 # quaternion
raw_gs_dim += 3 * self.d_sh # color
if self.pred_offset_depth:
raw_gs_dim += 1
return raw_gs_dim

View File

@@ -0,0 +1,133 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict as TyDict
from typing import List, Sequence
import torch
import torch.nn as nn
from depth_anything_3.model.dpt import DPT
from depth_anything_3.model.utils.head_utils import activate_head_gs, custom_interpolate
class GSDPT(DPT):
def __init__(
self,
dim_in: int,
patch_size: int = 14,
output_dim: int = 4,
activation: str = "linear",
conf_activation: str = "sigmoid",
features: int = 256,
out_channels: Sequence[int] = (256, 512, 1024, 1024),
pos_embed: bool = True,
feature_only: bool = False,
down_ratio: int = 1,
conf_dim: int = 1,
norm_type: str = "idt", # use to match legacy GS-DPT head, "idt" / "layer"
fusion_block_inplace: bool = False,
) -> None:
super().__init__(
dim_in=dim_in,
patch_size=patch_size,
output_dim=output_dim,
activation=activation,
conf_activation=conf_activation,
features=features,
out_channels=out_channels,
pos_embed=pos_embed,
down_ratio=down_ratio,
head_name="raw_gs",
use_sky_head=False,
norm_type=norm_type,
fusion_block_inplace=fusion_block_inplace,
)
self.conf_dim = conf_dim
if conf_dim and conf_dim > 1:
assert (
conf_activation == "linear"
), "use linear prediction when using view-dependent opacity"
merger_out_dim = features if feature_only else features // 2
self.images_merger = nn.Sequential(
nn.Conv2d(3, merger_out_dim // 4, 3, 1, 1), # fewer channels first
nn.GELU(),
nn.Conv2d(merger_out_dim // 4, merger_out_dim // 2, 3, 1, 1),
nn.GELU(),
nn.Conv2d(merger_out_dim // 2, merger_out_dim, 3, 1, 1),
nn.GELU(),
)
# -------------------------------------------------------------------------
# Internal forward (single chunk)
# -------------------------------------------------------------------------
def _forward_impl(
self,
feats: List[torch.Tensor],
H: int,
W: int,
patch_start_idx: int,
images: torch.Tensor,
) -> TyDict[str, torch.Tensor]:
B, _, C = feats[0].shape
ph, pw = H // self.patch_size, W // self.patch_size
resized_feats = []
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
x = feats[take_idx][:, patch_start_idx:] # [B*S, N_patch, C]
x = self.norm(x)
x = x.permute(0, 2, 1).reshape(B, C, ph, pw) # [B*S, C, ph, pw]
x = self.projects[stage_idx](x)
if self.pos_embed:
x = self._add_pos_embed(x, W, H)
x = self.resize_layers[stage_idx](x) # Align scale
resized_feats.append(x)
# 2) Fusion pyramid (main branch only)
fused = self._fuse(resized_feats)
fused = self.scratch.output_conv1(fused)
# 3) Upsample to target resolution, optionally add position encoding again
h_out = int(ph * self.patch_size / self.down_ratio)
w_out = int(pw * self.patch_size / self.down_ratio)
fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
# inject the image information here
fused = fused + self.images_merger(images)
if self.pos_embed:
fused = self._add_pos_embed(fused, W, H)
# 4) Shared neck1
# feat = self.scratch.output_conv1(fused)
feat = fused
# 5) Main head: logits -> activate_head or single channel activation
main_logits = self.scratch.output_conv2(feat)
outs: TyDict[str, torch.Tensor] = {}
if self.has_conf:
pred, conf = activate_head_gs(
main_logits,
activation=self.activation,
conf_activation=self.conf_activation,
conf_dim=self.conf_dim,
)
outs[self.head_main] = pred.squeeze(1)
outs[f"{self.head_main}_conf"] = conf.squeeze(1)
else:
outs[self.head_main] = self._apply_activation_single(main_logits).squeeze(1)
return outs

View File

@@ -0,0 +1,223 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reference View Selection Strategies
This module provides different strategies for selecting a reference view
from multiple input views in multi-view depth estimation.
"""
import torch
from typing import Literal
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
def select_reference_view(
x: torch.Tensor,
strategy: RefViewStrategy = "saddle_balanced",
) -> torch.Tensor:
"""
Select a reference view from multiple views using the specified strategy.
Args:
x: Input tensor of shape (B, S, N, C) where
B = batch size
S = number of views
N = number of tokens
C = channel dimension
strategy: Selection strategy, one of:
- "first": Always select the first view
- "middle": Select the middle view
- "saddle_balanced": Select view with balanced features across multiple metrics
- "saddle_sim_range": Select view with largest similarity range
Returns:
b_idx: Tensor of shape (B,) containing the selected view index for each batch
"""
B, S, N, C = x.shape
# For single view, no reordering needed
if S <= 1:
return torch.zeros(B, dtype=torch.long, device=x.device)
# Simple position-based strategies
if strategy == "first":
return torch.zeros(B, dtype=torch.long, device=x.device)
elif strategy == "middle":
return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
# Feature-based strategies require normalized class tokens
# Extract and normalize class tokens (first token of each view)
img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # B S C
if strategy == "saddle_balanced":
# Select view with balanced features across multiple metrics
# Compute similarity matrix
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # B S S
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # B S
feat_norm = x[:, :, 0].norm(dim=-1) # B S
feat_var = img_class_feat.var(dim=-1) # B S
# Normalize all metrics to [0, 1]
def normalize_metric(metric):
min_val = metric.min(dim=1, keepdim=True).values
max_val = metric.max(dim=1, keepdim=True).values
return (metric - min_val) / (max_val - min_val + 1e-8)
sim_score_norm = normalize_metric(sim_score)
norm_norm = normalize_metric(feat_norm)
var_norm = normalize_metric(feat_var)
# Select view closest to the median (0.5) across all metrics
balance_score = (
(sim_score_norm - 0.5).abs() +
(norm_norm - 0.5).abs() +
(var_norm - 0.5).abs()
)
b_idx = balance_score.argmin(dim=1)
elif strategy == "saddle_sim_range":
# Select view with largest similarity range (max - min)
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # B S S
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
sim_max = sim_no_diag.max(dim=-1).values # B S
sim_min = sim_no_diag.min(dim=-1).values # B S
sim_range = sim_max - sim_min
b_idx = sim_range.argmax(dim=1)
else:
raise ValueError(
f"Unknown reference view selection strategy: {strategy}. "
f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
)
return b_idx
def reorder_by_reference(
x: torch.Tensor,
b_idx: torch.Tensor,
) -> torch.Tensor:
"""
Reorder views to place the selected reference view first.
Args:
x: Input tensor of shape (B, S, N, C)
b_idx: Reference view indices of shape (B,)
Returns:
Reordered tensor with reference view at position 0
Example:
If b_idx = [2] and S = 5 (views [0,1,2,3,4]),
result order is [2,0,1,3,4] (ref_idx first, then others in order)
"""
B, S = x.shape[0], x.shape[1]
# For single view, no reordering needed
if S <= 1:
return x
# Create position indices: (B, S) where each row is [0, 1, 2, ..., S-1]
positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1) # B S
# For each position, determine which original index it should take
# Position 0 gets ref_idx
# Position 1 to ref_idx gets indices 0 to ref_idx-1
# Position ref_idx+1 to S-1 gets indices ref_idx+1 to S-1
b_idx_expanded = b_idx.unsqueeze(1) # B 1
# Create the reordering indices
# For positions 1 to ref_idx: map to indices 0 to ref_idx-1 (shift by -1)
# For positions > ref_idx: keep the same
reorder_indices = positions.clone()
reorder_indices = torch.where(
(positions > 0) & (positions <= b_idx_expanded),
positions - 1,
positions
)
# Set position 0 to ref_idx
reorder_indices[:, 0] = b_idx
# Gather using advanced indexing
batch_indices = torch.arange(B, device=x.device).unsqueeze(1) # B 1
x_reordered = x[batch_indices, reorder_indices]
return x_reordered
def restore_original_order(
x: torch.Tensor,
b_idx: torch.Tensor,
) -> torch.Tensor:
"""
Restore original view order after processing.
Args:
x: Reordered tensor of shape (B, S, ...)
b_idx: Original reference view indices of shape (B,)
Returns:
Tensor with original view order restored
Example:
If original order was [0, 1, 2, 3, 4] and b_idx=2,
reordered becomes [2, 0, 1, 3, 4] (reference at position 0),
restore should return [0, 1, 2, 3, 4] (original order).
"""
B, S = x.shape[0], x.shape[1]
# For single view, no restoration needed
if S <= 1:
return x
# Create target position indices: (B, S) where each row is [0, 1, 2, ..., S-1]
target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1) # B S
# For each target position, determine which current position it comes from
# Target position 0 to ref_idx-1 <- Current position 1 to ref_idx (shift by +1)
# Target position ref_idx <- Current position 0
# Target position ref_idx+1 to S-1 <- Current position ref_idx+1 to S-1 (no change)
b_idx_expanded = b_idx.unsqueeze(1) # B 1
# Create the restore indices
restore_indices = torch.where(
target_positions < b_idx_expanded,
target_positions + 1, # Positions before ref_idx come from current position + 1
target_positions # Positions after ref_idx stay the same
)
# Target position = ref_idx comes from current position 0
# Use scatter to set specific positions
restore_indices = torch.scatter(
restore_indices,
dim=1,
index=b_idx_expanded,
src=torch.zeros_like(b_idx_expanded)
)
# Gather using advanced indexing
batch_indices = torch.arange(B, device=x.device).unsqueeze(1) # B 1
x_restored = x[batch_indices, restore_indices]
return x_restored

View File

@@ -0,0 +1,109 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
qk_norm: bool = False,
rope=None,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = rope
def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
# Debug breakpoint removed for production
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
q = self.rope(q, pos) if self.rope is not None else q
k = self.rope(k, pos) if self.rope is not None else k
x = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
attn_mask=attn_mask,
)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x

View File

@@ -0,0 +1,81 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from torch import Tensor, nn
from .attention import Attention, LayerScale, Mlp
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
rope=None,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
qk_norm=qk_norm,
rope=rope,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.sample_drop_ratio = 0.0 # Equivalent to always having drop_path=0
def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
# drop_path is always 0, so always take the else branch
x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
x = x + ffn_residual_func(x)
return x

View File

@@ -0,0 +1,340 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from math import isqrt
from typing import Literal, Optional
import torch
from einops import rearrange, repeat
from tqdm import tqdm
from depth_anything_3.specs import Gaussians
from depth_anything_3.utils.camera_trj_helpers import (
interpolate_extrinsics,
interpolate_intrinsics,
render_dolly_zoom_path,
render_stabilization_path,
render_wander_path,
render_wobble_inter_path,
)
from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, get_fov
from depth_anything_3.utils.logger import logger
try:
from gsplat import rasterization
except ImportError:
logger.warn(
"Dependency `gsplat` is required for rendering 3DGS. "
"Install via: pip install git+https://github.com/nerfstudio-project/"
"gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70"
)
def render_3dgs(
extrinsics: torch.Tensor, # "batch_views 4 4", w2c
intrinsics: torch.Tensor, # "batch_views 3 3", normalized
image_shape: tuple[int, int],
gaussian: Gaussians,
background_color: Optional[torch.Tensor] = None, # "batch_views 3"
use_sh: bool = True,
num_view: int = 1,
color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+D",
**kwargs,
) -> tuple[
torch.Tensor, # "batch_views 3 height width"
torch.Tensor, # "batch_views height width"
]:
# extract gaussian params
gaussian_means = gaussian.means
gaussian_scales = gaussian.scales
gaussian_quats = gaussian.rotations
gaussian_opacities = gaussian.opacities
gaussian_sh_coefficients = gaussian.harmonics
b, _, _ = extrinsics.shape
if background_color is None:
background_color = repeat(torch.tensor([0.0, 0.0, 0.0]), "c -> b c", b=b).to(
gaussian_sh_coefficients
)
if use_sh:
_, _, _, n = gaussian_sh_coefficients.shape
degree = isqrt(n) - 1
shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()
else: # use color
shs = (
gaussian_sh_coefficients.squeeze(-1).sigmoid().contiguous()
) # (b, g, c), normed to (0, 1)
h, w = image_shape
fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
tan_fov_x = (0.5 * fov_x).tan()
tan_fov_y = (0.5 * fov_y).tan()
focal_length_x = w / (2 * tan_fov_x)
focal_length_y = h / (2 * tan_fov_y)
view_matrix = extrinsics.float()
all_images = []
all_radii = []
all_depths = []
# render view in a batch based, each batch contains one scene
# assume the Gaussian parameters are originally repeated along the view dim
batch_scene = b // num_view
def index_i_gs_attr(full_attr, idx):
# return rearrange(full_attr, "(b v) ... -> b v ...", v=num_view)[idx, 0]
return full_attr[idx]
for i in range(batch_scene):
K = repeat(
torch.tensor(
[
[0, 0, w / 2.0],
[0, 0, h / 2.0],
[0, 0, 1],
]
),
"i j -> v i j",
v=num_view,
).to(gaussian_means)
K[:, 0, 0] = focal_length_x.reshape(batch_scene, num_view)[i]
K[:, 1, 1] = focal_length_y.reshape(batch_scene, num_view)[i]
i_means = index_i_gs_attr(gaussian_means, i) # [N, 3]
i_scales = index_i_gs_attr(gaussian_scales, i)
i_quats = index_i_gs_attr(gaussian_quats, i)
i_opacities = index_i_gs_attr(gaussian_opacities, i) # [N,]
i_colors = index_i_gs_attr(shs, i) # [N, K, 3]
i_viewmats = rearrange(view_matrix, "(b v) ... -> b v ...", v=num_view)[i] # [v, 4, 4]
i_backgrounds = rearrange(background_color, "(b v) ... -> b v ...", v=num_view)[
i
] # [v, 3]
render_colors, render_alphas, info = rasterization(
means=i_means,
quats=i_quats, # [N, 4]
scales=i_scales, # [N, 3]
opacities=i_opacities,
colors=i_colors,
viewmats=i_viewmats, # [v, 4, 4]
Ks=K, # [v, 3, 3]
backgrounds=i_backgrounds,
render_mode=color_mode,
width=w,
height=h,
packed=False,
sh_degree=degree if use_sh else None,
)
depth = render_colors[..., -1].unbind(dim=0)
image = rearrange(render_colors[..., :3], "v h w c -> v c h w").unbind(dim=0)
radii = info["radii"].unbind(dim=0)
try:
info["means2d"].retain_grad() # [1, N, 2]
except Exception:
pass
all_images.extend(image)
all_depths.extend(depth)
all_radii.extend(radii)
return torch.stack(all_images), torch.stack(all_depths)
def run_renderer_in_chunk_w_trj_mode(
gaussians: Gaussians,
extrinsics: torch.Tensor, # world2cam, "batch view 4 4" | "batch view 3 4"
intrinsics: torch.Tensor, # unnormed intrinsics, "batch view 3 3"
image_shape: tuple[int, int],
chunk_size: Optional[int] = 8,
trj_mode: Literal[
"original",
"smooth",
"interpolate",
"interpolate_smooth",
"wander",
"dolly_zoom",
"extend",
"wobble_inter",
] = "smooth",
input_shape: Optional[tuple[int, int]] = None,
enable_tqdm: Optional[bool] = False,
**kwargs,
) -> tuple[
torch.Tensor, # color, "batch view 3 height width"
torch.Tensor, # depth, "batch view height width"
]:
cam2world = affine_inverse(as_homogeneous(extrinsics))
if input_shape is not None:
in_h, in_w = input_shape
else:
in_h, in_w = image_shape
intr_normed = intrinsics.clone().detach()
intr_normed[..., 0, :] /= in_w
intr_normed[..., 1, :] /= in_h
if extrinsics.shape[1] <= 1:
assert trj_mode in [
"wander",
"dolly_zoom",
], "Please set trj_mode to 'wander' or 'dolly_zoom' when n_views=1"
def _smooth_trj_fn_batch(raw_c2ws, k_size=50):
try:
smooth_c2ws = torch.stack(
[render_stabilization_path(c2w_i, k_size) for c2w_i in raw_c2ws],
dim=0,
)
except Exception as e:
print(f"[DEBUG] Path smoothing failed with error: {e}.")
smooth_c2ws = raw_c2ws
return smooth_c2ws
# get rendered trj
if trj_mode == "original":
tgt_c2w = cam2world
tgt_intr = intr_normed
elif trj_mode == "smooth":
tgt_c2w = _smooth_trj_fn_batch(cam2world)
tgt_intr = intr_normed
elif trj_mode in ["interpolate", "interpolate_smooth", "extend"]:
inter_len = 8
total_len = (cam2world.shape[1] - 1) * inter_len
if total_len > 24 * 18: # no more than 18s
inter_len = max(1, 24 * 10 // (cam2world.shape[1] - 1))
if total_len < 24 * 2: # no less than 2s
inter_len = max(1, 24 * 2 // (cam2world.shape[1] - 1))
if inter_len > 2:
t = torch.linspace(0, 1, inter_len, dtype=torch.float32, device=cam2world.device)
t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
tgt_c2w_b = []
tgt_intr_b = []
for b_idx in range(cam2world.shape[0]):
tgt_c2w = []
tgt_intr = []
for cur_idx in range(cam2world.shape[1] - 1):
tgt_c2w.append(
interpolate_extrinsics(
cam2world[b_idx, cur_idx], cam2world[b_idx, cur_idx + 1], t
)[(0 if cur_idx == 0 else 1) :]
)
tgt_intr.append(
interpolate_intrinsics(
intr_normed[b_idx, cur_idx], intr_normed[b_idx, cur_idx + 1], t
)[(0 if cur_idx == 0 else 1) :]
)
tgt_c2w_b.append(torch.cat(tgt_c2w))
tgt_intr_b.append(torch.cat(tgt_intr))
tgt_c2w = torch.stack(tgt_c2w_b) # b v 4 4
tgt_intr = torch.stack(tgt_intr_b) # b v 3 3
else:
tgt_c2w = cam2world
tgt_intr = intr_normed
if trj_mode in ["interpolate_smooth", "extend"]:
tgt_c2w = _smooth_trj_fn_batch(tgt_c2w)
if trj_mode == "extend":
# apply dolly_zoom and wander in the middle frame
assert cam2world.shape[0] == 1, "extend only supports for batch_size=1 currently."
mid_idx = tgt_c2w.shape[1] // 2
c2w_wd, intr_wd = render_wander_path(
tgt_c2w[0, mid_idx],
tgt_intr[0, mid_idx],
h=in_h,
w=in_w,
num_frames=max(36, min(60, mid_idx // 2)),
max_disp=24.0,
)
c2w_dz, intr_dz = render_dolly_zoom_path(
tgt_c2w[0, mid_idx],
tgt_intr[0, mid_idx],
h=in_h,
w=in_w,
num_frames=max(36, min(60, mid_idx // 2)),
)
tgt_c2w = torch.cat(
[
tgt_c2w[:, :mid_idx],
c2w_wd.unsqueeze(0),
c2w_dz.unsqueeze(0),
tgt_c2w[:, mid_idx:],
],
dim=1,
)
tgt_intr = torch.cat(
[
tgt_intr[:, :mid_idx],
intr_wd.unsqueeze(0),
intr_dz.unsqueeze(0),
tgt_intr[:, mid_idx:],
],
dim=1,
)
elif trj_mode in ["wander", "dolly_zoom"]:
if trj_mode == "wander":
render_fn = render_wander_path
extra_kwargs = {"max_disp": 24.0}
else:
render_fn = render_dolly_zoom_path
extra_kwargs = {"D_focus": 30.0, "max_disp": 2.0}
tgt_c2w = []
tgt_intr = []
for b_idx in range(cam2world.shape[0]):
c2w_i, intr_i = render_fn(
cam2world[b_idx, 0], intr_normed[b_idx, 0], h=in_h, w=in_w, **extra_kwargs
)
tgt_c2w.append(c2w_i)
tgt_intr.append(intr_i)
tgt_c2w = torch.stack(tgt_c2w)
tgt_intr = torch.stack(tgt_intr)
elif trj_mode == "wobble_inter":
tgt_c2w, tgt_intr = render_wobble_inter_path(
cam2world=cam2world,
intr_normed=intr_normed,
inter_len=10,
n_skip=3,
)
else:
raise Exception(f"trj mode [{trj_mode}] is not implemented.")
_, v = tgt_c2w.shape[:2]
tgt_extr = affine_inverse(tgt_c2w)
if chunk_size is None:
chunk_size = v
chunk_size = min(v, chunk_size)
all_colors = []
all_depths = []
for chunk_idx in tqdm(
range(math.ceil(v / chunk_size)),
desc="Rendering novel views",
disable=(not enable_tqdm),
leave=False,
):
s = int(chunk_idx * chunk_size)
e = int((chunk_idx + 1) * chunk_size)
cur_n_view = tgt_extr[:, s:e].shape[1]
color, depth = render_3dgs(
extrinsics=rearrange(tgt_extr[:, s:e], "b v ... -> (b v) ..."), # w2c
intrinsics=rearrange(tgt_intr[:, s:e], "b v ... -> (b v) ..."), # normed
image_shape=image_shape,
gaussian=gaussians,
num_view=cur_n_view,
**kwargs,
)
all_colors.append(rearrange(color, "(b v) ... -> b v ...", v=cur_n_view))
all_depths.append(rearrange(depth, "(b v) ... -> b v ...", v=cur_n_view))
all_colors = torch.cat(all_colors, dim=1)
all_depths = torch.cat(all_depths, dim=1)
return all_colors, all_depths

View File

@@ -0,0 +1,230 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
# -----------------------------------------------------------------------------
# Activation functions
# -----------------------------------------------------------------------------
def activate_head_gs(out, activation="norm_exp", conf_activation="expp1", conf_dim=None):
"""
Process network output to extract GS params and density values.
Density could be view-dependent as SH coefficient
Args:
out: Network output tensor (B, C, H, W)
activation: Activation type for 3D points
conf_activation: Activation type for confidence values
Returns:
Tuple of (3D points tensor, confidence tensor)
"""
# Move channels from last dim to the 4th dimension => (B, H, W, C)
fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
# Split into xyz (first C-1 channels) and confidence (last channel)
conf_dim = 1 if conf_dim is None else conf_dim
xyz = fmap[:, :, :, :-conf_dim]
conf = fmap[:, :, :, -1] if conf_dim == 1 else fmap[:, :, :, -conf_dim:]
if activation == "norm_exp":
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
xyz_normed = xyz / d
pts3d = xyz_normed * torch.expm1(d)
elif activation == "norm":
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
elif activation == "exp":
pts3d = torch.exp(xyz)
elif activation == "relu":
pts3d = F.relu(xyz)
elif activation == "sigmoid":
pts3d = torch.sigmoid(xyz)
elif activation == "linear":
pts3d = xyz
else:
raise ValueError(f"Unknown activation: {activation}")
if conf_activation == "expp1":
conf_out = 1 + conf.exp()
elif conf_activation == "expp0":
conf_out = conf.exp()
elif conf_activation == "sigmoid":
conf_out = torch.sigmoid(conf)
elif conf_activation == "linear":
conf_out = conf
else:
raise ValueError(f"Unknown conf_activation: {conf_activation}")
return pts3d, conf_out
# -----------------------------------------------------------------------------
# Other utilities
# -----------------------------------------------------------------------------
class Permute(nn.Module):
"""nn.Module wrapper around Tensor.permute for cleaner nn.Sequential usage."""
dims: Tuple[int, ...]
def __init__(self, dims: Tuple[int, ...]) -> None:
super().__init__()
self.dims = dims
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override]
return x.permute(*self.dims)
def position_grid_to_embed(
pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100
) -> torch.Tensor:
"""
Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
Args:
pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
embed_dim: Output channel dimension for embeddings
Returns:
Tensor of shape (H, W, embed_dim) with positional embeddings
"""
H, W, grid_dim = pos_grid.shape
assert grid_dim == 2
pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
# Process x and y coordinates separately
emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
# Combine and reshape
emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
return emb.view(H, W, embed_dim) # [H, W, D]
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
"""
This function generates a 1D positional embedding from a given grid using sine and cosine functions. # noqa
Args:
- embed_dim: The embedding dimension.
- pos: The position to generate the embedding from.
Returns:
- emb: The generated 1D positional embedding.
"""
assert embed_dim % 2 == 0
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
omega /= embed_dim / 2.0
omega = 1.0 / omega_0**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb.float()
# Inspired by https://github.com/microsoft/moge
def create_uv_grid(
width: int,
height: int,
aspect_ratio: float = None,
dtype: torch.dtype = None,
device: torch.device = None,
) -> torch.Tensor:
"""
Create a normalized UV grid of shape (width, height, 2).
The grid spans horizontally and vertically according to an aspect ratio,
ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
corner is at (x_span, y_span), normalized by the diagonal of the plane.
Args:
width (int): Number of points horizontally.
height (int): Number of points vertically.
aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
dtype (torch.dtype, optional): Data type of the resulting tensor.
device (torch.device, optional): Device on which the tensor is created.
Returns:
torch.Tensor: A (width, height, 2) tensor of UV coordinates.
"""
# Derive aspect ratio if not explicitly provided
if aspect_ratio is None:
aspect_ratio = float(width) / float(height)
# Compute normalized spans for X and Y
diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
span_x = aspect_ratio / diag_factor
span_y = 1.0 / diag_factor
# Establish the linspace boundaries
left_x = -span_x * (width - 1) / width
right_x = span_x * (width - 1) / width
top_y = -span_y * (height - 1) / height
bottom_y = span_y * (height - 1) / height
# Generate 1D coordinates
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
# Create 2D meshgrid (width x height) and stack into UV
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
uv_grid = torch.stack((uu, vv), dim=-1)
return uv_grid
# -----------------------------------------------------------------------------
# Interpolation (safe interpolation, avoid INT_MAX overflow)
# -----------------------------------------------------------------------------
def custom_interpolate(
x: torch.Tensor,
size: Union[Tuple[int, int], None] = None,
scale_factor: Union[float, None] = None,
mode: str = "bilinear",
align_corners: bool = True,
) -> torch.Tensor:
"""
Safe interpolation implementation to avoid INT_MAX overflow in torch.nn.functional.interpolate.
"""
if size is None:
assert scale_factor is not None, "Either size or scale_factor must be provided."
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
INT_MAX = 1610612736
total = size[0] * size[1] * x.shape[0] * x.shape[1]
if total > INT_MAX:
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
outs = [
nn.functional.interpolate(c, size=size, mode=mode, align_corners=align_corners)
for c in chunks
]
return torch.cat(outs, dim=0).contiguous()
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

View File

@@ -0,0 +1,208 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
def extri_intri_to_pose_encoding(
extrinsics,
intrinsics,
image_size_hw=None,
):
"""Convert camera extrinsics and intrinsics to a compact pose encoding."""
# extrinsics: BxSx3x4
# intrinsics: BxSx3x3
R = extrinsics[:, :, :3, :3] # BxSx3x3
T = extrinsics[:, :, :3, 3] # BxSx3
quat = mat_to_quat(R)
# Note the order of h and w here
H, W = image_size_hw
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
return pose_encoding
def pose_encoding_to_extri_intri(
pose_encoding,
image_size_hw=None,
):
"""Convert a pose encoding back to camera extrinsics and intrinsics."""
T = pose_encoding[..., :3]
quat = pose_encoding[..., 3:7]
fov_h = pose_encoding[..., 7]
fov_w = pose_encoding[..., 8]
R = quat_to_mat(quat)
extrinsics = torch.cat([R, T[..., None]], dim=-1)
H, W = image_size_hw
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
intrinsics[..., 0, 0] = fx
intrinsics[..., 1, 1] = fy
intrinsics[..., 0, 2] = W / 2
intrinsics[..., 1, 2] = H / 2
intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
return extrinsics, intrinsics
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
"""
Quaternion Order: XYZW or say ijkr, scalar-last
Convert rotations given as quaternions to rotation matrices.
Args:
quaternions: quaternions with real part last,
as tensor of shape (..., 4).
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
i, j, k, r = torch.unbind(quaternions, -1)
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part last, as tensor of shape (..., 4).
Quaternion Order: XYZW or say ijkr, scalar-last
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
batch_dim + (4,)
)
out = out[..., [1, 2, 3, 0]]
out = standardize_quaternion(out)
return out
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
part is non negative.
Args:
quaternions: Quaternions with real part last,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
def cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w):
# cam_quat_xyzw: (b, n, 4) in xyzw
# c2w: (b, n, 4, 4)
b, n = cam_quat_xyzw.shape[:2]
# 1. xyzw -> wxyz
cam_quat_wxyz = torch.cat(
[
cam_quat_xyzw[..., 3:4], # w
cam_quat_xyzw[..., 0:1], # x
cam_quat_xyzw[..., 1:2], # y
cam_quat_xyzw[..., 2:3], # z
],
dim=-1,
)
# 2. Quaternion to matrix
cam_quat_wxyz_flat = cam_quat_wxyz.reshape(-1, 4)
rotmat_cam = quat_to_mat(cam_quat_wxyz_flat).reshape(b, n, 3, 3)
# 3. Transform to world space
rotmat_c2w = c2w[..., :3, :3]
rotmat_world = torch.matmul(rotmat_c2w, rotmat_cam)
# 4. Matrix to quaternion (wxyz)
rotmat_world_flat = rotmat_world.reshape(-1, 3, 3)
world_quat_wxyz_flat = mat_to_quat(rotmat_world_flat)
world_quat_wxyz = world_quat_wxyz_flat.reshape(b, n, 4)
return world_quat_wxyz

View File

@@ -0,0 +1,50 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from pathlib import Path
def get_all_models() -> OrderedDict:
"""
Scans all YAML files in the configs directory and returns a sorted dictionary where:
- Keys are model names (YAML filenames without the .yaml extension)
- Values are absolute paths to the corresponding YAML files
"""
# Get path to the configs directory within the da3 package
# Works both in development and after pip installation
# configs_dir = files("depth_anything_3").joinpath("configs")
configs_dir = Path(__file__).resolve().parent / "configs"
# Ensure path is a Path object for consistent cross-platform handling
configs_dir = Path(configs_dir)
model_entries = []
# Iterate through all items in the configs directory
for item in configs_dir.iterdir():
# Filter for YAML files (excluding directories)
if item.is_file() and item.suffix == ".yaml":
# Extract model name (filename without .yaml extension)
model_name = item.stem
# Get absolute path (resolve() handles symlinks)
file_abs_path = str(item.resolve())
model_entries.append((model_name, file_abs_path))
# Sort entries by model name and convert to OrderedDict
sorted_entries = sorted(model_entries, key=lambda x: x[0])
return OrderedDict(sorted_entries)
# Global registry for external imports
MODEL_REGISTRY = get_all_models()

View File

@@ -0,0 +1,24 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Services module for Depth Anything 3.
"""
from depth_anything_3.services.backend import create_app, start_server
__all__ = [
start_server,
create_app,
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,806 @@
#!/usr/bin/env python3
# flake8: noqa: E501
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Depth Anything 3 Gallery Server (two-level, single-file)
Now supports paginated depth preview (4 per page).
"""
import argparse
import json
import mimetypes
import os
import posixpath
import sys
from functools import partial
from http import HTTPStatus
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import quote, unquote
# ------------------------------ Embedded HTML ------------------------------ #
HTML_PAGE = r"""<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Depth Anything 3 Gallery</title>
<meta name="viewport" content="width=device-width, initial-scale=1" />
<link rel="icon" href="https://i.postimg.cc/rFSzGJ7J/light-icon.jpg" media="(prefers-color-scheme: light)">
<link rel="icon" href="https://i.postimg.cc/P5gZfJsf/dark-icon.jpg" media="(prefers-color-scheme: dark)">
<script type="module" src="https://unpkg.com/@google/model-viewer/dist/model-viewer.min.js"></script>
<style>
:root {
--gap:16px; --card-radius:16px; --shadow:0 8px 24px rgba(0,0,0,.12);
--maxW:1036px; --maxH:518px;
--tech-blue: #00d4ff;
--tech-cyan: #00ffcc;
--tech-purple: #7877c6;
}
*{ box-sizing:border-box }
/* Dark mode tech theme */
@media (prefers-color-scheme: dark) {
body{
margin:0; font:16px/1.5 system-ui,-apple-system,Segoe UI,Roboto,sans-serif;
background: linear-gradient(135deg, #0a0a0a 0%, #1a1a2e 50%, #16213e 100%);
color:#e8eaed;
position: relative;
overflow-x: hidden;
}
body::before {
content: '';
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background:
radial-gradient(circle at 20% 80%, rgba(120, 119, 198, 0.3) 0%, transparent 50%),
radial-gradient(circle at 80% 20%, rgba(255, 119, 198, 0.3) 0%, transparent 50%),
radial-gradient(circle at 40% 40%, rgba(120, 219, 255, 0.2) 0%, transparent 50%);
animation: techPulse 8s ease-in-out infinite;
z-index: -1;
}
}
/* Light mode tech theme */
@media (prefers-color-scheme: light) {
body{
margin:0; font:16px/1.5 system-ui,-apple-system,Segoe UI,Roboto,sans-serif;
background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 50%, #cbd5e1 100%);
color:#1e293b;
position: relative;
overflow-x: hidden;
}
body::before {
content: '';
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background:
radial-gradient(circle at 20% 80%, rgba(0, 212, 255, 0.1) 0%, transparent 50%),
radial-gradient(circle at 80% 20%, rgba(0, 102, 255, 0.1) 0%, transparent 50%),
radial-gradient(circle at 40% 40%, rgba(0, 255, 204, 0.08) 0%, transparent 50%);
animation: techPulse 8s ease-in-out infinite;
z-index: -1;
}
}
@keyframes techPulse {
0%, 100% { opacity: 0.5; }
50% { opacity: 0.8; }
}
@keyframes techGradient {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
/* Dark mode header */
@media (prefers-color-scheme: dark) {
header{
padding:20px 24px; position:sticky; top:0;
background:linear-gradient(180deg,rgba(10,10,10,0.9) 60%,rgba(10,10,10,0));
z-index:2; border-bottom:1px solid rgba(0, 212, 255, 0.2);
backdrop-filter: blur(10px);
}
h1{
margin:0; font-size:22px;
background: linear-gradient(45deg, var(--tech-blue), var(--tech-cyan), var(--tech-purple));
background-size: 400% 400%;
-webkit-background-clip: text;
background-clip: text;
color: transparent;
animation: techGradient 3s ease infinite;
text-shadow: 0 0 30px rgba(0, 212, 255, 0.5);
}
.muted{ opacity:.7; font-size:13px; color: #a0a0a0; }
#backBtn{
display:none; padding:6px 10px; border-radius:10px;
border:1px solid rgba(0, 212, 255, 0.3);
background:rgba(0, 0, 0, 0.3);
color:#e8eaed; cursor:pointer;
transition: all 0.3s ease;
}
#backBtn:hover {
border-color: var(--tech-blue);
box-shadow: 0 0 10px rgba(0, 212, 255, 0.3);
}
#search{
flex:1 1 260px; min-width:240px; max-width:520px;
padding:10px 14px; border-radius:12px;
border:1px solid rgba(0, 212, 255, 0.3);
background:rgba(0, 0, 0, 0.3);
color:#e8eaed; outline:none;
transition: all 0.3s ease;
}
#search:focus {
border-color: var(--tech-blue);
box-shadow: 0 0 10px rgba(0, 212, 255, 0.3);
}
}
/* Light mode header */
@media (prefers-color-scheme: light) {
header{
padding:20px 24px; position:sticky; top:0;
background:linear-gradient(180deg,rgba(248,250,252,0.9) 60%,rgba(248,250,252,0));
z-index:2; border-bottom:1px solid rgba(0, 212, 255, 0.3);
backdrop-filter: blur(10px);
}
h1{
margin:0; font-size:22px;
background: linear-gradient(45deg, #0066ff, #00d4ff, #00ffcc);
background-size: 400% 400%;
-webkit-background-clip: text;
background-clip: text;
color: transparent;
animation: techGradient 3s ease infinite;
text-shadow: 0 0 20px rgba(0, 102, 255, 0.3);
}
.muted{ opacity:.7; font-size:13px; color: #64748b; }
#backBtn{
display:none; padding:6px 10px; border-radius:10px;
border:1px solid rgba(0, 212, 255, 0.4);
background:rgba(255, 255, 255, 0.8);
color:#1e293b; cursor:pointer;
transition: all 0.3s ease;
}
#backBtn:hover {
border-color: #0066ff;
box-shadow: 0 0 10px rgba(0, 102, 255, 0.3);
}
#search{
flex:1 1 260px; min-width:240px; max-width:520px;
padding:10px 14px; border-radius:12px;
border:1px solid rgba(0, 212, 255, 0.4);
background:rgba(255, 255, 255, 0.8);
color:#1e293b; outline:none;
transition: all 0.3s ease;
}
#search:focus {
border-color: #0066ff;
box-shadow: 0 0 10px rgba(0, 102, 255, 0.3);
}
}
.row{ display:flex; gap:12px; align-items:center; flex-wrap:wrap; justify-content:center; }
main{ padding:16px 24px 24px; display:grid; place-items:center; }
.group-wrap{ width:min(900px,100%); }
.group-list{ list-style:none; margin:0; padding:0; display:grid; gap:10px; }
/* Dark mode cards */
@media (prefers-color-scheme: dark) {
.group-item{
display:flex; align-items:center; gap:12px; padding:12px 14px;
background:rgba(0, 0, 0, 0.3); border:1px solid rgba(0, 212, 255, 0.2); border-radius:14px; cursor:pointer;
transition: all 0.3s ease;
backdrop-filter: blur(10px);
}
.group-item:hover{
transform: translateY(-1px);
border-color:var(--tech-blue);
box-shadow: 0 4px 15px rgba(0, 212, 255, 0.2);
}
.card{
background:rgba(0, 0, 0, 0.3); border:1px solid rgba(0, 212, 255, 0.2); border-radius:var(--card-radius);
overflow:hidden; box-shadow:var(--shadow);
transition:all 0.3s ease; cursor:pointer; display:flex; flex-direction:column; max-width:var(--maxW);
backdrop-filter: blur(10px);
}
.card:hover{
transform:translateY(-2px);
border-color:var(--tech-blue);
box-shadow: 0 8px 25px rgba(0, 212, 255, 0.2);
}
.thumb-box{
position:relative; width:100%; aspect-ratio:2/1;
background:linear-gradient(135deg, #0e121b 0%, #1a1a2e 100%);
display:grid; place-items:center; overflow:hidden;
border-bottom: 1px solid rgba(0, 212, 255, 0.1);
}
.open{
font-size:12px; opacity:.7; padding:6px 8px;
border:1px solid rgba(0, 212, 255, 0.3);
border-radius:10px;
background:rgba(0, 212, 255, 0.1);
transition: all 0.3s ease;
}
.open:hover {
background:rgba(0, 212, 255, 0.2);
border-color: var(--tech-blue);
}
}
/* Light mode cards */
@media (prefers-color-scheme: light) {
.group-item{
display:flex; align-items:center; gap:12px; padding:12px 14px;
background:rgba(255, 255, 255, 0.8); border:1px solid rgba(0, 212, 255, 0.3); border-radius:14px; cursor:pointer;
transition: all 0.3s ease;
backdrop-filter: blur(10px);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}
.group-item:hover{
transform: translateY(-1px);
border-color:#0066ff;
box-shadow: 0 4px 15px rgba(0, 102, 255, 0.2);
}
.card{
background:rgba(255, 255, 255, 0.8); border:1px solid rgba(0, 212, 255, 0.3); border-radius:var(--card-radius);
overflow:hidden; box-shadow:0 4px 6px rgba(0, 0, 0, 0.1);
transition:all 0.3s ease; cursor:pointer; display:flex; flex-direction:column; max-width:var(--maxW);
backdrop-filter: blur(10px);
}
.card:hover{
transform:translateY(-2px);
border-color:#0066ff;
box-shadow: 0 8px 25px rgba(0, 102, 255, 0.2);
}
.thumb-box{
position:relative; width:100%; aspect-ratio:2/1;
background:linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
display:grid; place-items:center; overflow:hidden;
border-bottom: 1px solid rgba(0, 212, 255, 0.2);
}
.open{
font-size:12px; opacity:.7; padding:6px 8px;
border:1px solid rgba(0, 212, 255, 0.4);
border-radius:10px;
background:rgba(0, 212, 255, 0.1);
transition: all 0.3s ease;
}
.open:hover {
background:rgba(0, 212, 255, 0.2);
border-color: #0066ff;
}
}
.gname{ font-weight:600; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; width:100%; }
.grid{
width:min(1200px,100%);
display:grid;
grid-template-columns:repeat(auto-fill,minmax(260px,1fr));
gap:var(--gap);
align-items:start;
justify-items:stretch;
margin: 0 auto;
padding: 0 20px;
}
.thumb{ max-width:100%; max-height:100%; object-fit:contain; display:block; }
.meta{ padding:12px 14px; display:flex; justify-content:space-between; align-items:center; gap:8px; }
.title{ font-weight:600; font-size:14px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; }
.empty{ opacity:.6; padding:40px 0; text-align:center; }
.crumb{ font-size:13px; opacity:.8; }
.overlay{ position:fixed; inset:0; background:rgba(0,0,0,.6); display:none; place-items:center; padding:20px; z-index:10; }
.overlay.show{ display:grid; }
/* Dark mode viewer */
@media (prefers-color-scheme: dark) {
.viewer{
inline-size:min(92vw,var(--maxW));
block-size:min(82vh,var(--maxH));
background:#0e121b; border:1px solid rgba(0, 212, 255, 0.3); border-radius:18px; overflow:hidden; position:relative; box-shadow:0 12px 36px rgba(0,0,0,.35);
display:grid;
}
.chip{ background:rgba(0,0,0,.45); border:1px solid rgba(0, 212, 255, 0.3); color:#e8eaed; padding:6px 10px; border-radius:12px; font-size:12px; max-width:60%; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; }
.btn{ margin-left:auto; background:rgba(0, 0, 0, 0.3); color:#e8eaed; border:1px solid rgba(0, 212, 255, 0.3); border-radius:10px; padding:6px 10px; cursor:pointer; transition: all 0.3s ease; }
.btn:hover { border-color: var(--tech-blue); box-shadow: 0 0 10px rgba(0, 212, 255, 0.3); }
.mv-box{ width:100%; aspect-ratio:1036/518; background:#0b0d12; border:1px solid rgba(0, 212, 255, 0.2); border-radius:12px; overflow:hidden; }
.mv-box model-viewer{ width:100%; height:100%; background:#0b0d12; }
.res-cell{ position:relative; width:100%; aspect-ratio:2/1; background:#0e121b; border:1px solid rgba(0, 212, 255, 0.2); border-radius:12px; overflow:hidden; display:grid; place-items:center; }
.res-empty{ position:absolute; inset:0; display:grid; place-items:center; opacity:.55; font-size:12px; color:#9aa0a6; }
.download-icon{ background:rgba(0, 0, 0, 0.6); border:1px solid rgba(0, 212, 255, 0.3); color:#e8eaed; box-shadow:0 4px 12px rgba(0,0,0,0.3); }
.download-icon:hover{ background:rgba(0, 212, 255, 0.2); border-color:var(--tech-blue); box-shadow:0 0 20px rgba(0, 212, 255, 0.4); transform:scale(1.05); }
}
/* Light mode viewer */
@media (prefers-color-scheme: light) {
.viewer{
inline-size:min(92vw,var(--maxW));
block-size:min(82vh,var(--maxH));
background:#f8fafc; border:1px solid rgba(0, 212, 255, 0.4); border-radius:18px; overflow:hidden; position:relative; box-shadow:0 12px 36px rgba(0,0,0,.15);
display:grid;
}
.chip{ background:rgba(255,255,255,0.8); border:1px solid rgba(0, 212, 255, 0.4); color:#1e293b; padding:6px 10px; border-radius:12px; font-size:12px; max-width:60%; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; }
.btn{ margin-left:auto; background:rgba(255, 255, 255, 0.8); color:#1e293b; border:1px solid rgba(0, 212, 255, 0.4); border-radius:10px; padding:6px 10px; cursor:pointer; transition: all 0.3s ease; }
.btn:hover { border-color: #0066ff; box-shadow: 0 0 10px rgba(0, 102, 255, 0.3); }
.mv-box{ width:100%; aspect-ratio:1036/518; background:#f8fafc; border:1px solid rgba(0, 212, 255, 0.3); border-radius:12px; overflow:hidden; }
.mv-box model-viewer{ width:100%; height:100%; background:#f8fafc; }
.res-cell{ position:relative; width:100%; aspect-ratio:2/1; background:#f8fafc; border:1px solid rgba(0, 212, 255, 0.3); border-radius:12px; overflow:hidden; display:grid; place-items:center; }
.res-empty{ position:absolute; inset:0; display:grid; place-items:center; opacity:.55; font-size:12px; color:#64748b; }
.download-icon{ background:rgba(255, 255, 255, 0.9); border:1px solid rgba(0, 212, 255, 0.4); color:#1e293b; box-shadow:0 4px 12px rgba(0,0,0,0.15); }
.download-icon:hover{ background:rgba(0, 212, 255, 0.2); border-color:#0066ff; box-shadow:0 0 20px rgba(0, 102, 255, 0.4); transform:scale(1.05); }
}
.viewer-header{ position:absolute; top:8px; left:8px; right:8px; display:flex; gap:8px; align-items:center; z-index:2; }
.viewer-body{ height:100%; display:grid; grid-template-rows:auto auto; gap:12px; padding:36px 8px 8px 8px; overflow:auto; }
.res-grid{ display:grid; grid-template-columns:1fr 1fr; gap:8px; }
.res-img{ max-width:100%; max-height:100%; object-fit:contain; display:block; }
.download-icon{ position:absolute; bottom:16px; right:16px; width:44px; height:44px; border-radius:50%; display:grid; place-items:center; font-size:20px; cursor:pointer; z-index:3; transition:all 0.3s ease; }
/* Pagination controls */
.pager {
grid-column: 1 / -1;
justify-content: center;
align-items: center;
display: flex;
gap: 16px;
margin-top: 8px;
font-size: 13px;
text-align: center;
}
/* Dark mode pagination */
@media (prefers-color-scheme: dark) {
.pager {
color: #ccc;
}
.pager button {
padding: 4px 10px;
border-radius: 8px;
border: 1px solid rgba(0, 212, 255, 0.3);
background: rgba(0, 0, 0, 0.3);
color: #e8eaed;
cursor: pointer;
transition: all 0.3s ease;
}
.pager button:hover:not(:disabled) {
border-color: var(--tech-blue);
box-shadow: 0 0 8px rgba(0, 212, 255, 0.2);
}
.pager button:disabled {
opacity: 0.4;
cursor: not-allowed;
}
}
/* Light mode pagination */
@media (prefers-color-scheme: light) {
.pager {
color: #64748b;
}
.pager button {
padding: 4px 10px;
border-radius: 8px;
border: 1px solid rgba(0, 212, 255, 0.4);
background: rgba(255, 255, 255, 0.8);
color: #1e293b;
cursor: pointer;
transition: all 0.3s ease;
}
.pager button:hover:not(:disabled) {
border-color: #0066ff;
box-shadow: 0 0 8px rgba(0, 102, 255, 0.2);
}
.pager button:disabled {
opacity: 0.4;
cursor: not-allowed;
}
}
/* Intro card styles */
@media (prefers-color-scheme: dark) {
.intro-card {
background: linear-gradient(135deg, rgba(0, 212, 255, 0.1) 0%, rgba(0, 102, 255, 0.1) 100%);
border: 1px solid rgba(0, 212, 255, 0.2);
backdrop-filter: blur(10px);
}
.intro-title {
background: linear-gradient(45deg, var(--tech-blue), var(--tech-cyan), var(--tech-purple));
background-size: 400% 400%;
-webkit-background-clip: text;
background-clip: text;
color: transparent;
animation: techGradient 3s ease infinite;
text-shadow: 0 0 20px rgba(0, 212, 255, 0.3);
}
.intro-description {
color: #e0e0e0;
}
}
@media (prefers-color-scheme: light) {
.intro-card {
background: linear-gradient(135deg, rgba(0, 212, 255, 0.05) 0%, rgba(0, 102, 255, 0.05) 100%);
border: 1px solid rgba(0, 212, 255, 0.3);
backdrop-filter: blur(10px);
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.intro-title {
background: linear-gradient(45deg, #0066ff, #00d4ff, #00ffcc);
background-size: 400% 400%;
-webkit-background-clip: text;
background-clip: text;
color: transparent;
animation: techGradient 3s ease infinite;
text-shadow: 0 0 15px rgba(0, 102, 255, 0.2);
}
.intro-description {
color: #334155;
}
}
footer{
opacity:.55;
font-size:12px;
padding:12px 24px 24px;
text-align:center;
display:flex;
justify-content:center;
align-items:center;
width:100%;
}
</style>
</head>
<body>
<header>
<div class="row">
<button id="backBtn">← Back</button>
<h1 id="pageTitle">Depth Anything 3 Gallery</h1>
<span id="crumb" class="crumb"></span>
<input id="search" placeholder="Search…" />
</div>
<div class="muted" id="hint" style="text-align: center;">Level 1 shows groups only; click a group to browse scenes and previews.</div>
</header>
<main>
<!-- Tech intro card -->
<div class="intro-card" style="margin-bottom: 30px; padding: 25px; border-radius: 15px; text-align: center; max-width: 800px;">
<h2 class="intro-title" style="margin: 0 0 15px 0; font-size: 1.8em; font-weight: 700;">
🎯 Depth Anything 3 Gallery
</h2>
<p class="intro-description" style="margin: 0; font-size: 1.1em; line-height: 1.6;">
Explore 3D reconstructions and depth visualizations from Depth Anything 3.
Browse through groups of scenes, preview 3D models, and examine depth maps interactively.
</p>
</div>
<div id="level1" class="group-wrap" aria-live="polite">
<ul id="groupList" class="group-list"></ul>
<div id="groupEmpty" class="empty" style="display:none;">No available groups</div>
</div>
<div id="level2" style="display:none; width:100%;" aria-live="polite">
<div id="topPager" class="pager" style="margin-bottom: 16px;"></div>
<div id="grid" class="grid"></div>
<div id="sceneEmpty" class="empty" style="display:none;">No available scenes in this group</div>
</div>
</main>
<div id="overlay" class="overlay" role="dialog" aria-modal="true" aria-label="3D Preview">
<div class="viewer" id="viewer">
<div class="viewer-header">
<div id="viewerTitle" class="chip">Loading…</div>
<button id="toggleView" class="btn" title="Toggle between 3D-only and resource view">Resource View</button>
<button id="closeBtn" class="btn">Close</button>
</div>
<div id="downloadBtn" class="download-icon" title="Download GLB model">⬇</div>
<div class="viewer-body">
<div class="mv-box"><model-viewer id="mv"
src=""
ar
camera-controls
auto-rotate
interaction-prompt="auto"
shadow-intensity="0.7"
exposure="1.0"
alt="GLB Preview"></model-viewer></div>
<div class="res-grid" id="resGrid" hidden></div>
</div>
</div>
</div>
<footer>Depth Anything 3 Gallery. Copyright 2025 Depth Anything 3 authors.</footer>
<script>
const level1=document.getElementById('level1'),level2=document.getElementById('level2'),pageTitle=document.getElementById('pageTitle'),crumb=document.getElementById('crumb'),backBtn=document.getElementById('backBtn'),hint=document.getElementById('hint'),searchInput=document.getElementById('search'),groupList=document.getElementById('groupList'),groupEmpty=document.getElementById('groupEmpty'),topPager=document.getElementById('topPager'),grid=document.getElementById('grid'),sceneEmpty=document.getElementById('sceneEmpty'),overlay=document.getElementById('overlay'),viewer=document.getElementById('viewer'),mv=document.getElementById('mv'),viewerTitle=document.getElementById('viewerTitle'),downloadBtn=document.getElementById('downloadBtn'),toggleViewBtn=document.getElementById('toggleView'),closeBtn=document.getElementById('closeBtn'),resGrid=document.getElementById('resGrid');
let GROUPS=[],SCENES=[],currentGroup=null,currentScene=null,currentPage=1,currentScenePage=1;
const qs=()=>new URLSearchParams(location.search);
async function loadGroups(){const r=await fetch('/manifest.json',{cache:'no-store'});if(!r.ok)throw new Error(r.status+' '+r.statusText);const j=await r.json();GROUPS=j.groups||[];renderGroups(GROUPS);}
async function loadScenes(g){const r=await fetch('/manifest/'+encodeURIComponent(g)+'.json',{cache:'no-store'});if(!r.ok)throw new Error(r.status+' '+r.statusText);const j=await r.json();SCENES=j.items||[];const p=parseInt(qs().get('page'))||1;renderScenes(SCENES,p);}
function renderGroups(list){groupList.innerHTML='';const q=searchInput.value.trim().toLowerCase();const f=list.filter(g=>(g.title||g.id||'').toLowerCase().includes(q));if(!f.length){groupEmpty.style.display='';return;}groupEmpty.style.display='none';for(const g of f){const li=document.createElement('li');li.className='group-item';li.title=g.title||g.id;li.onclick=()=>enterLevel2(g.id,{push:true});const name=document.createElement('div');name.className='gname';name.textContent=g.title||g.id;li.appendChild(name);groupList.appendChild(li);}}
function renderScenes(list,page=1){topPager.innerHTML='';grid.innerHTML='';const q=searchInput.value.trim().toLowerCase();const f=list.filter(x=>(x.title||'').toLowerCase().includes(q)||(x.id||'').toLowerCase().includes(q));if(!f.length){sceneEmpty.style.display='';topPager.style.display='none';return;}sceneEmpty.style.display='none';topPager.style.display='flex';const perPage=16;const total=f.length;const totalPages=Math.max(1,Math.ceil(total/perPage));currentScenePage=page;const u=new URL(location.href);u.searchParams.set('page',page);history.replaceState(null,'',u);const subset=f.slice((page-1)*perPage,page*perPage);for(const i of subset){const c=document.createElement('div');c.className='card';c.title=i.title;const b=document.createElement('div');b.className='thumb-box';const img=document.createElement('img');img.className='thumb';img.loading='lazy';img.alt=i.title;img.src=i.thumbnail;b.appendChild(img);const m=document.createElement('div');m.className='meta';const t=document.createElement('div');t.className='title';t.textContent=i.title;const o=document.createElement('div');o.className='open';o.textContent='Preview';m.appendChild(t);m.appendChild(o);c.appendChild(b);c.appendChild(m);c.onclick=()=>openViewer(i,{push:true});grid.appendChild(c);}function buildPager(){const pg=document.createElement('div');pg.className='pager';const prev=document.createElement('button');prev.textContent='← Prev';prev.disabled=page<=1;prev.onclick=()=>renderScenes(list,page-1);const info=document.createElement('span');info.textContent=`${page} / ${totalPages}`;const next=document.createElement('button');next.textContent='Next →';next.disabled=page>=totalPages;next.onclick=()=>renderScenes(list,page+1);pg.appendChild(prev);pg.appendChild(info);pg.appendChild(next);return pg;}topPager.innerHTML='';topPager.appendChild(buildPager());grid.appendChild(buildPager());}
function enterLevel1({push=false}={}){currentGroup=null;pageTitle.textContent='Depth Anything 3 Gallery';crumb.textContent='';backBtn.style.display='none';hint.style.display='';level1.style.display='';level2.style.display='none';overlay.classList.remove('show');mv.src='';const u=new URL(location.href);u.searchParams.delete('group');u.searchParams.delete('id');u.searchParams.delete('page');push?history.pushState(null,'',u):history.replaceState(null,'',u);searchInput.value='';loadGroups().catch(e=>{groupList.innerHTML='';groupEmpty.style.display='';groupEmpty.textContent='Failed to load groups: '+e;});}
async function enterLevel2(g,{push=false}={}){currentGroup=g;pageTitle.textContent=g;crumb.textContent='(group)';backBtn.style.display='';hint.style.display='none';level1.style.display='none';level2.style.display='';overlay.classList.remove('show');mv.src='';const u=new URL(location.href);u.searchParams.set('group',g);u.searchParams.delete('id');push?history.pushState(null,'',u):history.replaceState(null,'',u);searchInput.value='';try{await loadScenes(g);const id=qs().get('id');if(id){const hit=SCENES.find(x=>x.id===id);if(hit)openViewer(hit,{push:false});}}catch(e){grid.innerHTML='';sceneEmpty.style.display='';sceneEmpty.textContent='Failed to load scenes: '+e;}}
function buildResGrid(i,page=1){
resGrid.innerHTML='';
const imgs=i.depth_images||[];
const perPage=4;
const total=imgs.length;
const totalPages=Math.max(1, Math.ceil(total/perPage));
currentPage=page;
const subset=imgs.slice((page-1)*perPage,(page-1)*perPage+perPage);
for(let k=0;k<4;k++){
const cell=document.createElement('div');
cell.className='res-cell';
if(subset[k]){
const im=document.createElement('img');
im.className='res-img';
im.src=subset[k];
im.alt=(i.title||'scene')+' depth '+(k+1+(page-1)*perPage);
im.loading='lazy';
cell.appendChild(im);
} else {
const ph=document.createElement('div');
ph.className='res-empty';
ph.textContent='N/A';
cell.appendChild(ph);
}
resGrid.appendChild(cell);
}
// pagination bar (always rebuilt)
const pager=document.createElement('div');
pager.className='pager';
const prev=document.createElement('button');
prev.textContent='← Prev';
prev.disabled=page<=1;
prev.onclick=()=>buildResGrid(i,page-1);
const info=document.createElement('span');
info.textContent=`${page} / ${totalPages}`;
const next=document.createElement('button');
next.textContent='Next →';
next.disabled=page>=totalPages;
next.onclick=()=>buildResGrid(i,page+1);
pager.appendChild(prev);
pager.appendChild(info);
pager.appendChild(next);
resGrid.appendChild(pager);
}
function openViewer(i,{push=false}={}){currentScene=i;viewerTitle.textContent=i.title;mv.src=i.model;overlay.classList.add('show');resGrid.hidden=true;toggleViewBtn.textContent='Resource View';viewer.style.blockSize='min(82vh,var(--maxH))';buildResGrid(i,1);downloadBtn.onclick=()=>{const a=document.createElement('a');a.href=i.model;a.download=i.title+'.glb';a.click();};if(push){const u=new URL(location.href);if(!u.searchParams.get('group'))u.searchParams.set('group',currentGroup||'');u.searchParams.set('id',i.id);history.pushState(null,'',u);}}
function toggleView(){const hidden=!resGrid.hidden;resGrid.hidden=hidden;toggleViewBtn.textContent=hidden?'Resource View':'3D Only';viewer.style.blockSize=hidden?'min(82vh,var(--maxH))':'min(92vh,900px)';}
function closeViewer(){const hasId=!!qs().get('id');if(hasId&&history.length>1){history.back();return;}const u=new URL(location.href);u.searchParams.delete('id');history.replaceState(null,'',u);overlay.classList.remove('show');mv.src='';}
overlay.onclick=e=>{if(e.target===overlay)closeViewer();};closeBtn.onclick=closeViewer;toggleViewBtn.onclick=toggleView;backBtn.onclick=()=>history.back();
searchInput.oninput=()=>{!qs().get('group')?renderGroups(GROUPS):renderScenes(SCENES,1);};
window.onpopstate=()=>routeFromURL();
async function routeFromURL(){if(location.pathname!="/")history.replaceState(null,'','/'+location.search);const g=qs().get('group');const id=qs().get('id');if(!g){enterLevel1({push:false});return;}await enterLevel2(g,{push:false});if(id){const hit=SCENES.find(x=>x.id===id);if(hit)openViewer(hit,{push:false});else{overlay.classList.remove('show');mv.src='';}}else{overlay.classList.remove('show');mv.src='';}}
routeFromURL();
</script>
</body>
</html>
"""
# ------------------------------ Utilities ------------------------------ #
IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
def _url_join(*parts: str) -> str:
norm = posixpath.join(*[p.replace("\\", "/") for p in parts])
segs = [s for s in norm.split("/") if s not in ("", ".")]
return "/".join(quote(s) for s in segs)
def _is_plain_name(name: str) -> bool:
return all(c not in name for c in ("/", "\\")) and name not in (".", "..")
def build_group_list(root_dir: str) -> dict:
groups = []
try:
for gname in sorted(os.listdir(root_dir)):
gpath = os.path.join(root_dir, gname)
if not os.path.isdir(gpath):
continue
has_scene = False
try:
for sname in os.listdir(gpath):
spath = os.path.join(gpath, sname)
if not os.path.isdir(spath):
continue
if os.path.exists(os.path.join(spath, "scene.glb")) and os.path.exists(
os.path.join(spath, "scene.jpg")
):
has_scene = True
break
except Exception:
pass
if has_scene:
groups.append({"id": gname, "title": gname})
except Exception as e:
print(f"[warn] build_group_list failed: {e}", file=sys.stderr)
return {"groups": groups}
def build_group_manifest(root_dir: str, group: str) -> dict:
items = []
gpath = os.path.join(root_dir, group)
try:
if not os.path.isdir(gpath):
return {"group": group, "items": []}
for sname in sorted(os.listdir(gpath)):
spath = os.path.join(gpath, sname)
if not os.path.isdir(spath):
continue
glb_fs = os.path.join(spath, "scene.glb")
jpg_fs = os.path.join(spath, "scene.jpg")
if not (os.path.exists(glb_fs) and os.path.exists(jpg_fs)):
continue
depth_images = []
dpath = os.path.join(spath, "depth_vis")
if os.path.isdir(dpath):
files = [
f for f in os.listdir(dpath) if os.path.splitext(f)[1].lower() in IMAGE_EXTS
]
for fn in sorted(files):
depth_images.append("/" + _url_join(group, sname, "depth_vis", fn))
items.append(
{
"id": sname,
"title": sname,
"model": "/" + _url_join(group, sname, "scene.glb"),
"thumbnail": "/" + _url_join(group, sname, "scene.jpg"),
"depth_images": depth_images,
}
)
except Exception as e:
print(f"[warn] build_group_manifest failed for {group}: {e}", file=sys.stderr)
return {"group": group, "items": items}
class GalleryHandler(SimpleHTTPRequestHandler):
def __init__(self, *args, directory=None, **kwargs):
super().__init__(*args, directory=directory, **kwargs)
def do_GET(self):
if self.path in ("/", "/index.html") or self.path.startswith("/?"):
content = HTML_PAGE.encode("utf-8")
self.send_response(HTTPStatus.OK)
self.send_header("Content-Type", "text/html; charset=utf-8")
self.send_header("Content-Length", str(len(content)))
self.send_header("Cache-Control", "no-store")
self.end_headers()
self.wfile.write(content)
return
if self.path == "/manifest.json":
data = json.dumps(
build_group_list(self.directory), ensure_ascii=False, indent=2
).encode("utf-8")
self.send_response(HTTPStatus.OK)
self.send_header("Content-Type", "application/json; charset=utf-8")
self.send_header("Content-Length", str(len(data)))
self.send_header("Cache-Control", "no-store")
self.end_headers()
self.wfile.write(data)
return
if self.path.startswith("/manifest/") and self.path.endswith(".json"):
group_enc = self.path[len("/manifest/") : -len(".json")]
try:
group = unquote(group_enc)
except Exception:
group = group_enc
if not _is_plain_name(group):
self.send_error(HTTPStatus.BAD_REQUEST, "Invalid group name")
return
data = json.dumps(
build_group_manifest(self.directory, group), ensure_ascii=False, indent=2
).encode("utf-8")
self.send_response(HTTPStatus.OK)
self.send_header("Content-Type", "application/json; charset=utf-8")
self.send_header("Content-Length", str(len(data)))
self.send_header("Cache-Control", "no-store")
self.end_headers()
self.wfile.write(data)
return
if self.path == "/favicon.ico":
self.send_response(HTTPStatus.NO_CONTENT)
self.end_headers()
return
return super().do_GET()
def list_directory(self, path):
self.send_error(HTTPStatus.NOT_FOUND, "Directory listing disabled")
return None
def gallery():
parser = argparse.ArgumentParser(
description="Depth Anything 3 Gallery Server (two-level, with pagination)"
)
parser.add_argument(
"-d", "--dir", required=True, help="Gallery root directory (two-level: group/scene)"
)
parser.add_argument("-p", "--port", type=int, default=8000, help="Port (default 8000)")
parser.add_argument("--host", default="127.0.0.1", help="Host address (default 127.0.0.1)")
parser.add_argument("--open", action="store_true", help="Open browser after launch")
args = parser.parse_args()
root_dir = os.path.abspath(args.dir)
if not os.path.isdir(root_dir):
print(f"[error] Directory not found: {root_dir}", file=sys.stderr)
sys.exit(1)
Handler = partial(GalleryHandler, directory=root_dir)
server = ThreadingHTTPServer((args.host, args.port), Handler)
addr = f"http://{args.host}:{args.port}/"
print(f"[info] Serving gallery from: {root_dir}")
print(f"[info] Open: {addr}")
if args.open:
try:
import webbrowser
webbrowser.open(addr)
except Exception as e:
print(f"[warn] Failed to open browser: {e}", file=sys.stderr)
try:
server.serve_forever()
except KeyboardInterrupt:
print("\n[info] Shutting down...")
finally:
server.server_close()
def main():
"""Main entry point for gallery server."""
mimetypes.add_type("model/gltf-binary", ".glb")
gallery()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,239 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unified Inference Service
Provides unified interface for local and remote inference
"""
from typing import Any, Dict, List, Optional, Union
import numpy as np
import requests
import typer
from ..api import DepthAnything3
class InferenceService:
"""Unified inference service class"""
def __init__(self, model_dir: str, device: str = "cuda"):
self.model_dir = model_dir
self.device = device
self.model = None
def load_model(self):
"""Load model"""
if self.model is None:
typer.echo(f"Loading model from {self.model_dir}...")
self.model = DepthAnything3.from_pretrained(self.model_dir).to(self.device)
return self.model
def run_local_inference(
self,
image_paths: List[str],
export_dir: str,
export_format: str = "mini_npz-glb",
process_res: int = 504,
process_res_method: str = "upper_bound_resize",
export_feat_layers: List[int] = None,
extrinsics: Optional[np.ndarray] = None,
intrinsics: Optional[np.ndarray] = None,
align_to_input_ext_scale: bool = True,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
conf_thresh_percentile: float = 40.0,
num_max_points: int = 1_000_000,
show_cameras: bool = True,
feat_vis_fps: int = 15,
) -> Any:
"""Run local inference"""
if export_feat_layers is None:
export_feat_layers = []
model = self.load_model()
# Prepare inference parameters
inference_kwargs = {
"image": image_paths,
"export_dir": export_dir,
"export_format": export_format,
"process_res": process_res,
"process_res_method": process_res_method,
"export_feat_layers": export_feat_layers,
"align_to_input_ext_scale": align_to_input_ext_scale,
"use_ray_pose": use_ray_pose,
"ref_view_strategy": ref_view_strategy,
"conf_thresh_percentile": conf_thresh_percentile,
"num_max_points": num_max_points,
"show_cameras": show_cameras,
"feat_vis_fps": feat_vis_fps,
}
# Add pose data (if exists)
if extrinsics is not None:
inference_kwargs["extrinsics"] = extrinsics
if intrinsics is not None:
inference_kwargs["intrinsics"] = intrinsics
# Run inference
typer.echo(f"Running inference on {len(image_paths)} images...")
prediction = model.inference(**inference_kwargs)
typer.echo(f"Results saved to {export_dir}")
typer.echo(f"Export format: {export_format}")
return prediction
def run_backend_inference(
self,
image_paths: List[str],
export_dir: str,
backend_url: str,
export_format: str = "mini_npz-glb",
process_res: int = 504,
process_res_method: str = "upper_bound_resize",
export_feat_layers: List[int] = None,
extrinsics: Optional[np.ndarray] = None,
intrinsics: Optional[np.ndarray] = None,
align_to_input_ext_scale: bool = True,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
conf_thresh_percentile: float = 40.0,
num_max_points: int = 1_000_000,
show_cameras: bool = True,
feat_vis_fps: int = 15,
) -> Dict[str, Any]:
"""Run backend inference"""
if export_feat_layers is None:
export_feat_layers = []
# Check backend status
if not self._check_backend_status(backend_url):
raise typer.BadParameter(f"Backend service is not running at {backend_url}")
# Prepare payload
payload = {
"image_paths": image_paths,
"export_dir": export_dir,
"export_format": export_format,
"process_res": process_res,
"process_res_method": process_res_method,
"export_feat_layers": export_feat_layers,
"align_to_input_ext_scale": align_to_input_ext_scale,
"use_ray_pose": use_ray_pose,
"ref_view_strategy": ref_view_strategy,
"conf_thresh_percentile": conf_thresh_percentile,
"num_max_points": num_max_points,
"show_cameras": show_cameras,
"feat_vis_fps": feat_vis_fps,
}
# Add pose data (if exists)
if extrinsics is not None:
payload["extrinsics"] = [ext.astype(np.float64).tolist() for ext in extrinsics]
if intrinsics is not None:
payload["intrinsics"] = [intr.astype(np.float64).tolist() for intr in intrinsics]
# Submit task
typer.echo("Submitting inference task to backend...")
try:
response = requests.post(f"{backend_url}/inference", json=payload, timeout=30)
response.raise_for_status()
result = response.json()
if result["success"]:
task_id = result["task_id"]
typer.echo("Task submitted successfully!")
typer.echo(f"Task ID: {task_id}")
typer.echo(f"Results will be saved to: {export_dir}")
typer.echo(f"Check backend logs for progress updates with task ID: {task_id}")
return result
else:
raise typer.BadParameter(
f"Backend inference submission failed: {result['message']}"
)
except requests.exceptions.RequestException as e:
raise typer.BadParameter(f"Backend inference submission failed: {e}")
def _check_backend_status(self, backend_url: str) -> bool:
"""Check backend status"""
try:
response = requests.get(f"{backend_url}/status", timeout=5)
return response.status_code == 200
except Exception:
return False
def run_inference(
image_paths: List[str],
export_dir: str,
model_dir: str,
device: str = "cuda",
backend_url: Optional[str] = None,
export_format: str = "mini_npz-glb",
process_res: int = 504,
process_res_method: str = "upper_bound_resize",
export_feat_layers: List[int] = None,
extrinsics: Optional[np.ndarray] = None,
intrinsics: Optional[np.ndarray] = None,
align_to_input_ext_scale: bool = True,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
conf_thresh_percentile: float = 40.0,
num_max_points: int = 1_000_000,
show_cameras: bool = True,
feat_vis_fps: int = 15,
) -> Union[Any, Dict[str, Any]]:
"""Unified inference interface"""
service = InferenceService(model_dir, device)
if backend_url:
return service.run_backend_inference(
image_paths=image_paths,
export_dir=export_dir,
backend_url=backend_url,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
extrinsics=extrinsics,
intrinsics=intrinsics,
align_to_input_ext_scale=align_to_input_ext_scale,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)
else:
return service.run_local_inference(
image_paths=image_paths,
export_dir=export_dir,
export_format=export_format,
process_res=process_res,
process_res_method=process_res_method,
export_feat_layers=export_feat_layers,
extrinsics=extrinsics,
intrinsics=intrinsics,
align_to_input_ext_scale=align_to_input_ext_scale,
use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy,
conf_thresh_percentile=conf_thresh_percentile,
num_max_points=num_max_points,
show_cameras=show_cameras,
feat_vis_fps=feat_vis_fps,
)

View File

@@ -0,0 +1,266 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Input Processing Service
Handles different types of inputs (image, images, colmap, video)
"""
import glob
import os
from typing import List, Tuple
import cv2
import numpy as np
import typer
from ..utils.read_write_model import read_model
class InputHandler:
"""Base input handler class"""
@staticmethod
def validate_path(path: str, path_type: str = "file") -> str:
"""Validate path"""
if not os.path.exists(path):
raise typer.BadParameter(f"{path_type} not found: {path}")
return path
@staticmethod
def handle_export_dir(export_dir: str, auto_cleanup: bool = False) -> str:
"""Handle export directory"""
if os.path.exists(export_dir):
if auto_cleanup:
typer.echo(f"Auto-cleaning existing export directory: {export_dir}")
import shutil
shutil.rmtree(export_dir)
os.makedirs(export_dir, exist_ok=True)
else:
typer.echo(f"Export directory '{export_dir}' already exists.")
if typer.confirm("Do you want to clean it and continue?"):
import shutil
shutil.rmtree(export_dir)
os.makedirs(export_dir, exist_ok=True)
typer.echo(f"Cleaned export directory: {export_dir}")
else:
typer.echo("Operation cancelled.")
raise typer.Exit(0)
else:
os.makedirs(export_dir, exist_ok=True)
return export_dir
class ImageHandler(InputHandler):
"""Single image handler"""
@staticmethod
def process(image_path: str) -> List[str]:
"""Process single image"""
InputHandler.validate_path(image_path, "Image file")
return [image_path]
class ImagesHandler(InputHandler):
"""Image directory handler"""
@staticmethod
def process(images_dir: str, image_extensions: str = "png,jpg,jpeg") -> List[str]:
"""Process image directory"""
InputHandler.validate_path(images_dir, "Images directory")
# Parse extensions
extensions = [ext.strip().lower() for ext in image_extensions.split(",")]
extensions = [ext if ext.startswith(".") else f".{ext}" for ext in extensions]
# Find image files
image_files = []
for ext in extensions:
pattern = f"*{ext}"
image_files.extend(glob.glob(os.path.join(images_dir, pattern)))
image_files.extend(glob.glob(os.path.join(images_dir, pattern.upper())))
image_files = sorted(list(set(image_files))) # Remove duplicates and sort
if not image_files:
raise typer.BadParameter(
f"No image files found in {images_dir} with extensions: {extensions}"
)
typer.echo(f"Found {len(image_files)} images to process")
return image_files
class ColmapHandler(InputHandler):
"""COLMAP data handler"""
@staticmethod
def process(
colmap_dir: str, sparse_subdir: str = ""
) -> Tuple[List[str], np.ndarray, np.ndarray]:
"""Process COLMAP data"""
InputHandler.validate_path(colmap_dir, "COLMAP directory")
# Build paths
images_dir = os.path.join(colmap_dir, "images")
if sparse_subdir:
sparse_dir = os.path.join(colmap_dir, "sparse", sparse_subdir)
else:
sparse_dir = os.path.join(colmap_dir, "sparse")
InputHandler.validate_path(images_dir, "Images directory")
InputHandler.validate_path(sparse_dir, "Sparse reconstruction directory")
# Load COLMAP data
typer.echo("Loading COLMAP reconstruction data...")
try:
cameras, images, points3D = read_model(sparse_dir)
typer.echo(
f"Loaded COLMAP data: {len(cameras)} cameras, {len(images)} images, "
f"{len(points3D)} 3D points."
)
# Get image files and pose data
image_files = []
extrinsics = []
intrinsics = []
for image_id, image_data in images.items():
image_name = image_data.name
image_path = os.path.join(images_dir, image_name)
if os.path.exists(image_path):
image_files.append(image_path)
# Get camera parameters
camera = cameras[image_data.camera_id]
# Convert quaternion to rotation matrix
R = image_data.qvec2rotmat()
t = image_data.tvec
# Create extrinsic matrix (world to camera)
extrinsic = np.eye(4)
extrinsic[:3, :3] = R
extrinsic[:3, 3] = t
extrinsics.append(extrinsic)
# Create intrinsic matrix
if camera.model == "PINHOLE":
fx, fy, cx, cy = camera.params
elif camera.model == "SIMPLE_PINHOLE":
f, cx, cy = camera.params
fx = fy = f
else:
# For other models, use basic pinhole approximation
fx = fy = camera.params[0] if len(camera.params) > 0 else 1000
cx = camera.width / 2
cy = camera.height / 2
intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
intrinsics.append(intrinsic)
if not image_files:
raise typer.BadParameter("No valid images found in COLMAP data")
typer.echo(f"Found {len(image_files)} valid images with pose data")
return image_files, np.array(extrinsics), np.array(intrinsics)
except Exception as e:
raise typer.BadParameter(f"Failed to load COLMAP data: {e}")
class VideoHandler(InputHandler):
"""Video handler"""
@staticmethod
def process(video_path: str, output_dir: str, fps: float = 1.0) -> List[str]:
"""Process video, extract frames"""
InputHandler.validate_path(video_path, "Video file")
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise typer.BadParameter(f"Cannot open video: {video_path}")
# Get video properties
video_fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / video_fps
# Calculate frame interval (ensure at least 1)
frame_interval = max(1, int(video_fps / fps))
actual_fps = video_fps / frame_interval
typer.echo(f"Video FPS: {video_fps:.2f}, Duration: {duration:.2f}s")
# Warn if requested FPS is higher than video FPS
if fps > video_fps:
typer.echo(
f"⚠️ Warning: Requested sampling FPS ({fps:.2f}) exceeds video FPS ({video_fps:.2f})", # noqa: E501
err=True,
)
typer.echo(
f"⚠️ Using maximum available FPS: {actual_fps:.2f} (extracting every frame)",
err=True,
)
typer.echo(f"Extracting frames at {actual_fps:.2f} FPS (every {frame_interval} frame(s))")
# Create output directory
frames_dir = os.path.join(output_dir, "input_images")
os.makedirs(frames_dir, exist_ok=True)
frame_count = 0
saved_count = 0
while True:
ret, frame = cap.read()
if not ret:
break
if frame_count % frame_interval == 0:
frame_path = os.path.join(frames_dir, f"{saved_count:06d}.png")
cv2.imwrite(frame_path, frame)
saved_count += 1
frame_count += 1
cap.release()
typer.echo(f"Extracted {saved_count} frames to {frames_dir}")
# Get frame file list
frame_files = sorted(
[f for f in os.listdir(frames_dir) if f.endswith((".png", ".jpg", ".jpeg"))]
)
if not frame_files:
raise typer.BadParameter("No frames extracted from video")
return [os.path.join(frames_dir, f) for f in frame_files]
def parse_export_feat(export_feat_str: str) -> List[int]:
"""Parse export_feat parameter"""
if not export_feat_str:
return []
try:
return [int(x.strip()) for x in export_feat_str.split(",") if x.strip()]
except ValueError:
raise typer.BadParameter(
f"Invalid export_feat format: {export_feat_str}. "
"Use comma-separated integers like '0,1,2'"
)

View File

@@ -0,0 +1,45 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
import torch
@dataclass
class Gaussians:
"""3DGS parameters, all in world space"""
means: torch.Tensor # world points, "batch gaussian dim"
scales: torch.Tensor # scales_std, "batch gaussian 3"
rotations: torch.Tensor # world_quat_wxyz, "batch gaussian 4"
harmonics: torch.Tensor # world SH, "batch gaussian 3 d_sh"
opacities: torch.Tensor # opacity | opacity SH, "batch gaussian" | "batch gaussian 1 d_sh"
@dataclass
class Prediction:
depth: np.ndarray # N, H, W
is_metric: int
sky: np.ndarray | None = None # N, H, W
conf: np.ndarray | None = None # N, H, W
extrinsics: np.ndarray | None = None # N, 4, 4
intrinsics: np.ndarray | None = None # N, 3, 3
processed_images: np.ndarray | None = None # N, H, W, 3 - processed images for visualization
gaussians: Gaussians | None = None # 3D gaussians
aux: dict[str, Any] = None #
scale_factor: Optional[float] = None # metric scale

View File

@@ -0,0 +1,163 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Alignment utilities for depth estimation and metric scaling.
"""
from typing import Tuple
import torch
def least_squares_scale_scalar(
a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12
) -> torch.Tensor:
"""
Compute least squares scale factor s such that a ≈ s * b.
Args:
a: First tensor
b: Second tensor
eps: Small epsilon for numerical stability
Returns:
Scalar tensor containing the scale factor
Raises:
ValueError: If tensors have mismatched shapes or devices
TypeError: If tensors are not floating point
"""
if a.shape != b.shape:
raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
if a.device != b.device:
raise ValueError(f"Device mismatch: {a.device} vs {b.device}")
if not a.is_floating_point() or not b.is_floating_point():
raise TypeError("Tensors must be floating point type")
# Compute dot products for least squares solution
num = torch.dot(a.reshape(-1), b.reshape(-1))
den = torch.dot(b.reshape(-1), b.reshape(-1)).clamp_min(eps)
return num / den
def compute_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
"""
Compute non-sky mask from sky prediction.
Args:
sky_prediction: Sky prediction tensor
threshold: Threshold for sky classification
Returns:
Boolean mask where True indicates non-sky regions
"""
return sky_prediction < threshold
def compute_alignment_mask(
depth_conf: torch.Tensor,
non_sky_mask: torch.Tensor,
depth: torch.Tensor,
metric_depth: torch.Tensor,
median_conf: torch.Tensor,
min_depth_threshold: float = 1e-3,
min_metric_depth_threshold: float = 1e-2,
) -> torch.Tensor:
"""
Compute mask for depth alignment based on confidence and depth thresholds.
Args:
depth_conf: Depth confidence tensor
non_sky_mask: Non-sky region mask
depth: Predicted depth tensor
metric_depth: Metric depth tensor
median_conf: Median confidence threshold
min_depth_threshold: Minimum depth threshold
min_metric_depth_threshold: Minimum metric depth threshold
Returns:
Boolean mask for valid alignment regions
"""
return (
(depth_conf >= median_conf)
& non_sky_mask
& (metric_depth > min_metric_depth_threshold)
& (depth > min_depth_threshold)
)
def sample_tensor_for_quantile(tensor: torch.Tensor, max_samples: int = 100000) -> torch.Tensor:
"""
Sample tensor elements for quantile computation to reduce memory usage.
Args:
tensor: Input tensor to sample
max_samples: Maximum number of samples to take
Returns:
Sampled tensor
"""
if tensor.numel() <= max_samples:
return tensor
idx = torch.randperm(tensor.numel(), device=tensor.device)[:max_samples]
return tensor.flatten()[idx]
def apply_metric_scaling(
depth: torch.Tensor, intrinsics: torch.Tensor, scale_factor: float = 300.0
) -> torch.Tensor:
"""
Apply metric scaling to depth based on camera intrinsics.
Args:
depth: Input depth tensor
intrinsics: Camera intrinsics tensor
scale_factor: Scaling factor for metric conversion
Returns:
Scaled depth tensor
"""
focal_length = (intrinsics[:, :, 0, 0] + intrinsics[:, :, 1, 1]) / 2
return depth * (focal_length[:, :, None, None] / scale_factor)
def set_sky_regions_to_max_depth(
depth: torch.Tensor,
depth_conf: torch.Tensor,
non_sky_mask: torch.Tensor,
max_depth: float = 200.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Set sky regions to maximum depth and high confidence.
Args:
depth: Depth tensor
depth_conf: Depth confidence tensor
non_sky_mask: Non-sky region mask
max_depth: Maximum depth value for sky regions
Returns:
Tuple of (updated_depth, updated_depth_conf)
"""
depth = depth.clone()
# Set sky regions to max depth and high confidence
depth[~non_sky_mask] = max_depth
if depth_conf is not None:
depth_conf = depth_conf.clone()
depth_conf[~non_sky_mask] = 1.0
return depth, depth_conf
else:
return depth, None

View File

@@ -0,0 +1,58 @@
import argparse
def parse_scalar(s):
if not isinstance(s, str):
return s
t = s.strip()
l = t.lower()
if l == "true":
return True
if l == "false":
return False
if l in ("none", "null"):
return None
try:
return int(t, 10)
except Exception:
pass
try:
return float(t)
except Exception:
return s
def fn_kv_csv(s: str) -> dict[str, dict[str, object]]:
"""
Parse a string of comma-separated triplets: fn:key:value
Returns:
dict[fn_name] -> dict[key] = parsed_value
Example:
"fn1:width:1920,fn1:height:1080,fn2:quality:0.8"
-> {"fn1": {"width": 1920, "height": 1080}, "fn2": {"quality": 0.8}}
"""
result: dict[str, dict[str, object]] = {}
if not s:
return result
for item in s.split(","):
if not item:
continue
parts = item.split(":", 2) # allow value to contain ":" beyond first two separators
if len(parts) < 3:
raise argparse.ArgumentTypeError(f"Bad item '{item}', expected FN:KEY:VALUE")
fn, key, raw_val = parts[0], parts[1], parts[2]
# If you need to allow colons in values, join leftover parts:
# fn, key, raw_val = parts[0], parts[1], ":".join(parts[2:])
if not fn:
raise argparse.ArgumentTypeError(f"Bad item '{item}': empty function name")
if not key:
raise argparse.ArgumentTypeError(f"Bad item '{item}': empty key")
val = parse_scalar(raw_val)
bucket = result.setdefault(fn, {})
bucket[key] = val
return result

View File

@@ -0,0 +1,479 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import einsum, rearrange, reduce
try:
from scipy.spatial.transform import Rotation as R
except ImportError:
from depth_anything_3.utils.logger import logger
logger.warn("Dependency 'scipy' not found. Required for interpolating camera trajectory.")
from depth_anything_3.utils.geometry import as_homogeneous
@torch.no_grad()
def render_stabilization_path(poses, k_size=45):
"""Rendering stabilized camera path.
poses: [batch, 4, 4] or [batch, 3, 4],
return:
smooth path: [batch 4 4]"""
num_frames = poses.shape[0]
device = poses.device
dtype = poses.dtype
# Early exit for trivial cases
if num_frames <= 1:
return as_homogeneous(poses)
# Make k_size safe: positive odd and not larger than num_frames
# 1) Ensure odd
if k_size < 1:
k_size = 1
if k_size % 2 == 0:
k_size += 1
# 2) Cap to num_frames (keep odd)
max_odd = num_frames if (num_frames % 2 == 1) else (num_frames - 1)
if max_odd < 1:
max_odd = 1 # covers num_frames == 0 theoretically
k_size = min(k_size, max_odd)
# 3) enforce a minimum of 3 when possible (for better smoothing)
if num_frames >= 3 and k_size < 3:
k_size = 3
input_poses = []
for i in range(num_frames):
input_poses.append(
torch.cat([poses[i, :3, 0:1], poses[i, :3, 1:2], poses[i, :3, 3:4]], dim=-1)
)
input_poses = torch.stack(input_poses) # (num_frames, 3, 3)
# Prepare Gaussian kernel
gaussian_kernel = cv2.getGaussianKernel(ksize=k_size, sigma=-1).astype(np.float32).squeeze()
gaussian_kernel = torch.tensor(gaussian_kernel, dtype=dtype, device=device).view(1, 1, -1)
pad = k_size // 2
output_vectors = []
for idx in range(3): # For r1, r2, t
vec = (
input_poses[:, :, idx].T.unsqueeze(0).unsqueeze(0)
) # (1, 1, 3, num_frames) -> (1, 1, 3, num_frames)
# But actually, we want (batch=3, channel=1, width=num_frames)
# So:
vec = input_poses[:, :, idx].T.unsqueeze(1) # (3, 1, num_frames)
vec_padded = F.pad(vec, (pad, pad), mode="reflect")
filtered = F.conv1d(vec_padded, gaussian_kernel)
output_vectors.append(filtered.squeeze(1).T) # (num_frames, 3)
output_r1, output_r2, output_t = output_vectors # Each is (num_frames, 3)
# Normalize r1 and r2
output_r1 = output_r1 / output_r1.norm(dim=-1, keepdim=True)
output_r2 = output_r2 / output_r2.norm(dim=-1, keepdim=True)
output_poses = []
for i in range(num_frames):
output_r3 = torch.linalg.cross(output_r1[i], output_r2[i])
render_pose = torch.cat(
[
output_r1[i].unsqueeze(-1),
output_r2[i].unsqueeze(-1),
output_r3.unsqueeze(-1),
output_t[i].unsqueeze(-1),
],
dim=-1,
)
output_poses.append(render_pose[:3, :])
output_poses = as_homogeneous(torch.stack(output_poses, dim=0))
return output_poses
@torch.no_grad()
def render_wander_path(
cam2world: torch.Tensor,
intrinsic: torch.Tensor,
h: int,
w: int,
num_frames: int = 120,
max_disp: float = 48.0,
):
device, dtype = cam2world.device, cam2world.dtype
fx = intrinsic[0, 0] * w
r = max_disp / fx
th = torch.linspace(0, 2.0 * torch.pi, steps=num_frames, device=device, dtype=dtype)
x = r * torch.sin(th)
yz = r * torch.cos(th) / 3.0
T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(num_frames, 1, 1)
T[:, :3, 3] = torch.stack([x, yz, yz], dim=-1) * -1.0
c2ws = cam2world.unsqueeze(0) @ T
# Start at reference pose and end back at reference pose
c2ws = torch.cat([cam2world.unsqueeze(0), c2ws, cam2world.unsqueeze(0)], dim=0)
Ks = intrinsic.unsqueeze(0).repeat(c2ws.shape[0], 1, 1)
return c2ws, Ks
@torch.no_grad()
def render_dolly_zoom_path(
cam2world: torch.Tensor,
intrinsic: torch.Tensor,
h: int,
w: int,
num_frames: int = 120,
max_disp: float = 0.1,
D_focus: float = 10.0,
):
device, dtype = cam2world.device, cam2world.dtype
fx0, fy0 = intrinsic[0, 0] * w, intrinsic[1, 1] * h
t = torch.linspace(0.0, 2.0, steps=num_frames, device=device, dtype=dtype)
z = 0.5 * (1.0 - torch.cos(torch.pi * t)) * max_disp
T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(num_frames, 1, 1)
T[:, 2, 3] = -z
c2ws = cam2world.unsqueeze(0) @ T
Df = torch.as_tensor(D_focus, device=device, dtype=dtype)
scale = (Df / (Df + z)).clamp(min=1e-6)
Ks = intrinsic.unsqueeze(0).repeat(num_frames, 1, 1)
Ks[:, 0, 0] = (fx0 * scale) / w
Ks[:, 1, 1] = (fy0 * scale) / h
return c2ws, Ks
@torch.no_grad()
def interpolate_intrinsics(
initial: torch.Tensor, # "*#batch 3 3"
final: torch.Tensor, # "*#batch 3 3"
t: torch.Tensor, # " time_step"
) -> torch.Tensor: # "*batch time_step 3 3"
initial = rearrange(initial, "... i j -> ... () i j")
final = rearrange(final, "... i j -> ... () i j")
t = rearrange(t, "t -> t () ()")
return initial + (final - initial) * t
def intersect_rays(
a_origins: torch.Tensor, # "*#batch dim"
a_directions: torch.Tensor, # "*#batch dim"
b_origins: torch.Tensor, # "*#batch dim"
b_directions: torch.Tensor, # "*#batch dim"
) -> torch.Tensor: # "*batch dim"
"""Compute the least-squares intersection of rays. Uses the math from here:
https://math.stackexchange.com/a/1762491/286022
"""
# Broadcast and stack the tensors.
a_origins, a_directions, b_origins, b_directions = torch.broadcast_tensors(
a_origins, a_directions, b_origins, b_directions
)
origins = torch.stack((a_origins, b_origins), dim=-2)
directions = torch.stack((a_directions, b_directions), dim=-2)
# Compute n_i * n_i^T - eye(3) from the equation.
n = einsum(directions, directions, "... n i, ... n j -> ... n i j")
n = n - torch.eye(3, dtype=origins.dtype, device=origins.device)
# Compute the left-hand side of the equation.
lhs = reduce(n, "... n i j -> ... i j", "sum")
# Compute the right-hand side of the equation.
rhs = einsum(n, origins, "... n i j, ... n j -> ... n i")
rhs = reduce(rhs, "... n i -> ... i", "sum")
# Left-matrix-multiply both sides by the inverse of lhs to find p.
return torch.linalg.lstsq(lhs, rhs).solution
def normalize(a: torch.Tensor) -> torch.Tensor: # "*#batch dim" -> "*#batch dim"
return a / a.norm(dim=-1, keepdim=True)
def generate_coordinate_frame(
y: torch.Tensor, # "*#batch 3"
z: torch.Tensor, # "*#batch 3"
) -> torch.Tensor: # "*batch 3 3"
"""Generate a coordinate frame given perpendicular, unit-length Y and Z vectors."""
y, z = torch.broadcast_tensors(y, z)
return torch.stack([y.cross(z, dim=-1), y, z], dim=-1)
def generate_rotation_coordinate_frame(
a: torch.Tensor, # "*#batch 3"
b: torch.Tensor, # "*#batch 3"
eps: float = 1e-4,
) -> torch.Tensor: # "*batch 3 3"
"""Generate a coordinate frame where the Y direction is normal to the plane defined
by unit vectors a and b. The other axes are arbitrary."""
device = a.device
# Replace every entry in b that's parallel to the corresponding entry in a with an
# arbitrary vector.
b = b.detach().clone()
parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps
b[parallel] = torch.tensor([0, 0, 1], dtype=b.dtype, device=device)
parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps
b[parallel] = torch.tensor([0, 1, 0], dtype=b.dtype, device=device)
# Generate the coordinate frame. The initial cross product defines the plane.
return generate_coordinate_frame(normalize(torch.linalg.cross(a, b)), a)
def matrix_to_euler(
rotations: torch.Tensor, # "*batch 3 3"
pattern: str,
) -> torch.Tensor: # "*batch 3"
*batch, _, _ = rotations.shape
rotations = rotations.reshape(-1, 3, 3)
angles_np = R.from_matrix(rotations.detach().cpu().numpy()).as_euler(pattern)
rotations = torch.tensor(angles_np, dtype=rotations.dtype, device=rotations.device)
return rotations.reshape(*batch, 3)
def euler_to_matrix(
rotations: torch.Tensor, # "*batch 3"
pattern: str,
) -> torch.Tensor: # "*batch 3 3"
*batch, _ = rotations.shape
rotations = rotations.reshape(-1, 3)
matrix_np = R.from_euler(pattern, rotations.detach().cpu().numpy()).as_matrix()
rotations = torch.tensor(matrix_np, dtype=rotations.dtype, device=rotations.device)
return rotations.reshape(*batch, 3, 3)
def extrinsics_to_pivot_parameters(
extrinsics: torch.Tensor, # "*#batch 4 4"
pivot_coordinate_frame: torch.Tensor, # "*#batch 3 3"
pivot_point: torch.Tensor, # "*#batch 3"
) -> torch.Tensor: # "*batch 5"
"""Convert the extrinsics to a representation with 5 degrees of freedom:
1. Distance from pivot point in the "X" (look cross pivot axis) direction.
2. Distance from pivot point in the "Y" (pivot axis) direction.
3. Distance from pivot point in the Z (look) direction
4. Angle in plane
5. Twist (rotation not in plane)
"""
# The pivot coordinate frame's Z axis is normal to the plane.
pivot_axis = pivot_coordinate_frame[..., :, 1]
# Compute the translation elements of the pivot parametrization.
translation_frame = generate_coordinate_frame(pivot_axis, extrinsics[..., :3, 2])
origin = extrinsics[..., :3, 3]
delta = pivot_point - origin
translation = einsum(translation_frame, delta, "... i j, ... i -> ... j")
# Add the rotation elements of the pivot parametrization.
inverted = pivot_coordinate_frame.inverse() @ extrinsics[..., :3, :3]
y, _, z = matrix_to_euler(inverted, "YXZ").unbind(dim=-1)
return torch.cat([translation, y[..., None], z[..., None]], dim=-1)
def pivot_parameters_to_extrinsics(
parameters: torch.Tensor, # "*#batch 5"
pivot_coordinate_frame: torch.Tensor, # "*#batch 3 3"
pivot_point: torch.Tensor, # "*#batch 3"
) -> torch.Tensor: # "*batch 4 4"
translation, y, z = parameters.split((3, 1, 1), dim=-1)
euler = torch.cat((y, torch.zeros_like(y), z), dim=-1)
rotation = pivot_coordinate_frame @ euler_to_matrix(euler, "YXZ")
# The pivot coordinate frame's Z axis is normal to the plane.
pivot_axis = pivot_coordinate_frame[..., :, 1]
translation_frame = generate_coordinate_frame(pivot_axis, rotation[..., :3, 2])
delta = einsum(translation_frame, translation, "... i j, ... j -> ... i")
origin = pivot_point - delta
*batch, _ = origin.shape
extrinsics = torch.eye(4, dtype=parameters.dtype, device=parameters.device)
extrinsics = extrinsics.broadcast_to((*batch, 4, 4)).clone()
extrinsics[..., 3, 3] = 1
extrinsics[..., :3, :3] = rotation
extrinsics[..., :3, 3] = origin
return extrinsics
def interpolate_circular(
a: torch.Tensor, # "*#batch"
b: torch.Tensor, # "*#batch"
t: torch.Tensor, # "*#batch"
) -> torch.Tensor: # " *batch"
a, b, t = torch.broadcast_tensors(a, b, t)
tau = 2 * torch.pi
a = a % tau
b = b % tau
# Consider piecewise edge cases.
d = (b - a).abs()
a_left = a - tau
d_left = (b - a_left).abs()
a_right = a + tau
d_right = (b - a_right).abs()
use_d = (d < d_left) & (d < d_right)
use_d_left = (d_left < d_right) & (~use_d)
use_d_right = (~use_d) & (~use_d_left)
result = a + (b - a) * t
result[use_d_left] = (a_left + (b - a_left) * t)[use_d_left]
result[use_d_right] = (a_right + (b - a_right) * t)[use_d_right]
return result
def interpolate_pivot_parameters(
initial: torch.Tensor, # "*#batch 5"
final: torch.Tensor, # "*#batch 5"
t: torch.Tensor, # " time_step"
) -> torch.Tensor: # "*batch time_step 5"
initial = rearrange(initial, "... d -> ... () d")
final = rearrange(final, "... d -> ... () d")
t = rearrange(t, "t -> t ()")
ti, ri = initial.split((3, 2), dim=-1)
tf, rf = final.split((3, 2), dim=-1)
t_lerp = ti + (tf - ti) * t
r_lerp = interpolate_circular(ri, rf, t)
return torch.cat((t_lerp, r_lerp), dim=-1)
@torch.no_grad()
def interpolate_extrinsics(
initial: torch.Tensor, # "*#batch 4 4"
final: torch.Tensor, # "*#batch 4 4"
t: torch.Tensor, # " time_step"
eps: float = 1e-4,
) -> torch.Tensor: # "*batch time_step 4 4"
"""Interpolate extrinsics by rotating around their "focus point," which is the
least-squares intersection between the look vectors of the initial and final
extrinsics.
"""
initial = initial.type(torch.float64)
final = final.type(torch.float64)
t = t.type(torch.float64)
# Based on the dot product between the look vectors, pick from one of two cases:
# 1. Look vectors are parallel: interpolate about their origins' midpoint.
# 3. Look vectors aren't parallel: interpolate about their focus point.
initial_look = initial[..., :3, 2]
final_look = final[..., :3, 2]
dot_products = einsum(initial_look, final_look, "... i, ... i -> ...")
parallel_mask = (dot_products.abs() - 1).abs() < eps
# Pick focus points.
initial_origin = initial[..., :3, 3]
final_origin = final[..., :3, 3]
pivot_point = 0.5 * (initial_origin + final_origin)
pivot_point[~parallel_mask] = intersect_rays(
initial_origin[~parallel_mask],
initial_look[~parallel_mask],
final_origin[~parallel_mask],
final_look[~parallel_mask],
)
# Convert to pivot parameters.
pivot_frame = generate_rotation_coordinate_frame(initial_look, final_look, eps=eps)
initial_params = extrinsics_to_pivot_parameters(initial, pivot_frame, pivot_point)
final_params = extrinsics_to_pivot_parameters(final, pivot_frame, pivot_point)
# Interpolate the pivot parameters.
interpolated_params = interpolate_pivot_parameters(initial_params, final_params, t)
# Convert back.
return pivot_parameters_to_extrinsics(
interpolated_params.type(torch.float32),
rearrange(pivot_frame, "... i j -> ... () i j").type(torch.float32),
rearrange(pivot_point, "... xyz -> ... () xyz").type(torch.float32),
)
@torch.no_grad()
def generate_wobble_transformation(
radius: torch.Tensor, # "*#batch"
t: torch.Tensor, # " time_step"
num_rotations: int = 1,
scale_radius_with_t: bool = True,
) -> torch.Tensor: # "*batch time_step 4 4"]:
# Generate a translation in the image plane.
tf = torch.eye(4, dtype=torch.float32, device=t.device)
tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()
radius = radius[..., None]
if scale_radius_with_t:
radius = radius * t
tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius
tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius
return tf
@torch.no_grad()
def render_wobble_inter_path(
cam2world: torch.Tensor, intr_normed: torch.Tensor, inter_len: int, n_skip: int = 3
):
"""
cam2world: [batch, 4, 4],
intr_normed: [batch, 3, 3]
"""
frame_per_round = n_skip * inter_len
num_rotations = 1
t = torch.linspace(0, 1, frame_per_round, dtype=torch.float32, device=cam2world.device)
# t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
tgt_c2w_b = []
tgt_intr_b = []
for b_idx in range(cam2world.shape[0]):
tgt_c2w = []
tgt_intr = []
for cur_idx in range(0, cam2world.shape[1] - n_skip, n_skip):
origin_a = cam2world[b_idx, cur_idx, :3, 3]
origin_b = cam2world[b_idx, cur_idx + n_skip, :3, 3]
delta = (origin_a - origin_b).norm(dim=-1)
if cur_idx == 0:
delta_prev = delta
else:
delta = (delta_prev + delta) / 2
delta_prev = delta
tf = generate_wobble_transformation(
radius=delta * 0.5,
t=t,
num_rotations=num_rotations,
scale_radius_with_t=False,
)
cur_extrs = (
interpolate_extrinsics(
cam2world[b_idx, cur_idx],
cam2world[b_idx, cur_idx + n_skip],
t,
)
@ tf
)
tgt_c2w.append(cur_extrs[(0 if cur_idx == 0 else 1) :])
tgt_intr.append(
interpolate_intrinsics(
intr_normed[b_idx, cur_idx],
intr_normed[b_idx, cur_idx + n_skip],
t,
)[(0 if cur_idx == 0 else 1) :]
)
tgt_c2w_b.append(torch.cat(tgt_c2w))
tgt_intr_b.append(torch.cat(tgt_intr))
tgt_c2w = torch.stack(tgt_c2w_b) # b v 4 4
tgt_intr = torch.stack(tgt_intr_b) # b v 3 3
return tgt_c2w, tgt_intr

View File

@@ -0,0 +1,19 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
DEFAULT_MODEL = "depth-anything/DA3NESTED-GIANT-LARGE"
DEFAULT_EXPORT_DIR = "workspace/gallery/scene"
DEFAULT_GALLERY_DIR = "workspace/gallery"
DEFAULT_GRADIO_DIR = "workspace/gradio"
THRESH_FOR_REF_SELECTION = 3

View File

@@ -0,0 +1,59 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from depth_anything_3.specs import Prediction
from depth_anything_3.utils.export.gs import export_to_gs_ply, export_to_gs_video
from .colmap import export_to_colmap
from .depth_vis import export_to_depth_vis
from .feat_vis import export_to_feat_vis
from .glb import export_to_glb
from .npz import export_to_mini_npz, export_to_npz
def export(
prediction: Prediction,
export_format: str,
export_dir: str,
**kwargs,
):
if "-" in export_format:
export_formats = export_format.split("-")
for export_format in export_formats:
export(prediction, export_format, export_dir, **kwargs)
return # Prevent falling through to single-format handling
if export_format == "glb":
export_to_glb(prediction, export_dir, **kwargs.get(export_format, {}))
elif export_format == "mini_npz":
export_to_mini_npz(prediction, export_dir)
elif export_format == "npz":
export_to_npz(prediction, export_dir)
elif export_format == "feat_vis":
export_to_feat_vis(prediction, export_dir, **kwargs.get(export_format, {}))
elif export_format == "depth_vis":
export_to_depth_vis(prediction, export_dir)
elif export_format == "gs_ply":
export_to_gs_ply(prediction, export_dir, **kwargs.get(export_format, {}))
elif export_format == "gs_video":
export_to_gs_video(prediction, export_dir, **kwargs.get(export_format, {}))
elif export_format == "colmap":
export_to_colmap(prediction, export_dir, **kwargs.get(export_format, {}))
else:
raise ValueError(f"Unsupported export format: {export_format}")
__all__ = [
export,
]

View File

@@ -0,0 +1,150 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pycolmap
import cv2 as cv
import numpy as np
from PIL import Image
from depth_anything_3.specs import Prediction
from depth_anything_3.utils.logger import logger
from .glb import _depths_to_world_points_with_colors
def export_to_colmap(
prediction: Prediction,
export_dir: str,
image_paths: list[str],
conf_thresh_percentile: float = 40.0,
process_res_method: str = "upper_bound_resize",
) -> None:
# 1. Data preparation
conf_thresh = np.percentile(prediction.conf, conf_thresh_percentile)
points, colors = _depths_to_world_points_with_colors(
prediction.depth,
prediction.intrinsics,
prediction.extrinsics, # w2c
prediction.processed_images,
prediction.conf,
conf_thresh,
)
num_points = len(points)
logger.info(f"Exporting to COLMAP with {num_points} points")
num_frames = len(prediction.processed_images)
h, w = prediction.processed_images.shape[1:3]
points_xyf = _create_xyf(num_frames, h, w)
points_xyf = points_xyf[prediction.conf >= conf_thresh]
# 2. Set Reconstruction
reconstruction = pycolmap.Reconstruction()
point3d_ids = []
for vidx in range(num_points):
point3d_id = reconstruction.add_point3D(points[vidx], pycolmap.Track(), colors[vidx])
point3d_ids.append(point3d_id)
for fidx in range(num_frames):
orig_w, orig_h = Image.open(image_paths[fidx]).size
intrinsic = prediction.intrinsics[fidx]
if process_res_method.endswith("resize"):
intrinsic[:1] *= orig_w / w
intrinsic[1:2] *= orig_h / h
elif process_res_method == "crop":
raise NotImplementedError("COLMAP export for crop method is not implemented")
else:
raise ValueError(f"Unknown process_res_method: {process_res_method}")
pycolmap_intri = np.array(
[intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]]
)
extrinsic = prediction.extrinsics[fidx]
cam_from_world = pycolmap.Rigid3d(pycolmap.Rotation3d(extrinsic[:3, :3]), extrinsic[:3, 3])
# set and add camera
camera = pycolmap.Camera()
camera.camera_id = fidx + 1
camera.model = pycolmap.CameraModelId.PINHOLE
camera.width = orig_w
camera.height = orig_h
camera.params = pycolmap_intri
reconstruction.add_camera(camera)
# set and add rig (from camera)
rig = pycolmap.Rig()
rig.rig_id = camera.camera_id
rig.add_ref_sensor(camera.sensor_id)
reconstruction.add_rig(rig)
# set image
image = pycolmap.Image()
image.image_id = fidx + 1
image.camera_id = camera.camera_id
# set and add frame (from image)
frame = pycolmap.Frame()
frame.frame_id = image.image_id
frame.rig_id = camera.camera_id
frame.add_data_id(image.data_id)
frame.rig_from_world = cam_from_world
reconstruction.add_frame(frame)
# set point2d and update track
point2d_list = []
points_in_frame = points_xyf[:, 2].astype(np.int32) == fidx
for vidx in np.where(points_in_frame)[0]:
point2d = points_xyf[vidx][:2]
point2d[0] *= orig_w / w
point2d[1] *= orig_h / h
point3d_id = point3d_ids[vidx]
point2d_list.append(pycolmap.Point2D(point2d, point3d_id))
reconstruction.point3D(point3d_id).track.add_element(
image.image_id, len(point2d_list) - 1
)
# set and add image
image.frame_id = image.image_id
image.name = os.path.basename(image_paths[fidx])
image.points2D = pycolmap.Point2DList(point2d_list)
reconstruction.add_image(image)
# 3. Export
reconstruction.write(export_dir)
def _create_xyf(num_frames, height, width):
"""
Creates a grid of pixel coordinates and frame indices (fidx) for all frames.
"""
# Create coordinate grids for a single frame
y_grid, x_grid = np.indices((height, width), dtype=np.int32)
x_grid = x_grid[np.newaxis, :, :]
y_grid = y_grid[np.newaxis, :, :]
# Broadcast to all frames
x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
y_coords = np.broadcast_to(y_grid, (num_frames, height, width))
# Create frame indices and broadcast
f_idx = np.arange(num_frames, dtype=np.int32)[:, np.newaxis, np.newaxis]
f_coords = np.broadcast_to(f_idx, (num_frames, height, width))
# Stack coordinates and frame indices
points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
return points_xyf

View File

@@ -0,0 +1,41 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import imageio
import numpy as np
from depth_anything_3.specs import Prediction
from depth_anything_3.utils.visualize import visualize_depth
def export_to_depth_vis(
prediction: Prediction,
export_dir: str,
):
# Use prediction.processed_images, which is already processed image data
if prediction.processed_images is None:
raise ValueError("prediction.processed_images is required but not available")
images_u8 = prediction.processed_images # (N,H,W,3) uint8
os.makedirs(os.path.join(export_dir, "depth_vis"), exist_ok=True)
for idx in range(prediction.depth.shape[0]):
depth_vis = visualize_depth(prediction.depth[idx])
image_vis = images_u8[idx]
depth_vis = depth_vis.astype(np.uint8)
image_vis = image_vis.astype(np.uint8)
vis_image = np.concatenate([image_vis, depth_vis], axis=1)
save_path = os.path.join(export_dir, f"depth_vis/{idx:04d}.jpg")
imageio.imwrite(save_path, vis_image, quality=95)

View File

@@ -0,0 +1,65 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import imageio
import numpy as np
from tqdm.auto import tqdm
from depth_anything_3.utils.parallel_utils import async_call
from depth_anything_3.utils.pca_utils import PCARGBVisualizer
@async_call
def export_to_feat_vis(
prediction,
export_dir,
fps=15,
):
"""Export feature visualization with PCA.
Args:
prediction: Model prediction containing feature maps
export_dir: Directory to export results
fps: Frame rate for output video (default: 15)
"""
out_dir = os.path.join(export_dir, "feat_vis")
os.makedirs(out_dir, exist_ok=True)
images = prediction.processed_images
for k, v in prediction.aux.items():
if not k.startswith("feat_layer_"):
continue
os.makedirs(os.path.join(out_dir, k), exist_ok=True)
viz = PCARGBVisualizer(basis_mode="fixed", percentile_mode="global", clip_percent=10.0)
viz.fit_reference(v)
feats_vis = viz.transform_video(v)
for idx in tqdm(range(len(feats_vis))):
img = images[idx]
feat_vis = (feats_vis[idx] * 255).astype(np.uint8)
feat_vis = cv2.resize(
feat_vis, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST
)
save_path = os.path.join(out_dir, f"{k}/{idx:06d}.jpg")
save = np.concatenate([img, feat_vis], axis=1)
imageio.imwrite(save_path, save, quality=95)
cmd = (
"ffmpeg -loglevel error -hide_banner -y "
f"-framerate {fps} -start_number 0 "
f"-i {out_dir}/{k}/%06d.jpg "
f"-c:v libx264 -pix_fmt yuv420p "
f"{out_dir}/{k}.mp4"
)
os.system(cmd)

View File

@@ -0,0 +1,432 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import numpy as np
import trimesh
from depth_anything_3.specs import Prediction
from depth_anything_3.utils.logger import logger
from .depth_vis import export_to_depth_vis
def set_sky_depth(prediction: Prediction, sky_mask: np.ndarray, sky_depth_def: float = 98.0):
non_sky_mask = ~sky_mask
valid_depth = prediction.depth[non_sky_mask]
if valid_depth.size > 0:
max_depth = np.percentile(valid_depth, sky_depth_def)
prediction.depth[sky_mask] = max_depth
def get_conf_thresh(
prediction: Prediction,
sky_mask: np.ndarray,
conf_thresh: float,
conf_thresh_percentile: float = 10.0,
ensure_thresh_percentile: float = 90.0,
):
if sky_mask is not None and (~sky_mask).sum() > 10:
conf_pixels = prediction.conf[~sky_mask]
else:
conf_pixels = prediction.conf
lower = np.percentile(conf_pixels, conf_thresh_percentile)
upper = np.percentile(conf_pixels, ensure_thresh_percentile)
conf_thresh = min(max(conf_thresh, lower), upper)
return conf_thresh
def export_to_glb(
prediction: Prediction,
export_dir: str,
num_max_points: int = 1_000_000,
conf_thresh: float = 1.05,
filter_black_bg: bool = False,
filter_white_bg: bool = False,
conf_thresh_percentile: float = 40.0,
ensure_thresh_percentile: float = 90.0,
sky_depth_def: float = 98.0,
show_cameras: bool = True,
camera_size: float = 0.03,
export_depth_vis: bool = True,
) -> str:
"""Generate a 3D point cloud and camera wireframes and export them as a ``.glb`` file.
The function builds a point cloud from the predicted depth maps, aligns it to the
first camera in glTF coordinates (X-right, Y-up, Z-backward), optionally draws
camera wireframes, and writes the result to ``scene.glb``. Auxiliary assets such as
depth visualizations can also be generated alongside the main export.
Args:
prediction: Model prediction containing depth, confidence, intrinsics, extrinsics,
and pre-processed images.
export_dir: Output directory where the glTF assets will be written.
num_max_points: Maximum number of points retained after downsampling.
conf_thresh: Base confidence threshold used before percentile adjustments.
filter_black_bg: Mark near-black background pixels for removal during confidence filtering.
filter_white_bg: Mark near-white background pixels for removal during confidence filtering.
conf_thresh_percentile: Lower percentile used when adapting the confidence threshold.
ensure_thresh_percentile: Upper percentile clamp for the adaptive threshold.
sky_depth_def: Percentile used to fill sky pixels with plausible depth values.
show_cameras: Whether to render camera wireframes in the exported scene.
camera_size: Relative camera wireframe scale as a fraction of the scene diagonal.
export_depth_vis: Whether to export raster depth visualisations alongside the glTF.
Returns:
Path to the exported ``scene.glb`` file.
"""
# 1) Use prediction.processed_images, which is already processed image data
assert (
prediction.processed_images is not None
), "Export to GLB: prediction.processed_images is required but not available"
assert (
prediction.depth is not None
), "Export to GLB: prediction.depth is required but not available"
assert (
prediction.intrinsics is not None
), "Export to GLB: prediction.intrinsics is required but not available"
assert (
prediction.extrinsics is not None
), "Export to GLB: prediction.extrinsics is required but not available"
assert (
prediction.conf is not None
), "Export to GLB: prediction.conf is required but not available"
logger.info(f"conf_thresh_percentile: {conf_thresh_percentile}")
logger.info(f"num max points: {num_max_points}")
logger.info(f"Exporting to GLB with num_max_points: {num_max_points}")
if prediction.processed_images is None:
raise ValueError("prediction.processed_images is required but not available")
images_u8 = prediction.processed_images # (N,H,W,3) uint8
# 2) Sky processing (if sky_mask is provided)
if getattr(prediction, "sky_mask", None) is not None:
set_sky_depth(prediction, prediction.sky_mask, sky_depth_def)
# 3) Confidence threshold (if no conf, then no filtering)
if filter_black_bg:
prediction.conf[(prediction.processed_images < 16).all(axis=-1)] = 1.0
if filter_white_bg:
prediction.conf[(prediction.processed_images >= 240).all(axis=-1)] = 1.0
conf_thr = get_conf_thresh(
prediction,
getattr(prediction, "sky_mask", None),
conf_thresh,
conf_thresh_percentile,
ensure_thresh_percentile,
)
# 4) Back-project to world coordinates and get colors (world frame)
points, colors = _depths_to_world_points_with_colors(
prediction.depth,
prediction.intrinsics,
prediction.extrinsics, # w2c
images_u8,
prediction.conf,
conf_thr,
)
# 5) Based on first camera orientation + glTF axis system, center by point cloud,
# construct alignment transform, and apply to point cloud
A = _compute_alignment_transform_first_cam_glTF_center_by_points(
prediction.extrinsics[0], points
) # (4,4)
if points.shape[0] > 0:
points = trimesh.transform_points(points, A)
# 6) Clean + downsample
points, colors = _filter_and_downsample(points, colors, num_max_points)
# 7) Assemble scene (add point cloud first)
scene = trimesh.Scene()
if scene.metadata is None:
scene.metadata = {}
scene.metadata["hf_alignment"] = A # For camera wireframes and external reuse
if points.shape[0] > 0:
pc = trimesh.points.PointCloud(vertices=points, colors=colors)
scene.add_geometry(pc)
# 8) Draw cameras (wireframe pyramids), using the same transform A
if show_cameras and prediction.intrinsics is not None and prediction.extrinsics is not None:
scene_scale = _estimate_scene_scale(points, fallback=1.0)
H, W = prediction.depth.shape[1:]
_add_cameras_to_scene(
scene=scene,
K=prediction.intrinsics,
ext_w2c=prediction.extrinsics,
image_sizes=[(H, W)] * prediction.depth.shape[0],
scale=scene_scale * camera_size,
)
# 9) Export
os.makedirs(export_dir, exist_ok=True)
out_path = os.path.join(export_dir, "scene.glb")
scene.export(out_path)
if export_depth_vis:
export_to_depth_vis(prediction, export_dir)
os.system(f"cp -r {export_dir}/depth_vis/0000.jpg {export_dir}/scene.jpg")
return out_path
# =========================
# utilities
# =========================
def _as_homogeneous44(ext: np.ndarray) -> np.ndarray:
"""
Accept (4,4) or (3,4) extrinsic parameters, return (4,4) homogeneous matrix.
"""
if ext.shape == (4, 4):
return ext
if ext.shape == (3, 4):
H = np.eye(4, dtype=ext.dtype)
H[:3, :4] = ext
return H
raise ValueError(f"extrinsic must be (4,4) or (3,4), got {ext.shape}")
def _depths_to_world_points_with_colors(
depth: np.ndarray,
K: np.ndarray,
ext_w2c: np.ndarray,
images_u8: np.ndarray,
conf: np.ndarray | None,
conf_thr: float,
) -> tuple[np.ndarray, np.ndarray]:
"""
For each frame, transform (u,v,1) through K^{-1} to get rays,
multiply by depth to camera frame, then use (w2c)^{-1} to transform to world frame.
Simultaneously extract colors.
"""
N, H, W = depth.shape
us, vs = np.meshgrid(np.arange(W), np.arange(H))
ones = np.ones_like(us)
pix = np.stack([us, vs, ones], axis=-1).reshape(-1, 3) # (H*W,3)
pts_all, col_all = [], []
for i in range(N):
d = depth[i] # (H,W)
valid = np.isfinite(d) & (d > 0)
if conf is not None:
valid &= conf[i] >= conf_thr
if not np.any(valid):
continue
d_flat = d.reshape(-1)
vidx = np.flatnonzero(valid.reshape(-1))
K_inv = np.linalg.inv(K[i]) # (3,3)
c2w = np.linalg.inv(_as_homogeneous44(ext_w2c[i])) # (4,4)
rays = K_inv @ pix[vidx].T # (3,M)
Xc = rays * d_flat[vidx][None, :] # (3,M)
Xc_h = np.vstack([Xc, np.ones((1, Xc.shape[1]))])
Xw = (c2w @ Xc_h)[:3].T.astype(np.float32) # (M,3)
cols = images_u8[i].reshape(-1, 3)[vidx].astype(np.uint8) # (M,3)
pts_all.append(Xw)
col_all.append(cols)
if len(pts_all) == 0:
return np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8)
return np.concatenate(pts_all, 0), np.concatenate(col_all, 0)
def _filter_and_downsample(points: np.ndarray, colors: np.ndarray, num_max: int):
if points.shape[0] == 0:
return points, colors
finite = np.isfinite(points).all(axis=1)
points, colors = points[finite], colors[finite]
if points.shape[0] > num_max:
idx = np.random.choice(points.shape[0], num_max, replace=False)
points, colors = points[idx], colors[idx]
return points, colors
def _estimate_scene_scale(points: np.ndarray, fallback: float = 1.0) -> float:
if points.shape[0] < 2:
return fallback
lo = np.percentile(points, 5, axis=0)
hi = np.percentile(points, 95, axis=0)
diag = np.linalg.norm(hi - lo)
return float(diag if np.isfinite(diag) and diag > 0 else fallback)
def _compute_alignment_transform_first_cam_glTF_center_by_points(
ext_w2c0: np.ndarray,
points_world: np.ndarray,
) -> np.ndarray:
"""Computes the transformation matrix to align the scene with glTF standards.
This function calculates a 4x4 homogeneous matrix that centers the scene's
point cloud and transforms its coordinate system from the computer vision (CV)
standard to the glTF standard.
The transformation process involves three main steps:
1. **Initial Alignment**: Orients the world coordinate system to match the
first camera's view (x-right, y-down, z-forward).
2. **Coordinate System Conversion**: Converts the CV camera frame to the
glTF frame (x-right, y-up, z-backward) by flipping the Y and Z axes.
3. **Centering**: Translates the entire scene so that the median of the
point cloud becomes the new origin (0,0,0).
Returns:
A 4x4 homogeneous transformation matrix (torch.Tensor or np.ndarray)
that applies these transformations. A: X' = A @ [X;1]
"""
w2c0 = _as_homogeneous44(ext_w2c0).astype(np.float64)
# CV -> glTF axis transformation
M = np.eye(4, dtype=np.float64)
M[1, 1] = -1.0 # flip Y
M[2, 2] = -1.0 # flip Z
# Don't center first
A_no_center = M @ w2c0
# Calculate point cloud center in new coordinate system (use median to resist outliers)
if points_world.shape[0] > 0:
pts_tmp = trimesh.transform_points(points_world, A_no_center)
center = np.median(pts_tmp, axis=0)
else:
center = np.zeros(3, dtype=np.float64)
T_center = np.eye(4, dtype=np.float64)
T_center[:3, 3] = -center
A = T_center @ A_no_center
return A
def _add_cameras_to_scene(
scene: trimesh.Scene,
K: np.ndarray,
ext_w2c: np.ndarray,
image_sizes: list[tuple[int, int]],
scale: float,
) -> None:
"""Draws camera frustums to visualize their position and orientation.
This function renders each camera as a wireframe pyramid, originating from
the camera's center and extending to the corners of its imaging plane.
It reads the 'hf_alignment' metadata from the scene to ensure the
wireframes are correctly aligned with the 3D point cloud.
"""
N = K.shape[0]
if N == 0:
return
# Alignment matrix consistent with point cloud (use identity matrix if missing)
A = None
try:
A = scene.metadata.get("hf_alignment", None) if scene.metadata else None
except Exception:
A = None
if A is None:
A = np.eye(4, dtype=np.float64)
for i in range(N):
H, W = image_sizes[i]
segs = _camera_frustum_lines(K[i], ext_w2c[i], W, H, scale) # (8,2,3) world frame
# Apply unified transformation
segs = trimesh.transform_points(segs.reshape(-1, 3), A).reshape(-1, 2, 3)
path = trimesh.load_path(segs)
color = _index_color_rgb(i, N)
if hasattr(path, "colors"):
path.colors = np.tile(color, (len(path.entities), 1))
scene.add_geometry(path)
def _camera_frustum_lines(
K: np.ndarray, ext_w2c: np.ndarray, W: int, H: int, scale: float
) -> np.ndarray:
corners = np.array(
[
[0, 0, 1.0],
[W - 1, 0, 1.0],
[W - 1, H - 1, 1.0],
[0, H - 1, 1.0],
],
dtype=float,
) # (4,3)
K_inv = np.linalg.inv(K)
c2w = np.linalg.inv(_as_homogeneous44(ext_w2c))
# camera center in world
Cw = (c2w @ np.array([0, 0, 0, 1.0]))[:3]
# rays -> z=1 plane points (camera frame)
rays = (K_inv @ corners.T).T
z = rays[:, 2:3]
z[z == 0] = 1.0
plane_cam = (rays / z) * scale # (4,3)
# to world
plane_w = []
for p in plane_cam:
pw = (c2w @ np.array([p[0], p[1], p[2], 1.0]))[:3]
plane_w.append(pw)
plane_w = np.stack(plane_w, 0) # (4,3)
segs = []
# center to corners
for k in range(4):
segs.append(np.stack([Cw, plane_w[k]], 0))
# rectangle edges
order = [0, 1, 2, 3, 0]
for a, b in zip(order[:-1], order[1:]):
segs.append(np.stack([plane_w[a], plane_w[b]], 0))
return np.stack(segs, 0) # (8,2,3)
def _index_color_rgb(i: int, n: int) -> np.ndarray:
h = (i + 0.5) / max(n, 1)
s, v = 0.85, 0.95
r, g, b = _hsv_to_rgb(h, s, v)
return (np.array([r, g, b]) * 255).astype(np.uint8)
def _hsv_to_rgb(h: float, s: float, v: float) -> tuple[float, float, float]:
i = int(h * 6.0)
f = h * 6.0 - i
p = v * (1.0 - s)
q = v * (1.0 - f * s)
t = v * (1.0 - (1.0 - f) * s)
i = i % 6
if i == 0:
r, g, b = v, t, p
elif i == 1:
r, g, b = q, v, p
elif i == 2:
r, g, b = p, v, t
elif i == 3:
r, g, b = p, q, v
elif i == 4:
r, g, b = t, p, v
else:
r, g, b = v, p, q
return r, g, b

View File

@@ -0,0 +1,154 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Literal, Optional
import moviepy.editor as mpy
import torch
from depth_anything_3.model.utils.gs_renderer import run_renderer_in_chunk_w_trj_mode
from depth_anything_3.specs import Prediction
from depth_anything_3.utils.gsply_helpers import save_gaussian_ply
from depth_anything_3.utils.layout_helpers import hcat, vcat
from depth_anything_3.utils.visualize import vis_depth_map_tensor
VIDEO_QUALITY_MAP = {
"low": {"crf": "28", "preset": "veryfast"},
"medium": {"crf": "23", "preset": "medium"},
"high": {"crf": "18", "preset": "slow"},
}
def export_to_gs_ply(
prediction: Prediction,
export_dir: str,
gs_views_interval: Optional[
int
] = 1, # export GS every N views, useful for extremely dense inputs
):
gs_world = prediction.gaussians
pred_depth = torch.from_numpy(prediction.depth).unsqueeze(-1).to(gs_world.means) # v h w 1
idx = 0
os.makedirs(os.path.join(export_dir, "gs_ply"), exist_ok=True)
save_path = os.path.join(export_dir, f"gs_ply/{idx:04d}.ply")
if gs_views_interval is None: # select around 12 views in total
gs_views_interval = max(pred_depth.shape[0] // 12, 1)
save_gaussian_ply(
gaussians=gs_world,
save_path=save_path,
ctx_depth=pred_depth,
shift_and_scale=False,
save_sh_dc_only=True,
gs_views_interval=gs_views_interval,
inv_opacity=True,
prune_by_depth_percent=0.9,
prune_border_gs=True,
match_3dgs_mcmc_dev=False,
)
def export_to_gs_video(
prediction: Prediction,
export_dir: str,
extrinsics: Optional[torch.Tensor] = None, # render views' world2cam, "b v 4 4"
intrinsics: Optional[torch.Tensor] = None, # render views' unnormed intrinsics, "b v 3 3"
out_image_hw: Optional[tuple[int, int]] = None, # render views' resolution, (h, w)
chunk_size: Optional[int] = 4,
trj_mode: Literal[
"original",
"smooth",
"interpolate",
"interpolate_smooth",
"wander",
"dolly_zoom",
"extend",
"wobble_inter",
] = "extend",
color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+ED",
vis_depth: Optional[Literal["hcat", "vcat"]] = "hcat",
enable_tqdm: Optional[bool] = True,
output_name: Optional[str] = None,
video_quality: Literal["low", "medium", "high"] = "high",
) -> None:
gs_world = prediction.gaussians
# if target poses are not provided, render the (smooth/interpolate) input poses
if extrinsics is not None:
tgt_extrs = extrinsics
else:
tgt_extrs = torch.from_numpy(prediction.extrinsics).unsqueeze(0).to(gs_world.means)
if prediction.is_metric:
scale_factor = prediction.scale_factor
if scale_factor is not None:
tgt_extrs[:, :, :3, 3] /= scale_factor
tgt_intrs = (
intrinsics
if intrinsics is not None
else torch.from_numpy(prediction.intrinsics).unsqueeze(0).to(gs_world.means)
)
# if render resolution is not provided, render the input ones
if out_image_hw is not None:
H, W = out_image_hw
else:
H, W = prediction.depth.shape[-2:]
# if single views, render wander trj
if tgt_extrs.shape[1] <= 1:
trj_mode = "wander"
# trj_mode = "dolly_zoom"
color, depth = run_renderer_in_chunk_w_trj_mode(
gaussians=gs_world,
extrinsics=tgt_extrs,
intrinsics=tgt_intrs,
image_shape=(H, W),
chunk_size=chunk_size,
trj_mode=trj_mode,
use_sh=True,
color_mode=color_mode,
enable_tqdm=enable_tqdm,
)
# save as video
ffmpeg_params = [
"-crf",
VIDEO_QUALITY_MAP[video_quality]["crf"],
"-preset",
VIDEO_QUALITY_MAP[video_quality]["preset"],
"-pix_fmt",
"yuv420p",
] # best compatibility
os.makedirs(os.path.join(export_dir, "gs_video"), exist_ok=True)
for idx in range(color.shape[0]):
video_i = color[idx]
if vis_depth is not None:
depth_i = vis_depth_map_tensor(depth[0])
cat_fn = hcat if vis_depth == "hcat" else vcat
video_i = torch.stack([cat_fn(c, d) for c, d in zip(video_i, depth_i)])
frames = list(
(video_i.clamp(0, 1) * 255).byte().permute(0, 2, 3, 1).cpu().numpy()
) # T x H x W x C, uint8, numpy()
fps = 24
clip = mpy.ImageSequenceClip(frames, fps=fps)
output_name = f"{idx:04d}_{trj_mode}" if output_name is None else output_name
save_path = os.path.join(export_dir, f"gs_video/{output_name}.mp4")
# clip.write_videofile(save_path, codec="libx264", audio=False, bitrate="4000k")
clip.write_videofile(
save_path,
codec="libx264",
audio=False,
fps=fps,
ffmpeg_params=ffmpeg_params,
)
return

View File

@@ -0,0 +1,73 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from depth_anything_3.specs import Prediction
from depth_anything_3.utils.parallel_utils import async_call
@async_call
def export_to_npz(
prediction: Prediction,
export_dir: str,
):
output_file = os.path.join(export_dir, "exports", "npz", "results.npz")
os.makedirs(os.path.dirname(output_file), exist_ok=True)
# Use prediction.processed_images, which is already processed image data
if prediction.processed_images is None:
raise ValueError("prediction.processed_images is required but not available")
image = prediction.processed_images # (N,H,W,3) uint8
# Build save dict with only non-None values
save_dict = {
"image": image,
"depth": np.round(prediction.depth, 6),
}
if prediction.conf is not None:
save_dict["conf"] = np.round(prediction.conf, 2)
if prediction.extrinsics is not None:
save_dict["extrinsics"] = prediction.extrinsics
if prediction.intrinsics is not None:
save_dict["intrinsics"] = prediction.intrinsics
# aux = {k: np.round(v, 4) for k, v in prediction.aux.items()}
np.savez_compressed(output_file, **save_dict)
@async_call
def export_to_mini_npz(
prediction: Prediction,
export_dir: str,
):
output_file = os.path.join(export_dir, "exports", "mini_npz", "results.npz")
os.makedirs(os.path.dirname(output_file), exist_ok=True)
# Build save dict with only non-None values
save_dict = {
"depth": np.round(prediction.depth, 6),
}
if prediction.conf is not None:
save_dict["conf"] = np.round(prediction.conf, 2)
if prediction.extrinsics is not None:
save_dict["extrinsics"] = prediction.extrinsics
if prediction.intrinsics is not None:
save_dict["intrinsics"] = prediction.intrinsics
np.savez_compressed(output_file, **save_dict)

View File

@@ -0,0 +1,30 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
def _denorm_and_to_uint8(image_tensor: torch.Tensor) -> np.ndarray:
"""Denormalize to [0,255] and output (N, H, W, 3) uint8."""
resnet_mean = torch.tensor(
[0.485, 0.456, 0.406], dtype=image_tensor.dtype, device=image_tensor.device
)
resnet_std = torch.tensor(
[0.229, 0.224, 0.225], dtype=image_tensor.dtype, device=image_tensor.device
)
img = image_tensor * resnet_std[None, :, None, None] + resnet_mean[None, :, None, None]
img = torch.clamp(img, 0.0, 1.0)
img = (img.permute(0, 2, 3, 1).cpu().numpy() * 255.0).round().astype(np.uint8) # (N,H,W,3)
return img

View File

@@ -0,0 +1,498 @@
# flake8: noqa: F722
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from types import SimpleNamespace
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from einops import einsum
def as_homogeneous(ext):
"""
Accept (..., 3,4) or (..., 4,4) extrinsics, return (...,4,4) homogeneous matrix.
Supports torch.Tensor or np.ndarray.
"""
if isinstance(ext, torch.Tensor):
# If already in homogeneous form
if ext.shape[-2:] == (4, 4):
return ext
elif ext.shape[-2:] == (3, 4):
# Create a new homogeneous matrix
ones = torch.zeros_like(ext[..., :1, :4])
ones[..., 0, 3] = 1.0
return torch.cat([ext, ones], dim=-2)
else:
raise ValueError(f"Invalid shape for torch.Tensor: {ext.shape}")
elif isinstance(ext, np.ndarray):
if ext.shape[-2:] == (4, 4):
return ext
elif ext.shape[-2:] == (3, 4):
ones = np.zeros_like(ext[..., :1, :4])
ones[..., 0, 3] = 1.0
return np.concatenate([ext, ones], axis=-2)
else:
raise ValueError(f"Invalid shape for np.ndarray: {ext.shape}")
else:
raise TypeError("Input must be a torch.Tensor or np.ndarray.")
@torch.jit.script
def affine_inverse(A: torch.Tensor):
R = A[..., :3, :3] # ..., 3, 3
T = A[..., :3, 3:] # ..., 3, 1
P = A[..., 3:, :] # ..., 1, 4
return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)
def transpose_last_two_axes(arr):
"""
for np < 2
"""
if arr.ndim < 2:
return arr
axes = list(range(arr.ndim))
# swap the last two
axes[-2], axes[-1] = axes[-1], axes[-2]
return arr.transpose(axes)
def affine_inverse_np(A: np.ndarray):
R = A[..., :3, :3]
T = A[..., :3, 3:]
P = A[..., 3:, :]
return np.concatenate(
[
np.concatenate([transpose_last_two_axes(R), -transpose_last_two_axes(R) @ T], axis=-1),
P,
],
axis=-2,
)
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
"""
Quaternion Order: XYZW or say ijkr, scalar-last
Convert rotations given as quaternions to rotation matrices.
Args:
quaternions: quaternions with real part last,
as tensor of shape (..., 4).
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
i, j, k, r = torch.unbind(quaternions, -1)
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part last, as tensor of shape (..., 4).
Quaternion Order: XYZW or say ijkr, scalar-last
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked.
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
batch_dim + (4,)
)
# Convert from rijk to ijkr
out = out[..., [1, 2, 3, 0]]
out = standardize_quaternion(out)
return out
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
part is non negative.
Args:
quaternions: Quaternions with real part last,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
def sample_image_grid(
shape: tuple[int, ...],
device: torch.device = torch.device("cpu"),
) -> tuple[
torch.Tensor, # float coordinates (xy indexing), "*shape dim"
torch.Tensor, # integer indices (ij indexing), "*shape dim"
]:
"""Get normalized (range 0 to 1) coordinates and integer indices for an image."""
# Each entry is a pixel-wise integer coordinate. In the 2D case, each entry is a
# (row, col) coordinate.
indices = [torch.arange(length, device=device) for length in shape]
stacked_indices = torch.stack(torch.meshgrid(*indices, indexing="ij"), dim=-1)
# Each entry is a floating-point coordinate in the range (0, 1). In the 2D case,
# each entry is an (x, y) coordinate.
coordinates = [(idx + 0.5) / length for idx, length in zip(indices, shape)]
coordinates = reversed(coordinates)
coordinates = torch.stack(torch.meshgrid(*coordinates, indexing="xy"), dim=-1)
return coordinates, stacked_indices
def homogenize_points(points: torch.Tensor) -> torch.Tensor: # "*batch dim" # "*batch dim+1"
"""Convert batched points (xyz) to (xyz1)."""
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
def homogenize_vectors(vectors: torch.Tensor) -> torch.Tensor: # "*batch dim" # "*batch dim+1"
"""Convert batched vectors (xyz) to (xyz0)."""
return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1)
def transform_rigid(
homogeneous_coordinates: torch.Tensor, # "*#batch dim"
transformation: torch.Tensor, # "*#batch dim dim"
) -> torch.Tensor: # "*batch dim"
"""Apply a rigid-body transformation to points or vectors."""
return einsum(
transformation,
homogeneous_coordinates.to(transformation.dtype),
"... i j, ... j -> ... i",
)
def transform_cam2world(
homogeneous_coordinates: torch.Tensor, # "*#batch dim"
extrinsics: torch.Tensor, # "*#batch dim dim"
) -> torch.Tensor: # "*batch dim"
"""Transform points from 3D camera coordinates to 3D world coordinates."""
return transform_rigid(homogeneous_coordinates, extrinsics)
def unproject(
coordinates: torch.Tensor, # "*#batch dim"
z: torch.Tensor, # "*#batch"
intrinsics: torch.Tensor, # "*#batch dim+1 dim+1"
) -> torch.Tensor: # "*batch dim+1"
"""Unproject 2D camera coordinates with the given Z values."""
# Apply the inverse intrinsics to the coordinates.
coordinates = homogenize_points(coordinates)
ray_directions = einsum(
intrinsics.float().inverse().to(intrinsics),
coordinates.to(intrinsics.dtype),
"... i j, ... j -> ... i",
)
# Apply the supplied depth values.
return ray_directions * z[..., None]
def get_world_rays(
coordinates: torch.Tensor, # "*#batch dim"
extrinsics: torch.Tensor, # "*#batch dim+2 dim+2"
intrinsics: torch.Tensor, # "*#batch dim+1 dim+1"
) -> tuple[
torch.Tensor, # origins, "*batch dim+1"
torch.Tensor, # directions, "*batch dim+1"
]:
# Get camera-space ray directions.
directions = unproject(
coordinates,
torch.ones_like(coordinates[..., 0]),
intrinsics,
)
directions = directions / directions.norm(dim=-1, keepdim=True)
# Transform ray directions to world coordinates.
directions = homogenize_vectors(directions)
directions = transform_cam2world(directions, extrinsics)[..., :-1]
# Tile the ray origins to have the same shape as the ray directions.
origins = extrinsics[..., :-1, -1].broadcast_to(directions.shape)
return origins, directions
def get_fov(intrinsics: torch.Tensor) -> torch.Tensor: # "batch 3 3" -> "batch 2"
intrinsics_inv = intrinsics.float().inverse().to(intrinsics)
def process_vector(vector):
vector = torch.tensor(vector, dtype=intrinsics.dtype, device=intrinsics.device)
vector = einsum(intrinsics_inv, vector, "b i j, j -> b i")
return vector / vector.norm(dim=-1, keepdim=True)
left = process_vector([0, 0.5, 1])
right = process_vector([1, 0.5, 1])
top = process_vector([0.5, 0, 1])
bottom = process_vector([0.5, 1, 1])
fov_x = (left * right).sum(dim=-1).acos()
fov_y = (top * bottom).sum(dim=-1).acos()
return torch.stack((fov_x, fov_y), dim=-1)
def map_pdf_to_opacity(
pdf: torch.Tensor, # " *batch"
global_step: int = 0,
opacity_mapping: Optional[dict] = None,
) -> torch.Tensor: # " *batch"
# https://www.desmos.com/calculator/opvwti3ba9
# Figure out the exponent.
if opacity_mapping is not None:
cfg = SimpleNamespace(**opacity_mapping)
x = cfg.initial + min(global_step / cfg.warm_up, 1) * (cfg.final - cfg.initial)
else:
x = 0.0
exponent = 2**x
# Map the probability density to an opacity.
return 0.5 * (1 - (1 - pdf) ** exponent + pdf ** (1 / exponent))
def normalize_homogenous_points(points):
"""Normalize the point vectors"""
return points / points[..., -1:]
def inverse_intrinsic_matrix(ixts):
""" """
return torch.inverse(ixts)
def pixel_space_to_camera_space(pixel_space_points, depth, intrinsics):
"""
Convert pixel space points to camera space points.
Args:
pixel_space_points (torch.Tensor): Pixel space points with shape (h, w, 2)
depth (torch.Tensor): Depth map with shape (b, v, h, w, 1)
intrinsics (torch.Tensor): Camera intrinsics with shape (b, v, 3, 3)
Returns:
torch.Tensor: Camera space points with shape (b, v, h, w, 3).
"""
pixel_space_points = homogenize_points(pixel_space_points)
# camera_space_points = torch.einsum(
# "b v i j , h w j -> b v h w i", intrinsics.inverse(), pixel_space_points
# )
camera_space_points = torch.einsum(
"b v i j , h w j -> b v h w i", inverse_intrinsic_matrix(intrinsics), pixel_space_points
)
camera_space_points = camera_space_points * depth
return camera_space_points
def camera_space_to_world_space(camera_space_points, c2w):
"""
Convert camera space points to world space points.
Args:
camera_space_points (torch.Tensor): Camera space points with shape (b, v, h, w, 3)
c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v, 4, 4)
Returns:
torch.Tensor: World space points with shape (b, v, h, w, 3).
"""
camera_space_points = homogenize_points(camera_space_points)
world_space_points = torch.einsum("b v i j , b v h w j -> b v h w i", c2w, camera_space_points)
return world_space_points[..., :3]
def camera_space_to_pixel_space(camera_space_points, intrinsics):
"""
Convert camera space points to pixel space points.
Args:
camera_space_points (torch.Tensor): Camera space points with shape (b, v1, v2, h, w, 3)
c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 3, 3)
Returns:
torch.Tensor: World space points with shape (b, v1, v2, h, w, 2).
"""
camera_space_points = normalize_homogenous_points(camera_space_points)
pixel_space_points = torch.einsum(
"b u i j , b v u h w j -> b v u h w i", intrinsics, camera_space_points
)
return pixel_space_points[..., :2]
def world_space_to_camera_space(world_space_points, c2w):
"""
Convert world space points to pixel space points.
Args:
world_space_points (torch.Tensor): World space points with shape (b, v1, h, w, 3)
c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 4, 4)
Returns:
torch.Tensor: Camera space points with shape (b, v1, v2, h, w, 3).
"""
world_space_points = homogenize_points(world_space_points)
camera_space_points = torch.einsum(
"b u i j , b v h w j -> b v u h w i", c2w.inverse(), world_space_points
)
return camera_space_points[..., :3]
def unproject_depth(
depth, intrinsics, c2w=None, ixt_normalized=False, num_patches_x=None, num_patches_y=None
):
"""
Turn the depth map into a 3D point cloud in world space
Args:
depth: (b, v, h, w, 1)
intrinsics: (b, v, 3, 3)
c2w: (b, v, 4, 4)
Returns:
torch.Tensor: World space points with shape (b, v, h, w, 3).
"""
if c2w is None:
c2w = torch.eye(4, device=depth.device, dtype=depth.dtype)
c2w = c2w[None, None].repeat(depth.shape[0], depth.shape[1], 1, 1)
if not ixt_normalized:
# Compute indices of pixels
h, w = depth.shape[-3], depth.shape[-2]
x_grid, y_grid = torch.meshgrid(
torch.arange(w, device=depth.device, dtype=depth.dtype),
torch.arange(h, device=depth.device, dtype=depth.dtype),
indexing="xy",
) # (h, w), (h, w)
else:
# ixt_normalized: h=w=2.0. cx, cy, fx, fy are normalized according to h=w=2.0
assert num_patches_x is not None and num_patches_y is not None
dx = 1 / num_patches_x
dy = 1 / num_patches_y
max_y = 1 - dy
min_y = -max_y
max_x = 1 - dx
min_x = -max_x
grid_shift = 1.0
y_grid, x_grid = torch.meshgrid(
torch.linspace(
min_y + grid_shift,
max_y + grid_shift,
num_patches_y,
dtype=torch.float32,
device=depth.device,
),
torch.linspace(
min_x + grid_shift,
max_x + grid_shift,
num_patches_x,
dtype=torch.float32,
device=depth.device,
),
indexing="ij",
)
# Compute coordinates of pixels in camera space
pixel_space_points = torch.stack((x_grid, y_grid), dim=-1) # (..., h, w, 2)
camera_points = pixel_space_to_camera_space(
pixel_space_points, depth, intrinsics
) # (..., h, w, 3)
# Convert points to world space
world_points = camera_space_to_world_space(camera_points, c2w) # (..., h, w, 3)
return world_points

View File

@@ -0,0 +1,173 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from einops import rearrange, repeat
from plyfile import PlyData, PlyElement
from torch import Tensor
from depth_anything_3.specs import Gaussians
def construct_list_of_attributes(num_rest: int) -> list[str]:
attributes = ["x", "y", "z", "nx", "ny", "nz"]
for i in range(3):
attributes.append(f"f_dc_{i}")
for i in range(num_rest):
attributes.append(f"f_rest_{i}")
attributes.append("opacity")
for i in range(3):
attributes.append(f"scale_{i}")
for i in range(4):
attributes.append(f"rot_{i}")
return attributes
def export_ply(
means: Tensor, # "gaussian 3"
scales: Tensor, # "gaussian 3"
rotations: Tensor, # "gaussian 4"
harmonics: Tensor, # "gaussian 3 d_sh"
opacities: Tensor, # "gaussian"
path: Path,
shift_and_scale: bool = False,
save_sh_dc_only: bool = True,
match_3dgs_mcmc_dev: Optional[bool] = False,
):
if shift_and_scale:
# Shift the scene so that the median Gaussian is at the origin.
means = means - means.median(dim=0).values
# Rescale the scene so that most Gaussians are within range [-1, 1].
scale_factor = means.abs().quantile(0.95, dim=0).max()
means = means / scale_factor
scales = scales / scale_factor
rotations = rotations.detach().cpu().numpy()
# Since current model use SH_degree = 4,
# which require large memory to store, we can only save the DC band to save memory.
f_dc = harmonics[..., 0]
f_rest = harmonics[..., 1:].flatten(start_dim=1)
if match_3dgs_mcmc_dev:
sh_degree = 3
n_rest = 3 * (sh_degree + 1) ** 2 - 3
f_rest = repeat(
torch.zeros_like(harmonics[..., :1]), "... i -> ... (n i)", n=(n_rest // 3)
).flatten(start_dim=1)
dtype_full = [
(attribute, "f4")
for attribute in construct_list_of_attributes(num_rest=n_rest)
if attribute not in ("nx", "ny", "nz")
]
else:
dtype_full = [
(attribute, "f4")
for attribute in construct_list_of_attributes(
0 if save_sh_dc_only else f_rest.shape[1]
)
]
elements = np.empty(means.shape[0], dtype=dtype_full)
attributes = [
means.detach().cpu().numpy(),
torch.zeros_like(means).detach().cpu().numpy(),
f_dc.detach().cpu().contiguous().numpy(),
f_rest.detach().cpu().contiguous().numpy(),
opacities[..., None].detach().cpu().numpy(),
scales.log().detach().cpu().numpy(),
rotations,
]
if match_3dgs_mcmc_dev:
attributes.pop(1) # dummy normal is not needed
elif save_sh_dc_only:
attributes.pop(3) # remove f_rest from attributes
attributes = np.concatenate(attributes, axis=1)
elements[:] = list(map(tuple, attributes))
path.parent.mkdir(exist_ok=True, parents=True)
PlyData([PlyElement.describe(elements, "vertex")]).write(path)
def inverse_sigmoid(x):
return torch.log(x / (1 - x))
def save_gaussian_ply(
gaussians: Gaussians,
save_path: str,
ctx_depth: torch.Tensor, # depth of input views; for getting shape and filtering, "v h w 1"
shift_and_scale: bool = False,
save_sh_dc_only: bool = True,
gs_views_interval: int = 1,
inv_opacity: Optional[bool] = True,
prune_by_depth_percent: Optional[float] = 1.0,
prune_border_gs: Optional[bool] = True,
match_3dgs_mcmc_dev: Optional[bool] = False,
):
b = gaussians.means.shape[0]
assert b == 1, "must set batch_size=1 when exporting 3D gaussians"
src_v, out_h, out_w, _ = ctx_depth.shape
# extract gs params
world_means = gaussians.means
world_shs = gaussians.harmonics
world_rotations = gaussians.rotations
gs_scales = gaussians.scales
gs_opacities = inverse_sigmoid(gaussians.opacities) if inv_opacity else gaussians.opacities
# Create a mask to filter the Gaussians.
# TODO: prune the sky region here
# throw away Gaussians at the borders, since they're generally of lower quality.
if prune_border_gs:
mask = torch.zeros_like(ctx_depth, dtype=torch.bool)
gstrim_h = int(8 / 256 * out_h)
gstrim_w = int(8 / 256 * out_w)
mask[:, gstrim_h:-gstrim_h, gstrim_w:-gstrim_w, :] = 1
else:
mask = torch.ones_like(ctx_depth, dtype=torch.bool)
# trim the far away point based on depth;
if prune_by_depth_percent is not None and prune_by_depth_percent < 1:
in_depths = ctx_depth
d_percentile = torch.quantile(
in_depths.view(in_depths.shape[0], -1), q=prune_by_depth_percent, dim=1
).view(-1, 1, 1)
d_mask = (in_depths[..., 0] <= d_percentile).unsqueeze(-1)
mask = mask & d_mask
mask = mask.squeeze(-1) # v h w
# helper fn, must place after mask
def trim_select_reshape(element):
selected_element = rearrange(
element[0], "(v h w) ... -> v h w ...", v=src_v, h=out_h, w=out_w
)
selected_element = selected_element[::gs_views_interval][mask[::gs_views_interval]]
return selected_element
export_ply(
means=trim_select_reshape(world_means),
scales=trim_select_reshape(gs_scales),
rotations=trim_select_reshape(world_rotations),
harmonics=trim_select_reshape(world_shs),
opacities=trim_select_reshape(gs_opacities),
path=Path(save_path),
shift_and_scale=shift_and_scale,
save_sh_dc_only=save_sh_dc_only,
match_3dgs_mcmc_dev=match_3dgs_mcmc_dev,
)

View File

@@ -0,0 +1,501 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Input processor for Depth Anything 3 (parallelized).
This version removes the square center-crop step for "*crop" methods (same as your note).
In addition, it parallelizes per-image preprocessing using the provided `parallel_execution`.
"""
from __future__ import annotations
from typing import Sequence
import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from depth_anything_3.utils.logger import logger
from depth_anything_3.utils.parallel_utils import parallel_execution
class InputProcessor:
"""Prepares a batch of images for model inference.
This processor converts a list of image file paths into a single, model-ready
tensor. The processing pipeline is executed in parallel across multiple workers
for efficiency.
Pipeline:
1) Load image and convert to RGB
2) Boundary resize (upper/lower bound, preserving aspect ratio)
3) Enforce divisibility by PATCH_SIZE:
- "*resize" methods: each dimension is rounded to nearest multiple
(may up/downscale a few px)
- "*crop" methods: each dimension is floored to nearest multiple via center crop
4) Convert to tensor and apply ImageNet normalization
5) Stack into (1, N, 3, H, W)
Parallelization:
- Each image is processed independently in a worker.
- Order of outputs matches the input order.
"""
NORMALIZE = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
PATCH_SIZE = 14
def __init__(self):
pass
# -----------------------------
# Public API
# -----------------------------
def __call__(
self,
image: list[np.ndarray | Image.Image | str],
extrinsics: np.ndarray | None = None,
intrinsics: np.ndarray | None = None,
process_res: int = 504,
process_res_method: str = "upper_bound_resize",
*,
num_workers: int = 8,
print_progress: bool = False,
sequential: bool | None = None,
desc: str | None = "Preprocess",
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Returns:
(tensor, extrinsics_list, intrinsics_list)
tensor shape: (1, N, 3, H, W)
"""
sequential = self._resolve_sequential(sequential, num_workers)
exts_list, ixts_list = self._validate_and_pack_meta(image, extrinsics, intrinsics)
results = self._run_parallel(
image=image,
exts_list=exts_list,
ixts_list=ixts_list,
process_res=process_res,
process_res_method=process_res_method,
num_workers=num_workers,
print_progress=print_progress,
sequential=sequential,
desc=desc,
)
proc_imgs, out_sizes, out_ixts, out_exts = self._unpack_results(results)
proc_imgs, out_sizes, out_ixts = self._unify_batch_shapes(proc_imgs, out_sizes, out_ixts)
batch_tensor = self._stack_batch(proc_imgs)
out_exts = (
torch.from_numpy(np.asarray(out_exts)).float()
if out_exts is not None and out_exts[0] is not None
else None
)
out_ixts = (
torch.from_numpy(np.asarray(out_ixts)).float()
if out_ixts is not None and out_ixts[0] is not None
else None
)
return (batch_tensor, out_exts, out_ixts)
# -----------------------------
# __call__ helpers
# -----------------------------
def _resolve_sequential(self, sequential: bool | None, num_workers: int) -> bool:
return (num_workers <= 1) if sequential is None else sequential
def _validate_and_pack_meta(
self,
images: list[np.ndarray | Image.Image | str],
extrinsics: np.ndarray | None,
intrinsics: np.ndarray | None,
) -> tuple[list[np.ndarray | None] | None, list[np.ndarray | None] | None]:
if extrinsics is not None and len(extrinsics) != len(images):
raise ValueError("Length of extrinsics must match images when provided.")
if intrinsics is not None and len(intrinsics) != len(images):
raise ValueError("Length of intrinsics must match images when provided.")
exts_list = [e for e in extrinsics] if extrinsics is not None else None
ixts_list = [k for k in intrinsics] if intrinsics is not None else None
return exts_list, ixts_list
def _run_parallel(
self,
*,
image: list[np.ndarray | Image.Image | str],
exts_list: list[np.ndarray | None] | None,
ixts_list: list[np.ndarray | None] | None,
process_res: int,
process_res_method: str,
num_workers: int,
print_progress: bool,
sequential: bool,
desc: str | None,
):
results = parallel_execution(
image,
exts_list,
ixts_list,
action=self._process_one, # (img, extrinsic, intrinsic, ...)
num_processes=num_workers,
print_progress=print_progress,
sequential=sequential,
desc=desc,
process_res=process_res,
process_res_method=process_res_method,
)
if not results:
raise RuntimeError(
"No preprocessing results returned. Check inputs and parallel_execution."
)
return results
def _unpack_results(self, results):
"""
results: List[Tuple[torch.Tensor, Tuple[H, W], Optional[np.ndarray], Optional[np.ndarray]]]
-> processed_images, out_sizes, out_intrinsics, out_extrinsics
"""
try:
processed_images, out_sizes, out_intrinsics, out_extrinsics = zip(*results)
except Exception as e:
raise RuntimeError(
"Unexpected results structure from parallel_execution: "
f"{type(results)} / sample: {results[0]}"
) from e
return list(processed_images), list(out_sizes), list(out_intrinsics), list(out_extrinsics)
def _unify_batch_shapes(
self,
processed_images: list[torch.Tensor],
out_sizes: list[tuple[int, int]],
out_intrinsics: list[np.ndarray | None],
) -> tuple[list[torch.Tensor], list[tuple[int, int]], list[np.ndarray | None]]:
"""Center-crop all tensors to the smallest H, W; adjust intrinsics' cx, cy accordingly."""
if len(set(out_sizes)) <= 1:
return processed_images, out_sizes, out_intrinsics
min_h = min(h for h, _ in out_sizes)
min_w = min(w for _, w in out_sizes)
logger.warn(
f"Images in batch have different sizes {out_sizes}; "
f"center-cropping all to smallest ({min_h},{min_w})"
)
center_crop = T.CenterCrop((min_h, min_w))
new_imgs, new_sizes, new_ixts = [], [], []
for img_t, (H, W), K in zip(processed_images, out_sizes, out_intrinsics):
crop_top = max(0, (H - min_h) // 2)
crop_left = max(0, (W - min_w) // 2)
new_imgs.append(center_crop(img_t))
new_sizes.append((min_h, min_w))
if K is None:
new_ixts.append(None)
else:
K_adj = K.copy()
K_adj[0, 2] -= crop_left
K_adj[1, 2] -= crop_top
new_ixts.append(K_adj)
return new_imgs, new_sizes, new_ixts
def _stack_batch(self, processed_images: list[torch.Tensor]) -> torch.Tensor:
return torch.stack(processed_images)
# -----------------------------
# Per-item worker
# -----------------------------
def _process_one(
self,
img: np.ndarray | Image.Image | str,
extrinsic: np.ndarray | None = None,
intrinsic: np.ndarray | None = None,
*,
process_res: int,
process_res_method: str,
) -> tuple[torch.Tensor, tuple[int, int], np.ndarray | None, np.ndarray | None]:
# Load & remember original size
pil_img = self._load_image(img)
orig_w, orig_h = pil_img.size
# Boundary resize
pil_img = self._resize_image(pil_img, process_res, process_res_method)
w, h = pil_img.size
intrinsic = self._resize_ixt(intrinsic, orig_w, orig_h, w, h)
# Enforce divisibility by PATCH_SIZE
if process_res_method.endswith("resize"):
pil_img = self._make_divisible_by_resize(pil_img, self.PATCH_SIZE)
new_w, new_h = pil_img.size
intrinsic = self._resize_ixt(intrinsic, w, h, new_w, new_h)
w, h = new_w, new_h
elif process_res_method.endswith("crop"):
pil_img = self._make_divisible_by_crop(pil_img, self.PATCH_SIZE)
new_w, new_h = pil_img.size
intrinsic = self._crop_ixt(intrinsic, w, h, new_w, new_h)
w, h = new_w, new_h
else:
raise ValueError(f"Unsupported process_res_method: {process_res_method}")
# Convert to tensor & normalize
img_tensor = self._normalize_image(pil_img)
_, H, W = img_tensor.shape
assert (W, H) == (w, h), "Tensor size mismatch with PIL image size after processing."
# Return: (img_tensor, (H, W), intrinsic, extrinsic)
return img_tensor, (H, W), intrinsic, extrinsic
# -----------------------------
# Intrinsics transforms
# -----------------------------
def _resize_ixt(
self,
intrinsic: np.ndarray | None,
orig_w: int,
orig_h: int,
w: int,
h: int,
) -> np.ndarray | None:
if intrinsic is None:
return None
K = intrinsic.copy()
# scale fx, cx by w ratio; fy, cy by h ratio
K[:1] *= w / float(orig_w)
K[1:2] *= h / float(orig_h)
return K
def _crop_ixt(
self,
intrinsic: np.ndarray | None,
orig_w: int,
orig_h: int,
w: int,
h: int,
) -> np.ndarray | None:
if intrinsic is None:
return None
K = intrinsic.copy()
crop_h = (orig_h - h) // 2
crop_w = (orig_w - w) // 2
K[0, 2] -= crop_w
K[1, 2] -= crop_h
return K
# -----------------------------
# I/O & normalization
# -----------------------------
def _load_image(self, img: np.ndarray | Image.Image | str) -> Image.Image:
if isinstance(img, str):
return Image.open(img).convert("RGB")
elif isinstance(img, np.ndarray):
# Assume HxWxC uint8/RGB
return Image.fromarray(img).convert("RGB")
elif isinstance(img, Image.Image):
return img.convert("RGB")
else:
raise ValueError(f"Unsupported image type: {type(img)}")
def _normalize_image(self, img: Image.Image) -> torch.Tensor:
img_tensor = T.ToTensor()(img)
return self.NORMALIZE(img_tensor)
# -----------------------------
# Boundary resizing
# -----------------------------
def _resize_image(self, img: Image.Image, target_size: int, method: str) -> Image.Image:
if method in ("upper_bound_resize", "upper_bound_crop"):
return self._resize_longest_side(img, target_size)
elif method in ("lower_bound_resize", "lower_bound_crop"):
return self._resize_shortest_side(img, target_size)
else:
raise ValueError(f"Unsupported resize method: {method}")
def _resize_longest_side(self, img: Image.Image, target_size: int) -> Image.Image:
w, h = img.size
longest = max(w, h)
if longest == target_size:
return img
scale = target_size / float(longest)
new_w = max(1, int(round(w * scale)))
new_h = max(1, int(round(h * scale)))
interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation)
return Image.fromarray(arr)
def _resize_shortest_side(self, img: Image.Image, target_size: int) -> Image.Image:
w, h = img.size
shortest = min(w, h)
if shortest == target_size:
return img
scale = target_size / float(shortest)
new_w = max(1, int(round(w * scale)))
new_h = max(1, int(round(h * scale)))
interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation)
return Image.fromarray(arr)
# -----------------------------
# Make divisible by PATCH_SIZE
# -----------------------------
def _make_divisible_by_crop(self, img: Image.Image, patch: int) -> Image.Image:
"""
Floor each dimension to the nearest multiple of PATCH_SIZE via center crop.
Example: 504x377 -> 504x364
"""
w, h = img.size
new_w = (w // patch) * patch
new_h = (h // patch) * patch
if new_w == w and new_h == h:
return img
left = (w - new_w) // 2
top = (h - new_h) // 2
return img.crop((left, top, left + new_w, top + new_h))
def _make_divisible_by_resize(self, img: Image.Image, patch: int) -> Image.Image:
"""
Round each dimension to nearest multiple of PATCH_SIZE via small resize.
"""
w, h = img.size
def nearest_multiple(x: int, p: int) -> int:
down = (x // p) * p
up = down + p
return up if abs(up - x) <= abs(x - down) else down
new_w = max(1, nearest_multiple(w, patch))
new_h = max(1, nearest_multiple(h, patch))
if new_w == w and new_h == h:
return img
upscale = (new_w > w) or (new_h > h)
interpolation = cv2.INTER_CUBIC if upscale else cv2.INTER_AREA
arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation)
return Image.fromarray(arr)
# Backward compatibility alias
InputAdapter = InputProcessor
# ===========================
# Minimal test runner (parallel execution)
# ===========================
if __name__ == "__main__":
"""
Minimal test suite:
- Creates pairs of images so batch shapes match.
- Tests all four process_res_methods.
- Prints fx fy cx cy IN->OUT per image.
- Includes cases with K/E provided and with None.
"""
def fmt_k_line(K: np.ndarray | None) -> str:
if K is None:
return "None"
fx, fy, cx, cy = float(K[0, 0]), float(K[1, 1]), float(K[0, 2]), float(K[1, 2])
return f"fx={fx:.3f} fy={fy:.3f} cx={cx:.3f} cy={cy:.3f}"
def show_result(
tag: str,
tensor: torch.Tensor,
Ks_in: Sequence[np.ndarray | None] | None = None,
Ks_out: Sequence[np.ndarray | None] | None = None,
):
B, N, C, H, W = tensor.shape
print(f"[{tag}] shape={tuple(tensor.shape)} HxW=({H},{W}) div14=({H%14==0},{W%14==0})")
assert H % 14 == 0 and W % 14 == 0, f"{tag}: output size not divisible by 14!"
if Ks_in is not None or Ks_out is not None:
Ks_in = Ks_in or [None] * N
Ks_out = Ks_out or [None] * N
for i in range(N):
print(f" K[{i}]: {fmt_k_line(Ks_in[i])} -> {fmt_k_line(Ks_out[i])}")
proc = InputProcessor()
process_res = 504
methods = ["upper_bound_resize", "upper_bound_crop", "lower_bound_resize", "lower_bound_crop"]
# Example sizes (two orientations)
small_sizes = [(680, 1208), (1208, 680)]
large_sizes = [(1208, 680), (680, 1208)]
def make_K(w, h, fx=1200.0, fy=1100.0):
cx, cy = w / 2.0, h / 2.0
K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
return K
def run_suite(suite_name: str, sizes: list[tuple[int, int]]):
print(f"\n===== {suite_name} =====")
for w, h in sizes:
img = Image.new("RGB", (w, h), color=(123, 222, 100))
batch_imgs = [img, img]
# intrinsics / extrinsics examples
Ks_in = [make_K(w, h), make_K(w, h)]
Es_in = [np.eye(4, dtype=np.float32), np.eye(4, dtype=np.float32)]
for m in methods:
tensor, Es_out, Ks_out = proc(
image=batch_imgs,
process_res=process_res,
process_res_method=m,
num_workers=8,
print_progress=False,
intrinsics=Ks_in, # test with non-None
extrinsics=Es_in,
)
show_result(f"{suite_name} size=({w},{h}) | {m}", tensor, Ks_in, Ks_out)
# Also test None path
tensor2, Es_out2, Ks_out2 = proc(
image=batch_imgs,
process_res=process_res,
process_res_method="upper_bound_resize",
num_workers=8,
intrinsics=None,
extrinsics=None,
)
show_result(
f"{suite_name} size=({w},{h}) | upper_bound_resize | no K/E",
tensor2,
None,
Ks_out2,
)
run_suite("SMALL", small_sizes)
run_suite("LARGE", large_sizes)
# Extra sanity for 504x376
print("\n===== EXTRA sanity for 504x376 =====")
img_example = Image.new("RGB", (504, 376), color=(10, 20, 30))
Ks_in_extra = [make_K(504, 376, fx=900.0, fy=900.0), make_K(504, 376, fx=900.0, fy=900.0)]
out_r, _, Ks_out_r = proc(
image=[img_example, img_example],
process_res=504,
process_res_method="upper_bound_resize",
num_workers=8,
intrinsics=Ks_in_extra,
)
out_c, _, Ks_out_c = proc(
image=[img_example, img_example],
process_res=504,
process_res_method="upper_bound_crop",
num_workers=8,
intrinsics=Ks_in_extra,
)
_, _, _, Hr, Wr = out_r.shape
_, _, _, Hc, Wc = out_c.shape
print(f"upper_bound_resize -> ({Hr},{Wr}) (rounded to nearest multiple of 14)")
show_result("Ks after upper_bound_resize", out_r, Ks_in_extra, Ks_out_r)
print(f"upper_bound_crop -> ({Hc},{Wc}) (floored to multiple of 14)")
show_result("Ks after upper_bound_crop", out_c, Ks_in_extra, Ks_out_c)

View File

@@ -0,0 +1,172 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Output processor for Depth Anything 3.
This module handles model output processing, including tensor-to-numpy conversion,
batch dimension removal, and Prediction object creation.
"""
from __future__ import annotations
import numpy as np
import torch
from addict import Dict as AddictDict
from depth_anything_3.specs import Prediction
class OutputProcessor:
"""
Output processor for converting model outputs to Prediction objects.
Handles tensor-to-numpy conversion, batch dimension removal,
and creates structured Prediction objects with proper data types.
"""
def __init__(self) -> None:
"""Initialize the output processor."""
def __call__(self, model_output: dict[str, torch.Tensor]) -> Prediction:
"""
Convert model output to Prediction object.
Args:
model_output: Model output dictionary containing depth, conf, extrinsics, intrinsics
Expected shapes: depth (B, N, 1, H, W), conf (B, N, 1, H, W),
extrinsics (B, N, 4, 4), intrinsics (B, N, 3, 3)
Returns:
Prediction: Object containing depth estimation results with shapes:
depth (N, H, W), conf (N, H, W), extrinsics (N, 4, 4), intrinsics (N, 3, 3)
"""
# Extract data from batch dimension (B=1, N=number of images)
depth = self._extract_depth(model_output)
conf = self._extract_conf(model_output)
extrinsics = self._extract_extrinsics(model_output)
intrinsics = self._extract_intrinsics(model_output)
sky = self._extract_sky(model_output)
aux = self._extract_aux(model_output)
gaussians = model_output.get("gaussians", None)
scale_factor = model_output.get("scale_factor", None)
return Prediction(
depth=depth,
sky=sky,
conf=conf,
extrinsics=extrinsics,
intrinsics=intrinsics,
is_metric=getattr(model_output, "is_metric", 0),
gaussians=gaussians,
aux=aux,
scale_factor=scale_factor,
)
def _extract_depth(self, model_output: dict[str, torch.Tensor]) -> np.ndarray:
"""
Extract depth tensor from model output and convert to numpy.
Args:
model_output: Model output dictionary
Returns:
Depth array with shape (N, H, W)
"""
depth = model_output["depth"].squeeze(0).squeeze(-1).cpu().numpy() # (N, H, W)
return depth
def _extract_conf(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
"""
Extract confidence tensor from model output and convert to numpy.
Args:
model_output: Model output dictionary
Returns:
Confidence array with shape (N, H, W) or None
"""
conf = model_output.get("depth_conf", None)
if conf is not None:
conf = conf.squeeze(0).cpu().numpy() # (N, H, W)
return conf
def _extract_extrinsics(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
"""
Extract extrinsics tensor from model output and convert to numpy.
Args:
model_output: Model output dictionary
Returns:
Extrinsics array with shape (N, 4, 4) or None
"""
extrinsics = model_output.get("extrinsics", None)
if extrinsics is not None:
extrinsics = extrinsics.squeeze(0).cpu().numpy() # (N, 4, 4)
return extrinsics
def _extract_intrinsics(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
"""
Extract intrinsics tensor from model output and convert to numpy.
Args:
model_output: Model output dictionary
Returns:
Intrinsics array with shape (N, 3, 3) or None
"""
intrinsics = model_output.get("intrinsics", None)
if intrinsics is not None:
intrinsics = intrinsics.squeeze(0).cpu().numpy() # (N, 3, 3)
return intrinsics
def _extract_sky(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
"""
Extract sky tensor from model output and convert to numpy.
Args:
model_output: Model output dictionary
Returns:
Sky mask array with shape (N, H, W) or None
"""
sky = model_output.get("sky", None)
if sky is not None:
sky = sky.squeeze(0).cpu().numpy() >= 0.5 # (N, H, W)
return sky
def _extract_aux(self, model_output: dict[str, torch.Tensor]) -> AddictDict:
"""
Extract auxiliary data from model output and convert to numpy.
Args:
model_output: Model output dictionary
Returns:
Dictionary containing auxiliary data
"""
aux = model_output.get("aux", None)
ret = AddictDict()
if aux is not None:
for k in aux.keys():
if isinstance(aux[k], torch.Tensor):
ret[k] = aux[k].squeeze(0).cpu().numpy()
else:
ret[k] = aux[k]
return ret
# Backward compatibility alias
OutputAdapter = OutputProcessor

View File

@@ -0,0 +1,216 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains useful layout utilities for images. They are:
- add_border: Add a border to an image.
- cat/hcat/vcat: Join images by arranging them in a line. If the images have different
sizes, they are aligned as specified (start, end, center). Allows you to specify a gap
between images.
Images are assumed to be float32 tensors with shape (channel, height, width).
"""
from typing import Any, Generator, Iterable, Literal, Union
import torch
from torch import Tensor
Alignment = Literal["start", "center", "end"]
Axis = Literal["horizontal", "vertical"]
Color = Union[
int,
float,
Iterable[int],
Iterable[float],
Tensor,
Tensor,
]
def _sanitize_color(color: Color) -> Tensor: # "#channel"
# Convert tensor to list (or individual item).
if isinstance(color, torch.Tensor):
color = color.tolist()
# Turn iterators and individual items into lists.
if isinstance(color, Iterable):
color = list(color)
else:
color = [color]
return torch.tensor(color, dtype=torch.float32)
def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]:
it = iter(iterable)
yield next(it)
for item in it:
yield delimiter
yield item
def _get_main_dim(main_axis: Axis) -> int:
return {
"horizontal": 2,
"vertical": 1,
}[main_axis]
def _get_cross_dim(main_axis: Axis) -> int:
return {
"horizontal": 1,
"vertical": 2,
}[main_axis]
def _compute_offset(base: int, overlay: int, align: Alignment) -> slice:
assert base >= overlay
offset = {
"start": 0,
"center": (base - overlay) // 2,
"end": base - overlay,
}[align]
return slice(offset, offset + overlay)
def overlay(
base: Tensor, # "channel base_height base_width"
overlay: Tensor, # "channel overlay_height overlay_width"
main_axis: Axis,
main_axis_alignment: Alignment,
cross_axis_alignment: Alignment,
) -> Tensor: # "channel base_height base_width"
# The overlay must be smaller than the base.
_, base_height, base_width = base.shape
_, overlay_height, overlay_width = overlay.shape
assert base_height >= overlay_height and base_width >= overlay_width
# Compute spacing on the main dimension.
main_dim = _get_main_dim(main_axis)
main_slice = _compute_offset(
base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment
)
# Compute spacing on the cross dimension.
cross_dim = _get_cross_dim(main_axis)
cross_slice = _compute_offset(
base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment
)
# Combine the slices and paste the overlay onto the base accordingly.
selector = [..., None, None]
selector[main_dim] = main_slice
selector[cross_dim] = cross_slice
result = base.clone()
result[selector] = overlay
return result
def cat(
main_axis: Axis,
*images: Iterable[Tensor], # "channel _ _"
align: Alignment = "center",
gap: int = 8,
gap_color: Color = 1,
) -> Tensor: # "channel height width"
"""Arrange images in a line. The interface resembles a CSS div with flexbox."""
device = images[0].device
gap_color = _sanitize_color(gap_color).to(device)
# Find the maximum image side length in the cross axis dimension.
cross_dim = _get_cross_dim(main_axis)
cross_axis_length = max(image.shape[cross_dim] for image in images)
# Pad the images.
padded_images = []
for image in images:
# Create an empty image with the correct size.
padded_shape = list(image.shape)
padded_shape[cross_dim] = cross_axis_length
base = torch.ones(padded_shape, dtype=torch.float32, device=device)
base = base * gap_color[:, None, None]
padded_images.append(overlay(base, image, main_axis, "start", align))
# Intersperse separators if necessary.
if gap > 0:
# Generate a separator.
c, _, _ = images[0].shape
separator_size = [gap, gap]
separator_size[cross_dim - 1] = cross_axis_length
separator = torch.ones((c, *separator_size), dtype=torch.float32, device=device)
separator = separator * gap_color[:, None, None]
# Intersperse the separator between the images.
padded_images = list(_intersperse(padded_images, separator))
return torch.cat(padded_images, dim=_get_main_dim(main_axis))
def hcat(
*images: Iterable[Tensor], # "channel _ _"
align: Literal["start", "center", "end", "top", "bottom"] = "start",
gap: int = 8,
gap_color: Color = 1,
):
"""Shorthand for a horizontal linear concatenation."""
return cat(
"horizontal",
*images,
align={
"start": "start",
"center": "center",
"end": "end",
"top": "start",
"bottom": "end",
}[align],
gap=gap,
gap_color=gap_color,
)
def vcat(
*images: Iterable[Tensor], # "channel _ _"
align: Literal["start", "center", "end", "left", "right"] = "start",
gap: int = 8,
gap_color: Color = 1,
):
"""Shorthand for a horizontal linear concatenation."""
return cat(
"vertical",
*images,
align={
"start": "start",
"center": "center",
"end": "end",
"left": "start",
"right": "end",
}[align],
gap=gap,
gap_color=gap_color,
)
def add_border(
image: Tensor, # "channel height width"
border: int = 8,
color: Color = 1,
) -> Tensor: # "channel new_height new_width"
color = _sanitize_color(color).to(image)
c, h, w = image.shape
result = torch.empty(
(c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device
)
result[:] = color[:, None, None]
result[:, border : h + border, border : w + border] = image
return result

View File

@@ -0,0 +1,82 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
class Color:
RED = "\033[91m"
YELLOW = "\033[93m"
WHITE = "\033[97m"
GREEN = "\033[92m"
RESET = "\033[0m"
LOG_LEVELS = {"ERROR": 0, "WARN": 1, "INFO": 2, "DEBUG": 3}
COLOR_MAP = {"ERROR": Color.RED, "WARN": Color.YELLOW, "INFO": Color.WHITE, "DEBUG": Color.GREEN}
def get_env_log_level():
level = os.environ.get("DA3_LOG_LEVEL", "INFO").upper()
return LOG_LEVELS.get(level, LOG_LEVELS["INFO"])
class Logger:
def __init__(self):
self.level = get_env_log_level()
def log(self, level_str, *args, **kwargs):
level_key = level_str.split(":")[0].strip()
level_val = LOG_LEVELS.get(level_key)
if level_val is None:
raise ValueError(f"Unknown log level: {level_str}")
if self.level >= level_val:
color = COLOR_MAP[level_key]
msg = " ".join(str(arg) for arg in args)
# Align log level output in square brackets
# ERROR and DEBUG are 5 characters, INFO and WARN have an extra space for alignment
tag = level_key
if tag in ("INFO", "WARN"):
tag += " "
print(
f"{color}[{tag}] {msg}{Color.RESET}",
file=sys.stderr if level_key == "ERROR" else sys.stdout,
**kwargs,
)
def error(self, *args, **kwargs):
self.log("ERROR:", *args, **kwargs)
def warn(self, *args, **kwargs):
self.log("WARN:", *args, **kwargs)
def info(self, *args, **kwargs):
self.log("INFO:", *args, **kwargs)
def debug(self, *args, **kwargs):
self.log("DEBUG:", *args, **kwargs)
logger = Logger()
__all__ = ["logger"]
if __name__ == "__main__":
logger.info("This is an info message")
logger.warn("This is a warning message")
logger.error("This is an error message")
logger.debug("This is a debug message")

View File

@@ -0,0 +1,127 @@
"""
GPU memory utility helpers.
Shared cleanup and memory checking logic used by both the backend API and
the Gradio UI to keep memory-management behavior consistent.
"""
from __future__ import annotations
import gc
from typing import Any, Dict, Optional
import torch
def get_gpu_memory_info() -> Optional[Dict[str, Any]]:
"""Return a snapshot of current GPU memory usage or None if CUDA not available.
Keys in returned dict: total_gb, allocated_gb, reserved_gb, free_gb, utilization
"""
if not torch.cuda.is_available():
return None
try:
device = torch.cuda.current_device()
total_memory = torch.cuda.get_device_properties(device).total_memory
allocated_memory = torch.cuda.memory_allocated(device)
reserved_memory = torch.cuda.memory_reserved(device)
free_memory = total_memory - reserved_memory
return {
"total_gb": total_memory / 1024 ** 3,
"allocated_gb": allocated_memory / 1024 ** 3,
"reserved_gb": reserved_memory / 1024 ** 3,
"free_gb": free_memory / 1024 ** 3,
"utilization": (reserved_memory / total_memory) * 100,
}
except Exception:
return None
def cleanup_cuda_memory() -> None:
"""Perform a robust GPU cleanup sequence.
This includes synchronizing, emptying caches, collecting IPC handles and
running the Python garbage collector. Use this instead of a raw
``torch.cuda.empty_cache()`` where you need reliable freeing of GPU memory
between model loads or in error handling paths.
"""
try:
if torch.cuda.is_available():
mem_before = get_gpu_memory_info()
torch.cuda.synchronize()
torch.cuda.empty_cache()
# Collect cross-process cuda resources
try:
torch.cuda.ipc_collect()
except Exception:
# Older PyTorch versions or non-cuda devices may not support
# ipc_collect (no-op if not available)
pass
gc.collect()
mem_after = get_gpu_memory_info()
if mem_before and mem_after:
freed = mem_before["reserved_gb"] - mem_after["reserved_gb"]
print(
f"CUDA cleanup: freed {freed:.2f}GB, "
f"available: {mem_after['free_gb']:.2f}GB/{mem_after['total_gb']:.2f}GB"
)
else:
print("CUDA memory cleanup completed")
except Exception as e:
print(f"Warning: CUDA cleanup failed: {e}")
def check_memory_availability(required_gb: float = 2.0) -> tuple[bool, str]:
"""Return whether at least ``required_gb`` seems available on the current GPU.
The returned tuple is (is_available, message) with a human-friendly message.
"""
try:
if not torch.cuda.is_available():
return False, "CUDA is not available"
mem_info = get_gpu_memory_info()
if mem_info is None:
return True, "Cannot check memory, proceeding anyway"
if mem_info["free_gb"] < required_gb:
return (
False,
(
f"Insufficient GPU memory: {mem_info['free_gb']:.2f}GB available, "
f"{required_gb:.2f}GB required. Total: {mem_info['total_gb']:.2f}GB, "
f"Used: {mem_info['reserved_gb']:.2f}GB ({mem_info['utilization']:.1f}%)"
),
)
return (
True,
(
f"Memory check passed: {mem_info['free_gb']:.2f}GB available, "
f"{required_gb:.2f}GB required"
),
)
except Exception as e:
return True, f"Memory check failed: {e}, proceeding anyway"
def estimate_memory_requirement(num_images: int, process_res: int) -> float:
"""Heuristic estimate for memory usage (GB) based on image count and resolution.
This mirrors the simple policy used by the backend service so other code
(e.g., Gradio UI) can make consistent decisions when checking available
memory before loading a model or running inference.
Args:
num_images: Number of images to process.
process_res: Processing resolution.
Returns:
Estimated memory requirement in GB.
"""
base_memory = 2.0
per_image_memory = (process_res / 504) ** 2 * 0.5
total_memory = base_memory + (num_images * per_image_memory * 0.1)
return total_memory

View File

@@ -0,0 +1,149 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model loading and state dict conversion utilities.
"""
from typing import Dict, Tuple
import torch
from depth_anything_3.utils.logger import logger
def convert_general_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert general model state dict to match current model architecture.
Args:
state_dict: Original state dictionary
Returns:
Converted state dictionary
"""
# Replace module prefixes
state_dict = {k.replace("module.", "model."): v for k, v in state_dict.items()}
state_dict = {k.replace(".net.", ".backbone."): v for k, v in state_dict.items()}
# Remove camera token if present
if "model.backbone.pretrained.camera_token" in state_dict:
del state_dict["model.backbone.pretrained.camera_token"]
# Replace camera token naming
state_dict = {
k.replace(".camera_token_extra", ".camera_token"): v for k, v in state_dict.items()
}
# Replace head naming
state_dict = {
k.replace("model.all_heads.camera_cond_head", "model.cam_enc"): v
for k, v in state_dict.items()
}
state_dict = {
k.replace("model.all_heads.camera_head", "model.cam_dec"): v for k, v in state_dict.items()
}
state_dict = {k.replace(".more_mlps.", ".backbone."): v for k, v in state_dict.items()}
state_dict = {k.replace(".fc_rot.", ".fc_qvec."): v for k, v in state_dict.items()}
state_dict = {
k.replace("model.all_heads.head", "model.head"): v for k, v in state_dict.items()
}
# Replace output naming
state_dict = {
k.replace("output_conv2_additional.sky_mask", "sky_output_conv2"): v
for k, v in state_dict.items()
}
state_dict = {k.replace("_ray.", "_aux."): v for k, v in state_dict.items()}
# Update GS-DPT head naming and value
state_dict = {k.replace("gaussian_param_head.", "gs_head."): v for k, v in state_dict.items()}
return state_dict
def convert_metric_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert metric model state dict to match current model architecture.
Args:
state_dict: Original metric state dictionary
Returns:
Converted state dictionary
"""
# Add module prefix for metric models
state_dict = {"module." + k: v for k, v in state_dict.items()}
return convert_general_state_dict(state_dict)
def load_pretrained_weights(model, model_path: str, is_metric: bool = False) -> Tuple[list, list]:
"""
Load pretrained weights for a single model.
Args:
model: Model instance to load weights into
model_path: Path to the pretrained weights
is_metric: Whether this is a metric model
Returns:
Tuple of (missed_keys, unexpected_keys)
"""
state_dict = torch.load(model_path, map_location="cpu")
if is_metric:
state_dict = convert_metric_state_dict(state_dict)
else:
state_dict = convert_general_state_dict(state_dict)
missed, unexpected = model.load_state_dict(state_dict, strict=False)
logger.info("Missed keys:", missed)
logger.info("Unexpected keys:", unexpected)
return missed, unexpected
def load_pretrained_nested_weights(
model, main_model_path: str, metric_model_path: str
) -> Tuple[list, list]:
"""
Load pretrained weights for a nested model with both main and metric branches.
Args:
model: Nested model instance
main_model_path: Path to main model weights
metric_model_path: Path to metric model weights
Returns:
Tuple of (missed_keys, unexpected_keys)
"""
# Load main model weights
state_dict0 = torch.load(main_model_path, map_location="cpu")
state_dict0 = convert_general_state_dict(state_dict0)
state_dict0 = {k.replace("model.", "model.da3."): v for k, v in state_dict0.items()}
# Load metric model weights
state_dict1 = torch.load(metric_model_path, map_location="cpu")
state_dict1 = convert_metric_state_dict(state_dict1)
state_dict1 = {k.replace("model.", "model.da3_metric."): v for k, v in state_dict1.items()}
# Combine state dictionaries
combined_state_dict = state_dict0.copy()
combined_state_dict.update(state_dict1)
missed, unexpected = model.load_state_dict(combined_state_dict, strict=False)
print("Missed keys:", missed)
print("Unexpected keys:", unexpected)
return missed, unexpected

View File

@@ -0,0 +1,133 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from functools import wraps
from multiprocessing.pool import ThreadPool
from threading import Thread
from typing import Callable, Dict, List
import imageio
from tqdm import tqdm
def async_call_func(func):
@wraps(func)
async def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
# Use run_in_executor to run the blocking function in a separate thread
return await loop.run_in_executor(None, func, *args, **kwargs)
return wrapper
slice_func = lambda chunk_index, chunk_dim, chunk_size: [slice(None)] * chunk_dim + [
slice(chunk_index, chunk_index + chunk_size)
]
def async_call(fn):
def wrapper(*args, **kwargs):
Thread(target=fn, args=args, kwargs=kwargs).start()
return wrapper
def _save_image_impl(save_img, save_path):
"""Common implementation for saving images synchronously or asynchronously"""
os.makedirs(os.path.dirname(save_path), exist_ok=True)
imageio.imwrite(save_path, save_img)
@async_call
def save_image_async(save_img, save_path):
"""Save image asynchronously"""
_save_image_impl(save_img, save_path)
def save_image(save_img, save_path):
"""Save image synchronously"""
_save_image_impl(save_img, save_path)
def parallel_execution(
*args,
action: Callable,
num_processes=32,
print_progress=False,
sequential=False,
async_return=False,
desc=None,
**kwargs,
):
# Partially copy from EasyVolumetricVideo (parallel_execution)
# NOTE: we expect first arg / or kwargs to be distributed
# NOTE: print_progress arg is reserved.
# `*args` packs all positional arguments passed to the function into a tuple
args = list(args)
def get_length(args: List, kwargs: Dict):
for a in args:
if isinstance(a, list):
return len(a)
for v in kwargs.values():
if isinstance(v, list):
return len(v)
raise NotImplementedError
def get_action_args(length: int, args: List, kwargs: Dict, i: int):
action_args = [
(arg[i] if isinstance(arg, list) and len(arg) == length else arg) for arg in args
]
# TODO: Support all types of iterable
action_kwargs = {
key: (
kwargs[key][i]
if isinstance(kwargs[key], list) and len(kwargs[key]) == length
else kwargs[key]
)
for key in kwargs
}
return action_args, action_kwargs
if not sequential:
# Create ThreadPool
pool = ThreadPool(processes=num_processes)
# Spawn threads
results = []
asyncs = []
length = get_length(args, kwargs)
for i in range(length):
action_args, action_kwargs = get_action_args(length, args, kwargs, i)
async_result = pool.apply_async(action, action_args, action_kwargs)
asyncs.append(async_result)
# Join threads and get return values
if not async_return:
for async_result in tqdm(asyncs, desc=desc, disable=not print_progress):
results.append(async_result.get()) # will sync the corresponding thread
pool.close()
pool.join()
return results
else:
return pool
else:
results = []
length = get_length(args, kwargs)
for i in tqdm(range(length), desc=desc, disable=not print_progress):
action_args, action_kwargs = get_action_args(length, args, kwargs, i)
async_result = action(*action_args, **action_kwargs)
results.append(async_result)
return results

View File

@@ -0,0 +1,284 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PCA utilities for feature visualization and dimensionality reduction (video-friendly).
- Support frame-by-frame: transform_frame / transform_video
- Support one-time global PCA fitting and reuse (mean, V3) for stable colors
- Support Procrustes alignment (solving principal component order/sign/rotation jumps)
- Support global fixed or temporal EMA for percentiles (time dimension only, no spatial)
"""
import numpy as np
import torch
def pca_to_rgb_4d_bf16_percentile(
x_np: np.ndarray,
device=None,
q_oversample: int = 6,
clip_percent: float = 10.0, # Percentage to clip from top and bottom (0~49.9)
return_uint8: bool = False,
enable_autocast_bf16: bool = True,
):
"""
Reduce numpy array of shape (49, 27, 36, 3072) to 3D via PCA and visualize as (49, 27, 36, 3).
- PCA uses torch.pca_lowrank (randomized SVD), defaults to GPU.
- Uses CUDA bf16 autocast in computation (if available),
then per-channel percentile clipping and normalization.
- Default removes 5% outliers from top and bottom (adjustable via clip_percent) to
improve visualization contrast.
Parameters
----------
x_np : np.ndarray
Shape must be (49, 27, 36, 3072). dtype recommended float32/float64.
device : str | None
Specify 'cuda' or 'cpu'. Auto-select if None (prefer cuda).
q_oversample : int
Oversampling q for pca_lowrank, must be >= 3.
Slightly larger than target dim (3) is more stable, default 6.
clip_percent : float
Percentage to clip from top and bottom (0~49.9),
e.g. 5.0 means clip lowest 5% and highest 5% per channel.
return_uint8 : bool
True returns uint8(0~255), otherwise returns float32(0~1).
enable_autocast_bf16 : bool
Enable bf16 autocast on CUDA.
Returns
-------
np.ndarray
Array of shape (49, 27, 36, 3), float32[0,1] or uint8[0,255].
"""
assert (
x_np.ndim == 4
) # and x_np.shape[-1] == 3072, f"expect (49,27,36,3072), got {x_np.shape}"
B1, B2, B3, D = x_np.shape
N = B1 * B2 * B3
# Device selection
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Convert input to torch, unified float32
X = torch.from_numpy(x_np.reshape(N, D)).to(device=device, dtype=torch.float32)
# Parameter and safety checks
k = 3
q = max(int(q_oversample), k)
clip_percent = float(clip_percent)
if not (0.0 <= clip_percent < 50.0):
raise ValueError(
"clip_percent must be in [0, 50), e.g. 5.0 means clip 5% from top and bottom"
)
low = clip_percent / 100.0
high = 1.0 - low
with torch.no_grad():
# Zero mean
mean = X.mean(dim=0, keepdim=True)
Xc = X - mean
# Main computation: PCA + projection, try to use bf16
# (auto-fallback if operator not supported)
device.startswith("cuda") and enable_autocast_bf16
U, S, V = torch.pca_lowrank(Xc, q=q, center=False) # V: (D, q)
V3 = V[:, :k] # (3072, 3)
PCs = Xc @ V3 # (N, 3)
# === Per-channel percentile clipping and normalization to [0,1] ===
# Vectorized one-time calculation of low/high percentiles for each channel
qs = torch.tensor([low, high], device=PCs.device, dtype=PCs.dtype)
qvals = torch.quantile(PCs, q=qs, dim=0) # Shape (2, 3)
lo = qvals[0] # (3,)
hi = qvals[1] # (3,)
# Avoid degenerate case where hi==lo
denom = torch.clamp(hi - lo, min=1e-8)
# Broadcast clipping + normalization
PCs = torch.clamp(PCs, lo, hi)
PCs = (PCs - lo) / denom # (N, 3) in [0,1]
# Restore 4D
PCs = PCs.reshape(B1, B2, B3, k)
# Output
if return_uint8:
out = (PCs * 255.0).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
else:
out = PCs.clamp(0, 1).to(torch.float32).cpu().numpy()
return out
class PCARGBVisualizer:
"""
Stable PCA→RGB for video features shaped (T, H, W, D) or a single frame (H, W, D).
- Global mean/V3 reference for stable colors
- Per-frame PCA with Procrustes alignment to V3_ref (basis_mode='procrustes')
- Percentile normalization with global or EMA stats (time-only, no spatial smoothing)
"""
def __init__(
self,
device=None,
q_oversample: int = 16,
clip_percent: float = 10.0,
return_uint8: bool = False,
enable_autocast_bf16: bool = True,
basis_mode: str = "procrustes", # 'fixed' | 'procrustes'
percentile_mode: str = "ema", # 'global' | 'ema'
ema_alpha: float = 0.1,
denom_eps: float = 1e-4,
):
assert 0.0 <= clip_percent < 50.0
assert basis_mode in ("fixed", "procrustes")
assert percentile_mode in ("global", "ema")
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.q = max(int(q_oversample), 6)
self.clip_percent = float(clip_percent)
self.return_uint8 = return_uint8
self.enable_autocast_bf16 = enable_autocast_bf16
self.basis_mode = basis_mode
self.percentile_mode = percentile_mode
self.ema_alpha = float(ema_alpha)
self.denom_eps = float(denom_eps)
# reference state
self.mean_ref = None # (1, D)
self.V3_ref = None # (D, 3)
self.lo_ref = None # (3,)
self.hi_ref = None # (3,)
@torch.no_grad()
def fit_reference(self, frames):
"""
Fit global mean/V3 and initialize percentiles from a reference set.
frames: ndarray (T,H,W,D) or list of (H,W,D)
"""
if isinstance(frames, np.ndarray):
if frames.ndim != 4:
raise ValueError("fit_reference expects (T,H,W,D) ndarray.")
T, H, W, D = frames.shape
X = torch.from_numpy(frames.reshape(T * H * W, D))
else: # list of (H,W,D)
xs = [torch.from_numpy(x.reshape(-1, x.shape[-1])) for x in frames]
D = xs[0].shape[-1]
X = torch.cat(xs, dim=0)
X = X.to(self.device, dtype=torch.float32)
X = torch.nan_to_num(X, nan=0.0, posinf=1e6, neginf=-1e6)
mean = X.mean(0, keepdim=True)
Xc = X - mean
U, S, V = torch.pca_lowrank(Xc, q=max(self.q, 8), center=False)
V3 = V[:, :3] # (D,3)
PCs = Xc @ V3
low = self.clip_percent / 100.0
high = 1.0 - low
qs = torch.tensor([low, high], device=PCs.device, dtype=PCs.dtype)
qvals = torch.quantile(PCs, q=qs, dim=0)
lo, hi = qvals[0], qvals[1]
self.mean_ref = mean
self.V3_ref = V3
if self.percentile_mode == "global":
self.lo_ref, self.hi_ref = lo, hi
else:
self.lo_ref = lo.clone()
self.hi_ref = hi.clone()
@torch.no_grad()
def _project_with_stable_colors(self, X: torch.Tensor) -> torch.Tensor:
"""
X: (N,D) where N = H*W
Returns PCs_raw: (N,3) using stable basis (fixed or Procrustes-aligned)
"""
assert self.mean_ref is not None and self.V3_ref is not None, "Call fit_reference() first."
X = torch.nan_to_num(X, nan=0.0, posinf=1e6, neginf=-1e6)
Xc = X - self.mean_ref
if self.basis_mode == "fixed":
V3_used = self.V3_ref
else:
U, S, V = torch.pca_lowrank(Xc, q=max(self.q, 6), center=False)
V3 = V[:, :3] # (D,3)
M = V3.T @ self.V3_ref
Uo, So, Vh = torch.linalg.svd(M)
R = Uo @ Vh
V3_used = V3 @ R
# Optional polarity fix via anchor
a = self.V3_ref.mean(0, keepdim=True)
sign = torch.sign((V3_used * a).sum(0, keepdim=True)).clamp(min=-1)
V3_used = V3_used * sign
return Xc @ V3_used
@torch.no_grad()
def _normalize_rgb(self, PCs_raw: torch.Tensor) -> torch.Tensor:
assert self.lo_ref is not None and self.hi_ref is not None
if self.percentile_mode == "global":
lo, hi = self.lo_ref, self.hi_ref
else:
low = self.clip_percent / 100.0
high = 1.0 - low
qs = torch.tensor([low, high], device=PCs_raw.device, dtype=PCs_raw.dtype)
qvals = torch.quantile(PCs_raw, q=qs, dim=0)
lo_now, hi_now = qvals[0], qvals[1]
a = self.ema_alpha
self.lo_ref = (1 - a) * self.lo_ref + a * lo_now
self.hi_ref = (1 - a) * self.hi_ref + a * hi_now
lo, hi = self.lo_ref, self.hi_ref
denom = torch.clamp(hi - lo, min=self.denom_eps)
PCs = torch.clamp(PCs_raw, lo, hi)
PCs = (PCs - lo) / denom
return PCs.clamp_(0, 1)
@torch.no_grad()
def transform_frame(self, frame: np.ndarray) -> np.ndarray:
"""
frame: (H,W,D) -> (H,W,3)
"""
if frame.ndim != 3:
raise ValueError("transform_frame expects (H,W,D).")
H, W, D = frame.shape
X = torch.from_numpy(frame.reshape(H * W, D)).to(self.device, dtype=torch.float32)
PCs_raw = self._project_with_stable_colors(X)
PCs = self._normalize_rgb(PCs_raw).reshape(H, W, 3)
if self.return_uint8:
return (PCs * 255.0).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
return PCs.to(torch.float32).cpu().numpy()
@torch.no_grad()
def transform_video(self, frames) -> np.ndarray:
"""
frames: (T,H,W,D) or list of (H,W,D)
returns: (T,H,W,3)
"""
outs = []
if isinstance(frames, np.ndarray):
if frames.ndim != 4:
raise ValueError("transform_video expects (T,H,W,D).")
T, H, W, D = frames.shape
for t in range(T):
outs.append(self.transform_frame(frames[t]))
else:
for f in frames:
outs.append(self.transform_frame(f))
return np.stack(outs, axis=0)

View File

@@ -0,0 +1,347 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
import torch
from evo.core.trajectory import PosePath3D
from depth_anything_3.utils.geometry import affine_inverse, affine_inverse_np
def batch_apply_alignment_to_enc(
rots: torch.Tensor, trans: torch.Tensor, scales: torch.Tensor, enc_list: List[torch.Tensor]
):
pass
def batch_apply_alignment_to_ext(
rots: torch.Tensor, trans: torch.Tensor, scales: torch.Tensor, ext: torch.Tensor
):
device, _ = ext.device, ext.dtype
if ext.shape[-2:] == (3, 4):
pad = torch.zeros((*ext.shape[:-2], 4, 4), dtype=ext.dtype, device=device)
pad[..., :3, :4] = ext
pad[..., 3, 3] = 1.0
ext = pad
pose_est = affine_inverse(ext)
pose_new_align_rot = rots[:, None] @ pose_est[..., :3, :3]
pose_new_align_trans = (
scales[:, None, None] * (rots[:, None] @ pose_est[..., :3, 3:])[..., 0] + trans[:, None]
)
pose_new_align = torch.zeros_like(ext)
pose_new_align[..., :3, :3] = pose_new_align_rot
pose_new_align[..., :3, 3] = pose_new_align_trans
pose_new_align[..., 3, 3] = 1.0
return affine_inverse(pose_new_align)[:, :3]
def batch_align_poses_umeyama(ext_ref: torch.Tensor, ext_est: torch.Tensor):
device, dtype = ext_ref.device, ext_ref.dtype
assert ext_ref.dtype in [torch.float32, torch.float64]
assert ext_est.dtype in [torch.float32, torch.float64]
assert ext_ref.requires_grad is False
assert ext_est.requires_grad is False
rots, trans, scales = [], [], []
for b in range(ext_ref.shape[0]):
r, t, s = align_poses_umeyama(ext_ref[b].cpu().numpy(), ext_est[b].cpu().numpy())
rots.append(torch.from_numpy(r).to(device=device, dtype=dtype))
trans.append(torch.from_numpy(t).to(device=device, dtype=dtype))
scales.append(torch.tensor(s, device=device, dtype=dtype))
return torch.stack(rots), torch.stack(trans), torch.stack(scales)
# Dependencies: affine_inverse_np, PosePath3D (maintain consistency with your existing project)
def _to44(ext):
if ext.shape[1] == 3:
out = np.eye(4)[None].repeat(len(ext), 0)
out[:, :3, :4] = ext
return out
return ext
def _poses_from_ext(ext_ref, ext_est):
ext_ref = _to44(ext_ref)
ext_est = _to44(ext_est)
pose_ref = affine_inverse_np(ext_ref)
pose_est = affine_inverse_np(ext_est)
return pose_ref, pose_est
def _umeyama_sim3_from_paths(pose_ref, pose_est):
path_ref = PosePath3D(poses_se3=pose_ref.copy())
path_est = PosePath3D(poses_se3=pose_est.copy())
r, t, s = path_est.align(path_ref, correct_scale=True)
pose_est_aligned = np.stack(path_est.poses_se3)
return r, t, s, pose_est_aligned
def _apply_sim3_to_poses(poses, r, t, s):
out = poses.copy()
Ri = poses[:, :3, :3]
ti = poses[:, :3, 3]
out[:, :3, :3] = r @ Ri
out[:, :3, 3] = (r @ (s * ti.T)).T + t
return out
def _median_nn_thresh(pose_ref, pose_est_aligned):
P_ref = pose_ref[:, :3, 3]
P_est = pose_est_aligned[:, :3, 3]
dists = []
for p in P_est:
dd = np.linalg.norm(P_ref - p[None, :], axis=1)
dists.append(dd.min())
return float(np.median(dists)) if dists else 0.0
def _ransac_align_sim3(
pose_ref, pose_est, sub_n=None, inlier_thresh=None, max_iters=10, random_state=None
):
rng = np.random.default_rng(random_state)
N = pose_ref.shape[0]
idx_all = np.arange(N)
if sub_n is None:
sub_n = max(3, (N + 1) // 2)
else:
sub_n = max(3, min(sub_n, N))
# Pre-alignment + default threshold
r0, t0, s0, pose_est0 = _umeyama_sim3_from_paths(pose_ref, pose_est)
if inlier_thresh is None:
inlier_thresh = _median_nn_thresh(pose_ref, pose_est0)
P_ref_all = pose_ref[:, :3, 3]
best_model = (r0, t0, s0)
best_inliers = None
best_score = (-1, np.inf) # (num_inliers, mean_err)
for _ in range(max_iters):
sample = rng.choice(idx_all, size=sub_n, replace=False)
try:
r, t, s, _ = _umeyama_sim3_from_paths(pose_ref[sample], pose_est[sample])
except Exception:
continue
pose_h = _apply_sim3_to_poses(pose_est, r, t, s)
P_h = pose_h[:, :3, 3]
errs = np.linalg.norm(P_h - P_ref_all, axis=1) # Match by same index
inliers = errs <= inlier_thresh
k = int(inliers.sum())
mean_err = float(errs[inliers].mean()) if k > 0 else np.inf
if (k > best_score[0]) or (k == best_score[0] and mean_err < best_score[1]):
best_score = (k, mean_err)
best_model = (r, t, s)
best_inliers = inliers
# Fit again with best inliers
if best_inliers is not None and best_inliers.sum() >= 3:
r, t, s, _ = _umeyama_sim3_from_paths(pose_ref[best_inliers], pose_est[best_inliers])
else:
r, t, s = best_model
return r, t, s
def align_poses_umeyama(
ext_ref: np.ndarray,
ext_est: np.ndarray,
return_aligned=False,
ransac=False,
sub_n=None,
inlier_thresh=None,
ransac_max_iters=10,
random_state=None,
):
"""
Align estimated trajectory to reference using Umeyama Sim(3).
Default no RANSAC; if ransac=True, use RANSAC (max iterations default 10).
- sub_n defaults to half the number of frames (rounded up, at least 3)
- inlier_thresh defaults to median of "distance from each estimated pose to
nearest reference pose after pre-alignment"
Returns rotation (3x3), translation (3,), scale; optionally returns aligned extrinsics (4x4).
"""
pose_ref, pose_est = _poses_from_ext(ext_ref, ext_est)
if not ransac:
r, t, s, pose_est_aligned = _umeyama_sim3_from_paths(pose_ref, pose_est)
else:
r, t, s = _ransac_align_sim3(
pose_ref,
pose_est,
sub_n=sub_n,
inlier_thresh=inlier_thresh,
max_iters=ransac_max_iters,
random_state=random_state,
)
pose_est_aligned = _apply_sim3_to_poses(pose_est, r, t, s)
if return_aligned:
ext_est_aligned = affine_inverse_np(pose_est_aligned)
return r, t, s, ext_est_aligned
return r, t, s
# def align_poses_umeyama(ext_ref: np.ndarray, ext_est: np.ndarray, return_aligned=False):
# """
# Align estimated trajectory to reference trajectory using Umeyama Sim(3)
# alignment (via evo PosePath3D). # noqa
# Returns rotation, translation, and scale.
# """
# # If input extrinsics are 3x4, convert to 4x4 by padding
# if ext_ref.shape[1] == 3:
# ext_ref_ = np.eye(4)[None].repeat(len(ext_ref), 0)
# ext_ref_[:, :3] = ext_ref
# ext_ref = ext_ref_
# if ext_est.shape[1] == 3:
# ext_est_ = np.eye(4)[None].repeat(len(ext_est), 0)
# ext_est_[:, :3] = ext_est
# ext_est = ext_est_
# # Convert to camera poses (inverse extrinsics)
# pose_ref = affine_inverse_np(ext_ref)
# pose_est = affine_inverse_np(ext_est)
# # Create evo PosePath3D objects
# path_ref = PosePath3D(poses_se3=pose_ref)
# path_est = PosePath3D(poses_se3=pose_est)
# r, t, s = path_est.align(path_ref, correct_scale=True)
# if return_aligned:
# return r, t, s, affine_inverse_np(np.stack(path_est.poses_se3))
# else:
# return r, t, s
def apply_umeyama_alignment_to_ext(
rot: np.ndarray, # (3,3)
trans: np.ndarray, # (3,) or (1,3)
scale: float,
ext_est: np.ndarray, # (...,4,4) or (...,3,4)
) -> np.ndarray:
"""
Apply Sim(3) (R, t, s) to a batch of world-to-camera extrinsics ext_est.
Returns the aligned extrinsics, with the same shape as input.
"""
# Allow 3x4 extrinsics: pad to 4x4
if ext_est.shape[-2:] == (3, 4):
pad = np.zeros((*ext_est.shape[:-2], 4, 4), dtype=ext_est.dtype)
pad[..., :3, :4] = ext_est
pad[..., 3, 3] = 1.0
ext_est = pad
# Convert world-to-camera to camera-to-world
pose_est = affine_inverse_np(ext_est) # (...,4,4)
R_e = pose_est[..., :3, :3] # (...,3,3)
t_e = pose_est[..., :3, 3] # (...,3)
# Apply Sim(3) transformation
R_a = np.einsum("ij,...jk->...ik", rot, R_e) # (...,3,3)
t_a = scale * np.einsum("ij,...j->...i", rot, t_e) + trans # (...,3)
# Assemble the transformed pose
pose_a = np.zeros_like(pose_est)
pose_a[..., :3, :3] = R_a
pose_a[..., :3, 3] = t_a
pose_a[..., 3, 3] = 1.0
# Convert back to world-to-camera
return affine_inverse_np(pose_a)
def transform_points_sim3(points, rot, trans, scale, inverse=False):
"""
Sim(3) transform point cloud
points: (N, 3)
rot: (3, 3)
trans: (3,) or (1, 3)
scale: float
inverse: Whether to do inverse transform (ref->est)
Returns: (N, 3)
"""
if not inverse:
# Forward: est -> ref
return scale * (points @ rot.T) + trans
else:
# Inverse: ref -> est
return ((points - trans) @ rot) / scale
def _rand_rot():
u1, u2, u3 = np.random.rand(3)
q = np.array(
[
np.sqrt(1 - u1) * np.sin(2 * np.math.pi * u2),
np.sqrt(1 - u1) * np.cos(2 * np.math.pi * u2),
np.sqrt(u1) * np.sin(2 * np.math.pi * u3),
np.sqrt(u1) * np.cos(2 * np.math.pi * u3),
]
)
w, x, y, z = q
return np.array(
[
[1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
[2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
[2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
]
)
def _rand_pose():
R, t = _rand_rot(), np.random.randn(3)
P = np.eye(4)
P[:3, :3] = R
P[:3, 3] = t
return P
if __name__ == "__main__":
np.random.seed(42)
# 1. Randomly generate reference trajectory and Sim(3)
N = 8
pose_ref = np.stack([_rand_pose() for _ in range(N)]) # (N,4,4) cam→world
rot_gt = _rand_rot()
scale_gt = 2.3
trans_gt = np.random.randn(3)
# 2. Generate estimated trajectory (apply Sim(3))
pose_est = np.zeros_like(pose_ref)
for i in range(N):
R = pose_ref[i][:3, :3]
t = pose_ref[i][:3, 3]
pose_est[i][:3, :3] = rot_gt @ R
pose_est[i][:3, 3] = scale_gt * (rot_gt @ t) + trans_gt
pose_est[i][3, 3] = 1.0
# 3. Get extrinsics (world->cam)
ext_ref = affine_inverse_np(pose_ref)
ext_est = affine_inverse_np(pose_est)
# 4. Use umeyama alignment, estimate Sim(3)
r_est, t_est, s_est = align_poses_umeyama(ext_ref, ext_est)
print("GT scale:", scale_gt, "Estimated:", s_est)
print("GT trans:", trans_gt, "Estimated:", t_est)
print("GT rot:\n", rot_gt, "\nEstimated:\n", r_est)
# 5. Random point cloud, in ref frame
num_points = 100
points_ref = np.random.randn(num_points, 3)
# 6. Use GT Sim(3) inverse transform to est frame
points_est = transform_points_sim3(points_ref, rot_gt, trans_gt, scale_gt, inverse=True)
# 7. Use estimated Sim(3) forward transform back to ref frame
points_ref_recovered = transform_points_sim3(points_est, r_est, t_est, s_est, inverse=False)
# 8. Check error
err = np.abs(points_ref_recovered - points_ref)
print("Point cloud sim3 transform error (mean abs):", err.mean())
print("Point cloud sim3 transform error (max abs):", err.max())
assert err.mean() < 1e-6, "Mean sim3 transform error too large!"
assert err.max() < 1e-5, "Max sim3 transform error too large!"
print("Sim(3) point cloud transform & alignment test passed!")

View File

@@ -0,0 +1,523 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import repeat
from .geometry import unproject_depth
def compute_optimal_rotation_intrinsics_batch(
rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2, weights=None,
n_sample = None,
n_iter=100,
num_sample_for_ransac=8,
rand_sample_iters_idx=None,
):
"""
Args:
rays_origin (torch.Tensor): (B, N, 3)
rays_target (torch.Tensor): (B, N, 3)
z_threshold (float): Threshold for z value to be considered valid.
Returns:
R (torch.tensor): (3, 3)
focal_length (torch.tensor): (2,)
principal_point (torch.tensor): (2,)
"""
device = rays_origin.device
B, N, _ = rays_origin.shape
z_mask = torch.logical_and(
torch.abs(rays_target[:, :, 2]) > z_threshold, torch.abs(rays_origin[:, :, 2]) > z_threshold
) # (B, N, 1)
rays_origin = rays_origin.clone()
rays_target = rays_target.clone()
rays_origin[:, :, 0][z_mask] /= rays_origin[:, :, 2][z_mask]
rays_origin[:, :, 1][z_mask] /= rays_origin[:, :, 2][z_mask]
rays_target[:, :, 0][z_mask] /= rays_target[:, :, 2][z_mask]
rays_target[:, :, 1][z_mask] /= rays_target[:, :, 2][z_mask]
rays_origin = rays_origin[:, :, :2]
rays_target = rays_target[:, :, :2]
assert weights is not None, "weights must be provided"
weights[~z_mask] = 0
A_list = []
max_chunk_size = 2
for i in range(0, rays_origin.shape[0], max_chunk_size):
A = ransac_find_homography_weighted_fast_batch(
rays_origin[i:i+max_chunk_size],
rays_target[i:i+max_chunk_size],
weights[i:i+max_chunk_size],
n_iter=n_iter,
n_sample = n_sample,
num_sample_for_ransac=num_sample_for_ransac,
reproj_threshold=reproj_threshold,
rand_sample_iters_idx=rand_sample_iters_idx,
max_inlier_num=8000,
)
A = A.to(device)
A_need_inv_mask = torch.linalg.det(A) < 0
A[A_need_inv_mask] = -A[A_need_inv_mask]
A_list.append(A)
A = torch.cat(A_list, dim=0)
R_list = []
f_list = []
pp_list = []
for i in range(A.shape[0]):
R, L = ql_decomposition(A[i])
L = L / L[2][2]
f = torch.stack((L[0][0], L[1][1]))
pp = torch.stack((L[2][0], L[2][1]))
R_list.append(R)
f_list.append(f)
pp_list.append(pp)
R = torch.stack(R_list)
f = torch.stack(f_list)
pp = torch.stack(pp_list)
return R, f, pp
# https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/
def ql_decomposition(A):
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float()
A_tilde = torch.matmul(A, P)
Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
Q = torch.matmul(Q_tilde, P)
L = torch.matmul(torch.matmul(P, R_tilde), P)
d = torch.diag(L)
Q[:, 0] *= torch.sign(d[0])
Q[:, 1] *= torch.sign(d[1])
Q[:, 2] *= torch.sign(d[2])
L[0] *= torch.sign(d[0])
L[1] *= torch.sign(d[1])
L[2] *= torch.sign(d[2])
return Q, L
def find_homography_least_squares_weighted_torch(src_pts, dst_pts, confident_weight):
"""
src_pts: (N,2) source points (torch.Tensor, float32/float64)
dst_pts: (N,2) target points (torch.Tensor, float32/float64)
confident_weight: (N,) weights (torch.Tensor)
Returns: (3,3) homography matrix H (torch.Tensor)
"""
assert src_pts.shape == dst_pts.shape
N = src_pts.shape[0]
if N < 4:
raise ValueError("At least 4 points are required to compute homography.")
assert confident_weight.shape == (N,)
w = confident_weight.sqrt().unsqueeze(1) # (N,1)
x = src_pts[:, 0:1] # (N,1)
y = src_pts[:, 1:2] # (N,1)
u = dst_pts[:, 0:1]
v = dst_pts[:, 1:2]
zeros = torch.zeros_like(x)
# Construct A matrix (2N, 9)
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
A = torch.cat([A1, A2], dim=0) # (2N, 9)
# SVD
# Note: torch.linalg.svd returns U, S, Vh, where Vh is the transpose of V
_, _, Vh = torch.linalg.svd(A)
H = Vh[-1].reshape(3, 3)
H = H / H[-1, -1]
return H
def ransac_find_homography_weighted(
src_pts,
dst_pts,
confident_weight,
n_iter=100,
sample_ratio=0.2,
reproj_threshold=3.0,
num_sample_for_ransac=16,
random_seed=None,
):
"""
RANSAC version of weighted Homography estimation.
Sample 4 points from the top 50% weighted points each time.
reproj_threshold: points with reprojection error less than this value are inliers
Returns: best_H
"""
if random_seed is not None:
torch.manual_seed(random_seed)
N = src_pts.shape[0]
assert N >= 4
# 1. Select top 50% weighted points
sorted_idx = torch.argsort(confident_weight, descending=True)
n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
candidate_idx = sorted_idx[:n_sample]
best_inlier_mask = None
best_score = 0
for _ in range(n_iter):
# 2. Randomly sample 4 points
idx = candidate_idx[torch.randperm(n_sample)[:num_sample_for_ransac]]
# 3. Compute Homography
try:
H = find_homography_least_squares_weighted_torch(
src_pts[idx], dst_pts[idx], confident_weight[idx]
)
except Exception:
H = torch.eye(3, dtype=src_pts.dtype, device=src_pts.device)
# 4. Compute reprojection error for all points
src_homo = torch.cat(
[src_pts, torch.ones(N, 1, dtype=src_pts.dtype, device=src_pts.device)], dim=1
)
proj = (H @ src_homo.T).T
proj = proj[:, :2] / proj[:, 2:3]
error = ((proj - dst_pts) ** 2).sum(dim=1).sqrt() # Euclidean distance
inlier_mask = error < reproj_threshold
total_score = (inlier_mask * confident_weight).sum().item()
n_inlier = inlier_mask.sum().item()
if n_inlier < 4:
continue # At least 4 inliers required for fitting
if total_score > best_score:
best_score = total_score
best_inlier_mask = inlier_mask
# 5. Refit Homography using inliers
H_inlier = find_homography_least_squares_weighted_torch(
src_pts[best_inlier_mask], dst_pts[best_inlier_mask], confident_weight[best_inlier_mask]
)
return H_inlier
def find_homography_least_squares_weighted_torch_batch(
src_pts_batch, dst_pts_batch, confident_weight_batch
):
"""
Batch version of weighted least squares Homography
src_pts_batch: (B, K, 2)
dst_pts_batch: (B, K, 2)
confident_weight_batch: (B, K)
Returns: (B, 3, 3)
"""
B, K, _ = src_pts_batch.shape
w = confident_weight_batch.sqrt().unsqueeze(2) # (B,K,1)
x = src_pts_batch[:, :, 0:1]
y = src_pts_batch[:, :, 1:2]
u = dst_pts_batch[:, :, 0:1]
v = dst_pts_batch[:, :, 1:2]
zeros = torch.zeros_like(x)
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
A = torch.cat([A1, A2], dim=1) # (B, 2K, 9)
# SVD: torch.linalg.svd supports batch
_, _, Vh = torch.linalg.svd(A)
H = Vh[:, -1].reshape(B, 3, 3)
H = H / H[:, 2:3, 2:3]
return H
def ransac_find_homography_weighted_fast(
src_pts,
dst_pts,
confident_weight,
n_sample,
n_iter=100,
reproj_threshold=3.0,
num_sample_for_ransac=8,
random_seed=None,
rand_sample_iters_idx=None,
):
"""
Batch version of RANSAC weighted Homography estimation.
Returns: H_inlier
"""
if random_seed is not None:
torch.manual_seed(random_seed)
N = src_pts.shape[0]
device = src_pts.device
assert N >= 4
# 1. Select top weighted points by sample_ratio
sorted_idx = torch.argsort(confident_weight, descending=True)
candidate_idx = sorted_idx[:n_sample] # (n_sample,)
if rand_sample_iters_idx is None:
rand_sample_iters_idx = torch.stack(
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
dim=0,
) # (n_iter, num_sample_for_ransac)
# 2. Generate all sampling groups at once
# shape: (n_iter, num_sample_for_ransac)
rand_idx = candidate_idx[rand_sample_iters_idx] # (n_iter, num_sample_for_ransac)
# 3. Construct batch input
src_pts_batch = src_pts[rand_idx] # (n_iter, num_sample_for_ransac, 2)
dst_pts_batch = dst_pts[rand_idx] # (n_iter, num_sample_for_ransac, 2)
confident_weight_batch = confident_weight[rand_idx] # (n_iter, num_sample_for_ransac)
# 4. Batch fit Homography
H_batch = find_homography_least_squares_weighted_torch_batch(
src_pts_batch, dst_pts_batch, confident_weight_batch
) # (n_iter, 3, 3)
# 5. Batch evaluate inliers for all H
src_homo = torch.cat(
[src_pts, torch.ones(N, 1, dtype=src_pts.dtype, device=src_pts.device)], dim=1
) # (N,3)
src_homo_expand = src_homo.unsqueeze(0).expand(n_iter, N, 3) # (n_iter, N, 3)
dst_pts_expand = dst_pts.unsqueeze(0).expand(n_iter, N, 2) # (n_iter, N, 2)
confident_weight_expand = confident_weight.unsqueeze(0).expand(n_iter, N) # (n_iter, N)
# H_batch: (n_iter, 3, 3)
proj = torch.bmm(src_homo_expand, H_batch.transpose(1, 2)) # (n_iter, N, 3)
proj_xy = proj[:, :, :2] / proj[:, :, 2:3] # (n_iter, N, 2)
error = ((proj_xy - dst_pts_expand) ** 2).sum(dim=2).sqrt() # (n_iter, N)
inlier_mask = error < reproj_threshold # (n_iter, N)
total_score = (inlier_mask * confident_weight_expand).sum(dim=1) # (n_iter,)
# 6. Select the sampling group with the highest score
best_idx = torch.argmax(total_score)
best_inlier_mask = inlier_mask[best_idx] # (N,)
inlier_src_pts = src_pts[best_inlier_mask]
inlier_dst_pts = dst_pts[best_inlier_mask]
inlier_confident_weight = confident_weight[best_inlier_mask]
max_inlier_num = 10000
sorted_idx = torch.argsort(inlier_confident_weight, descending=True)
# method 1: sort according to confident_weight, and only keep max_inlier_num pts
# sorted_idx = sorted_idx[:max_inlier_num]
# method 2: random choose max_inlier_num pts
sorted_idx = sorted_idx[torch.randperm(len(sorted_idx))[:max_inlier_num]]
inlier_src_pts = inlier_src_pts[sorted_idx]
inlier_dst_pts = inlier_dst_pts[sorted_idx]
inlier_confident_weight = inlier_confident_weight[sorted_idx]
# 7. Refit Homography using inliers
H_inlier = find_homography_least_squares_weighted_torch(
inlier_src_pts, inlier_dst_pts, inlier_confident_weight
)
return H_inlier
def ransac_find_homography_weighted_fast_batch(
src_pts, # (B, N, 3)
dst_pts, # (B, N, 2)
confident_weight, # (B, N)
n_sample,
n_iter=100,
reproj_threshold=3.0,
num_sample_for_ransac=8,
max_inlier_num=10000,
random_seed=None,
rand_sample_iters_idx=None,
):
"""
Batch version of RANSAC weighted Homography estimation (supports batch).
Input:
src_pts: (B, N, 2)
dst_pts: (B, N, 2)
confident_weight: (B, N)
Returns:
H_inlier: (B, 3, 3)
"""
if random_seed is not None:
torch.manual_seed(random_seed)
B, N, _ = src_pts.shape
assert N >= 4
device = src_pts.device
# 1. Select top weighted points by sample_ratio
sorted_idx = torch.argsort(confident_weight, descending=True, dim=1) # (B, N)
candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample)
# 2. Generate all sampling groups at once
# rand_idx: (B, n_iter, num_sample_for_ransac)
if rand_sample_iters_idx is None:
rand_sample_iters_idx = torch.stack(
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
dim=0,
) # (n_iter, num_sample_for_ransac)
rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, num_sample_for_ransac)
# 3. Construct batch input
# Indexing method below: (B, n_iter, num_sample_for_ransac, ...)
b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, n_iter, num_sample_for_ransac)
src_pts_batch = src_pts[b_idx, rand_idx] # (B, n_iter, num_sample_for_ransac, 2)
dst_pts_batch = dst_pts[b_idx, rand_idx] # (B, n_iter, num_sample_for_ransac, 2)
confident_weight_batch = confident_weight[b_idx, rand_idx] # (B, n_iter, num_sample_for_ransac)
# 4. Batch fit Homography
# Need to implement batch version that supports (B, n_iter, num_sample_for_ransac, ...) input
# Output H_batch: (B, n_iter, 3, 3)
cB, cN = src_pts_batch.shape[:2]
H_batch = find_homography_least_squares_weighted_torch_batch(
src_pts_batch.flatten(0, 1), dst_pts_batch.flatten(0, 1), confident_weight_batch.flatten(0, 1)
) # (B, n_iter, 3, 3)
H_batch = H_batch.unflatten(0, (cB, cN))
# 5. Batch evaluate inliers for all H
src_homo = torch.cat(
[src_pts, torch.ones(B, N, 1, dtype=src_pts.dtype, device=src_pts.device)], dim=2
) # (B, N, 3)
src_homo_expand = src_homo.unsqueeze(1).expand(B, n_iter, N, 3) # (B, n_iter, N, 3)
dst_pts_expand = dst_pts.unsqueeze(1).expand(B, n_iter, N, 2) # (B, n_iter, N, 2)
confident_weight_expand = confident_weight.unsqueeze(1).expand(B, n_iter, N) # (B, n_iter, N)
# H_batch: (B, n_iter, 3, 3)
# Need to reshape H_batch to (B*n_iter, 3, 3), src_homo_expand to (B*n_iter, N, 3)
H_batch_flat = H_batch.reshape(-1, 3, 3)
src_homo_expand_flat = src_homo_expand.reshape(-1, N, 3)
proj = torch.bmm(src_homo_expand_flat, H_batch_flat.transpose(1, 2)) # (B*n_iter, N, 3)
proj_xy = proj[:, :, :2] / proj[:, :, 2:3] # (B*n_iter, N, 2)
proj_xy = proj_xy.reshape(B, n_iter, N, 2)
error = ((proj_xy - dst_pts_expand) ** 2).sum(dim=3).sqrt() # (B, n_iter, N)
inlier_mask = error < reproj_threshold # (B, n_iter, N)
total_score = (inlier_mask * confident_weight_expand).sum(dim=2) # (B, n_iter)
# 6. Select the sampling group with the highest score
best_idx = torch.argmax(total_score, dim=1) # (B,)
best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx] # (B, N)
# 7. Refit Homography using inliers
H_inlier_list = []
for b in range(B):
mask = best_inlier_mask[b]
inlier_src_pts = src_pts[b][mask] # (?, 3)
inlier_dst_pts = dst_pts[b][mask] # (?, 2)
inlier_confident_weight = confident_weight[b][mask] # (?)
sorted_idx = torch.argsort(inlier_confident_weight, descending=True)
# # method 1: sort according to confident_weight, and only keep max_inlier_num pts
# sorted_idx = sorted_idx[:max_inlier_num]
# method 2: random choose max_inlier_num pts
if len(sorted_idx) > max_inlier_num:
# random choose from first 95% confident pts
keep_len = max(int(len(sorted_idx) * 0.95), max_inlier_num)
sorted_idx = sorted_idx[:keep_len]
perm = torch.randperm(len(sorted_idx), device=device)[:max_inlier_num]
sorted_idx = sorted_idx[perm]
inlier_src_pts = inlier_src_pts[sorted_idx]
inlier_dst_pts = inlier_dst_pts[sorted_idx]
inlier_confident_weight = inlier_confident_weight[sorted_idx]
H_inlier = find_homography_least_squares_weighted_torch(
inlier_src_pts, inlier_dst_pts, inlier_confident_weight
) # (3, 3)
H_inlier_list.append(H_inlier)
H_inlier = torch.stack(H_inlier_list, dim=0) # (B, 3, 3)
return H_inlier
def get_params_for_ransac(N, device):
n_iter=100
sample_ratio=0.3
num_sample_for_ransac=8
n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
rand_sample_iters_idx = torch.stack(
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
dim=0,
) # (n_iter, num_sample_for_ransac)
return n_iter, num_sample_for_ransac, n_sample, rand_sample_iters_idx
def camray_to_caminfo(camray, confidence=None, reproj_threshold=0.2, training=False):
"""
Args:
camray: (B, S, num_patches_y, num_patches_x, 6)
confidence: (B, S, num_patches_y, num_patches_x)
Returns:
R: (B, S, 3, 3)
T: (B, S, 3)
focal_lengths: (B, S, 2)
principal_points: (B, S, 2)
"""
if confidence is None:
confidence = torch.ones_like(camray[:, :, :, :, 0])
B, S, num_patches_y, num_patches_x, _ = camray.shape
# identity K, assume imw=imh=2.0
I_K = torch.eye(3, dtype=camray.dtype, device=camray.device)
I_K[0, 2] = 1.0
I_K[1, 2] = 1.0
# repeat I_K to match camray
I_K = I_K.unsqueeze(0).unsqueeze(0).expand(B, S, -1, -1)
cam_plane_depth = torch.ones(
B, S, num_patches_y, num_patches_x, 1, dtype=camray.dtype, device=camray.device
)
I_cam_plane_unproj = unproject_depth(
cam_plane_depth,
I_K,
c2w=None,
ixt_normalized=True,
num_patches_x=num_patches_x,
num_patches_y=num_patches_y,
) # (B, S, num_patches_y, num_patches_x, 3)
camray = camray.flatten(0, 1).flatten(1, 2) # (B*S, num_patches_y*num_patches_x, 6)
I_cam_plane_unproj = I_cam_plane_unproj.flatten(0, 1).flatten(
1, 2
) # (B*S, num_patches_y*num_patches_x, 3)
confidence = confidence.flatten(0, 1).flatten(1, 2) # (B*S, num_patches_y*num_patches_x)
# Compute optimal rotation to align rays
N = camray.shape[-2]
device = camray.device
n_iter, num_sample_for_ransac, n_sample, rand_sample_iters_idx = get_params_for_ransac(N, device)
# Use batch processing (confidence is guaranteed to be not None at this point)
if training:
camray = camray.clone().detach()
I_cam_plane_unproj = I_cam_plane_unproj.clone().detach()
confidence = confidence.clone().detach()
R, focal_lengths, principal_points = compute_optimal_rotation_intrinsics_batch(
I_cam_plane_unproj,
camray[:, :, :3],
reproj_threshold=reproj_threshold,
weights=confidence,
n_sample = n_sample,
n_iter=n_iter,
num_sample_for_ransac=num_sample_for_ransac,
rand_sample_iters_idx=rand_sample_iters_idx,
)
T = torch.sum(camray[:, :, 3:] * confidence.unsqueeze(-1), dim=1) / torch.sum(
confidence, dim=-1, keepdim=True
)
R = R.reshape(B, S, 3, 3)
T = T.reshape(B, S, 3)
focal_lengths = focal_lengths.reshape(B, S, 2)
principal_points = principal_points.reshape(B, S, 2)
return R, T, 1.0 / focal_lengths, principal_points + 1.0
def get_extrinsic_from_camray(camray, conf, patch_size_y, patch_size_x, training=False):
pred_R, pred_T, pred_focal_lengths, pred_principal_points = camray_to_caminfo(
camray, confidence=conf.squeeze(-1), training=training
)
pred_extrinsic = torch.cat(
[
torch.cat([pred_R, pred_T.unsqueeze(-1)], dim=-1),
repeat(
torch.tensor([0, 0, 0, 1], dtype=pred_R.dtype, device=pred_R.device),
"c -> b s 1 c",
b=pred_R.shape[0],
s=pred_R.shape[1],
),
],
dim=-2,
) # B, S, 4, 4
return pred_extrinsic, pred_focal_lengths, pred_principal_points

View File

@@ -0,0 +1,585 @@
# Copyright (c), ETH Zurich and UNC Chapel Hill.
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
# All rights reserved.
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 11/05/2025
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
# its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import argparse
import collections
import os
import struct
import numpy as np
CameraModel = collections.namedtuple("CameraModel", ["model_id", "model_name", "num_params"])
Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple(
"Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
)
Point3D = collections.namedtuple(
"Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
)
class Image(BaseImage):
def qvec2rotmat(self):
return qvec2rotmat(self.qvec)
CAMERA_MODELS = {
CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
CameraModel(model_id=3, model_name="RADIAL", num_params=5),
CameraModel(model_id=4, model_name="OPENCV", num_params=8),
CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
CameraModel(model_id=7, model_name="FOV", num_params=5),
CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
}
CAMERA_MODEL_IDS = {camera_model.model_id: camera_model for camera_model in CAMERA_MODELS}
CAMERA_MODEL_NAMES = {camera_model.model_name: camera_model for camera_model in CAMERA_MODELS}
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
"""Read and unpack the next bytes from a binary file.
:param fid:
:param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
:param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
:param endian_character: Any of {@, =, <, >, !}
:return: Tuple of read and unpacked values.
"""
data = fid.read(num_bytes)
return struct.unpack(endian_character + format_char_sequence, data)
def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
"""pack and write to a binary file.
:param fid:
:param data: data to send, if multiple elements are sent at the same time,
they should be encapsuled either in a list or a tuple
:param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
should be the same length as the data list or tuple
:param endian_character: Any of {@, =, <, >, !}
"""
if isinstance(data, (list, tuple)):
bytes = struct.pack(endian_character + format_char_sequence, *data)
else:
bytes = struct.pack(endian_character + format_char_sequence, data)
fid.write(bytes)
def read_cameras_text(path):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::WriteCamerasText(const std::string& path)
void Reconstruction::ReadCamerasText(const std::string& path)
"""
cameras = {}
with open(path) as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
elems = line.split()
camera_id = int(elems[0])
model = elems[1]
width = int(elems[2])
height = int(elems[3])
params = np.array(tuple(map(float, elems[4:])))
cameras[camera_id] = Camera(
id=camera_id,
model=model,
width=width,
height=height,
params=params,
)
return cameras
def read_cameras_binary(path_to_model_file):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::WriteCamerasBinary(const std::string& path)
void Reconstruction::ReadCamerasBinary(const std::string& path)
"""
cameras = {}
with open(path_to_model_file, "rb") as fid:
num_cameras = read_next_bytes(fid, 8, "Q")[0]
for _ in range(num_cameras):
camera_properties = read_next_bytes(fid, num_bytes=24, format_char_sequence="iiQQ")
camera_id = camera_properties[0]
model_id = camera_properties[1]
model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
width = camera_properties[2]
height = camera_properties[3]
num_params = CAMERA_MODEL_IDS[model_id].num_params
params = read_next_bytes(
fid,
num_bytes=8 * num_params,
format_char_sequence="d" * num_params,
)
cameras[camera_id] = Camera(
id=camera_id,
model=model_name,
width=width,
height=height,
params=np.array(params),
)
assert len(cameras) == num_cameras
return cameras
def write_cameras_text(cameras, path):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::WriteCamerasText(const std::string& path)
void Reconstruction::ReadCamerasText(const std::string& path)
"""
HEADER = (
"# Camera list with one line of data per camera:\n"
+ "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
+ f"# Number of cameras: {len(cameras)}\n"
)
with open(path, "w") as fid:
fid.write(HEADER)
for _, cam in cameras.items():
to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
line = " ".join([str(elem) for elem in to_write])
fid.write(line + "\n")
def write_cameras_binary(cameras, path_to_model_file):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::WriteCamerasBinary(const std::string& path)
void Reconstruction::ReadCamerasBinary(const std::string& path)
"""
with open(path_to_model_file, "wb") as fid:
write_next_bytes(fid, len(cameras), "Q")
for _, cam in cameras.items():
model_id = CAMERA_MODEL_NAMES[cam.model].model_id
camera_properties = [cam.id, model_id, cam.width, cam.height]
write_next_bytes(fid, camera_properties, "iiQQ")
for p in cam.params:
write_next_bytes(fid, float(p), "d")
return cameras
def read_images_text(path):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::ReadImagesText(const std::string& path)
void Reconstruction::WriteImagesText(const std::string& path)
"""
images = {}
with open(path) as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
elems = line.split()
image_id = int(elems[0])
qvec = np.array(tuple(map(float, elems[1:5])))
tvec = np.array(tuple(map(float, elems[5:8])))
camera_id = int(elems[8])
image_name = elems[9]
elems = fid.readline().split()
xys = np.column_stack(
[
tuple(map(float, elems[0::3])),
tuple(map(float, elems[1::3])),
]
)
point3D_ids = np.array(tuple(map(int, elems[2::3])))
images[image_id] = Image(
id=image_id,
qvec=qvec,
tvec=tvec,
camera_id=camera_id,
name=image_name,
xys=xys,
point3D_ids=point3D_ids,
)
return images
def read_images_binary(path_to_model_file):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::ReadImagesBinary(const std::string& path)
void Reconstruction::WriteImagesBinary(const std::string& path)
"""
images = {}
with open(path_to_model_file, "rb") as fid:
num_reg_images = read_next_bytes(fid, 8, "Q")[0]
for _ in range(num_reg_images):
binary_image_properties = read_next_bytes(
fid, num_bytes=64, format_char_sequence="idddddddi"
)
image_id = binary_image_properties[0]
qvec = np.array(binary_image_properties[1:5])
tvec = np.array(binary_image_properties[5:8])
camera_id = binary_image_properties[8]
binary_image_name = b""
current_char = read_next_bytes(fid, 1, "c")[0]
while current_char != b"\x00": # look for the ASCII 0 entry
binary_image_name += current_char
current_char = read_next_bytes(fid, 1, "c")[0]
image_name = binary_image_name.decode("utf-8")
num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0]
x_y_id_s = read_next_bytes(
fid,
num_bytes=24 * num_points2D,
format_char_sequence="ddq" * num_points2D,
)
xys = np.column_stack(
[
tuple(map(float, x_y_id_s[0::3])),
tuple(map(float, x_y_id_s[1::3])),
]
)
point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
images[image_id] = Image(
id=image_id,
qvec=qvec,
tvec=tvec,
camera_id=camera_id,
name=image_name,
xys=xys,
point3D_ids=point3D_ids,
)
return images
def write_images_text(images, path):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::ReadImagesText(const std::string& path)
void Reconstruction::WriteImagesText(const std::string& path)
"""
if len(images) == 0:
mean_observations = 0
else:
mean_observations = sum((len(img.point3D_ids) for _, img in images.items())) / len(images)
HEADER = (
"# Image list with two lines of data per image:\n"
+ "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
+ "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
+ "# Number of images: {}, mean observations per image: {}\n".format(
len(images), mean_observations
)
)
with open(path, "w") as fid:
fid.write(HEADER)
for _, img in images.items():
image_header = [
img.id,
*img.qvec,
*img.tvec,
img.camera_id,
img.name,
]
first_line = " ".join(map(str, image_header))
fid.write(first_line + "\n")
points_strings = []
for xy, point3D_id in zip(img.xys, img.point3D_ids):
points_strings.append(" ".join(map(str, [*xy, point3D_id])))
fid.write(" ".join(points_strings) + "\n")
def write_images_binary(images, path_to_model_file):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::ReadImagesBinary(const std::string& path)
void Reconstruction::WriteImagesBinary(const std::string& path)
"""
with open(path_to_model_file, "wb") as fid:
write_next_bytes(fid, len(images), "Q")
for _, img in images.items():
write_next_bytes(fid, img.id, "i")
write_next_bytes(fid, img.qvec.tolist(), "dddd")
write_next_bytes(fid, img.tvec.tolist(), "ddd")
write_next_bytes(fid, img.camera_id, "i")
for char in img.name:
write_next_bytes(fid, char.encode("utf-8"), "c")
write_next_bytes(fid, b"\x00", "c")
write_next_bytes(fid, len(img.point3D_ids), "Q")
for xy, p3d_id in zip(img.xys, img.point3D_ids):
write_next_bytes(fid, [*xy, p3d_id], "ddq")
def read_points3D_text(path):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::ReadPoints3DText(const std::string& path)
void Reconstruction::WritePoints3DText(const std::string& path)
"""
points3D = {}
with open(path) as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
elems = line.split()
point3D_id = int(elems[0])
xyz = np.array(tuple(map(float, elems[1:4])))
rgb = np.array(tuple(map(int, elems[4:7])))
error = float(elems[7])
image_ids = np.array(tuple(map(int, elems[8::2])))
point2D_idxs = np.array(tuple(map(int, elems[9::2])))
points3D[point3D_id] = Point3D(
id=point3D_id,
xyz=xyz,
rgb=rgb,
error=error,
image_ids=image_ids,
point2D_idxs=point2D_idxs,
)
return points3D
def read_points3D_binary(path_to_model_file):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::ReadPoints3DBinary(const std::string& path)
void Reconstruction::WritePoints3DBinary(const std::string& path)
"""
points3D = {}
with open(path_to_model_file, "rb") as fid:
num_points = read_next_bytes(fid, 8, "Q")[0]
for _ in range(num_points):
binary_point_line_properties = read_next_bytes(
fid, num_bytes=43, format_char_sequence="QdddBBBd"
)
point3D_id = binary_point_line_properties[0]
xyz = np.array(binary_point_line_properties[1:4])
rgb = np.array(binary_point_line_properties[4:7])
error = np.array(binary_point_line_properties[7])
track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0]
track_elems = read_next_bytes(
fid,
num_bytes=8 * track_length,
format_char_sequence="ii" * track_length,
)
image_ids = np.array(tuple(map(int, track_elems[0::2])))
point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
points3D[point3D_id] = Point3D(
id=point3D_id,
xyz=xyz,
rgb=rgb,
error=error,
image_ids=image_ids,
point2D_idxs=point2D_idxs,
)
return points3D
def write_points3D_text(points3D, path):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::ReadPoints3DText(const std::string& path)
void Reconstruction::WritePoints3DText(const std::string& path)
"""
if len(points3D) == 0:
mean_track_length = 0
else:
mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items())) / len(points3D)
HEADER = (
"# 3D point list with one line of data per point:\n"
+ "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n"
+ "# Number of points: {}, mean track length: {}\n".format(
len(points3D), mean_track_length
)
)
with open(path, "w") as fid:
fid.write(HEADER)
for _, pt in points3D.items():
point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
fid.write(" ".join(map(str, point_header)) + " ")
track_strings = []
for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
track_strings.append(" ".join(map(str, [image_id, point2D])))
fid.write(" ".join(track_strings) + "\n")
def write_points3D_binary(points3D, path_to_model_file):
"""
see: src/colmap/scene/reconstruction.cc
void Reconstruction::ReadPoints3DBinary(const std::string& path)
void Reconstruction::WritePoints3DBinary(const std::string& path)
"""
with open(path_to_model_file, "wb") as fid:
write_next_bytes(fid, len(points3D), "Q")
for _, pt in points3D.items():
write_next_bytes(fid, pt.id, "Q")
write_next_bytes(fid, pt.xyz.tolist(), "ddd")
write_next_bytes(fid, pt.rgb.tolist(), "BBB")
write_next_bytes(fid, pt.error, "d")
track_length = pt.image_ids.shape[0]
write_next_bytes(fid, track_length, "Q")
for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
write_next_bytes(fid, [image_id, point2D_id], "ii")
def detect_model_format(path, ext):
if (
os.path.isfile(os.path.join(path, "cameras" + ext))
and os.path.isfile(os.path.join(path, "images" + ext))
and os.path.isfile(os.path.join(path, "points3D" + ext))
):
print("Detected model format: '" + ext + "'")
return True
return False
def read_model(path, ext=""):
# try to detect the extension automatically
if ext == "":
if detect_model_format(path, ".bin"):
ext = ".bin"
elif detect_model_format(path, ".txt"):
ext = ".txt"
else:
print("Provide model format: '.bin' or '.txt'")
return
if ext == ".txt":
cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
images = read_images_text(os.path.join(path, "images" + ext))
points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
else:
cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
images = read_images_binary(os.path.join(path, "images" + ext))
points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
return cameras, images, points3D
def write_model(cameras, images, points3D, path, ext=".bin"):
if ext == ".txt":
write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
write_images_text(images, os.path.join(path, "images" + ext))
write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
else:
write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
write_images_binary(images, os.path.join(path, "images" + ext))
write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
return cameras, images, points3D
def qvec2rotmat(qvec):
return np.array(
[
[
1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
],
[
2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
],
[
2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
],
]
)
def rotmat2qvec(R):
Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
K = (
np.array(
[
[Rxx - Ryy - Rzz, 0, 0, 0],
[Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
[Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
[Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
]
)
/ 3.0
)
eigvals, eigvecs = np.linalg.eigh(K)
qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
if qvec[0] < 0:
qvec *= -1
return qvec
def main():
parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models")
parser.add_argument("--input_model", help="path to input model folder")
parser.add_argument(
"--input_format",
choices=[".bin", ".txt"],
help="input model format",
default="",
)
parser.add_argument("--output_model", help="path to output model folder")
parser.add_argument(
"--output_format",
choices=[".bin", ".txt"],
help="output model format",
default=".txt",
)
args = parser.parse_args()
cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
print("num_cameras:", len(cameras))
print("num_images:", len(images))
print("num_points3D:", len(points3D))
if args.output_model is not None:
write_model(
cameras,
images,
points3D,
path=args.output_model,
ext=args.output_format,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,36 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from addict import Dict
class Registry(Dict[str, Any]):
def __init__(self):
super().__init__()
self._map = Dict({})
def register(self, name=None):
def decorator(cls):
key = name or cls.__name__
self._map[key] = cls
return cls
return decorator
def get(self, name):
return self._map[name]
def all(self):
return self._map

View File

@@ -0,0 +1,88 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import isqrt
import torch
from einops import einsum
try:
from e3nn.o3 import matrix_to_angles, wigner_D
except ImportError:
from depth_anything_3.utils.logger import logger
logger.warn("Dependency 'e3nn' not found. Required for rotating the camera space SH coeff")
def project_to_so3_strict(M: torch.Tensor) -> torch.Tensor:
if M.shape[-2:] != (3, 3):
raise ValueError("Input must be a batch of 3x3 matrices (i.e., shape [..., 3, 3]).")
# 1. Compute SVD
U, S, Vh = torch.linalg.svd(M)
V = Vh.mH
# 2. Handle reflection case (det = -1)
det_U = torch.det(U)
det_V = torch.det(V)
is_reflection = (det_U * det_V) < 0
correction_sign = torch.where(
is_reflection[..., None],
torch.tensor([1, 1, -1.0], device=M.device, dtype=M.dtype),
torch.tensor([1, 1, 1.0], device=M.device, dtype=M.dtype),
)
correction_matrix = torch.diag_embed(correction_sign)
U_corrected = U @ correction_matrix
R_so3_initial = U_corrected @ V.transpose(-2, -1)
# 3. Explicitly ensure determinant is 1 (or extremely close)
current_det = torch.det(R_so3_initial)
det_correction_factor = torch.pow(current_det, -1 / 3)[..., None, None]
R_so3_final = R_so3_initial * det_correction_factor
return R_so3_final
def rotate_sh(
sh_coefficients: torch.Tensor, # "*#batch n"
rotations: torch.Tensor, # "*#batch 3 3"
) -> torch.Tensor: # "*batch n"
# https://github.com/graphdeco-inria/gaussian-splatting/issues/176#issuecomment-2452412653
device = sh_coefficients.device
dtype = sh_coefficients.dtype
*_, n = sh_coefficients.shape
with torch.autocast(device_type=rotations.device.type, enabled=False):
rotations_float32 = rotations.to(torch.float32)
# switch axes: yzx -> xyz
P = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]]).unsqueeze(0).to(rotations_float32)
permuted_rotations = torch.linalg.inv(P) @ rotations_float32 @ P
# ensure rotation has det == 1 in float32 type
permuted_rotations_so3 = project_to_so3_strict(permuted_rotations)
alpha, beta, gamma = matrix_to_angles(permuted_rotations_so3)
result = []
for degree in range(isqrt(n)):
with torch.device(device):
sh_rotations = wigner_D(degree, alpha, -beta, gamma).type(dtype)
sh_rotated = einsum(
sh_rotations,
sh_coefficients[..., degree**2 : (degree + 1) ** 2],
"... i j, ... j -> ... i",
)
result.append(sh_rotated)
return torch.cat(result, dim=-1)

View File

@@ -0,0 +1,120 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib
import numpy as np
import torch
from einops import rearrange
from depth_anything_3.utils.logger import logger
def visualize_depth(
depth: np.ndarray,
depth_min=None,
depth_max=None,
percentile=2,
ret_minmax=False,
ret_type=np.uint8,
cmap="Spectral",
):
"""
Visualize a depth map using a colormap.
Args:
depth: Input depth map array
depth_min: Minimum depth value for normalization. If None, uses percentile
depth_max: Maximum depth value for normalization. If None, uses percentile
percentile: Percentile for min/max computation if not provided
ret_minmax: Whether to return min/max depth values
ret_type: Return array type (uint8 or float)
cmap: Matplotlib colormap name to use
Returns:
Colored depth visualization as numpy array
If ret_minmax=True, also returns depth_min and depth_max
"""
depth = depth.copy()
depth.copy()
valid_mask = depth > 0
depth[valid_mask] = 1 / depth[valid_mask]
if depth_min is None:
if valid_mask.sum() <= 10:
depth_min = 0
else:
depth_min = np.percentile(depth[valid_mask], percentile)
if depth_max is None:
if valid_mask.sum() <= 10:
depth_max = 0
else:
depth_max = np.percentile(depth[valid_mask], 100 - percentile)
if depth_min == depth_max:
depth_min = depth_min - 1e-6
depth_max = depth_max + 1e-6
cm = matplotlib.colormaps[cmap]
depth = ((depth - depth_min) / (depth_max - depth_min)).clip(0, 1)
depth = 1 - depth
img_colored_np = cm(depth[None], bytes=False)[:, :, :, 0:3] # value from 0 to 1
if ret_type == np.uint8:
img_colored_np = (img_colored_np[0] * 255.0).astype(np.uint8)
elif ret_type == np.float32 or ret_type == np.float64:
img_colored_np = img_colored_np[0]
else:
raise ValueError(f"Invalid return type: {ret_type}")
if ret_minmax:
return img_colored_np, depth_min, depth_max
else:
return img_colored_np
# GS video rendering visulization function, since it operates in Tensor space...
def vis_depth_map_tensor(
result: torch.Tensor, # "*batch height width"
color_map: str = "Spectral",
) -> torch.Tensor: # "*batch 3 height with"
"""
Color-map the depth map.
"""
far = result.reshape(-1)[:16_000_000].float().quantile(0.99).log().to(result)
try:
near = result[result > 0][:16_000_000].float().quantile(0.01).log().to(result)
except (RuntimeError, ValueError) as e:
logger.error(f"No valid depth values found. Reason: {e}")
near = torch.zeros_like(far)
result = result.log()
result = (result - near) / (far - near)
return apply_color_map_to_image(result, color_map)
def apply_color_map(
x: torch.Tensor, # " *batch"
color_map: str = "inferno",
) -> torch.Tensor: # "*batch 3"
cmap = matplotlib.cm.get_cmap(color_map)
# Convert to NumPy so that Matplotlib color maps can be used.
mapped = cmap(x.float().detach().clip(min=0, max=1).cpu().numpy())[..., :3]
# Convert back to the original format.
return torch.tensor(mapped, device=x.device, dtype=torch.float32)
def apply_color_map_to_image(
image: torch.Tensor, # "*batch height width"
color_map: str = "inferno",
) -> torch.Tensor: # "*batch 3 height with"
image = apply_color_map(image, color_map)
return rearrange(image, "... h w c -> ... c h w")

View File

@@ -0,0 +1,35 @@
《Depth Anything V3》
网址https://github.com/ByteDance-Seed/Depth-Anything-3
# 1. 配置
gh repo clone ByteDance-Seed/Depth-Anything-3
cd Depth-Anything-3
conda create -n da3 python=3.12 -y
conda activate da3
conda install -c nvidia cudatoolkit=11.8
conda install -c nvidia cuda-nvcc=11.8
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
# 0. 请确认 11.8 编译器确实存在于您的 Conda 环境文件夹中: If this prints a path (e.g., /home/.../envs/REG/bin/nvcc), proceed to step 1
ls $CONDA_PREFIX/bin/nvcc
# 1. Force the shell to prioritize your Conda bin directory
export PATH=$CONDA_PREFIX/bin:$PATH
# 2. Explicitly tell the build script where CUDA is located
export CUDA_HOME=$CONDA_PREFIX
# 3. Check the version again - it MUST say 11.8 now
nvcc -V
# cuda version need to be same with `nvcc --version` for gsplat
pip install torch==2.7.1 torchvision==0.22.1 --index-url https://download.pytorch.org/whl/cu118
pip install -e .
pip install --no-build-isolation 'git+https://github.com/nerfstudio-project/gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70'
# 2. 下载深度预训练模型
mkdir checkpoints
mkdir checkpoints/DA3MONO-LARGE
https://huggingface.co/depth-anything/DA3MONO-LARGE
mkdir checkpoints/DA3METRIC-LARGE
mkdir checkpoints/DA3-SMALL
mkdir checkpoints/DA3-BASE
mkdir checkpoints/DA3-LARGE
mkdir checkpoints/DA3-GIANT
mkdir checkpoints/DA3NESTED-GIANT-LARGE